diff --git a/src/interface/web/app/components/loginPrompt/loginPrompt.tsx b/src/interface/web/app/components/loginPrompt/loginPrompt.tsx index 633d7b41..756c2108 100644 --- a/src/interface/web/app/components/loginPrompt/loginPrompt.tsx +++ b/src/interface/web/app/components/loginPrompt/loginPrompt.tsx @@ -52,10 +52,6 @@ export default function LoginPrompt(props: LoginPromptProps) { const [useEmailSignIn, setUseEmailSignIn] = useState(false); - const [email, setEmail] = useState(""); - const [checkEmail, setCheckEmail] = useState(false); - const [recheckEmail, setRecheckEmail] = useState(false); - useEffect(() => { const google = (window as any).google; @@ -118,49 +114,13 @@ export default function LoginPrompt(props: LoginPromptProps) { }); }; - function handleMagicLinkSignIn() { - fetch("/auth/magic", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ email: email }), - }) - .then((res) => { - if (res.ok) { - setCheckEmail(true); - if (checkEmail) { - setRecheckEmail(true); - } - return res.json(); - } else { - throw new Error("Failed to send magic link"); - } - }) - .then((data) => { - console.log(data); - }) - .catch((err) => { - console.error(err); - }); - } - if (props.isMobileWidth) { return (
{useEmailSignIn ? ( - + ) : (
{useEmailSignIn ? ( - + ) : ( void; - checkEmail: boolean; - setCheckEmail: (checkEmail: boolean) => void; setUseEmailSignIn: (useEmailSignIn: boolean) => void; - recheckEmail: boolean; - setRecheckEmail: (recheckEmail: boolean) => void; - handleMagicLinkSignIn: () => void; }) { const [otp, setOTP] = useState(""); const [otpError, setOTPError] = useState(""); const [numFailures, setNumFailures] = useState(0); + const [email, setEmail] = useState(""); + const [checkEmail, setCheckEmail] = useState(false); + const [recheckEmail, setRecheckEmail] = useState(false); + const [sendEmailError, setSendEmailError] = useState(""); function checkOTPAndRedirect() { const verifyUrl = `/auth/magic?code=${encodeURIComponent(otp)}&email=${encodeURIComponent(email)}`; @@ -275,6 +217,39 @@ function EmailSignInContext({ }); } + function handleMagicLinkSignIn() { + fetch("/auth/magic", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ email: email }), + }) + .then((res) => { + if (res.ok) { + setCheckEmail(true); + if (checkEmail) { + setRecheckEmail(true); + } + return res.json(); + } else if (res.status === 429 || res.status === 404) { + return res.json().then((data) => { // returned so the throw below rejects the outer chain and reaches .catch + setSendEmailError(data.detail); + throw new Error(data.detail); + }); + } else { + setSendEmailError("Failed to send email. Contact developers for assistance."); + throw new Error("Failed to send magic link via email."); + } + }) + .then((data) => { + console.log(data); + }) + .catch((err) => { + console.error(err); + }); + } + return (
- + {sendEmailError &&
{sendEmailError}
} +
)} {checkEmail && (
@@ -359,9 +335,7 @@ function EmailSignInContext({ variant="ghost" className="p-0 text-orange-500" disabled={recheckEmail} - onClick={() => { - handleMagicLinkSignIn(); - }} + onClick={handleMagicLinkSignIn} > Resend email diff --git a/src/khoj/configure.py b/src/khoj/configure.py index e6612631..40d61a88 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -37,6 +37,7 @@ from khoj.database.adapters import ( aget_or_create_user_by_phone_number, aget_user_by_phone_number, ais_user_subscribed, + delete_ratelimit_records, delete_user_requests, get_all_users, get_or_create_search_models, @@ -428,8 +429,10 @@ def upload_telemetry(): @schedule.repeat(schedule.every(31).minutes) @clean_connections def delete_old_user_requests(): - num_deleted = delete_user_requests() - logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests") + num_user_ratelimit_requests = delete_user_requests() + num_ratelimit_requests = delete_ratelimit_records() + if state.verbose > 2: + logger.debug(f"🗑️ Deleted {num_user_ratelimit_requests + num_ratelimit_requests} stale rate limit requests") @schedule.repeat(schedule.every(17).minutes) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 7ffbf67a..086be4b0 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -27,6 +27,7 @@ from django.contrib.sessions.backends.db import SessionStore from django.db.models import Prefetch, Q from django.db.models.manager import BaseManager from django.db.utils import IntegrityError +from django.utils import timezone as django_timezone from django_apscheduler import util from django_apscheduler.models import DjangoJob, DjangoJobExecution from fastapi import HTTPException @@ -49,6 +50,7 @@ from khoj.database.models import ( NotionConfig, ProcessLock, PublicConversation, + RateLimitRecord, ReflectiveQuestion, SearchModelConfig, ServerChatSettings, @@ -233,20 +235,21 @@ async def 
acreate_user_by_phone_number(phone_number: str) -> KhojUser: return user -async def aget_or_create_user_by_email(input_email: str) -> tuple[KhojUser, bool]: - email, is_valid_email = normalize_email(input_email) +async def aget_or_create_user_by_email(input_email: str, check_deliverability=False) -> tuple[KhojUser, bool]: + # Validate deliverability to email address of new user + email, is_valid_email = normalize_email(input_email, check_deliverability=check_deliverability) is_existing_user = await KhojUser.objects.filter(email=email).aexists() - # Validate email address of new users if not is_existing_user and not is_valid_email: logger.error(f"Account creation failed. Invalid email address: {email}") return None, False + # Get/create user based on email address user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create( defaults={"username": email, "email": email} ) # Generate a secure 6-digit numeric code - user.email_verification_code = f"{secrets.randbelow(1000000):06}" + user.email_verification_code = f"{secrets.randbelow(int(1e6)):06}" user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5) await user.asave() @@ -516,8 +519,18 @@ def get_user_notion_config(user: KhojUser): return config -def delete_user_requests(window: timedelta = timedelta(days=1)): - return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete() +def delete_user_requests(max_age: timedelta = timedelta(days=1)): + """Deletes UserRequests entries older than the specified max_age.""" + cutoff = django_timezone.now() - max_age + deleted_count, _ = UserRequests.objects.filter(created_at__lte=cutoff).delete() + return deleted_count + + +def delete_ratelimit_records(max_age: timedelta = timedelta(days=1)): + """Deletes RateLimitRecord entries older than the specified max_age.""" + cutoff = django_timezone.now() - max_age + deleted_count, _ = RateLimitRecord.objects.filter(created_at__lt=cutoff).delete() + 
return deleted_count @arequire_valid_user diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index e590bffd..7297ce11 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -25,6 +25,7 @@ from khoj.database.models import ( KhojUser, NotionConfig, ProcessLock, + RateLimitRecord, ReflectiveQuestion, SearchModelConfig, ServerChatSettings, @@ -179,6 +180,7 @@ admin.site.register(NotionConfig, unfold_admin.ModelAdmin) admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin) admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin) admin.site.register(UserRequests, unfold_admin.ModelAdmin) +admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin) @admin.register(Agent) diff --git a/src/khoj/database/migrations/0088_ratelimitrecord.py b/src/khoj/database/migrations/0088_ratelimitrecord.py new file mode 100644 index 00000000..220e98f5 --- /dev/null +++ b/src/khoj/database/migrations/0088_ratelimitrecord.py @@ -0,0 +1,28 @@ +# Generated by Django 5.0.13 on 2025-04-07 07:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0087_alter_aimodelapi_api_key"), + ] + + operations = [ + migrations.CreateModel( + name="RateLimitRecord", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("identifier", models.CharField(db_index=True, max_length=255)), + ("slug", models.CharField(db_index=True, max_length=255)), + ], + options={ + "ordering": ["-created_at"], + "indexes": [ + models.Index(fields=["identifier", "slug", "created_at"], name="database_ra_identif_031adf_idx") + ], + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 6da88c53..429d010e 100644 --- a/src/khoj/database/models/__init__.py +++ 
b/src/khoj/database/models/__init__.py @@ -730,10 +730,28 @@ class EntryDates(DbBaseModel): class UserRequests(DbBaseModel): + """Stores user requests to the server for rate limiting.""" + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) slug = models.CharField(max_length=200) +class RateLimitRecord(DbBaseModel): + """Stores individual request timestamps for rate limiting.""" + + identifier = models.CharField(max_length=255, db_index=True) # IP address or email + slug = models.CharField(max_length=255, db_index=True) # Differentiates limit types + + class Meta: + indexes = [ + models.Index(fields=["identifier", "slug", "created_at"]), + ] + ordering = ["-created_at"] + + def __str__(self): + return f"{self.slug} - {self.identifier} at {self.created_at}" + + class DataStore(DbBaseModel): key = models.CharField(max_length=200, unique=True) value = models.JSONField(default=dict) diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 838cfa48..b2a1e623 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -3,11 +3,10 @@ import datetime import logging import os from typing import Optional -from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +from urllib.parse import urlencode, urlparse, urlunparse import requests -from fastapi import APIRouter, Depends -from pydantic import BaseModel, EmailStr +from fastapi import APIRouter, Depends, HTTPException from starlette.authentication import requires from starlette.config import Config from starlette.requests import Request @@ -25,21 +24,20 @@ from khoj.database.adapters import ( ) from khoj.routers.email import send_magic_link_email, send_welcome_email from khoj.routers.helpers import ( + EmailAttemptRateLimiter, EmailVerificationApiRateLimiter, + MagicLinkForm, get_next_url, update_telemetry_state, ) from khoj.utils import state +from khoj.utils.helpers import in_debug_mode logger = logging.getLogger(__name__) auth_router = APIRouter() -class 
MagicLinkForm(BaseModel): - email: EmailStr - - if not state.anonymous_mode: missing_requirements = [] from authlib.integrations.starlette_client import OAuth, OAuthError @@ -78,24 +76,36 @@ async def login(request: Request): @auth_router.post("/magic") -async def login_magic_link(request: Request, form: MagicLinkForm): +async def login_magic_link( + request: Request, + form: MagicLinkForm, + email_limiter=Depends(EmailAttemptRateLimiter(requests=20, window=60 * 60 * 24, slug="magic_link_login_by_email")), +): if request.user.is_authenticated: # Clear the session if user is already authenticated request.session.pop("user", None) - user, is_new = await aget_or_create_user_by_email(form.email) + # Get/create user if valid email address + check_deliverability = state.billing_enabled and not in_debug_mode() + user, is_new = await aget_or_create_user_by_email(form.email, check_deliverability=check_deliverability) + if not user: + raise HTTPException(status_code=404, detail="Invalid email address. 
Please fix before trying again.") - if user: - unique_id = user.email_verification_code - await send_magic_link_email(user.email, unique_id, request.base_url) - if is_new: - update_telemetry_state( - request=request, - telemetry_type="api", - api="create_user__email", - metadata={"server_id": str(user.uuid)}, - ) - logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}") + # Rate limit email login by user + user_limiter = EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_login_by_user") + await user_limiter(email=user.email) + + # Send email with magic link + unique_id = user.email_verification_code + await send_magic_link_email(user.email, unique_id, request.base_url) + if is_new: + update_telemetry_state( + request=request, + telemetry_type="api", + api="create_user__email", + metadata={"server_id": str(user.uuid)}, + ) + logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}") return Response(status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d15c0021..88747e46 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -33,8 +33,9 @@ import requests from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async +from django.utils import timezone as django_timezone from fastapi import Depends, Header, HTTPException, Request, UploadFile -from pydantic import BaseModel, Field +from pydantic import BaseModel, EmailStr, Field from starlette.authentication import has_required_scope from starlette.requests import URL @@ -46,10 +47,10 @@ from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, FileObjectAdapters, + aget_user_by_email, ais_user_subscribed, create_khoj_token, get_khoj_tokens, - get_user_by_email, get_user_name, get_user_notion_config, get_user_subscription_state, @@ -64,6 +65,7 @@ from khoj.database.models import ( KhojUser, NotionConfig, ProcessLock, + RateLimitRecord, 
Subscription, TextToImageModelConfig, UserRequests, @@ -112,6 +114,7 @@ from khoj.utils.helpers import ( LRU, ConversationCommand, get_file_type, + in_debug_mode, is_none_or_empty, is_valid_url, log_telemetry, @@ -1613,47 +1616,71 @@ class FeedbackData(BaseModel): sentiment: str -class EmailVerificationApiRateLimiter: +class MagicLinkForm(BaseModel): + email: EmailStr + + +class EmailAttemptRateLimiter: + """Rate limiter for email attempts BEFORE get/create user with valid email address.""" + def __init__(self, requests: int, window: int, slug: str): self.requests = requests - self.window = window + self.window = window # Window in seconds self.slug = slug - def __call__(self, request: Request): - # Rate limiting disabled if billing is disabled - if state.billing_enabled is False: + async def __call__(self, form: MagicLinkForm): + # Disable login rate limiting in debug mode + if in_debug_mode(): return - # Extract the email query parameter - email = request.query_params.get("email") + # Calculate the time window cutoff + cutoff = django_timezone.now() - timedelta(seconds=self.window) - if email: - logger.info(f"Email query parameter: {email}") + # Count recent attempts for this email and slug + count = await RateLimitRecord.objects.filter( + identifier=form.email, slug=self.slug, created_at__gte=cutoff + ).acount() - user: KhojUser = get_user_by_email(email) - - if not user: + if count >= self.requests: + logger.warning(f"Email attempt rate limit exceeded for {form.email} (slug: {self.slug})") raise HTTPException( - status_code=404, - detail="User not found.", + status_code=429, detail="Too many requests for your email address. Please wait before trying again." 
) + # Record the current attempt + await RateLimitRecord.objects.acreate(identifier=form.email, slug=self.slug) + + +class EmailVerificationApiRateLimiter: + """Rate limiter for actions AFTER user with valid email address is known to exist""" + + def __init__(self, requests: int, window: int, slug: str): + self.requests = requests + self.window = window # Window in seconds + self.slug = slug + + async def __call__(self, email: str = None): + # Disable login rate limiting in debug mode + if in_debug_mode(): + return + + user: KhojUser = await aget_user_by_email(email) + if not user: + raise HTTPException(status_code=404, detail="User not found.") + # Remove requests outside of the time window - cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window) - count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count() + cutoff = django_timezone.now() - timedelta(seconds=self.window) + count_requests = await UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).acount() # Check if the user has exceeded the rate limit if count_requests >= self.requests: - logger.info( + logger.warning( f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}." ) - raise HTTPException( - status_code=429, - detail="Ran out of login attempts", - ) + raise HTTPException(status_code=429, detail="Ran out of login attempts. 
Please wait before trying again.") # Add the current request to the db - UserRequests.objects.create(user=user, slug=self.slug) + await UserRequests.objects.acreate(user=user, slug=self.slug) class ApiUserRateLimiter: @@ -1677,7 +1704,7 @@ class ApiUserRateLimiter: subscribed = has_required_scope(request, ["premium"]) # Remove requests outside of the time window - cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window) + cutoff = django_timezone.now() - timedelta(seconds=self.window) count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count() # Check if the user has exceeded the rate limit @@ -1779,7 +1806,7 @@ class ConversationCommandRateLimiter: subscribed = has_required_scope(request, ["premium"]) # Remove requests outside of the 24-hr time window - cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24) + cutoff = django_timezone.now() - timedelta(seconds=60 * 60 * 24) command_slug = f"{self.slug}_{conversation_command.value}" count_requests = await UserRequests.objects.filter( user=user, created_at__gte=cutoff, slug=command_slug