From e9adb58c165ceed384a269e7b5fa90d996cb2919 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 13 Nov 2023 18:36:30 -0800 Subject: [PATCH] Rate limit calls to the /chat API per user, per day/minute --- src/khoj/routers/api.py | 5 ++++- src/khoj/routers/helpers.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index fbdfbd63..2bf7fd6f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -7,7 +7,7 @@ import json from typing import List, Optional, Union, Any # External Packages -from fastapi import APIRouter, HTTPException, Header, Request +from fastapi import APIRouter, Depends, HTTPException, Header, Request from starlette.authentication import requires from asgiref.sync import sync_to_async @@ -36,6 +36,7 @@ from khoj.routers.helpers import ( agenerate_chat_response, update_telemetry_state, is_ready_to_chat, + ApiUserRateLimiter, ) from khoj.processor.conversation.prompts import help_message from khoj.processor.conversation.openai.gpt import extract_questions @@ -587,6 +588,8 @@ async def chat( user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), host: Optional[str] = Header(None), + rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)), + rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)), ) -> Response: user = request.user.object diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 46ef0641..2fc7ab79 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,12 +1,17 @@ -import logging +# Standard Packages import asyncio +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from functools import partial +import logging +from time import time from typing import Iterator, List, Optional, Union, Tuple, Dict -from concurrent.futures import 
class ApiUserRateLimiter:
    """Sliding-window, per-user request limiter usable as a FastAPI dependency.

    An instance is callable with the incoming request. It allows at most
    ``requests`` accepted calls per user within any trailing ``window``
    seconds and raises an HTTP 429 error once that budget is used up.

    The request history lives in memory on the instance, so every limiter
    instance tracks its own, independent budget.
    """

    def __init__(self, requests: int, window: int):
        """Configure the limiter.

        Args:
            requests: Maximum number of accepted requests per user per window.
            window: Length of the sliding window, in seconds.
        """
        self.requests = requests
        self.window = window
        # Maps a user's uuid to the timestamps of their accepted requests,
        # oldest first (timestamps are only ever appended).
        # NOTE(review): keys are whatever `KhojUser.uuid` returns; confirm it
        # really is a `str` as the annotation claims.
        self.cache: dict[str, list[float]] = defaultdict(list)

    def __call__(self, request: Request):
        """Record the request for its user, or reject it if over the limit.

        Args:
            request: Incoming request; `request.user.object` must be the
                authenticated user.

        Raises:
            HTTPException: With status 429 when the user already has
                `self.requests` accepted requests inside the current window.
                Rejected requests are not recorded.
        """
        user: KhojUser = request.user.object
        user_requests = self.cache[user.uuid]

        # Sample the clock once so the cutoff and the recorded timestamp agree.
        now = time()
        cutoff = now - self.window

        # Drop the leading run of timestamps that fell out of the window with
        # a single in-place slice delete, instead of repeated O(n) `pop(0)`.
        expired = 0
        for timestamp in user_requests:
            if timestamp >= cutoff:
                break
            expired += 1
        if expired:
            del user_requests[:expired]

        # Check if the user has exceeded the rate limit
        if len(user_requests) >= self.requests:
            raise HTTPException(status_code=429, detail="Too Many Requests")

        # Add the current request to the cache
        user_requests.append(now)