mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Rate limit calls to the /chat API per user, per day/minute
This commit is contained in:
@@ -7,7 +7,7 @@ import json
|
|||||||
from typing import List, Optional, Union, Any
|
from typing import List, Optional, Union, Any
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
@@ -36,6 +36,7 @@ from khoj.routers.helpers import (
|
|||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
|
ApiUserRateLimiter,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.prompts import help_message
|
from khoj.processor.conversation.prompts import help_message
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
@@ -587,6 +588,8 @@ async def chat(
|
|||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||||
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
import logging
|
# Standard Packages
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import logging
|
||||||
|
from time import time
|
||||||
from typing import Iterator, List, Optional, Union, Tuple, Dict
|
from typing import Iterator, List, Optional, Union, Tuple, Dict
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
|
# External Packages
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import GPT4AllProcessorModel
|
from khoj.utils.config import GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||||
@@ -16,6 +21,7 @@ from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
|||||||
from database.models import KhojUser, Subscription
|
from database.models import KhojUser, Subscription
|
||||||
from database.adapters import ConversationAdapters
|
from database.adapters import ConversationAdapters
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
@@ -191,3 +197,26 @@ def generate_chat_response(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
class ApiUserRateLimiter:
    """Sliding-window, in-memory rate limiter keyed by user.

    An instance is callable with the incoming request, so it can be passed to
    FastAPI's ``Depends``. It remembers, per user, the timestamps of the
    requests it accepted during the last ``window`` seconds and rejects a
    request with HTTP 429 once ``requests`` of them are already on record.
    State lives on the instance only; it is not shared between instances.
    """

    def __init__(self, requests: int, window: int):
        """Configure the limiter.

        Args:
            requests: Maximum number of requests accepted per user per window.
            window: Length of the sliding window, in seconds.
        """
        # Local import keeps this class self-contained; the module only
        # imports `defaultdict` from collections.
        from collections import deque

        self.requests = requests
        self.window = window
        # Accepted request timestamps per user, oldest first. A deque gives
        # O(1) removal of expired entries from the left end.
        # NOTE(review): keys are whatever `user.uuid` yields; confirm it is a str.
        self.cache: dict[str, deque[float]] = defaultdict(deque)

    def __call__(self, request: Request):
        """Record the request for its user, or reject it if over the limit.

        Args:
            request: Incoming request; the user is read from ``request.user.object``.

        Raises:
            HTTPException: With status 429 when the user already has
                ``requests`` accepted requests inside the current window.
                Rejected requests are not recorded.
        """
        user: KhojUser = request.user.object
        user_requests = self.cache[user.uuid]

        # Read the clock once so pruning and recording agree on "now"
        now = time()

        # Remove requests outside of the time window
        cutoff = now - self.window
        while user_requests and user_requests[0] < cutoff:
            user_requests.popleft()

        # Check if the user has exceeded the rate limit
        if len(user_requests) >= self.requests:
            raise HTTPException(status_code=429, detail="Too Many Requests")

        # Add the current request to the cache
        user_requests.append(now)
|
||||||
|
|||||||
Reference in New Issue
Block a user