mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Rate limit calls to the /chat API per user, per day/minute
This commit is contained in:
@@ -7,7 +7,7 @@ import json
|
||||
from typing import List, Optional, Union, Any
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
@@ -36,6 +36,7 @@ from khoj.routers.helpers import (
|
||||
agenerate_chat_response,
|
||||
update_telemetry_state,
|
||||
is_ready_to_chat,
|
||||
ApiUserRateLimiter,
|
||||
)
|
||||
from khoj.processor.conversation.prompts import help_message
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
@@ -587,6 +588,8 @@ async def chat(
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user = request.user.object
|
||||
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import logging
|
||||
# Standard Packages
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import logging
|
||||
from time import time
|
||||
from typing import Iterator, List, Optional, Union, Tuple, Dict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# External Packages
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||
@@ -16,6 +21,7 @@ from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import ConversationAdapters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
@@ -191,3 +197,26 @@ def generate_chat_response(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return chat_response, metadata
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
    """In-memory, per-user sliding-window rate limiter.

    An instance is callable with the incoming request, which lets it be handed
    to ``Depends(...)`` on a route. Every instance keeps its own history, so
    two instances with different limits (e.g. per minute and per day) count
    independently of each other.

    The history lives only in this process' memory: it is not shared between
    worker processes and is lost on restart.
    """

    def __init__(self, requests: int, window: int):
        """Configure the limiter.

        Args:
            requests: Maximum number of calls a single user may make within
                any ``window``-second span.
            window: Length of the sliding window, in seconds.
        """
        self.requests = requests
        self.window = window
        # Maps a user's uuid to the ``time()`` timestamps of the calls that
        # were let through for that user, oldest first.
        self.cache: dict[str, list[float]] = defaultdict(list)

    def __call__(self, request: Request):
        """Record a call by the requesting user, or reject it.

        Args:
            request: Incoming request; the user is read from
                ``request.user.object``.

        Raises:
            HTTPException: With status 429 when the user already has
                ``requests`` calls recorded inside the current window. A
                rejected call is not added to the history.
        """
        user: KhojUser = request.user.object
        user_requests = self.cache[user.uuid]

        # Read the clock once so the cutoff and the timestamp stored below
        # are derived from the same instant.
        now = time()
        cutoff = now - self.window

        # Timestamps are appended in call order, so expired entries sit at the
        # front. Count them and drop them with one slice deletion instead of
        # repeated pop(0), which shifts the whole list on every removal.
        expired = 0
        for timestamp in user_requests:
            if timestamp >= cutoff:
                break
            expired += 1
        if expired:
            del user_requests[:expired]

        # NOTE(review): nothing here takes a lock; if this callable can be
        # invoked from several threads at once, the check-then-append below is
        # not atomic and a user could slightly exceed the limit - confirm.
        if len(user_requests) >= self.requests:
            raise HTTPException(status_code=429, detail="Too Many Requests")

        # Only calls that are let through count against the limit.
        user_requests.append(now)
|
||||
|
||||
Reference in New Issue
Block a user