Add email based rate limiting to email login API endpoint

Server:
 - Rate limit based on unverified email before creating user
 - Check email address for deliverability before creating user
 - Track rate limit for unverified email in new non-user keyed table

Web app:
 - Show error in login popup to user on failure/throttling
 - Simplify login popup logic by moving magic link handling logic
   into EmailSigninContext instead of passing require props via parent
This commit is contained in:
Debanjum
2025-04-06 15:14:06 +05:30
parent fe308c2911
commit d0a933b072
8 changed files with 198 additions and 123 deletions

View File

@@ -52,10 +52,6 @@ export default function LoginPrompt(props: LoginPromptProps) {
const [useEmailSignIn, setUseEmailSignIn] = useState(false);
const [email, setEmail] = useState("");
const [checkEmail, setCheckEmail] = useState(false);
const [recheckEmail, setRecheckEmail] = useState(false);
useEffect(() => {
const google = (window as any).google;
@@ -118,49 +114,13 @@ export default function LoginPrompt(props: LoginPromptProps) {
});
};
function handleMagicLinkSignIn() {
fetch("/auth/magic", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email: email }),
})
.then((res) => {
if (res.ok) {
setCheckEmail(true);
if (checkEmail) {
setRecheckEmail(true);
}
return res.json();
} else {
throw new Error("Failed to send magic link");
}
})
.then((data) => {
console.log(data);
})
.catch((err) => {
console.error(err);
});
}
if (props.isMobileWidth) {
return (
<Drawer open={true} onOpenChange={props.onOpenChange}>
<DrawerContent className={`flex flex-col gap-4 w-full mb-4`}>
<div>
{useEmailSignIn ? (
<EmailSignInContext
email={email}
setEmail={setEmail}
checkEmail={checkEmail}
setCheckEmail={setCheckEmail}
setUseEmailSignIn={setUseEmailSignIn}
recheckEmail={recheckEmail}
setRecheckEmail={setRecheckEmail}
handleMagicLinkSignIn={handleMagicLinkSignIn}
/>
<EmailSignInContext setUseEmailSignIn={setUseEmailSignIn} />
) : (
<MainSignInContext
handleGoogleScriptLoad={handleGoogleScriptLoad}
@@ -187,16 +147,7 @@ export default function LoginPrompt(props: LoginPromptProps) {
</VisuallyHidden.Root>
<div>
{useEmailSignIn ? (
<EmailSignInContext
email={email}
setEmail={setEmail}
checkEmail={checkEmail}
setCheckEmail={setCheckEmail}
setUseEmailSignIn={setUseEmailSignIn}
recheckEmail={recheckEmail}
setRecheckEmail={setRecheckEmail}
handleMagicLinkSignIn={handleMagicLinkSignIn}
/>
<EmailSignInContext setUseEmailSignIn={setUseEmailSignIn} />
) : (
<MainSignInContext
handleGoogleScriptLoad={handleGoogleScriptLoad}
@@ -214,26 +165,17 @@ export default function LoginPrompt(props: LoginPromptProps) {
}
function EmailSignInContext({
email,
setEmail,
checkEmail,
setCheckEmail,
setUseEmailSignIn,
recheckEmail,
handleMagicLinkSignIn,
}: {
email: string;
setEmail: (email: string) => void;
checkEmail: boolean;
setCheckEmail: (checkEmail: boolean) => void;
setUseEmailSignIn: (useEmailSignIn: boolean) => void;
recheckEmail: boolean;
setRecheckEmail: (recheckEmail: boolean) => void;
handleMagicLinkSignIn: () => void;
}) {
const [otp, setOTP] = useState("");
const [otpError, setOTPError] = useState("");
const [numFailures, setNumFailures] = useState(0);
const [email, setEmail] = useState("");
const [checkEmail, setCheckEmail] = useState(false);
const [recheckEmail, setRecheckEmail] = useState(false);
const [sendEmailError, setSendEmailError] = useState("");
function checkOTPAndRedirect() {
const verifyUrl = `/auth/magic?code=${encodeURIComponent(otp)}&email=${encodeURIComponent(email)}`;
@@ -275,6 +217,39 @@ function EmailSignInContext({
});
}
function handleMagicLinkSignIn() {
fetch("/auth/magic", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email: email }),
})
.then((res) => {
if (res.ok) {
setCheckEmail(true);
if (checkEmail) {
setRecheckEmail(true);
}
return res.json();
} else if (res.status === 429 || res.status === 404) {
res.json().then((data) => {
setSendEmailError(data.detail);
throw new Error(data.detail);
});
} else {
setSendEmailError("Failed to send email. Contact developers for assistance.");
throw new Error("Failed to send magic link via email.");
}
})
.then((data) => {
console.log(data);
})
.catch((err) => {
console.error(err);
});
}
return (
<div className="flex flex-col gap-4 p-4">
<Button
@@ -297,7 +272,7 @@ function EmailSignInContext({
: "You will receive a sign-in code on the email address you provide below"}
</div>
{!checkEmail && (
<>
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
<Input
placeholder="Email"
className="p-6 w-[300px] mx-auto rounded-lg"
@@ -320,7 +295,8 @@ function EmailSignInContext({
<PaperPlaneTilt className="h-6 w-6 mr-2 font-bold" />
{checkEmail ? "Check your email" : "Send sign in code"}
</Button>
</>
{sendEmailError && <div className="text-red-500 text-sm">{sendEmailError}</div>}
</div>
)}
{checkEmail && (
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
@@ -359,9 +335,7 @@ function EmailSignInContext({
variant="ghost"
className="p-0 text-orange-500"
disabled={recheckEmail}
onClick={() => {
handleMagicLinkSignIn();
}}
onClick={handleMagicLinkSignIn}
>
<ArrowsClockwise className="h-6 w-6 mr-2 text-gray-500" />
Resend email

View File

@@ -37,6 +37,7 @@ from khoj.database.adapters import (
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
ais_user_subscribed,
delete_ratelimit_records,
delete_user_requests,
get_all_users,
get_or_create_search_models,
@@ -428,8 +429,10 @@ def upload_telemetry():
@schedule.repeat(schedule.every(31).minutes)
@clean_connections
def delete_old_user_requests():
num_deleted = delete_user_requests()
logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests")
num_user_ratelimit_requests = delete_user_requests()
num_ratelimit_requests = delete_ratelimit_records()
if state.verbose > 2:
logger.debug(f"🗑️ Deleted {num_user_ratelimit_requests + num_ratelimit_requests} stale rate limit requests")
@schedule.repeat(schedule.every(17).minutes)

View File

@@ -27,6 +27,7 @@ from django.contrib.sessions.backends.db import SessionStore
from django.db.models import Prefetch, Q
from django.db.models.manager import BaseManager
from django.db.utils import IntegrityError
from django.utils import timezone as django_timezone
from django_apscheduler import util
from django_apscheduler.models import DjangoJob, DjangoJobExecution
from fastapi import HTTPException
@@ -49,6 +50,7 @@ from khoj.database.models import (
NotionConfig,
ProcessLock,
PublicConversation,
RateLimitRecord,
ReflectiveQuestion,
SearchModelConfig,
ServerChatSettings,
@@ -233,20 +235,21 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
return user
async def aget_or_create_user_by_email(input_email: str) -> tuple[KhojUser, bool]:
email, is_valid_email = normalize_email(input_email)
async def aget_or_create_user_by_email(input_email: str, check_deliverability=False) -> tuple[KhojUser, bool]:
# Validate deliverability to email address of new user
email, is_valid_email = normalize_email(input_email, check_deliverability=check_deliverability)
is_existing_user = await KhojUser.objects.filter(email=email).aexists()
# Validate email address of new users
if not is_existing_user and not is_valid_email:
logger.error(f"Account creation failed. Invalid email address: {email}")
return None, False
# Get/create user based on email address
user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
defaults={"username": email, "email": email}
)
# Generate a secure 6-digit numeric code
user.email_verification_code = f"{secrets.randbelow(1000000):06}"
user.email_verification_code = f"{secrets.randbelow(int(1e6)):06}"
user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5)
await user.asave()
@@ -516,8 +519,18 @@ def get_user_notion_config(user: KhojUser):
return config
def delete_user_requests(window: timedelta = timedelta(days=1)):
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
def delete_user_requests(max_age: timedelta = timedelta(days=1)):
"""Deletes UserRequests entries older than the specified max_age."""
cutoff = django_timezone.now() - max_age
deleted_count, _ = UserRequests.objects.filter(created_at__lte=cutoff).delete()
return deleted_count
def delete_ratelimit_records(max_age: timedelta = timedelta(days=1)):
"""Deletes RateLimitRecord entries older than the specified max_age."""
cutoff = django_timezone.now() - max_age
deleted_count, _ = RateLimitRecord.objects.filter(created_at__lt=cutoff).delete()
return deleted_count
@arequire_valid_user

View File

@@ -25,6 +25,7 @@ from khoj.database.models import (
KhojUser,
NotionConfig,
ProcessLock,
RateLimitRecord,
ReflectiveQuestion,
SearchModelConfig,
ServerChatSettings,
@@ -179,6 +180,7 @@ admin.site.register(NotionConfig, unfold_admin.ModelAdmin)
admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin)
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
@admin.register(Agent)

View File

@@ -0,0 +1,28 @@
# Generated by Django 5.0.13 on 2025-04-07 07:10
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0087_alter_aimodelapi_api_key"),
]
operations = [
migrations.CreateModel(
name="RateLimitRecord",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("identifier", models.CharField(db_index=True, max_length=255)),
("slug", models.CharField(db_index=True, max_length=255)),
],
options={
"ordering": ["-created_at"],
"indexes": [
models.Index(fields=["identifier", "slug", "created_at"], name="database_ra_identif_031adf_idx")
],
},
),
]

View File

@@ -730,10 +730,28 @@ class EntryDates(DbBaseModel):
class UserRequests(DbBaseModel):
"""Stores user requests to the server for rate limiting."""
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
slug = models.CharField(max_length=200)
class RateLimitRecord(DbBaseModel):
"""Stores individual request timestamps for rate limiting."""
identifier = models.CharField(max_length=255, db_index=True) # IP address or email
slug = models.CharField(max_length=255, db_index=True) # Differentiates limit types
class Meta:
indexes = [
models.Index(fields=["identifier", "slug", "created_at"]),
]
ordering = ["-created_at"]
def __str__(self):
return f"{self.slug} - {self.identifier} at {self.created_at}"
class DataStore(DbBaseModel):
key = models.CharField(max_length=200, unique=True)
value = models.JSONField(default=dict)

View File

@@ -3,11 +3,10 @@ import datetime
import logging
import os
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from urllib.parse import urlencode, urlparse, urlunparse
import requests
from fastapi import APIRouter, Depends
from pydantic import BaseModel, EmailStr
from fastapi import APIRouter, Depends, HTTPException
from starlette.authentication import requires
from starlette.config import Config
from starlette.requests import Request
@@ -25,21 +24,20 @@ from khoj.database.adapters import (
)
from khoj.routers.email import send_magic_link_email, send_welcome_email
from khoj.routers.helpers import (
EmailAttemptRateLimiter,
EmailVerificationApiRateLimiter,
MagicLinkForm,
get_next_url,
update_telemetry_state,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode
logger = logging.getLogger(__name__)
auth_router = APIRouter()
class MagicLinkForm(BaseModel):
email: EmailStr
if not state.anonymous_mode:
missing_requirements = []
from authlib.integrations.starlette_client import OAuth, OAuthError
@@ -78,24 +76,36 @@ async def login(request: Request):
@auth_router.post("/magic")
async def login_magic_link(request: Request, form: MagicLinkForm):
async def login_magic_link(
request: Request,
form: MagicLinkForm,
email_limiter=Depends(EmailAttemptRateLimiter(requests=20, window=60 * 60 * 24, slug="magic_link_login_by_email")),
):
if request.user.is_authenticated:
# Clear the session if user is already authenticated
request.session.pop("user", None)
user, is_new = await aget_or_create_user_by_email(form.email)
# Get/create user if valid email address
check_deliverability = state.billing_enabled and not in_debug_mode()
user, is_new = await aget_or_create_user_by_email(form.email, check_deliverability=check_deliverability)
if not user:
raise HTTPException(status_code=404, detail="Invalid email address. Please fix before trying again.")
if user:
unique_id = user.email_verification_code
await send_magic_link_email(user.email, unique_id, request.base_url)
if is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user__email",
metadata={"server_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
# Rate limit email login by user
user_limiter = EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_login_by_user")
await user_limiter(email=user.email)
# Send email with magic link
unique_id = user.email_verification_code
await send_magic_link_email(user.email, unique_id, request.base_url)
if is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user__email",
metadata={"server_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
return Response(status_code=200)

View File

@@ -33,8 +33,9 @@ import requests
from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from django.utils import timezone as django_timezone
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field
from pydantic import BaseModel, EmailStr, Field
from starlette.authentication import has_required_scope
from starlette.requests import URL
@@ -46,10 +47,10 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
aget_user_by_email,
ais_user_subscribed,
create_khoj_token,
get_khoj_tokens,
get_user_by_email,
get_user_name,
get_user_notion_config,
get_user_subscription_state,
@@ -64,6 +65,7 @@ from khoj.database.models import (
KhojUser,
NotionConfig,
ProcessLock,
RateLimitRecord,
Subscription,
TextToImageModelConfig,
UserRequests,
@@ -112,6 +114,7 @@ from khoj.utils.helpers import (
LRU,
ConversationCommand,
get_file_type,
in_debug_mode,
is_none_or_empty,
is_valid_url,
log_telemetry,
@@ -1613,47 +1616,71 @@ class FeedbackData(BaseModel):
sentiment: str
class EmailVerificationApiRateLimiter:
class MagicLinkForm(BaseModel):
email: EmailStr
class EmailAttemptRateLimiter:
"""Rate limiter for email attempts BEFORE get/create user with valid email address."""
def __init__(self, requests: int, window: int, slug: str):
self.requests = requests
self.window = window
self.window = window # Window in seconds
self.slug = slug
def __call__(self, request: Request):
# Rate limiting disabled if billing is disabled
if state.billing_enabled is False:
async def __call__(self, form: MagicLinkForm):
# Disable login rate limiting in debug mode
if in_debug_mode():
return
# Extract the email query parameter
email = request.query_params.get("email")
# Calculate the time window cutoff
cutoff = django_timezone.now() - timedelta(seconds=self.window)
if email:
logger.info(f"Email query parameter: {email}")
# Count recent attempts for this email and slug
count = await RateLimitRecord.objects.filter(
identifier=form.email, slug=self.slug, created_at__gte=cutoff
).acount()
user: KhojUser = get_user_by_email(email)
if not user:
if count >= self.requests:
logger.warning(f"Email attempt rate limit exceeded for {form.email} (slug: {self.slug})")
raise HTTPException(
status_code=404,
detail="User not found.",
status_code=429, detail="Too many requests for your email address. Please wait before trying again."
)
# Record the current attempt
await RateLimitRecord.objects.acreate(identifier=form.email, slug=self.slug)
class EmailVerificationApiRateLimiter:
"""Rate limiter for actions AFTER user with valid email address is known to exist"""
def __init__(self, requests: int, window: int, slug: str):
self.requests = requests
self.window = window # Window in seconds
self.slug = slug
async def __call__(self, email: str = None):
# Disable login rate limiting in debug mode
if in_debug_mode():
return
user: KhojUser = await aget_user_by_email(email)
if not user:
raise HTTPException(status_code=404, detail="User not found.")
# Remove requests outside of the time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
cutoff = django_timezone.now() - timedelta(seconds=self.window)
count_requests = await UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).acount()
# Check if the user has exceeded the rate limit
if count_requests >= self.requests:
logger.info(
logger.warning(
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
)
raise HTTPException(
status_code=429,
detail="Ran out of login attempts",
)
raise HTTPException(status_code=429, detail="Ran out of login attempts. Please wait before trying again.")
# Add the current request to the db
UserRequests.objects.create(user=user, slug=self.slug)
await UserRequests.objects.acreate(user=user, slug=self.slug)
class ApiUserRateLimiter:
@@ -1677,7 +1704,7 @@ class ApiUserRateLimiter:
subscribed = has_required_scope(request, ["premium"])
# Remove requests outside of the time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
cutoff = django_timezone.now() - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
# Check if the user has exceeded the rate limit
@@ -1779,7 +1806,7 @@ class ConversationCommandRateLimiter:
subscribed = has_required_scope(request, ["premium"])
# Remove requests outside of the 24-hr time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
cutoff = django_timezone.now() - timedelta(seconds=60 * 60 * 24)
command_slug = f"{self.slug}_{conversation_command.value}"
count_requests = await UserRequests.objects.filter(
user=user, created_at__gte=cutoff, slug=command_slug