mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 05:40:17 +00:00
Rate limit the count and total size of images shared via API
This commit is contained in:
@@ -265,7 +265,8 @@ export default function Chat() {
|
|||||||
try {
|
try {
|
||||||
await readChatStream(response);
|
await readChatStream(response);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
const apiError = await response.json();
|
||||||
|
console.error(apiError);
|
||||||
// Retrieve latest message being processed
|
// Retrieve latest message being processed
|
||||||
const currentMessage = messages.find((message) => !message.completed);
|
const currentMessage = messages.find((message) => !message.completed);
|
||||||
if (!currentMessage) return;
|
if (!currentMessage) return;
|
||||||
@@ -274,7 +275,11 @@ export default function Chat() {
|
|||||||
const errorMessage = (err as Error).message;
|
const errorMessage = (err as Error).message;
|
||||||
if (errorMessage.includes("Error in input stream"))
|
if (errorMessage.includes("Error in input stream"))
|
||||||
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
||||||
else
|
else if (response.status === 429) {
|
||||||
|
"detail" in apiError
|
||||||
|
? (currentMessage.rawResponse = `${apiError.detail}`)
|
||||||
|
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
|
||||||
|
} else
|
||||||
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
||||||
|
|
||||||
// Complete message streaming teardown properly
|
// Complete message streaming teardown properly
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
div.actualInputArea {
|
div.actualInputArea {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: auto 1fr auto auto;
|
grid-template-columns: auto 1fr auto auto;
|
||||||
max-width: 700px;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
|||||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
|
ApiImageRateLimiter,
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
|
ChatRequestBody,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
@@ -523,22 +525,6 @@ async def set_conversation_title(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Request body accepted by the chat API endpoint (read as `body` in `chat()`).
# Only `q` is required; every other field is optional and has a default.
# NOTE(review): the field meanings noted below are inferred from their names
# unless stated otherwise - confirm against the chat handler.
class ChatRequestBody(BaseModel):
    q: str  # presumably the user's chat message / query
    n: Optional[int] = 7  # presumably the max number of references to retrieve - TODO confirm
    d: Optional[float] = None  # presumably a search distance threshold - TODO confirm
    stream: Optional[bool] = False  # presumably toggles a streamed response
    title: Optional[str] = None
    conversation_id: Optional[str] = None
    # Client location and locale hints
    city: Optional[str] = None
    region: Optional[str] = None
    country: Optional[str] = None
    country_code: Optional[str] = None
    timezone: Optional[str] = None
    # ApiImageRateLimiter treats each entry as a base64 string that may be
    # URL-quoted and may carry a "data:...;base64," prefix
    images: Optional[list[str]] = None
    create_new: Optional[bool] = False  # presumably forces a new conversation - TODO confirm
|
|
||||||
|
|
||||||
|
|
||||||
@api_chat.post("")
|
@api_chat.post("")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat(
|
async def chat(
|
||||||
@@ -551,6 +537,7 @@ async def chat(
|
|||||||
rate_limiter_per_day=Depends(
|
rate_limiter_per_day=Depends(
|
||||||
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
||||||
),
|
),
|
||||||
|
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=10)),
|
||||||
):
|
):
|
||||||
# Access the parameters from the body
|
# Access the parameters from the body
|
||||||
q = body.q
|
q = body.q
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -21,7 +22,7 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import pytz
|
import pytz
|
||||||
@@ -30,6 +31,7 @@ from apscheduler.job import Job
|
|||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
from pydantic import BaseModel
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
|
|
||||||
@@ -1019,6 +1021,22 @@ def generate_chat_response(
|
|||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
# Request body accepted by the chat API endpoint (read as `body` in `chat()`).
# It is defined in this helpers module so that request dependencies such as
# ApiImageRateLimiter can take it as a typed parameter.
# Only `q` is required; every other field is optional and has a default.
# NOTE(review): the field meanings noted below are inferred from their names
# unless stated otherwise - confirm against the chat handler.
class ChatRequestBody(BaseModel):
    q: str  # presumably the user's chat message / query
    n: Optional[int] = 7  # presumably the max number of references to retrieve - TODO confirm
    d: Optional[float] = None  # presumably a search distance threshold - TODO confirm
    stream: Optional[bool] = False  # presumably toggles a streamed response
    title: Optional[str] = None
    conversation_id: Optional[str] = None
    # Client location and locale hints
    city: Optional[str] = None
    region: Optional[str] = None
    country: Optional[str] = None
    country_code: Optional[str] = None
    timezone: Optional[str] = None
    # ApiImageRateLimiter treats each entry as a base64 string that may be
    # URL-quoted and may carry a "data:...;base64," prefix
    images: Optional[list[str]] = None
    create_new: Optional[bool] = False  # presumably forces a new conversation - TODO confirm
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
@@ -1064,13 +1082,58 @@ class ApiUserRateLimiter:
|
|||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=429,
|
status_code=429,
|
||||||
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the current request to the cache
|
# Add the current request to the cache
|
||||||
UserRequests.objects.create(user=user, slug=self.slug)
|
UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiImageRateLimiter:
    """FastAPI dependency limiting the count and combined size of images in a chat request.

    Instances are callable and meant to be wrapped in `Depends(...)`. The check
    is skipped when billing is disabled, when the user is unauthenticated, or
    when the request carries no images.
    """

    def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
        """
        Args:
            max_images: Maximum number of images allowed in a single request.
            max_combined_size_mb: Maximum combined decoded size of all images, in MB.
        """
        self.max_images = max_images
        self.max_combined_size_mb = max_combined_size_mb

    def __call__(self, request: Request, body: ChatRequestBody):
        """Validate the images attached to the chat request body.

        Raises:
            HTTPException: 429 when there are more than `max_images` images or
                their combined decoded size exceeds `max_combined_size_mb`;
                400 when an image is not valid base64.
        """
        # Image limits are only enforced when billing is enabled
        if state.billing_enabled is False:
            return

        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return

        if not body.images:
            return

        # Check number of images
        if len(body.images) > self.max_images:
            raise HTTPException(
                status_code=429,
                detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
            )

        # Check total size of images. Accumulate whole bytes to avoid float
        # drift, and stop as soon as the budget is exceeded so the remaining
        # images are not decoded needlessly.
        max_combined_size_bytes = self.max_combined_size_mb * 1024 * 1024
        total_size_bytes = 0
        for image in body.images:
            total_size_bytes += self._decoded_size_bytes(image)
            if total_size_bytes > max_combined_size_bytes:
                raise HTTPException(
                    status_code=429,
                    detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
                )

    @staticmethod
    def _decoded_size_bytes(image: str) -> int:
        """Return the decoded byte size of one base64 encoded image string.

        Raises:
            HTTPException: 400 when the payload is not valid base64.
        """
        # Unquote the image in case it's URL encoded
        image = unquote(image)

        # Remove the data:image/jpeg;base64, part if present
        _, separator, payload = image.partition(",")
        if not separator:
            payload = image

        try:
            # Decode base64 to get the actual size
            return len(base64.b64decode(payload))
        except ValueError as err:
            # binascii.Error (bad padding) subclasses ValueError, and non-ASCII
            # input raises a plain ValueError. Both are client errors, so
            # report them as such instead of failing with a server error.
            raise HTTPException(
                status_code=400,
                detail="I couldn't read one of the images you shared. Could you check it is a valid image and try again?",
            ) from err
|
||||||
|
|
||||||
|
|
||||||
class ConversationCommandRateLimiter:
|
class ConversationCommandRateLimiter:
|
||||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||||
self.slug = slug
|
self.slug = slug
|
||||||
|
|||||||
Reference in New Issue
Block a user