mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 21:29:13 +00:00
Add websocket chat api endpoint to ease bi-directional communication
This commit is contained in:
@@ -220,7 +220,16 @@ def set_state(args):
|
|||||||
def start_server(app, host=None, port=None, socket=None):
|
def start_server(app, host=None, port=None, socket=None):
|
||||||
logger.info("🌖 Khoj is ready to engage")
|
logger.info("🌖 Khoj is ready to engage")
|
||||||
if socket:
|
if socket:
|
||||||
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None)
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
proxy_headers=True,
|
||||||
|
uds=socket,
|
||||||
|
log_level="debug" if state.verbose > 1 else "info",
|
||||||
|
use_colors=True,
|
||||||
|
log_config=None,
|
||||||
|
ws_ping_timeout=300,
|
||||||
|
timeout_keep_alive=60,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
@@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None):
|
|||||||
log_level="debug" if state.verbose > 1 else "info",
|
log_level="debug" if state.verbose > 1 else "info",
|
||||||
use_colors=True,
|
use_colors=True,
|
||||||
log_config=None,
|
log_config=None,
|
||||||
|
ws_ping_timeout=300,
|
||||||
timeout_keep_alive=60,
|
timeout_keep_alive=60,
|
||||||
**state.ssl_config if state.ssl_config else {},
|
**state.ssl_config if state.ssl_config else {},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional
|
|||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Depends,
|
||||||
|
HTTPException,
|
||||||
|
Request,
|
||||||
|
WebSocket,
|
||||||
|
WebSocketDisconnect,
|
||||||
|
)
|
||||||
from fastapi.responses import RedirectResponse, Response, StreamingResponse
|
from fastapi.responses import RedirectResponse, Response, StreamingResponse
|
||||||
|
from fastapi.websockets import WebSocketState
|
||||||
from starlette.authentication import has_required_scope, requires
|
from starlette.authentication import has_required_scope, requires
|
||||||
|
from starlette.requests import Headers
|
||||||
|
|
||||||
from khoj.app.settings import ALLOWED_HOSTS
|
from khoj.app.settings import ALLOWED_HOSTS
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
@@ -657,19 +666,12 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -
|
|||||||
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
|
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@api_chat.post("")
|
async def event_generator(
|
||||||
@requires(["authenticated"])
|
|
||||||
async def chat(
|
|
||||||
request: Request,
|
|
||||||
common: CommonQueryParams,
|
|
||||||
body: ChatRequestBody,
|
body: ChatRequestBody,
|
||||||
rate_limiter_per_minute=Depends(
|
user_scope: Any,
|
||||||
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
common: CommonQueryParams,
|
||||||
),
|
headers: Headers,
|
||||||
rate_limiter_per_day=Depends(
|
request_obj: Request | WebSocket,
|
||||||
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
|
||||||
),
|
|
||||||
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
|
|
||||||
):
|
):
|
||||||
# Access the parameters from the body
|
# Access the parameters from the body
|
||||||
q = body.q
|
q = body.q
|
||||||
@@ -688,17 +690,14 @@ async def chat(
|
|||||||
raw_query_files = body.files
|
raw_query_files = body.files
|
||||||
interrupt_flag = body.interrupt
|
interrupt_flag = body.interrupt
|
||||||
|
|
||||||
async def event_generator(q: str, images: list[str]):
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
ttft = None
|
ttft = None
|
||||||
chat_metadata: dict = {}
|
chat_metadata: dict = {}
|
||||||
conversation = None
|
conversation = None
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = user_scope.object
|
||||||
is_subscribed = has_required_scope(request, ["premium"])
|
is_subscribed = has_required_scope(request_obj, ["premium"])
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
train_of_thought = []
|
train_of_thought = []
|
||||||
nonlocal conversation_id
|
|
||||||
nonlocal raw_query_files
|
|
||||||
cancellation_event = asyncio.Event()
|
cancellation_event = asyncio.Event()
|
||||||
|
|
||||||
tracer: dict = {
|
tracer: dict = {
|
||||||
@@ -709,13 +708,13 @@ async def chat(
|
|||||||
}
|
}
|
||||||
|
|
||||||
uploaded_images: list[str] = []
|
uploaded_images: list[str] = []
|
||||||
if images:
|
if raw_images:
|
||||||
for image in images:
|
for image in raw_images:
|
||||||
decoded_string = unquote(image)
|
decoded_string = unquote(image)
|
||||||
base64_data = decoded_string.split(",", 1)[1]
|
base64_data = decoded_string.split(",", 1)[1]
|
||||||
image_bytes = base64.b64decode(base64_data)
|
image_bytes = base64.b64decode(base64_data)
|
||||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||||
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
|
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id)
|
||||||
if not uploaded_image:
|
if not uploaded_image:
|
||||||
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
|
uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
|
||||||
@@ -739,14 +738,14 @@ async def chat(
|
|||||||
generated_mermaidjs_diagram: str = None
|
generated_mermaidjs_diagram: str = None
|
||||||
generated_asset_results: Dict = dict()
|
generated_asset_results: Dict = dict()
|
||||||
program_execution_context: List[str] = []
|
program_execution_context: List[str] = []
|
||||||
chat_history: List[ChatMessageModel] = []
|
|
||||||
|
|
||||||
# Create a task to monitor for disconnections
|
# Create a task to monitor for disconnections
|
||||||
disconnect_monitor_task = None
|
disconnect_monitor_task = None
|
||||||
|
|
||||||
async def monitor_disconnection():
|
async def monitor_disconnection():
|
||||||
|
if isinstance(request_obj, Request):
|
||||||
try:
|
try:
|
||||||
msg = await request.receive()
|
msg = await request_obj.receive()
|
||||||
if msg["type"] == "http.disconnect":
|
if msg["type"] == "http.disconnect":
|
||||||
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
|
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
|
||||||
cancellation_event.set()
|
cancellation_event.set()
|
||||||
@@ -765,7 +764,7 @@ async def chat(
|
|||||||
operator_results=operator_results,
|
operator_results=operator_results,
|
||||||
research_results=research_results,
|
research_results=research_results,
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
client_application=request.user.client_app,
|
client_application=user_scope.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
@@ -778,6 +777,36 @@ async def chat(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in disconnect monitor: {e}")
|
logger.error(f"Error in disconnect monitor: {e}")
|
||||||
|
elif isinstance(request_obj, WebSocket):
|
||||||
|
while request_obj.client_state == WebSocketState.CONNECTED:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.")
|
||||||
|
cancellation_event.set()
|
||||||
|
if conversation:
|
||||||
|
await asyncio.shield(
|
||||||
|
save_to_conversation_log(
|
||||||
|
q,
|
||||||
|
chat_response="",
|
||||||
|
user=user,
|
||||||
|
chat_history=chat_history,
|
||||||
|
compiled_references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
code_results=code_results,
|
||||||
|
operator_results=operator_results,
|
||||||
|
research_results=research_results,
|
||||||
|
inferred_queries=inferred_queries,
|
||||||
|
client_application=user_scope.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
query_images=uploaded_images,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
|
raw_query_files=raw_query_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
raw_generated_files=generated_asset_results,
|
||||||
|
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Cancel the disconnect monitor task if it is still running
|
# Cancel the disconnect monitor task if it is still running
|
||||||
async def cancel_disconnect_monitor():
|
async def cancel_disconnect_monitor():
|
||||||
@@ -864,12 +893,12 @@ async def chat(
|
|||||||
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
||||||
logger.info(f"Chat response cost: ${cost:.5f}")
|
logger.info(f"Chat response cost: ${cost:.5f}")
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request_obj,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="chat",
|
api="chat",
|
||||||
client=common.client,
|
client=common.client,
|
||||||
user_agent=request.headers.get("user-agent"),
|
user_agent=headers.get("user-agent"),
|
||||||
host=request.headers.get("host"),
|
host=headers.get("host"),
|
||||||
metadata=chat_metadata,
|
metadata=chat_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -894,7 +923,7 @@ async def chat(
|
|||||||
|
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
user,
|
user,
|
||||||
client_application=request.user.client_app,
|
client_application=user_scope.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
title=title,
|
title=title,
|
||||||
create_new=body.create_new,
|
create_new=body.create_new,
|
||||||
@@ -935,7 +964,7 @@ async def chat(
|
|||||||
# Refresh conversation to check if interrupted message saved to DB
|
# Refresh conversation to check if interrupted message saved to DB
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
user,
|
user,
|
||||||
client_application=request.user.client_app,
|
client_application=user_scope.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
@@ -1006,7 +1035,7 @@ async def chat(
|
|||||||
cmds_to_rate_limit += conversation_commands
|
cmds_to_rate_limit += conversation_commands
|
||||||
for cmd in cmds_to_rate_limit:
|
for cmd in cmds_to_rate_limit:
|
||||||
try:
|
try:
|
||||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd)
|
||||||
q = q.replace(f"/{cmd.value}", "").strip()
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
async for result in send_llm_response(str(e.detail), tracer.get("usage")):
|
async for result in send_llm_response(str(e.detail), tracer.get("usage")):
|
||||||
@@ -1091,9 +1120,7 @@ async def chat(
|
|||||||
inferred_queries.extend(result[1])
|
inferred_queries.extend(result[1])
|
||||||
defiltered_query = result[2]
|
defiltered_query = result[2]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = (
|
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||||
f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
|
||||||
)
|
|
||||||
logger.error(error_message, exc_info=True)
|
logger.error(error_message, exc_info=True)
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||||
@@ -1343,9 +1370,7 @@ async def chat(
|
|||||||
else:
|
else:
|
||||||
error_message = "Failed to generate diagram. Please try again later."
|
error_message = "Failed to generate diagram. Please try again later."
|
||||||
program_execution_context.append(
|
program_execution_context.append(
|
||||||
prompts.failed_diagram_generation.format(
|
prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt)
|
||||||
attempted_diagram=better_diagram_description_prompt
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_event(ChatEvent.STATUS, error_message):
|
async for result in send_event(ChatEvent.STATUS, error_message):
|
||||||
@@ -1423,7 +1448,7 @@ async def chat(
|
|||||||
operator_results=operator_results,
|
operator_results=operator_results,
|
||||||
research_results=research_results,
|
research_results=research_results,
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
client_application=request.user.client_app,
|
client_application=user_scope.client_app,
|
||||||
conversation_id=str(conversation.id),
|
conversation_id=str(conversation.id),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
@@ -1450,11 +1475,119 @@ async def chat(
|
|||||||
# Cancel the disconnect monitor task if it is still running
|
# Cancel the disconnect monitor task if it is still running
|
||||||
await cancel_disconnect_monitor()
|
await cancel_disconnect_monitor()
|
||||||
|
|
||||||
## Stream Text Response
|
|
||||||
if stream:
|
@api_chat.websocket("/ws")
|
||||||
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
@requires(["authenticated"])
|
||||||
## Non-Streaming Text Response
|
async def chat_ws(
    websocket: WebSocket,
    common: CommonQueryParams,
):
    """Chat over a websocket: read JSON chat messages in a loop and stream each response back.

    Each received message must be a JSON object with a "q" field. It is parsed
    into a ChatRequestBody, checked against the per-minute, per-day and image
    rate limiters, and then handed to process_chat_request in a background task.
    A new message cancels the task still serving the previous message, so at
    most one request per connection is being processed at a time.

    Problems with a single message (missing "q", invalid body, rate limit hit)
    are sent to the client as a JSON {"error": ...} text frame and the
    connection is kept open. A client disconnect cancels the running task.
    Any other exception cancels the running task and closes the socket with
    code 1011.
    """
    await websocket.accept()

    # Initialize rate limiters
    rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
    rate_limiter_per_day = ApiUserRateLimiter(
        requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day"
    )
    image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)

    current_task = None

    try:
        while True:
            data = await websocket.receive_json()

            # Handle regular chat messages - ensure data is a JSON object with the required fields
            if not isinstance(data, dict) or "q" not in data:
                await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
                continue

            # A malformed message is the client's mistake, not a server failure.
            # Report it and keep the connection open instead of letting it fall
            # through to the catch-all handler below, which closes the socket.
            # Pydantic validation errors derive from ValueError.
            try:
                body = ChatRequestBody(**data)
            except ValueError as e:
                await websocket.send_text(json.dumps({"error": f"Invalid chat message: {e}"}))
                continue

            # Apply rate limiting manually
            try:
                # The user rate limiters run synchronous database queries.
                # Run them in a worker thread so they neither block the event
                # loop nor get called from within the async context directly.
                await sync_to_async(rate_limiter_per_minute.check_websocket)(websocket)
                await sync_to_async(rate_limiter_per_day.check_websocket)(websocket)
                image_rate_limiter.check_websocket(websocket, body)
            except HTTPException as e:
                await websocket.send_text(json.dumps({"error": e.detail}))
                continue

            # Cancel any ongoing task before starting a new one
            if current_task and not current_task.done():
                current_task.cancel()
                try:
                    await current_task
                except asyncio.CancelledError:
                    pass

            # Create a new task for processing the chat request
            current_task = asyncio.create_task(process_chat_request(websocket, body, common))

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
        if current_task and not current_task.done():
            current_task.cancel()
    except Exception as e:
        logger.error(f"Error in websocket chat: {e}", exc_info=True)
        if current_task and not current_task.done():
            current_task.cancel()
        await websocket.close(code=1011, reason="Internal Server Error")
|
||||||
|
|
||||||
|
|
||||||
|
async def process_chat_request(
    websocket: WebSocket,
    body: ChatRequestBody,
    common: CommonQueryParams,
):
    """Run one chat request and forward every generated event to the websocket client.

    Cancellation is logged at debug level and re-raised so the awaiting caller
    sees the task as cancelled. Any other failure is logged, reported to the
    client as a JSON {"error": "Internal server error"} text frame, and then
    re-raised.
    """
    try:
        # Websocket responses are always streamed, so the stream flag on the body is not consulted
        chat_events = event_generator(
            body,
            websocket.scope["user"],
            common,
            websocket.headers,
            websocket,
        )
        async for chat_event in chat_events:
            await websocket.send_text(chat_event)
    except asyncio.CancelledError:
        logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
        raise
    except Exception as error:
        logger.error(f"Error processing chat request: {error}", exc_info=True)
        await websocket.send_text(json.dumps({"error": "Internal server error"}))
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@api_chat.post("")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def chat(
|
||||||
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
|
body: ChatRequestBody,
|
||||||
|
rate_limiter_per_minute=Depends(
|
||||||
|
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
||||||
|
),
|
||||||
|
rate_limiter_per_day=Depends(
|
||||||
|
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||||
|
),
|
||||||
|
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
|
||||||
|
):
|
||||||
|
response_iterator = event_generator(
|
||||||
|
body,
|
||||||
|
request.user,
|
||||||
|
common,
|
||||||
|
request.headers,
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream Text Response
|
||||||
|
if body.stream:
|
||||||
|
return StreamingResponse(response_iterator, media_type="text/plain")
|
||||||
|
# Non-Streaming Text Response
|
||||||
else:
|
else:
|
||||||
response_iterator = event_generator(q, images=raw_images)
|
|
||||||
response_data = await read_chat_stream(response_iterator)
|
response_data = await read_chat_stream(response_iterator)
|
||||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from apscheduler.job import Job
|
|||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.utils import timezone as django_timezone
|
from django.utils import timezone as django_timezone
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
|
||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
@@ -1936,6 +1936,53 @@ class ApiUserRateLimiter:
|
|||||||
# Add the current request to the cache
|
# Add the current request to the cache
|
||||||
UserRequests.objects.create(user=user, slug=self.slug)
|
UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
def check_websocket(self, websocket: WebSocket):
    """Enforce this limiter's request quota for a message received over a websocket.

    Does nothing when billing is disabled or when the websocket scope holds no
    authenticated user. Otherwise it counts the user's UserRequests rows for
    this limiter's slug created within the last `window` seconds. If the quota
    that applies to the user (subscribed_requests for users with the "premium"
    scope, requests for everyone else) is already used up, it raises; if not,
    it records the current request as a new UserRequests row.

    This method is synchronous and queries the database.

    Raises:
        HTTPException: status 429 when the applicable quota is exhausted.
    """
    # Rate limiting disabled if billing is disabled
    if state.billing_enabled is False:
        return

    # Rate limiting is disabled if user unauthenticated.
    scope_user = websocket.scope.get("user")
    if not scope_user or not scope_user.is_authenticated:
        return

    user: KhojUser = scope_user.object
    is_premium = has_required_scope(websocket, ["premium"])

    # Only requests made inside the current time window count towards the quota
    window_start = django_timezone.now() - timedelta(seconds=self.window)
    used_requests = UserRequests.objects.filter(user=user, created_at__gte=window_start, slug=self.slug).count()

    # Response shown when subscribing would not raise the user's limit
    come_back_tomorrow = "I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?"

    # Check if the user has exceeded the rate limit
    if is_premium:
        if used_requests >= self.subscribed_requests:
            logger.info(
                f"Rate limit: {used_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
            )
            raise HTTPException(status_code=429, detail=come_back_tomorrow)
    elif used_requests >= self.requests:
        # Do not suggest subscribing when the subscribed quota is not any higher
        if self.requests >= self.subscribed_requests:
            logger.info(
                f"Rate limit: {used_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}."
            )
            raise HTTPException(status_code=429, detail=come_back_tomorrow)

        logger.info(
            f"Rate limit: {used_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
        )
        raise HTTPException(
            status_code=429,
            detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?",
        )

    # Add the current request to the cache
    UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
class ApiImageRateLimiter:
|
class ApiImageRateLimiter:
|
||||||
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
||||||
@@ -1983,6 +2030,47 @@ class ApiImageRateLimiter:
|
|||||||
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
    """Enforce the per-message image limits for a chat message received over a websocket.

    Does nothing when billing is disabled, when the websocket scope holds no
    authenticated user, or when the message carries no images. Otherwise it
    rejects the message if it has more than `max_images` images or if the
    combined size of the base64-decoded images exceeds `max_combined_size_mb`.

    Raises:
        HTTPException: status 429 when either limit is exceeded.
    """
    if state.billing_enabled is False:
        return

    # Rate limiting is disabled if user unauthenticated.
    scope_user = websocket.scope.get("user")
    if not scope_user or not scope_user.is_authenticated:
        return

    if not body.images:
        return

    # Check number of images
    image_count = len(body.images)
    if image_count > self.max_images:
        logger.info(f"Rate limit: {image_count}/{self.max_images} images not allowed per message.")
        raise HTTPException(
            status_code=429,
            detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
        )

    # Check total size of images
    bytes_per_mb = 1024 * 1024
    combined_size_mb = 0.0
    for raw_image in body.images:
        # The image may arrive URL encoded
        decoded_image = unquote(raw_image)
        # Drop the data URI prefix (e.g. data:image/jpeg;base64,) when there is one
        _, comma, after_comma = decoded_image.partition(",")
        base64_payload = after_comma if comma else decoded_image
        # Measure the decoded bytes, not the longer base64 text
        combined_size_mb += len(base64.b64decode(base64_payload)) / bytes_per_mb

    if combined_size_mb > self.max_combined_size_mb:
        logger.info(f"Data limit: {combined_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
        raise HTTPException(
            status_code=429,
            detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
        )
|
||||||
|
|
||||||
|
|
||||||
class ConversationCommandRateLimiter:
|
class ConversationCommandRateLimiter:
|
||||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||||
@@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter:
|
|||||||
self.subscribed_rate_limit = subscribed_rate_limit
|
self.subscribed_rate_limit = subscribed_rate_limit
|
||||||
self.restricted_commands = [ConversationCommand.Research]
|
self.restricted_commands = [ConversationCommand.Research]
|
||||||
|
|
||||||
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
|
async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
|
||||||
if state.billing_enabled is False:
|
if state.billing_enabled is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user