mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Add email based rate limiting to email login API endpoint
Server: - Rate limit based on unverified email before creating user - Check email address for deliverability before creating user - Track rate limit for unverified email in new non-user keyed table Web app: - Show error in login popup to user on failure/throttling - Simplify login popup logic by moving magic link handling logic into EmailSigninContext instead of passing require props via parent
This commit is contained in:
@@ -52,10 +52,6 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
|||||||
|
|
||||||
const [useEmailSignIn, setUseEmailSignIn] = useState(false);
|
const [useEmailSignIn, setUseEmailSignIn] = useState(false);
|
||||||
|
|
||||||
const [email, setEmail] = useState("");
|
|
||||||
const [checkEmail, setCheckEmail] = useState(false);
|
|
||||||
const [recheckEmail, setRecheckEmail] = useState(false);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const google = (window as any).google;
|
const google = (window as any).google;
|
||||||
|
|
||||||
@@ -118,49 +114,13 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
function handleMagicLinkSignIn() {
|
|
||||||
fetch("/auth/magic", {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ email: email }),
|
|
||||||
})
|
|
||||||
.then((res) => {
|
|
||||||
if (res.ok) {
|
|
||||||
setCheckEmail(true);
|
|
||||||
if (checkEmail) {
|
|
||||||
setRecheckEmail(true);
|
|
||||||
}
|
|
||||||
return res.json();
|
|
||||||
} else {
|
|
||||||
throw new Error("Failed to send magic link");
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.then((data) => {
|
|
||||||
console.log(data);
|
|
||||||
})
|
|
||||||
.catch((err) => {
|
|
||||||
console.error(err);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (props.isMobileWidth) {
|
if (props.isMobileWidth) {
|
||||||
return (
|
return (
|
||||||
<Drawer open={true} onOpenChange={props.onOpenChange}>
|
<Drawer open={true} onOpenChange={props.onOpenChange}>
|
||||||
<DrawerContent className={`flex flex-col gap-4 w-full mb-4`}>
|
<DrawerContent className={`flex flex-col gap-4 w-full mb-4`}>
|
||||||
<div>
|
<div>
|
||||||
{useEmailSignIn ? (
|
{useEmailSignIn ? (
|
||||||
<EmailSignInContext
|
<EmailSignInContext setUseEmailSignIn={setUseEmailSignIn} />
|
||||||
email={email}
|
|
||||||
setEmail={setEmail}
|
|
||||||
checkEmail={checkEmail}
|
|
||||||
setCheckEmail={setCheckEmail}
|
|
||||||
setUseEmailSignIn={setUseEmailSignIn}
|
|
||||||
recheckEmail={recheckEmail}
|
|
||||||
setRecheckEmail={setRecheckEmail}
|
|
||||||
handleMagicLinkSignIn={handleMagicLinkSignIn}
|
|
||||||
/>
|
|
||||||
) : (
|
) : (
|
||||||
<MainSignInContext
|
<MainSignInContext
|
||||||
handleGoogleScriptLoad={handleGoogleScriptLoad}
|
handleGoogleScriptLoad={handleGoogleScriptLoad}
|
||||||
@@ -187,16 +147,7 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
|||||||
</VisuallyHidden.Root>
|
</VisuallyHidden.Root>
|
||||||
<div>
|
<div>
|
||||||
{useEmailSignIn ? (
|
{useEmailSignIn ? (
|
||||||
<EmailSignInContext
|
<EmailSignInContext setUseEmailSignIn={setUseEmailSignIn} />
|
||||||
email={email}
|
|
||||||
setEmail={setEmail}
|
|
||||||
checkEmail={checkEmail}
|
|
||||||
setCheckEmail={setCheckEmail}
|
|
||||||
setUseEmailSignIn={setUseEmailSignIn}
|
|
||||||
recheckEmail={recheckEmail}
|
|
||||||
setRecheckEmail={setRecheckEmail}
|
|
||||||
handleMagicLinkSignIn={handleMagicLinkSignIn}
|
|
||||||
/>
|
|
||||||
) : (
|
) : (
|
||||||
<MainSignInContext
|
<MainSignInContext
|
||||||
handleGoogleScriptLoad={handleGoogleScriptLoad}
|
handleGoogleScriptLoad={handleGoogleScriptLoad}
|
||||||
@@ -214,26 +165,17 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function EmailSignInContext({
|
function EmailSignInContext({
|
||||||
email,
|
|
||||||
setEmail,
|
|
||||||
checkEmail,
|
|
||||||
setCheckEmail,
|
|
||||||
setUseEmailSignIn,
|
setUseEmailSignIn,
|
||||||
recheckEmail,
|
|
||||||
handleMagicLinkSignIn,
|
|
||||||
}: {
|
}: {
|
||||||
email: string;
|
|
||||||
setEmail: (email: string) => void;
|
|
||||||
checkEmail: boolean;
|
|
||||||
setCheckEmail: (checkEmail: boolean) => void;
|
|
||||||
setUseEmailSignIn: (useEmailSignIn: boolean) => void;
|
setUseEmailSignIn: (useEmailSignIn: boolean) => void;
|
||||||
recheckEmail: boolean;
|
|
||||||
setRecheckEmail: (recheckEmail: boolean) => void;
|
|
||||||
handleMagicLinkSignIn: () => void;
|
|
||||||
}) {
|
}) {
|
||||||
const [otp, setOTP] = useState("");
|
const [otp, setOTP] = useState("");
|
||||||
const [otpError, setOTPError] = useState("");
|
const [otpError, setOTPError] = useState("");
|
||||||
const [numFailures, setNumFailures] = useState(0);
|
const [numFailures, setNumFailures] = useState(0);
|
||||||
|
const [email, setEmail] = useState("");
|
||||||
|
const [checkEmail, setCheckEmail] = useState(false);
|
||||||
|
const [recheckEmail, setRecheckEmail] = useState(false);
|
||||||
|
const [sendEmailError, setSendEmailError] = useState("");
|
||||||
|
|
||||||
function checkOTPAndRedirect() {
|
function checkOTPAndRedirect() {
|
||||||
const verifyUrl = `/auth/magic?code=${encodeURIComponent(otp)}&email=${encodeURIComponent(email)}`;
|
const verifyUrl = `/auth/magic?code=${encodeURIComponent(otp)}&email=${encodeURIComponent(email)}`;
|
||||||
@@ -275,6 +217,39 @@ function EmailSignInContext({
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleMagicLinkSignIn() {
|
||||||
|
fetch("/auth/magic", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ email: email }),
|
||||||
|
})
|
||||||
|
.then((res) => {
|
||||||
|
if (res.ok) {
|
||||||
|
setCheckEmail(true);
|
||||||
|
if (checkEmail) {
|
||||||
|
setRecheckEmail(true);
|
||||||
|
}
|
||||||
|
return res.json();
|
||||||
|
} else if (res.status === 429 || res.status === 404) {
|
||||||
|
res.json().then((data) => {
|
||||||
|
setSendEmailError(data.detail);
|
||||||
|
throw new Error(data.detail);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
setSendEmailError("Failed to send email. Contact developers for assistance.");
|
||||||
|
throw new Error("Failed to send magic link via email.");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.then((data) => {
|
||||||
|
console.log(data);
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.error(err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-4 p-4">
|
<div className="flex flex-col gap-4 p-4">
|
||||||
<Button
|
<Button
|
||||||
@@ -297,7 +272,7 @@ function EmailSignInContext({
|
|||||||
: "You will receive a sign-in code on the email address you provide below"}
|
: "You will receive a sign-in code on the email address you provide below"}
|
||||||
</div>
|
</div>
|
||||||
{!checkEmail && (
|
{!checkEmail && (
|
||||||
<>
|
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
|
||||||
<Input
|
<Input
|
||||||
placeholder="Email"
|
placeholder="Email"
|
||||||
className="p-6 w-[300px] mx-auto rounded-lg"
|
className="p-6 w-[300px] mx-auto rounded-lg"
|
||||||
@@ -320,7 +295,8 @@ function EmailSignInContext({
|
|||||||
<PaperPlaneTilt className="h-6 w-6 mr-2 font-bold" />
|
<PaperPlaneTilt className="h-6 w-6 mr-2 font-bold" />
|
||||||
{checkEmail ? "Check your email" : "Send sign in code"}
|
{checkEmail ? "Check your email" : "Send sign in code"}
|
||||||
</Button>
|
</Button>
|
||||||
</>
|
{sendEmailError && <div className="text-red-500 text-sm">{sendEmailError}</div>}
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
{checkEmail && (
|
{checkEmail && (
|
||||||
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
|
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
|
||||||
@@ -359,9 +335,7 @@ function EmailSignInContext({
|
|||||||
variant="ghost"
|
variant="ghost"
|
||||||
className="p-0 text-orange-500"
|
className="p-0 text-orange-500"
|
||||||
disabled={recheckEmail}
|
disabled={recheckEmail}
|
||||||
onClick={() => {
|
onClick={handleMagicLinkSignIn}
|
||||||
handleMagicLinkSignIn();
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
<ArrowsClockwise className="h-6 w-6 mr-2 text-gray-500" />
|
<ArrowsClockwise className="h-6 w-6 mr-2 text-gray-500" />
|
||||||
Resend email
|
Resend email
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from khoj.database.adapters import (
|
|||||||
aget_or_create_user_by_phone_number,
|
aget_or_create_user_by_phone_number,
|
||||||
aget_user_by_phone_number,
|
aget_user_by_phone_number,
|
||||||
ais_user_subscribed,
|
ais_user_subscribed,
|
||||||
|
delete_ratelimit_records,
|
||||||
delete_user_requests,
|
delete_user_requests,
|
||||||
get_all_users,
|
get_all_users,
|
||||||
get_or_create_search_models,
|
get_or_create_search_models,
|
||||||
@@ -428,8 +429,10 @@ def upload_telemetry():
|
|||||||
@schedule.repeat(schedule.every(31).minutes)
|
@schedule.repeat(schedule.every(31).minutes)
|
||||||
@clean_connections
|
@clean_connections
|
||||||
def delete_old_user_requests():
|
def delete_old_user_requests():
|
||||||
num_deleted = delete_user_requests()
|
num_user_ratelimit_requests = delete_user_requests()
|
||||||
logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests")
|
num_ratelimit_requests = delete_ratelimit_records()
|
||||||
|
if state.verbose > 2:
|
||||||
|
logger.debug(f"🗑️ Deleted {num_user_ratelimit_requests + num_ratelimit_requests} stale rate limit requests")
|
||||||
|
|
||||||
|
|
||||||
@schedule.repeat(schedule.every(17).minutes)
|
@schedule.repeat(schedule.every(17).minutes)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from django.contrib.sessions.backends.db import SessionStore
|
|||||||
from django.db.models import Prefetch, Q
|
from django.db.models import Prefetch, Q
|
||||||
from django.db.models.manager import BaseManager
|
from django.db.models.manager import BaseManager
|
||||||
from django.db.utils import IntegrityError
|
from django.db.utils import IntegrityError
|
||||||
|
from django.utils import timezone as django_timezone
|
||||||
from django_apscheduler import util
|
from django_apscheduler import util
|
||||||
from django_apscheduler.models import DjangoJob, DjangoJobExecution
|
from django_apscheduler.models import DjangoJob, DjangoJobExecution
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
@@ -49,6 +50,7 @@ from khoj.database.models import (
|
|||||||
NotionConfig,
|
NotionConfig,
|
||||||
ProcessLock,
|
ProcessLock,
|
||||||
PublicConversation,
|
PublicConversation,
|
||||||
|
RateLimitRecord,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
SearchModelConfig,
|
SearchModelConfig,
|
||||||
ServerChatSettings,
|
ServerChatSettings,
|
||||||
@@ -233,20 +235,21 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def aget_or_create_user_by_email(input_email: str) -> tuple[KhojUser, bool]:
|
async def aget_or_create_user_by_email(input_email: str, check_deliverability=False) -> tuple[KhojUser, bool]:
|
||||||
email, is_valid_email = normalize_email(input_email)
|
# Validate deliverability to email address of new user
|
||||||
|
email, is_valid_email = normalize_email(input_email, check_deliverability=check_deliverability)
|
||||||
is_existing_user = await KhojUser.objects.filter(email=email).aexists()
|
is_existing_user = await KhojUser.objects.filter(email=email).aexists()
|
||||||
# Validate email address of new users
|
|
||||||
if not is_existing_user and not is_valid_email:
|
if not is_existing_user and not is_valid_email:
|
||||||
logger.error(f"Account creation failed. Invalid email address: {email}")
|
logger.error(f"Account creation failed. Invalid email address: {email}")
|
||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
|
# Get/create user based on email address
|
||||||
user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
|
user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
|
||||||
defaults={"username": email, "email": email}
|
defaults={"username": email, "email": email}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate a secure 6-digit numeric code
|
# Generate a secure 6-digit numeric code
|
||||||
user.email_verification_code = f"{secrets.randbelow(1000000):06}"
|
user.email_verification_code = f"{secrets.randbelow(int(1e6)):06}"
|
||||||
user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5)
|
user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5)
|
||||||
await user.asave()
|
await user.asave()
|
||||||
|
|
||||||
@@ -516,8 +519,18 @@ def get_user_notion_config(user: KhojUser):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def delete_user_requests(window: timedelta = timedelta(days=1)):
|
def delete_user_requests(max_age: timedelta = timedelta(days=1)):
|
||||||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
"""Deletes UserRequests entries older than the specified max_age."""
|
||||||
|
cutoff = django_timezone.now() - max_age
|
||||||
|
deleted_count, _ = UserRequests.objects.filter(created_at__lte=cutoff).delete()
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
def delete_ratelimit_records(max_age: timedelta = timedelta(days=1)):
|
||||||
|
"""Deletes RateLimitRecord entries older than the specified max_age."""
|
||||||
|
cutoff = django_timezone.now() - max_age
|
||||||
|
deleted_count, _ = RateLimitRecord.objects.filter(created_at__lt=cutoff).delete()
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
@arequire_valid_user
|
@arequire_valid_user
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from khoj.database.models import (
|
|||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
ProcessLock,
|
ProcessLock,
|
||||||
|
RateLimitRecord,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
SearchModelConfig,
|
SearchModelConfig,
|
||||||
ServerChatSettings,
|
ServerChatSettings,
|
||||||
@@ -179,6 +180,7 @@ admin.site.register(NotionConfig, unfold_admin.ModelAdmin)
|
|||||||
admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin)
|
admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin)
|
||||||
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
|
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
|
||||||
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
|
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
|
||||||
|
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
|
||||||
|
|
||||||
|
|
||||||
@admin.register(Agent)
|
@admin.register(Agent)
|
||||||
|
|||||||
28
src/khoj/database/migrations/0088_ratelimitrecord.py
Normal file
28
src/khoj/database/migrations/0088_ratelimitrecord.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Generated by Django 5.0.13 on 2025-04-07 07:10
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0087_alter_aimodelapi_api_key"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="RateLimitRecord",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("identifier", models.CharField(db_index=True, max_length=255)),
|
||||||
|
("slug", models.CharField(db_index=True, max_length=255)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"ordering": ["-created_at"],
|
||||||
|
"indexes": [
|
||||||
|
models.Index(fields=["identifier", "slug", "created_at"], name="database_ra_identif_031adf_idx")
|
||||||
|
],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -730,10 +730,28 @@ class EntryDates(DbBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UserRequests(DbBaseModel):
|
class UserRequests(DbBaseModel):
|
||||||
|
"""Stores user requests to the server for rate limiting."""
|
||||||
|
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
slug = models.CharField(max_length=200)
|
slug = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitRecord(DbBaseModel):
|
||||||
|
"""Stores individual request timestamps for rate limiting."""
|
||||||
|
|
||||||
|
identifier = models.CharField(max_length=255, db_index=True) # IP address or email
|
||||||
|
slug = models.CharField(max_length=255, db_index=True) # Differentiates limit types
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
indexes = [
|
||||||
|
models.Index(fields=["identifier", "slug", "created_at"]),
|
||||||
|
]
|
||||||
|
ordering = ["-created_at"]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.slug} - {self.identifier} at {self.created_at}"
|
||||||
|
|
||||||
|
|
||||||
class DataStore(DbBaseModel):
|
class DataStore(DbBaseModel):
|
||||||
key = models.CharField(max_length=200, unique=True)
|
key = models.CharField(max_length=200, unique=True)
|
||||||
value = models.JSONField(default=dict)
|
value = models.JSONField(default=dict)
|
||||||
|
|||||||
@@ -3,11 +3,10 @@ import datetime
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
from urllib.parse import urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from starlette.config import Config
|
from starlette.config import Config
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
@@ -25,21 +24,20 @@ from khoj.database.adapters import (
|
|||||||
)
|
)
|
||||||
from khoj.routers.email import send_magic_link_email, send_welcome_email
|
from khoj.routers.email import send_magic_link_email, send_welcome_email
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
|
EmailAttemptRateLimiter,
|
||||||
EmailVerificationApiRateLimiter,
|
EmailVerificationApiRateLimiter,
|
||||||
|
MagicLinkForm,
|
||||||
get_next_url,
|
get_next_url,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
)
|
)
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import in_debug_mode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
auth_router = APIRouter()
|
auth_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class MagicLinkForm(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
|
|
||||||
|
|
||||||
if not state.anonymous_mode:
|
if not state.anonymous_mode:
|
||||||
missing_requirements = []
|
missing_requirements = []
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||||
@@ -78,24 +76,36 @@ async def login(request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@auth_router.post("/magic")
|
@auth_router.post("/magic")
|
||||||
async def login_magic_link(request: Request, form: MagicLinkForm):
|
async def login_magic_link(
|
||||||
|
request: Request,
|
||||||
|
form: MagicLinkForm,
|
||||||
|
email_limiter=Depends(EmailAttemptRateLimiter(requests=20, window=60 * 60 * 24, slug="magic_link_login_by_email")),
|
||||||
|
):
|
||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
# Clear the session if user is already authenticated
|
# Clear the session if user is already authenticated
|
||||||
request.session.pop("user", None)
|
request.session.pop("user", None)
|
||||||
|
|
||||||
user, is_new = await aget_or_create_user_by_email(form.email)
|
# Get/create user if valid email address
|
||||||
|
check_deliverability = state.billing_enabled and not in_debug_mode()
|
||||||
|
user, is_new = await aget_or_create_user_by_email(form.email, check_deliverability=check_deliverability)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="Invalid email address. Please fix before trying again.")
|
||||||
|
|
||||||
if user:
|
# Rate limit email login by user
|
||||||
unique_id = user.email_verification_code
|
user_limiter = EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_login_by_user")
|
||||||
await send_magic_link_email(user.email, unique_id, request.base_url)
|
await user_limiter(email=user.email)
|
||||||
if is_new:
|
|
||||||
update_telemetry_state(
|
# Send email with magic link
|
||||||
request=request,
|
unique_id = user.email_verification_code
|
||||||
telemetry_type="api",
|
await send_magic_link_email(user.email, unique_id, request.base_url)
|
||||||
api="create_user__email",
|
if is_new:
|
||||||
metadata={"server_id": str(user.uuid)},
|
update_telemetry_state(
|
||||||
)
|
request=request,
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
telemetry_type="api",
|
||||||
|
api="create_user__email",
|
||||||
|
metadata={"server_id": str(user.uuid)},
|
||||||
|
)
|
||||||
|
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||||
|
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|||||||
@@ -33,8 +33,9 @@ import requests
|
|||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.utils import timezone as django_timezone
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
|
|
||||||
@@ -46,10 +47,10 @@ from khoj.database.adapters import (
|
|||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
FileObjectAdapters,
|
FileObjectAdapters,
|
||||||
|
aget_user_by_email,
|
||||||
ais_user_subscribed,
|
ais_user_subscribed,
|
||||||
create_khoj_token,
|
create_khoj_token,
|
||||||
get_khoj_tokens,
|
get_khoj_tokens,
|
||||||
get_user_by_email,
|
|
||||||
get_user_name,
|
get_user_name,
|
||||||
get_user_notion_config,
|
get_user_notion_config,
|
||||||
get_user_subscription_state,
|
get_user_subscription_state,
|
||||||
@@ -64,6 +65,7 @@ from khoj.database.models import (
|
|||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
ProcessLock,
|
ProcessLock,
|
||||||
|
RateLimitRecord,
|
||||||
Subscription,
|
Subscription,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserRequests,
|
UserRequests,
|
||||||
@@ -112,6 +114,7 @@ from khoj.utils.helpers import (
|
|||||||
LRU,
|
LRU,
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
get_file_type,
|
get_file_type,
|
||||||
|
in_debug_mode,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_valid_url,
|
is_valid_url,
|
||||||
log_telemetry,
|
log_telemetry,
|
||||||
@@ -1613,47 +1616,71 @@ class FeedbackData(BaseModel):
|
|||||||
sentiment: str
|
sentiment: str
|
||||||
|
|
||||||
|
|
||||||
class EmailVerificationApiRateLimiter:
|
class MagicLinkForm(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class EmailAttemptRateLimiter:
|
||||||
|
"""Rate limiter for email attempts BEFORE get/create user with valid email address."""
|
||||||
|
|
||||||
def __init__(self, requests: int, window: int, slug: str):
|
def __init__(self, requests: int, window: int, slug: str):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
self.window = window
|
self.window = window # Window in seconds
|
||||||
self.slug = slug
|
self.slug = slug
|
||||||
|
|
||||||
def __call__(self, request: Request):
|
async def __call__(self, form: MagicLinkForm):
|
||||||
# Rate limiting disabled if billing is disabled
|
# Disable login rate limiting in debug mode
|
||||||
if state.billing_enabled is False:
|
if in_debug_mode():
|
||||||
return
|
return
|
||||||
|
|
||||||
# Extract the email query parameter
|
# Calculate the time window cutoff
|
||||||
email = request.query_params.get("email")
|
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||||
|
|
||||||
if email:
|
# Count recent attempts for this email and slug
|
||||||
logger.info(f"Email query parameter: {email}")
|
count = await RateLimitRecord.objects.filter(
|
||||||
|
identifier=form.email, slug=self.slug, created_at__gte=cutoff
|
||||||
|
).acount()
|
||||||
|
|
||||||
user: KhojUser = get_user_by_email(email)
|
if count >= self.requests:
|
||||||
|
logger.warning(f"Email attempt rate limit exceeded for {form.email} (slug: {self.slug})")
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=429, detail="Too many requests for your email address. Please wait before trying again."
|
||||||
detail="User not found.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Record the current attempt
|
||||||
|
await RateLimitRecord.objects.acreate(identifier=form.email, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailVerificationApiRateLimiter:
|
||||||
|
"""Rate limiter for actions AFTER user with valid email address is known to exist"""
|
||||||
|
|
||||||
|
def __init__(self, requests: int, window: int, slug: str):
|
||||||
|
self.requests = requests
|
||||||
|
self.window = window # Window in seconds
|
||||||
|
self.slug = slug
|
||||||
|
|
||||||
|
async def __call__(self, email: str = None):
|
||||||
|
# Disable login rate limiting in debug mode
|
||||||
|
if in_debug_mode():
|
||||||
|
return
|
||||||
|
|
||||||
|
user: KhojUser = await aget_user_by_email(email)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found.")
|
||||||
|
|
||||||
# Remove requests outside of the time window
|
# Remove requests outside of the time window
|
||||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
count_requests = await UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).acount()
|
||||||
|
|
||||||
# Check if the user has exceeded the rate limit
|
# Check if the user has exceeded the rate limit
|
||||||
if count_requests >= self.requests:
|
if count_requests >= self.requests:
|
||||||
logger.info(
|
logger.warning(
|
||||||
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
|
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=429, detail="Ran out of login attempts. Please wait before trying again.")
|
||||||
status_code=429,
|
|
||||||
detail="Ran out of login attempts",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the current request to the db
|
# Add the current request to the db
|
||||||
UserRequests.objects.create(user=user, slug=self.slug)
|
await UserRequests.objects.acreate(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
@@ -1677,7 +1704,7 @@ class ApiUserRateLimiter:
|
|||||||
subscribed = has_required_scope(request, ["premium"])
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
|
|
||||||
# Remove requests outside of the time window
|
# Remove requests outside of the time window
|
||||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||||
|
|
||||||
# Check if the user has exceeded the rate limit
|
# Check if the user has exceeded the rate limit
|
||||||
@@ -1779,7 +1806,7 @@ class ConversationCommandRateLimiter:
|
|||||||
subscribed = has_required_scope(request, ["premium"])
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
|
|
||||||
# Remove requests outside of the 24-hr time window
|
# Remove requests outside of the 24-hr time window
|
||||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
|
cutoff = django_timezone.now() - timedelta(seconds=60 * 60 * 24)
|
||||||
command_slug = f"{self.slug}_{conversation_command.value}"
|
command_slug = f"{self.slug}_{conversation_command.value}"
|
||||||
count_requests = await UserRequests.objects.filter(
|
count_requests = await UserRequests.objects.filter(
|
||||||
user=user, created_at__gte=cutoff, slug=command_slug
|
user=user, created_at__gte=cutoff, slug=command_slug
|
||||||
|
|||||||
Reference in New Issue
Block a user