mirror of
https://github.com/khoaliber/LetterFeed.git
synced 2026-03-02 13:18:27 +00:00
feat: authentication
This commit is contained in:
108
backend/app/core/auth.py
Normal file
108
backend/app/core/auth.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings as env_settings
|
||||
from app.core.database import get_db
|
||||
from app.core.hashing import get_password_hash
|
||||
from app.models.settings import Settings as SettingsModel
|
||||
from app.schemas.auth import TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login", auto_error=False)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_env_password_hash():
|
||||
"""Get and cache the password hash from environment variables."""
|
||||
if env_settings.auth_password:
|
||||
return get_password_hash(env_settings.auth_password)
|
||||
return None
|
||||
|
||||
|
||||
def _get_auth_credentials(db: Session) -> dict:
|
||||
"""Get auth credentials, prioritizing environment variables."""
|
||||
# Env vars take precedence
|
||||
if env_settings.auth_username and env_settings.auth_password:
|
||||
return {
|
||||
"username": env_settings.auth_username,
|
||||
"password_hash": _get_env_password_hash(),
|
||||
}
|
||||
|
||||
# Then check DB
|
||||
db_settings = db.query(SettingsModel).first()
|
||||
if db_settings and db_settings.auth_username and db_settings.auth_password_hash:
|
||||
return {
|
||||
"username": db_settings.auth_username,
|
||||
"password_hash": db_settings.auth_password_hash,
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
"""Create a new access token."""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, env_settings.secret_key, algorithm=env_settings.algorithm
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def protected_route(
|
||||
token: str | None = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Dependency to protect routes with JWTs."""
|
||||
auth_creds = _get_auth_credentials(db)
|
||||
|
||||
# If no auth credentials are set up, access is allowed.
|
||||
if not auth_creds.get("username") or not auth_creds.get("password_hash"):
|
||||
return
|
||||
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, env_settings.secret_key, algorithms=[env_settings.algorithm]
|
||||
)
|
||||
username: str | None = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
# Check if the username from the token matches the configured username
|
||||
correct_username = secrets.compare_digest(
|
||||
token_data.username, auth_creds["username"]
|
||||
)
|
||||
if not correct_username:
|
||||
raise credentials_exception
|
||||
|
||||
return token_data.username
|
||||
|
||||
|
||||
def is_auth_enabled(db: Session = Depends(get_db)):
|
||||
"""Dependency to check if auth is enabled."""
|
||||
auth_creds = _get_auth_credentials(db)
|
||||
return bool(auth_creds.get("username"))
|
||||
@@ -27,6 +27,13 @@ class Settings(BaseSettings):
|
||||
mark_as_read: bool = False
|
||||
email_check_interval: int = 15
|
||||
auto_add_new_senders: bool = False
|
||||
auth_username: str | None = None
|
||||
auth_password: str | None = None
|
||||
secret_key: str = Field(
|
||||
..., validation_alias=AliasChoices("SECRET_KEY", "LETTERFEED_SECRET_KEY")
|
||||
)
|
||||
algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 30
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
13
backend/app/core/hashing.py
Normal file
13
backend/app/core/hashing.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from passlib.context import CryptContext
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against a hashed one."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
@@ -86,16 +86,25 @@ def update_newsletter(
|
||||
db_newsletter.move_to_folder = newsletter_update.move_to_folder
|
||||
db_newsletter.extract_content = newsletter_update.extract_content
|
||||
|
||||
# Simple approach: delete existing senders and add new ones
|
||||
# More efficient sender update
|
||||
existing_emails = {sender.email for sender in db_newsletter.senders}
|
||||
new_emails = set(newsletter_update.sender_emails)
|
||||
|
||||
# Remove senders that are no longer in the list
|
||||
for sender in db_newsletter.senders:
|
||||
db.delete(sender)
|
||||
db.commit()
|
||||
if sender.email not in new_emails:
|
||||
db.delete(sender)
|
||||
|
||||
for email in newsletter_update.sender_emails:
|
||||
db_sender = Sender(id=generate(), email=email, newsletter_id=db_newsletter.id)
|
||||
db.add(db_sender)
|
||||
# Add new senders
|
||||
for email in new_emails:
|
||||
if email not in existing_emails:
|
||||
db_sender = Sender(
|
||||
id=generate(), email=email, newsletter_id=db_newsletter.id
|
||||
)
|
||||
db.add(db_sender)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_newsletter)
|
||||
|
||||
logger.info(f"Successfully updated newsletter with id={db_newsletter.id}")
|
||||
return get_newsletter(db, newsletter_id)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings as env_settings
|
||||
from app.core.hashing import get_password_hash
|
||||
from app.core.logging import get_logger
|
||||
from app.models.settings import Settings as SettingsModel
|
||||
from app.schemas.settings import Settings as SettingsSchema
|
||||
@@ -9,11 +10,10 @@ from app.schemas.settings import SettingsCreate
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_settings(db: Session, with_password: bool = False) -> SettingsSchema:
|
||||
"""Retrieve application settings, prioritizing environment variables over database."""
|
||||
logger.debug("Querying for settings")
|
||||
def create_initial_settings(db: Session):
|
||||
"""Create initial settings in the database if they don't exist."""
|
||||
logger.debug("Checking for initial settings.")
|
||||
db_settings = db.query(SettingsModel).first()
|
||||
|
||||
if not db_settings:
|
||||
logger.info(
|
||||
"No settings found in the database, creating new default settings from environment variables."
|
||||
@@ -25,12 +25,29 @@ def get_settings(db: Session, with_password: bool = False) -> SettingsSchema:
|
||||
k: v for k, v in env_settings.model_dump().items() if k in model_fields
|
||||
}
|
||||
|
||||
if env_settings.auth_password:
|
||||
env_data_for_db["auth_password_hash"] = get_password_hash(
|
||||
env_settings.auth_password
|
||||
)
|
||||
if "auth_password" in env_data_for_db:
|
||||
del env_data_for_db["auth_password"]
|
||||
|
||||
db_settings = SettingsModel(**env_data_for_db)
|
||||
db.add(db_settings)
|
||||
db.commit()
|
||||
db.refresh(db_settings)
|
||||
logger.info("Default settings created from environment variables.")
|
||||
|
||||
|
||||
def get_settings(db: Session, with_password: bool = False) -> SettingsSchema:
|
||||
"""Retrieve application settings, prioritizing environment variables over database."""
|
||||
logger.debug("Querying for settings")
|
||||
db_settings = db.query(SettingsModel).first()
|
||||
|
||||
if not db_settings:
|
||||
# This should not happen if create_initial_settings is called at startup.
|
||||
raise RuntimeError("Settings not initialized.")
|
||||
|
||||
# Build dictionary from DB model attributes, handling possible None values
|
||||
db_data = {
|
||||
"id": db_settings.id,
|
||||
@@ -41,6 +58,7 @@ def get_settings(db: Session, with_password: bool = False) -> SettingsSchema:
|
||||
"mark_as_read": db_settings.mark_as_read,
|
||||
"email_check_interval": db_settings.email_check_interval,
|
||||
"auto_add_new_senders": db_settings.auto_add_new_senders,
|
||||
"auth_username": db_settings.auth_username,
|
||||
}
|
||||
|
||||
# Get all environment settings that were explicitly set.
|
||||
@@ -80,14 +98,22 @@ def create_or_update_settings(db: Session, settings: SettingsCreate):
|
||||
db_settings = SettingsModel()
|
||||
db.add(db_settings)
|
||||
|
||||
update_data = settings.model_dump()
|
||||
update_data = settings.model_dump(exclude_unset=True)
|
||||
|
||||
# Do not update fields that are set by environment variables
|
||||
locked_fields = list(env_settings.model_dump(exclude_unset=True).keys())
|
||||
logger.debug(f"Fields locked by environment variables: {locked_fields}")
|
||||
|
||||
for key, value in update_data.items():
|
||||
if key not in locked_fields:
|
||||
if key in locked_fields:
|
||||
continue
|
||||
|
||||
if key == "auth_password":
|
||||
if value:
|
||||
db_settings.auth_password_hash = get_password_hash(value)
|
||||
else:
|
||||
db_settings.auth_password_hash = None
|
||||
elif hasattr(db_settings, key):
|
||||
setattr(db_settings, key, value)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.core.database import Base, engine
|
||||
from app.core.auth import protected_route
|
||||
from app.core.database import Base, SessionLocal, engine
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.core.scheduler import scheduler, start_scheduler_with_interval
|
||||
from app.routers import feeds, health, imap, newsletters
|
||||
from app.crud.settings import create_initial_settings
|
||||
from app.routers import auth, feeds, health, imap, newsletters
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -19,6 +21,10 @@ async def lifespan(app: FastAPI):
|
||||
logger.info(f"DATABASE_URL used: {settings.database_url}")
|
||||
logger.info("Starting up Letterfeed backend...")
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
with SessionLocal() as db:
|
||||
create_initial_settings(db)
|
||||
|
||||
start_scheduler_with_interval()
|
||||
yield
|
||||
if scheduler.running:
|
||||
@@ -43,6 +49,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
app.include_router(health.router)
|
||||
app.include_router(imap.router)
|
||||
app.include_router(newsletters.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(imap.router, dependencies=[Depends(protected_route)])
|
||||
app.include_router(newsletters.router, dependencies=[Depends(protected_route)])
|
||||
app.include_router(feeds.router)
|
||||
|
||||
@@ -17,3 +17,5 @@ class Settings(Base):
|
||||
mark_as_read = Column(Boolean, default=False)
|
||||
email_check_interval = Column(Integer, default=15) # Interval in minutes
|
||||
auto_add_new_senders = Column(Boolean, default=False)
|
||||
auth_username = Column(String, nullable=True)
|
||||
auth_password_hash = Column(String, nullable=True)
|
||||
|
||||
55
backend/app/routers/auth.py
Normal file
55
backend/app/routers/auth.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import secrets
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import (
|
||||
_get_auth_credentials,
|
||||
create_access_token,
|
||||
is_auth_enabled,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.hashing import verify_password
|
||||
from app.schemas.auth import Token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/auth/status")
|
||||
def auth_status(auth_enabled: bool = Depends(is_auth_enabled)):
|
||||
"""Check if authentication is enabled."""
|
||||
return {"auth_enabled": auth_enabled}
|
||||
|
||||
|
||||
@router.post("/auth/login", response_model=Token)
|
||||
def login_for_access_token(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Verify username and password and return an access token."""
|
||||
auth_creds = _get_auth_credentials(db)
|
||||
if not auth_creds.get("username") or not auth_creds.get("password_hash"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication is not configured on the server",
|
||||
)
|
||||
|
||||
correct_username = secrets.compare_digest(
|
||||
form_data.username, auth_creds["username"]
|
||||
)
|
||||
correct_password = verify_password(form_data.password, auth_creds["password_hash"])
|
||||
|
||||
if not (correct_username and correct_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": form_data.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
14
backend/app/schemas/auth.py
Normal file
14
backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""Schema for the access token."""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""Schema for the data encoded in the JWT."""
|
||||
|
||||
username: str | None = None
|
||||
@@ -13,12 +13,14 @@ class SettingsBase(BaseModel):
|
||||
mark_as_read: bool = False
|
||||
email_check_interval: int = 15
|
||||
auto_add_new_senders: bool = False
|
||||
auth_username: str | None = None
|
||||
|
||||
|
||||
class SettingsCreate(SettingsBase):
|
||||
"""Schema for creating or updating settings, including the IMAP password."""
|
||||
|
||||
imap_password: str
|
||||
auth_password: str | None = None
|
||||
|
||||
|
||||
class Settings(SettingsBase):
|
||||
@@ -26,6 +28,7 @@ class Settings(SettingsBase):
|
||||
|
||||
id: int
|
||||
imap_password: str | None = Field(None, exclude=True)
|
||||
auth_password: str | None = Field(None, exclude=True)
|
||||
locked_fields: List[str] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
os.environ["LETTERFEED_DATABASE_URL"] = "sqlite:///./test.db"
|
||||
os.environ["LETTERFEED_SECRET_KEY"] = "testsecret"
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
170
backend/app/tests/test_auth.py
Normal file
170
backend/app/tests/test_auth.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.crud.settings import create_or_update_settings
|
||||
from app.schemas.settings import SettingsCreate
|
||||
|
||||
|
||||
def test_auth_status_disabled(client: TestClient):
|
||||
"""Test auth status when auth is disabled."""
|
||||
response = client.get("/auth/status")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth_enabled": False}
|
||||
|
||||
|
||||
def test_auth_status_enabled(client: TestClient, db_session: Session):
|
||||
"""Test auth status when auth is enabled."""
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="test.com",
|
||||
imap_username="test",
|
||||
imap_password="password",
|
||||
auth_username="admin",
|
||||
auth_password="password",
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
response = client.get("/auth/status")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth_enabled": True}
|
||||
|
||||
|
||||
def test_login_endpoint(client: TestClient, db_session: Session):
|
||||
"""Test the /auth/login endpoint directly."""
|
||||
# Setup auth credentials in the database
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="test.com",
|
||||
imap_username="test",
|
||||
imap_password="password",
|
||||
auth_username="admin",
|
||||
auth_password="password",
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
# Test with correct credentials
|
||||
login_data = {"username": "admin", "password": "password"}
|
||||
response = client.post("/auth/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
assert "access_token" in json_response
|
||||
assert json_response["token_type"] == "bearer"
|
||||
|
||||
# Test with incorrect password
|
||||
login_data["password"] = "wrongpassword"
|
||||
response = client.post("/auth/login", data=login_data)
|
||||
assert response.status_code == 401
|
||||
|
||||
# Test with incorrect username
|
||||
login_data["username"] = "wronguser"
|
||||
login_data["password"] = "password"
|
||||
response = client.post("/auth/login", data=login_data)
|
||||
assert response.status_code == 401
|
||||
|
||||
# Test with no credentials
|
||||
response = client.post("/auth/login")
|
||||
assert response.status_code == 422 # FastAPI validation error for missing form data
|
||||
|
||||
|
||||
def test_protected_route_no_auth(client: TestClient, db_session: Session):
|
||||
"""Test accessing a protected route without auth enabled."""
|
||||
# Health is not protected, newsletters is.
|
||||
response = client.get("/newsletters")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_protected_route_with_auth_fail(client: TestClient, db_session: Session):
|
||||
"""Test accessing a protected route with auth enabled but wrong credentials."""
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="test.com",
|
||||
imap_username="test",
|
||||
imap_password="password",
|
||||
auth_username="admin",
|
||||
auth_password="password",
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
response = client.get("/newsletters")
|
||||
assert response.status_code == 401
|
||||
|
||||
response = client.get(
|
||||
"/newsletters", headers={"Authorization": "Bearer wrongtoken"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_protected_route_with_auth_success(client: TestClient, db_session: Session):
|
||||
"""Test accessing a protected route with auth enabled and correct credentials."""
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="test.com",
|
||||
imap_username="test",
|
||||
imap_password="password",
|
||||
auth_username="admin",
|
||||
auth_password="password",
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
# First, log in to get a token
|
||||
login_data = {"username": "admin", "password": "password"}
|
||||
response = client.post("/auth/login", data=login_data)
|
||||
token = response.json()["access_token"]
|
||||
|
||||
# Then, use the token to access the protected route
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = client.get("/newsletters", headers=headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_unprotected_route_with_auth(client: TestClient, db_session: Session):
|
||||
"""Test that feed endpoint is not protected."""
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="test.com",
|
||||
imap_username="test",
|
||||
imap_password="password",
|
||||
auth_username="admin",
|
||||
auth_password="password",
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
# Log in to get a token
|
||||
login_data = {"username": "admin", "password": "password"}
|
||||
login_response = client.post("/auth/login", data=login_data)
|
||||
token = login_response.json()["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# Create a newsletter to get a feed from
|
||||
newsletter_data = {"name": "Test Newsletter", "sender_emails": ["test@test.com"]}
|
||||
create_response = client.post("/newsletters", json=newsletter_data, headers=headers)
|
||||
newsletter_id = create_response.json()["id"]
|
||||
|
||||
response = client.get(f"/feeds/{newsletter_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_auth_with_env_vars(client: TestClient):
|
||||
"""Test authentication using environment variables."""
|
||||
with patch("app.core.auth.env_settings") as mock_env_settings:
|
||||
mock_env_settings.auth_username = "env_admin"
|
||||
mock_env_settings.auth_password = "env_password"
|
||||
mock_env_settings.secret_key = "test-secret"
|
||||
mock_env_settings.algorithm = "HS256"
|
||||
mock_env_settings.access_token_expire_minutes = 30
|
||||
|
||||
# Log in to get a token
|
||||
login_data = {"username": "env_admin", "password": "env_password"}
|
||||
login_response = client.post("/auth/login", data=login_data)
|
||||
assert login_response.status_code == 200
|
||||
token = login_response.json()["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
response = client.get("/newsletters", headers=headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get(
|
||||
"/newsletters", headers={"Authorization": "Bearer wrongtoken"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
response = client.get("/auth/status")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth_enabled": True}
|
||||
@@ -167,8 +167,12 @@ def test_process_emails_auto_add_sender(mock_imap, db_session: Session):
|
||||
|
||||
@patch("app.services.email_processor.imaplib.IMAP4_SSL")
|
||||
def test_process_emails_no_settings(mock_imap, db_session: Session):
|
||||
"""Test processing emails with no settings in the database."""
|
||||
# No settings in the DB
|
||||
"""Test processing emails with no settings configured."""
|
||||
# This test ensures that email processing is skipped if settings are not configured.
|
||||
# In the new flow, initial settings are created at startup, so we call it here.
|
||||
from app.crud.settings import create_initial_settings
|
||||
|
||||
create_initial_settings(db_session)
|
||||
process_emails(db_session)
|
||||
mock_imap.assert_not_called()
|
||||
|
||||
|
||||
@@ -20,10 +20,19 @@ def test_create_or_update_settings(db_session: Session):
|
||||
search_folder="INBOX",
|
||||
move_to_folder="Archive",
|
||||
mark_as_read=True,
|
||||
auth_username="user",
|
||||
auth_password="password",
|
||||
)
|
||||
settings = create_or_update_settings(db_session, settings_data)
|
||||
assert settings.imap_server == "imap.test.com"
|
||||
assert settings.mark_as_read
|
||||
assert settings.auth_username == "user"
|
||||
|
||||
# check password hash
|
||||
from app.models.settings import Settings as SettingsModel
|
||||
|
||||
db_settings = db_session.query(SettingsModel).first()
|
||||
assert db_settings.auth_password_hash is not None
|
||||
|
||||
updated_settings_data = SettingsCreate(
|
||||
imap_server="imap.updated.com",
|
||||
@@ -32,11 +41,13 @@ def test_create_or_update_settings(db_session: Session):
|
||||
search_folder="Inbox",
|
||||
move_to_folder=None,
|
||||
mark_as_read=False,
|
||||
auth_username="new_user",
|
||||
)
|
||||
updated_settings = create_or_update_settings(db_session, updated_settings_data)
|
||||
assert updated_settings.imap_server == "imap.updated.com"
|
||||
assert not updated_settings.mark_as_read
|
||||
assert updated_settings.move_to_folder is None
|
||||
assert updated_settings.auth_username == "new_user"
|
||||
|
||||
|
||||
def test_get_settings(db_session: Session):
|
||||
@@ -62,6 +73,8 @@ def test_get_settings_with_env_override(db_session: Session):
|
||||
imap_username="db_user",
|
||||
imap_password="db_pass",
|
||||
email_check_interval=15,
|
||||
auth_username="db_user",
|
||||
auth_password="db_password",
|
||||
)
|
||||
create_or_update_settings(db_session, db_settings_data)
|
||||
|
||||
@@ -72,8 +85,11 @@ def test_get_settings_with_env_override(db_session: Session):
|
||||
"imap_username": "env_user",
|
||||
"imap_password": "env_pass",
|
||||
"email_check_interval": 30,
|
||||
"auth_username": "env_auth_user",
|
||||
"auth_password": "env_auth_password",
|
||||
}
|
||||
mock_env_settings.imap_password = "env_pass"
|
||||
mock_env_settings.auth_password = "env_auth_password"
|
||||
|
||||
# 3. Call get_settings and assert the override
|
||||
settings = get_settings(db_session, with_password=True)
|
||||
@@ -81,8 +97,10 @@ def test_get_settings_with_env_override(db_session: Session):
|
||||
assert settings.imap_username == "env_user"
|
||||
assert settings.imap_password == "env_pass"
|
||||
assert settings.email_check_interval == 30
|
||||
assert settings.auth_username == "env_auth_user"
|
||||
assert "imap_server" in settings.locked_fields
|
||||
assert "imap_username" in settings.locked_fields
|
||||
assert "auth_username" in settings.locked_fields
|
||||
|
||||
# 4. Call create_or_update_settings and assert that locked fields are not updated
|
||||
update_data = SettingsCreate(
|
||||
@@ -90,11 +108,14 @@ def test_get_settings_with_env_override(db_session: Session):
|
||||
imap_username="new_user",
|
||||
imap_password="new_pass",
|
||||
email_check_interval=45,
|
||||
auth_username="new_auth_user",
|
||||
auth_password="new_auth_password",
|
||||
)
|
||||
updated_settings = create_or_update_settings(db_session, update_data)
|
||||
assert updated_settings.imap_server == "env.imap.com" # Should not change
|
||||
assert updated_settings.imap_username == "env_user" # Should not change
|
||||
assert updated_settings.email_check_interval == 30 # Should not change
|
||||
assert updated_settings.auth_username == "env_auth_user" # Should not change
|
||||
|
||||
|
||||
def test_create_newsletter(db_session: Session):
|
||||
|
||||
@@ -2,6 +2,10 @@ import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.crud.settings import create_or_update_settings
|
||||
from app.schemas.settings import SettingsCreate
|
||||
|
||||
|
||||
def test_health_check(client: TestClient):
|
||||
@@ -11,12 +15,8 @@ def test_health_check(client: TestClient):
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@patch("app.core.imap.imaplib.IMAP4_SSL")
|
||||
def test_update_imap_settings(mock_imap, client: TestClient):
|
||||
def test_update_imap_settings(client: TestClient):
|
||||
"""Test updating IMAP settings."""
|
||||
mock_imap.return_value.login.return_value = (None, None)
|
||||
mock_imap.return_value.logout.return_value = (None, None)
|
||||
|
||||
settings_data = {
|
||||
"imap_server": "imap.example.com",
|
||||
"imap_username": "test@example.com",
|
||||
@@ -34,12 +34,8 @@ def test_update_imap_settings(mock_imap, client: TestClient):
|
||||
assert response.json()["mark_as_read"]
|
||||
|
||||
|
||||
@patch("app.core.imap.imaplib.IMAP4_SSL")
|
||||
def test_get_imap_settings(mock_imap, client: TestClient):
|
||||
def test_get_imap_settings(client: TestClient):
|
||||
"""Test getting IMAP settings."""
|
||||
mock_imap.return_value.login.return_value = (None, None)
|
||||
mock_imap.return_value.logout.return_value = (None, None)
|
||||
|
||||
settings_data = {
|
||||
"imap_server": "imap.example.com",
|
||||
"imap_username": "test@example.com",
|
||||
@@ -57,20 +53,20 @@ def test_get_imap_settings(mock_imap, client: TestClient):
|
||||
|
||||
|
||||
@patch("app.core.imap.imaplib.IMAP4_SSL")
|
||||
def test_test_imap_connection(mock_imap, client: TestClient):
|
||||
def test_test_imap_connection(mock_imap, client: TestClient, db_session: Session):
|
||||
"""Test the IMAP connection."""
|
||||
mock_imap.return_value.login.return_value = (None, None)
|
||||
mock_imap.return_value.logout.return_value = (None, None)
|
||||
|
||||
settings_data = {
|
||||
"imap_server": "imap.example.com",
|
||||
"imap_username": "test@example.com",
|
||||
"imap_password": "password",
|
||||
"search_folder": "INBOX",
|
||||
"move_to_folder": "Processed",
|
||||
"mark_as_read": True,
|
||||
}
|
||||
client.post("/imap/settings", json=settings_data)
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="imap.example.com",
|
||||
imap_username="test@example.com",
|
||||
imap_password="password",
|
||||
search_folder="INBOX",
|
||||
move_to_folder="Processed",
|
||||
mark_as_read=True,
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
response = client.post("/imap/test")
|
||||
assert response.status_code == 200
|
||||
@@ -78,7 +74,7 @@ def test_test_imap_connection(mock_imap, client: TestClient):
|
||||
|
||||
|
||||
@patch("app.core.imap.imaplib.IMAP4_SSL")
|
||||
def test_get_imap_folders(mock_imap, client: TestClient):
|
||||
def test_get_imap_folders(mock_imap, client: TestClient, db_session: Session):
|
||||
"""Test getting IMAP folders."""
|
||||
mock_imap.return_value.login.return_value = (None, None)
|
||||
mock_imap.return_value.logout.return_value = (None, None)
|
||||
@@ -87,15 +83,15 @@ def test_get_imap_folders(mock_imap, client: TestClient):
|
||||
[b'(NOCONNECT NOSELECT) "/" "INBOX"', b'(NOCONNECT NOSELECT) "/" "Processed"'],
|
||||
)
|
||||
|
||||
settings_data = {
|
||||
"imap_server": "imap.example.com",
|
||||
"imap_username": "test@example.com",
|
||||
"imap_password": "password",
|
||||
"search_folder": "INBOX",
|
||||
"move_to_folder": "Processed",
|
||||
"mark_as_read": True,
|
||||
}
|
||||
client.post("/imap/settings", json=settings_data)
|
||||
settings_data = SettingsCreate(
|
||||
imap_server="imap.example.com",
|
||||
imap_username="test@example.com",
|
||||
imap_password="password",
|
||||
search_folder="INBOX",
|
||||
move_to_folder="Processed",
|
||||
mark_as_read=True,
|
||||
)
|
||||
create_or_update_settings(db_session, settings_data)
|
||||
|
||||
response = client.get("/imap/folders")
|
||||
assert response.status_code == 200
|
||||
@@ -141,7 +137,7 @@ def test_get_single_newsletter(client: TestClient):
|
||||
"""Test getting a single newsletter."""
|
||||
unique_email = f"newsletter_{uuid.uuid4()}@example.com"
|
||||
newsletter_data = {"name": "Third Newsletter", "sender_emails": [unique_email]}
|
||||
create_response = client.post("/newsletters/", json=newsletter_data)
|
||||
create_response = client.post("/newsletters", json=newsletter_data)
|
||||
newsletter_id = create_response.json()["id"]
|
||||
|
||||
response = client.get(f"/newsletters/{newsletter_id}")
|
||||
@@ -151,7 +147,7 @@ def test_get_single_newsletter(client: TestClient):
|
||||
|
||||
def test_get_nonexistent_newsletter(client: TestClient):
|
||||
"""Test getting a nonexistent newsletter."""
|
||||
response = client.get("/newsletters/999")
|
||||
response = client.get("/newsletters/nonexistent")
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Newsletter not found"}
|
||||
|
||||
@@ -160,7 +156,7 @@ def test_get_newsletter_feed(client: TestClient):
|
||||
"""Test generating a newsletter feed."""
|
||||
unique_email = f"feed_test_{uuid.uuid4()}@example.com"
|
||||
newsletter_data = {"name": "Feed Test Newsletter", "sender_emails": [unique_email]}
|
||||
create_response = client.post("/newsletters/", json=newsletter_data)
|
||||
create_response = client.post("/newsletters", json=newsletter_data)
|
||||
newsletter_id = create_response.json()["id"]
|
||||
|
||||
# Add some entries to the newsletter
|
||||
@@ -195,6 +191,6 @@ def test_get_newsletter_feed(client: TestClient):
|
||||
|
||||
def test_get_newsletter_feed_nonexistent_newsletter(client: TestClient):
|
||||
"""Test generating a feed for a nonexistent newsletter."""
|
||||
response = client.get("/feeds/999")
|
||||
response = client.get("/feeds/nonexistent")
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Newsletter not found"}
|
||||
|
||||
Reference in New Issue
Block a user