mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Add email based rate limiting to email login API endpoint
Server: - Rate limit based on unverified email before creating user - Check email address for deliverability before creating user - Track rate limit for unverified email in new non-user keyed table Web app: - Show error in login popup to user on failure/throttling - Simplify login popup logic by moving magic link handling logic into EmailSigninContext instead of passing require props via parent
This commit is contained in:
@@ -52,10 +52,6 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
||||
|
||||
const [useEmailSignIn, setUseEmailSignIn] = useState(false);
|
||||
|
||||
const [email, setEmail] = useState("");
|
||||
const [checkEmail, setCheckEmail] = useState(false);
|
||||
const [recheckEmail, setRecheckEmail] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const google = (window as any).google;
|
||||
|
||||
@@ -118,49 +114,13 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
||||
});
|
||||
};
|
||||
|
||||
function handleMagicLinkSignIn() {
|
||||
fetch("/auth/magic", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ email: email }),
|
||||
})
|
||||
.then((res) => {
|
||||
if (res.ok) {
|
||||
setCheckEmail(true);
|
||||
if (checkEmail) {
|
||||
setRecheckEmail(true);
|
||||
}
|
||||
return res.json();
|
||||
} else {
|
||||
throw new Error("Failed to send magic link");
|
||||
}
|
||||
})
|
||||
.then((data) => {
|
||||
console.log(data);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(err);
|
||||
});
|
||||
}
|
||||
|
||||
if (props.isMobileWidth) {
|
||||
return (
|
||||
<Drawer open={true} onOpenChange={props.onOpenChange}>
|
||||
<DrawerContent className={`flex flex-col gap-4 w-full mb-4`}>
|
||||
<div>
|
||||
{useEmailSignIn ? (
|
||||
<EmailSignInContext
|
||||
email={email}
|
||||
setEmail={setEmail}
|
||||
checkEmail={checkEmail}
|
||||
setCheckEmail={setCheckEmail}
|
||||
setUseEmailSignIn={setUseEmailSignIn}
|
||||
recheckEmail={recheckEmail}
|
||||
setRecheckEmail={setRecheckEmail}
|
||||
handleMagicLinkSignIn={handleMagicLinkSignIn}
|
||||
/>
|
||||
<EmailSignInContext setUseEmailSignIn={setUseEmailSignIn} />
|
||||
) : (
|
||||
<MainSignInContext
|
||||
handleGoogleScriptLoad={handleGoogleScriptLoad}
|
||||
@@ -187,16 +147,7 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
||||
</VisuallyHidden.Root>
|
||||
<div>
|
||||
{useEmailSignIn ? (
|
||||
<EmailSignInContext
|
||||
email={email}
|
||||
setEmail={setEmail}
|
||||
checkEmail={checkEmail}
|
||||
setCheckEmail={setCheckEmail}
|
||||
setUseEmailSignIn={setUseEmailSignIn}
|
||||
recheckEmail={recheckEmail}
|
||||
setRecheckEmail={setRecheckEmail}
|
||||
handleMagicLinkSignIn={handleMagicLinkSignIn}
|
||||
/>
|
||||
<EmailSignInContext setUseEmailSignIn={setUseEmailSignIn} />
|
||||
) : (
|
||||
<MainSignInContext
|
||||
handleGoogleScriptLoad={handleGoogleScriptLoad}
|
||||
@@ -214,26 +165,17 @@ export default function LoginPrompt(props: LoginPromptProps) {
|
||||
}
|
||||
|
||||
function EmailSignInContext({
|
||||
email,
|
||||
setEmail,
|
||||
checkEmail,
|
||||
setCheckEmail,
|
||||
setUseEmailSignIn,
|
||||
recheckEmail,
|
||||
handleMagicLinkSignIn,
|
||||
}: {
|
||||
email: string;
|
||||
setEmail: (email: string) => void;
|
||||
checkEmail: boolean;
|
||||
setCheckEmail: (checkEmail: boolean) => void;
|
||||
setUseEmailSignIn: (useEmailSignIn: boolean) => void;
|
||||
recheckEmail: boolean;
|
||||
setRecheckEmail: (recheckEmail: boolean) => void;
|
||||
handleMagicLinkSignIn: () => void;
|
||||
}) {
|
||||
const [otp, setOTP] = useState("");
|
||||
const [otpError, setOTPError] = useState("");
|
||||
const [numFailures, setNumFailures] = useState(0);
|
||||
const [email, setEmail] = useState("");
|
||||
const [checkEmail, setCheckEmail] = useState(false);
|
||||
const [recheckEmail, setRecheckEmail] = useState(false);
|
||||
const [sendEmailError, setSendEmailError] = useState("");
|
||||
|
||||
function checkOTPAndRedirect() {
|
||||
const verifyUrl = `/auth/magic?code=${encodeURIComponent(otp)}&email=${encodeURIComponent(email)}`;
|
||||
@@ -275,6 +217,39 @@ function EmailSignInContext({
|
||||
});
|
||||
}
|
||||
|
||||
function handleMagicLinkSignIn() {
|
||||
fetch("/auth/magic", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ email: email }),
|
||||
})
|
||||
.then((res) => {
|
||||
if (res.ok) {
|
||||
setCheckEmail(true);
|
||||
if (checkEmail) {
|
||||
setRecheckEmail(true);
|
||||
}
|
||||
return res.json();
|
||||
} else if (res.status === 429 || res.status === 404) {
|
||||
res.json().then((data) => {
|
||||
setSendEmailError(data.detail);
|
||||
throw new Error(data.detail);
|
||||
});
|
||||
} else {
|
||||
setSendEmailError("Failed to send email. Contact developers for assistance.");
|
||||
throw new Error("Failed to send magic link via email.");
|
||||
}
|
||||
})
|
||||
.then((data) => {
|
||||
console.log(data);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(err);
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<Button
|
||||
@@ -297,7 +272,7 @@ function EmailSignInContext({
|
||||
: "You will receive a sign-in code on the email address you provide below"}
|
||||
</div>
|
||||
{!checkEmail && (
|
||||
<>
|
||||
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
|
||||
<Input
|
||||
placeholder="Email"
|
||||
className="p-6 w-[300px] mx-auto rounded-lg"
|
||||
@@ -320,7 +295,8 @@ function EmailSignInContext({
|
||||
<PaperPlaneTilt className="h-6 w-6 mr-2 font-bold" />
|
||||
{checkEmail ? "Check your email" : "Send sign in code"}
|
||||
</Button>
|
||||
</>
|
||||
{sendEmailError && <div className="text-red-500 text-sm">{sendEmailError}</div>}
|
||||
</div>
|
||||
)}
|
||||
{checkEmail && (
|
||||
<div className="flex items-center justify-center gap-4 text-muted-foreground flex-col">
|
||||
@@ -359,9 +335,7 @@ function EmailSignInContext({
|
||||
variant="ghost"
|
||||
className="p-0 text-orange-500"
|
||||
disabled={recheckEmail}
|
||||
onClick={() => {
|
||||
handleMagicLinkSignIn();
|
||||
}}
|
||||
onClick={handleMagicLinkSignIn}
|
||||
>
|
||||
<ArrowsClockwise className="h-6 w-6 mr-2 text-gray-500" />
|
||||
Resend email
|
||||
|
||||
@@ -37,6 +37,7 @@ from khoj.database.adapters import (
|
||||
aget_or_create_user_by_phone_number,
|
||||
aget_user_by_phone_number,
|
||||
ais_user_subscribed,
|
||||
delete_ratelimit_records,
|
||||
delete_user_requests,
|
||||
get_all_users,
|
||||
get_or_create_search_models,
|
||||
@@ -428,8 +429,10 @@ def upload_telemetry():
|
||||
@schedule.repeat(schedule.every(31).minutes)
|
||||
@clean_connections
|
||||
def delete_old_user_requests():
|
||||
num_deleted = delete_user_requests()
|
||||
logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests")
|
||||
num_user_ratelimit_requests = delete_user_requests()
|
||||
num_ratelimit_requests = delete_ratelimit_records()
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"🗑️ Deleted {num_user_ratelimit_requests + num_ratelimit_requests} stale rate limit requests")
|
||||
|
||||
|
||||
@schedule.repeat(schedule.every(17).minutes)
|
||||
|
||||
@@ -27,6 +27,7 @@ from django.contrib.sessions.backends.db import SessionStore
|
||||
from django.db.models import Prefetch, Q
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.utils import IntegrityError
|
||||
from django.utils import timezone as django_timezone
|
||||
from django_apscheduler import util
|
||||
from django_apscheduler.models import DjangoJob, DjangoJobExecution
|
||||
from fastapi import HTTPException
|
||||
@@ -49,6 +50,7 @@ from khoj.database.models import (
|
||||
NotionConfig,
|
||||
ProcessLock,
|
||||
PublicConversation,
|
||||
RateLimitRecord,
|
||||
ReflectiveQuestion,
|
||||
SearchModelConfig,
|
||||
ServerChatSettings,
|
||||
@@ -233,20 +235,21 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
|
||||
return user
|
||||
|
||||
|
||||
async def aget_or_create_user_by_email(input_email: str) -> tuple[KhojUser, bool]:
|
||||
email, is_valid_email = normalize_email(input_email)
|
||||
async def aget_or_create_user_by_email(input_email: str, check_deliverability=False) -> tuple[KhojUser, bool]:
|
||||
# Validate deliverability to email address of new user
|
||||
email, is_valid_email = normalize_email(input_email, check_deliverability=check_deliverability)
|
||||
is_existing_user = await KhojUser.objects.filter(email=email).aexists()
|
||||
# Validate email address of new users
|
||||
if not is_existing_user and not is_valid_email:
|
||||
logger.error(f"Account creation failed. Invalid email address: {email}")
|
||||
return None, False
|
||||
|
||||
# Get/create user based on email address
|
||||
user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
|
||||
defaults={"username": email, "email": email}
|
||||
)
|
||||
|
||||
# Generate a secure 6-digit numeric code
|
||||
user.email_verification_code = f"{secrets.randbelow(1000000):06}"
|
||||
user.email_verification_code = f"{secrets.randbelow(int(1e6)):06}"
|
||||
user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5)
|
||||
await user.asave()
|
||||
|
||||
@@ -516,8 +519,18 @@ def get_user_notion_config(user: KhojUser):
|
||||
return config
|
||||
|
||||
|
||||
def delete_user_requests(window: timedelta = timedelta(days=1)):
|
||||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
||||
def delete_user_requests(max_age: timedelta = timedelta(days=1)):
|
||||
"""Deletes UserRequests entries older than the specified max_age."""
|
||||
cutoff = django_timezone.now() - max_age
|
||||
deleted_count, _ = UserRequests.objects.filter(created_at__lte=cutoff).delete()
|
||||
return deleted_count
|
||||
|
||||
|
||||
def delete_ratelimit_records(max_age: timedelta = timedelta(days=1)):
|
||||
"""Deletes RateLimitRecord entries older than the specified max_age."""
|
||||
cutoff = django_timezone.now() - max_age
|
||||
deleted_count, _ = RateLimitRecord.objects.filter(created_at__lt=cutoff).delete()
|
||||
return deleted_count
|
||||
|
||||
|
||||
@arequire_valid_user
|
||||
|
||||
@@ -25,6 +25,7 @@ from khoj.database.models import (
|
||||
KhojUser,
|
||||
NotionConfig,
|
||||
ProcessLock,
|
||||
RateLimitRecord,
|
||||
ReflectiveQuestion,
|
||||
SearchModelConfig,
|
||||
ServerChatSettings,
|
||||
@@ -179,6 +180,7 @@ admin.site.register(NotionConfig, unfold_admin.ModelAdmin)
|
||||
admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin)
|
||||
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
|
||||
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
|
||||
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
|
||||
|
||||
|
||||
@admin.register(Agent)
|
||||
|
||||
28
src/khoj/database/migrations/0088_ratelimitrecord.py
Normal file
28
src/khoj/database/migrations/0088_ratelimitrecord.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Generated by Django 5.0.13 on 2025-04-07 07:10
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0087_alter_aimodelapi_api_key"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="RateLimitRecord",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("identifier", models.CharField(db_index=True, max_length=255)),
|
||||
("slug", models.CharField(db_index=True, max_length=255)),
|
||||
],
|
||||
options={
|
||||
"ordering": ["-created_at"],
|
||||
"indexes": [
|
||||
models.Index(fields=["identifier", "slug", "created_at"], name="database_ra_identif_031adf_idx")
|
||||
],
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -730,10 +730,28 @@ class EntryDates(DbBaseModel):
|
||||
|
||||
|
||||
class UserRequests(DbBaseModel):
|
||||
"""Stores user requests to the server for rate limiting."""
|
||||
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
slug = models.CharField(max_length=200)
|
||||
|
||||
|
||||
class RateLimitRecord(DbBaseModel):
|
||||
"""Stores individual request timestamps for rate limiting."""
|
||||
|
||||
identifier = models.CharField(max_length=255, db_index=True) # IP address or email
|
||||
slug = models.CharField(max_length=255, db_index=True) # Differentiates limit types
|
||||
|
||||
class Meta:
|
||||
indexes = [
|
||||
models.Index(fields=["identifier", "slug", "created_at"]),
|
||||
]
|
||||
ordering = ["-created_at"]
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.slug} - {self.identifier} at {self.created_at}"
|
||||
|
||||
|
||||
class DataStore(DbBaseModel):
|
||||
key = models.CharField(max_length=200, unique=True)
|
||||
value = models.JSONField(default=dict)
|
||||
|
||||
@@ -3,11 +3,10 @@ import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
from urllib.parse import urlencode, urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from starlette.authentication import requires
|
||||
from starlette.config import Config
|
||||
from starlette.requests import Request
|
||||
@@ -25,21 +24,20 @@ from khoj.database.adapters import (
|
||||
)
|
||||
from khoj.routers.email import send_magic_link_email, send_welcome_email
|
||||
from khoj.routers.helpers import (
|
||||
EmailAttemptRateLimiter,
|
||||
EmailVerificationApiRateLimiter,
|
||||
MagicLinkForm,
|
||||
get_next_url,
|
||||
update_telemetry_state,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
auth_router = APIRouter()
|
||||
|
||||
|
||||
class MagicLinkForm(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
if not state.anonymous_mode:
|
||||
missing_requirements = []
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
@@ -78,24 +76,36 @@ async def login(request: Request):
|
||||
|
||||
|
||||
@auth_router.post("/magic")
|
||||
async def login_magic_link(request: Request, form: MagicLinkForm):
|
||||
async def login_magic_link(
|
||||
request: Request,
|
||||
form: MagicLinkForm,
|
||||
email_limiter=Depends(EmailAttemptRateLimiter(requests=20, window=60 * 60 * 24, slug="magic_link_login_by_email")),
|
||||
):
|
||||
if request.user.is_authenticated:
|
||||
# Clear the session if user is already authenticated
|
||||
request.session.pop("user", None)
|
||||
|
||||
user, is_new = await aget_or_create_user_by_email(form.email)
|
||||
# Get/create user if valid email address
|
||||
check_deliverability = state.billing_enabled and not in_debug_mode()
|
||||
user, is_new = await aget_or_create_user_by_email(form.email, check_deliverability=check_deliverability)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="Invalid email address. Please fix before trying again.")
|
||||
|
||||
if user:
|
||||
unique_id = user.email_verification_code
|
||||
await send_magic_link_email(user.email, unique_id, request.base_url)
|
||||
if is_new:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="create_user__email",
|
||||
metadata={"server_id": str(user.uuid)},
|
||||
)
|
||||
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||
# Rate limit email login by user
|
||||
user_limiter = EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_login_by_user")
|
||||
await user_limiter(email=user.email)
|
||||
|
||||
# Send email with magic link
|
||||
unique_id = user.email_verification_code
|
||||
await send_magic_link_email(user.email, unique_id, request.base_url)
|
||||
if is_new:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="create_user__email",
|
||||
metadata={"server_id": str(user.uuid)},
|
||||
)
|
||||
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -33,8 +33,9 @@ import requests
|
||||
from apscheduler.job import Job
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.utils import timezone as django_timezone
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
|
||||
@@ -46,10 +47,10 @@ from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
aget_user_by_email,
|
||||
ais_user_subscribed,
|
||||
create_khoj_token,
|
||||
get_khoj_tokens,
|
||||
get_user_by_email,
|
||||
get_user_name,
|
||||
get_user_notion_config,
|
||||
get_user_subscription_state,
|
||||
@@ -64,6 +65,7 @@ from khoj.database.models import (
|
||||
KhojUser,
|
||||
NotionConfig,
|
||||
ProcessLock,
|
||||
RateLimitRecord,
|
||||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
UserRequests,
|
||||
@@ -112,6 +114,7 @@ from khoj.utils.helpers import (
|
||||
LRU,
|
||||
ConversationCommand,
|
||||
get_file_type,
|
||||
in_debug_mode,
|
||||
is_none_or_empty,
|
||||
is_valid_url,
|
||||
log_telemetry,
|
||||
@@ -1613,47 +1616,71 @@ class FeedbackData(BaseModel):
|
||||
sentiment: str
|
||||
|
||||
|
||||
class EmailVerificationApiRateLimiter:
|
||||
class MagicLinkForm(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class EmailAttemptRateLimiter:
|
||||
"""Rate limiter for email attempts BEFORE get/create user with valid email address."""
|
||||
|
||||
def __init__(self, requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
self.window = window
|
||||
self.window = window # Window in seconds
|
||||
self.slug = slug
|
||||
|
||||
def __call__(self, request: Request):
|
||||
# Rate limiting disabled if billing is disabled
|
||||
if state.billing_enabled is False:
|
||||
async def __call__(self, form: MagicLinkForm):
|
||||
# Disable login rate limiting in debug mode
|
||||
if in_debug_mode():
|
||||
return
|
||||
|
||||
# Extract the email query parameter
|
||||
email = request.query_params.get("email")
|
||||
# Calculate the time window cutoff
|
||||
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||
|
||||
if email:
|
||||
logger.info(f"Email query parameter: {email}")
|
||||
# Count recent attempts for this email and slug
|
||||
count = await RateLimitRecord.objects.filter(
|
||||
identifier=form.email, slug=self.slug, created_at__gte=cutoff
|
||||
).acount()
|
||||
|
||||
user: KhojUser = get_user_by_email(email)
|
||||
|
||||
if not user:
|
||||
if count >= self.requests:
|
||||
logger.warning(f"Email attempt rate limit exceeded for {form.email} (slug: {self.slug})")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="User not found.",
|
||||
status_code=429, detail="Too many requests for your email address. Please wait before trying again."
|
||||
)
|
||||
|
||||
# Record the current attempt
|
||||
await RateLimitRecord.objects.acreate(identifier=form.email, slug=self.slug)
|
||||
|
||||
|
||||
class EmailVerificationApiRateLimiter:
|
||||
"""Rate limiter for actions AFTER user with valid email address is known to exist"""
|
||||
|
||||
def __init__(self, requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
self.window = window # Window in seconds
|
||||
self.slug = slug
|
||||
|
||||
async def __call__(self, email: str = None):
|
||||
# Disable login rate limiting in debug mode
|
||||
if in_debug_mode():
|
||||
return
|
||||
|
||||
user: KhojUser = await aget_user_by_email(email)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found.")
|
||||
|
||||
# Remove requests outside of the time window
|
||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||
count_requests = await UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).acount()
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
if count_requests >= self.requests:
|
||||
logger.info(
|
||||
logger.warning(
|
||||
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Ran out of login attempts",
|
||||
)
|
||||
raise HTTPException(status_code=429, detail="Ran out of login attempts. Please wait before trying again.")
|
||||
|
||||
# Add the current request to the db
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
await UserRequests.objects.acreate(user=user, slug=self.slug)
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
@@ -1677,7 +1704,7 @@ class ApiUserRateLimiter:
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
# Remove requests outside of the time window
|
||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
||||
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
@@ -1779,7 +1806,7 @@ class ConversationCommandRateLimiter:
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
# Remove requests outside of the 24-hr time window
|
||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
|
||||
cutoff = django_timezone.now() - timedelta(seconds=60 * 60 * 24)
|
||||
command_slug = f"{self.slug}_{conversation_command.value}"
|
||||
count_requests = await UserRequests.objects.filter(
|
||||
user=user, created_at__gte=cutoff, slug=command_slug
|
||||
|
||||
Reference in New Issue
Block a user