mirror of
https://github.com/khoaliber/LetterFeed.git
synced 2026-03-02 21:19:13 +00:00
feat: authentication
This commit is contained in:
108
backend/app/core/auth.py
Normal file
108
backend/app/core/auth.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings as env_settings
|
||||
from app.core.database import get_db
|
||||
from app.core.hashing import get_password_hash
|
||||
from app.models.settings import Settings as SettingsModel
|
||||
from app.schemas.auth import TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login", auto_error=False)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_env_password_hash():
|
||||
"""Get and cache the password hash from environment variables."""
|
||||
if env_settings.auth_password:
|
||||
return get_password_hash(env_settings.auth_password)
|
||||
return None
|
||||
|
||||
|
||||
def _get_auth_credentials(db: Session) -> dict:
|
||||
"""Get auth credentials, prioritizing environment variables."""
|
||||
# Env vars take precedence
|
||||
if env_settings.auth_username and env_settings.auth_password:
|
||||
return {
|
||||
"username": env_settings.auth_username,
|
||||
"password_hash": _get_env_password_hash(),
|
||||
}
|
||||
|
||||
# Then check DB
|
||||
db_settings = db.query(SettingsModel).first()
|
||||
if db_settings and db_settings.auth_username and db_settings.auth_password_hash:
|
||||
return {
|
||||
"username": db_settings.auth_username,
|
||||
"password_hash": db_settings.auth_password_hash,
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
"""Create a new access token."""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, env_settings.secret_key, algorithm=env_settings.algorithm
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def protected_route(
|
||||
token: str | None = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Dependency to protect routes with JWTs."""
|
||||
auth_creds = _get_auth_credentials(db)
|
||||
|
||||
# If no auth credentials are set up, access is allowed.
|
||||
if not auth_creds.get("username") or not auth_creds.get("password_hash"):
|
||||
return
|
||||
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, env_settings.secret_key, algorithms=[env_settings.algorithm]
|
||||
)
|
||||
username: str | None = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
# Check if the username from the token matches the configured username
|
||||
correct_username = secrets.compare_digest(
|
||||
token_data.username, auth_creds["username"]
|
||||
)
|
||||
if not correct_username:
|
||||
raise credentials_exception
|
||||
|
||||
return token_data.username
|
||||
|
||||
|
||||
def is_auth_enabled(db: Session = Depends(get_db)):
|
||||
"""Dependency to check if auth is enabled."""
|
||||
auth_creds = _get_auth_credentials(db)
|
||||
return bool(auth_creds.get("username"))
|
||||
@@ -27,6 +27,13 @@ class Settings(BaseSettings):
|
||||
mark_as_read: bool = False
|
||||
email_check_interval: int = 15
|
||||
auto_add_new_senders: bool = False
|
||||
auth_username: str | None = None
|
||||
auth_password: str | None = None
|
||||
secret_key: str = Field(
|
||||
..., validation_alias=AliasChoices("SECRET_KEY", "LETTERFEED_SECRET_KEY")
|
||||
)
|
||||
algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 30
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
13
backend/app/core/hashing.py
Normal file
13
backend/app/core/hashing.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from passlib.context import CryptContext
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against a hashed one."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
Reference in New Issue
Block a user