From e8fb79a3695083ae8c43be6d17dda2f6262b01ce Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 22 Oct 2024 04:26:55 -0700 Subject: [PATCH] Rate limit the count and total size of images shared via API --- src/interface/web/app/chat/page.tsx | 9 ++- .../chatInputArea/chatInputArea.module.css | 1 - src/khoj/routers/api_chat.py | 19 +----- src/khoj/routers/helpers.py | 67 ++++++++++++++++++- 4 files changed, 75 insertions(+), 21 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index d25d8222..b1524ea9 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -265,7 +265,8 @@ export default function Chat() { try { await readChatStream(response); } catch (err) { - console.error(err); + const apiError = await response.json(); + console.error(apiError); // Retrieve latest message being processed const currentMessage = messages.find((message) => !message.completed); if (!currentMessage) return; @@ -274,7 +275,11 @@ export default function Chat() { const errorMessage = (err as Error).message; if (errorMessage.includes("Error in input stream")) currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`; - else + else if (response.status === 429) { + "detail" in apiError + ? (currentMessage.rawResponse = `${apiError.detail}`) + : (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`); + } else currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`; // Complete message streaming teardown properly diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.module.css b/src/interface/web/app/components/chatInputArea/chatInputArea.module.css index 5561b158..cfee75f1 100644 --- a/src/interface/web/app/components/chatInputArea/chatInputArea.module.css +++ b/src/interface/web/app/components/chatInputArea/chatInputArea.module.css @@ -1,5 +1,4 @@ div.actualInputArea { display: grid; grid-template-columns: auto 1fr auto auto; - max-width: 700px; } diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index ee84c554..a4d213c8 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( + ApiImageRateLimiter, ApiUserRateLimiter, ChatEvent, + ChatRequestBody, CommonQueryParams, ConversationCommandRateLimiter, agenerate_chat_response, @@ -523,22 +525,6 @@ async def set_conversation_title( ) -class ChatRequestBody(BaseModel): - q: str - n: Optional[int] = 7 - d: Optional[float] = None - stream: Optional[bool] = False - title: Optional[str] = None - conversation_id: Optional[str] = None - city: Optional[str] = None - region: Optional[str] = None - country: Optional[str] = None - country_code: Optional[str] = None - timezone: Optional[str] = None - images: Optional[list[str]] = None - create_new: Optional[bool] = False - - @api_chat.post("") @requires(["authenticated"]) async def chat( @@ -551,6 +537,7 @@ async def chat( rate_limiter_per_day=Depends( ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day") ), + image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=10)), ): # Access the parameters from the body q = body.q diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 739a3ad6..cde38eb7 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,5 @@ import asyncio +import base64 import hashlib import json import logging @@ -21,7 +22,7 @@ from typing import ( Tuple, Union, ) -from urllib.parse import parse_qs, quote, urljoin, urlparse +from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse import cron_descriptor import pytz @@ -30,6 +31,7 @@ from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from fastapi import Depends, Header, HTTPException, Request, UploadFile +from pydantic import BaseModel from starlette.authentication import has_required_scope from starlette.requests import URL @@ -1019,6 +1021,22 @@ def generate_chat_response( return chat_response, metadata +class ChatRequestBody(BaseModel): + q: str + n: Optional[int] = 7 + d: Optional[float] = None + stream: Optional[bool] = False + title: Optional[str] = None + conversation_id: Optional[str] = None + city: Optional[str] = None + region: Optional[str] = None + country: Optional[str] = None + country_code: Optional[str] = None + timezone: Optional[str] = None + images: Optional[list[str]] = None + create_new: Optional[bool] = False + + class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): self.requests = requests @@ -1064,13 +1082,58 @@ class ApiUserRateLimiter: ) raise HTTPException( status_code=429, - detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).", + detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).", ) # Add the current request to the cache UserRequests.objects.create(user=user, slug=self.slug) +class ApiImageRateLimiter: + def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10): + self.max_images = max_images + self.max_combined_size_mb = max_combined_size_mb + + def __call__(self, request: Request, body: ChatRequestBody): + if state.billing_enabled is False: + return + + # Rate limiting is disabled if user unauthenticated. + # Other systems handle authentication + if not request.user.is_authenticated: + return + + if not body.images: + return + + # Check number of images + if len(body.images) > self.max_images: + raise HTTPException( + status_code=429, + detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.", + ) + + # Check total size of images + total_size_mb = 0.0 + for image in body.images: + # Unquote the image in case it's URL encoded + image = unquote(image) + # Assuming the image is a base64 encoded string + # Remove the data:image/jpeg;base64, part if present + if "," in image: + image = image.split(",", 1)[1] + + # Decode base64 to get the actual size + image_bytes = base64.b64decode(image) + total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB + + if total_size_mb > self.max_combined_size_mb: + raise HTTPException( + status_code=429, + detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", + ) + + class ConversationCommandRateLimiter: def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): self.slug = slug