Rate limit calls to the /chat API per user, per day/minute

This commit is contained in:
Debanjum Singh Solanky
2023-11-13 18:36:30 -08:00
parent 33a8eb0470
commit e9adb58c16
2 changed files with 35 additions and 3 deletions

View File

@@ -7,7 +7,7 @@ import json
from typing import List, Optional, Union, Any
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from starlette.authentication import requires
from asgiref.sync import sync_to_async
@@ -36,6 +36,7 @@ from khoj.routers.helpers import (
agenerate_chat_response,
update_telemetry_state,
is_ready_to_chat,
ApiUserRateLimiter,
)
from khoj.processor.conversation.prompts import help_message
from khoj.processor.conversation.openai.gpt import extract_questions
@@ -587,6 +588,8 @@ async def chat(
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
) -> Response:
user = request.user.object

View File

@@ -1,12 +1,17 @@
import logging
# Standard Packages
import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
import logging
from time import time
from typing import Iterator, List, Optional, Union, Tuple, Dict
from concurrent.futures import ThreadPoolExecutor
# External Packages
from fastapi import HTTPException, Request
# Internal Packages
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
@@ -16,6 +21,7 @@ from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser, Subscription
from database.adapters import ConversationAdapters
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Thread pool limited to a single worker thread.
executor = ThreadPoolExecutor(max_workers=1)
@@ -191,3 +197,26 @@ def generate_chat_response(
raise HTTPException(status_code=500, detail=str(e))
return chat_response, metadata
class ApiUserRateLimiter:
    """Sliding-window rate limiter keyed by user, usable as a FastAPI dependency.

    An instance is callable with the incoming request. Each admitted call is
    timestamped in ``self.cache`` (held in this instance's memory) under the
    requesting user's ``uuid``. A call is rejected when the user already has
    ``requests`` admitted calls within the trailing ``window`` seconds.
    """

    def __init__(self, requests: int, window: int):
        """Configure the limiter.

        Args:
            requests: Maximum number of calls admitted per user in one window.
            window: Length of the sliding window, in seconds.
        """
        self.requests = requests
        self.window = window
        # Maps a user's uuid to the timestamps of their admitted calls,
        # oldest first.
        self.cache: dict[str, list[float]] = defaultdict(list)

    def __call__(self, request: Request):
        """Admit or reject the current request for the requesting user.

        Args:
            request: Incoming request; the user is read from
                ``request.user.object``.

        Raises:
            HTTPException: With status 429 when the user has already reached
                the allowed number of calls within the window.
        """
        user: KhojUser = request.user.object
        user_requests = self.cache[user.uuid]

        # Read the clock once so the cutoff and the recorded timestamp agree.
        now = time()
        cutoff = now - self.window

        # Timestamps are appended in call order, so the expired ones form a
        # prefix. Count them and drop them with one slice deletion instead of
        # shifting the whole list once per expired entry with pop(0).
        expired = 0
        for timestamp in user_requests:
            if timestamp >= cutoff:
                break
            expired += 1
        if expired:
            del user_requests[:expired]

        # Check if the user has exceeded the rate limit
        if len(user_requests) >= self.requests:
            raise HTTPException(status_code=429, detail="Too Many Requests")

        # Add the current request to the cache
        user_requests.append(now)