Add websocket chat api endpoint to ease bi-directional communication

This commit is contained in:
Debanjum
2025-06-18 16:45:09 -07:00
parent 99ed796c00
commit 38dd85c91f
3 changed files with 939 additions and 708 deletions

View File

@@ -220,7 +220,16 @@ def set_state(args):
def start_server(app, host=None, port=None, socket=None): def start_server(app, host=None, port=None, socket=None):
logger.info("🌖 Khoj is ready to engage") logger.info("🌖 Khoj is ready to engage")
if socket: if socket:
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None) uvicorn.run(
app,
proxy_headers=True,
uds=socket,
log_level="debug" if state.verbose > 1 else "info",
use_colors=True,
log_config=None,
ws_ping_timeout=300,
timeout_keep_alive=60,
)
else: else:
uvicorn.run( uvicorn.run(
app, app,
@@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None):
log_level="debug" if state.verbose > 1 else "info", log_level="debug" if state.verbose > 1 else "info",
use_colors=True, use_colors=True,
log_config=None, log_config=None,
ws_ping_timeout=300,
timeout_keep_alive=60, timeout_keep_alive=60,
**state.ssl_config if state.ssl_config else {}, **state.ssl_config if state.ssl_config else {},
) )

View File

@@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.responses import RedirectResponse, Response, StreamingResponse from fastapi.responses import RedirectResponse, Response, StreamingResponse
from fastapi.websockets import WebSocketState
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
from starlette.requests import Headers
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
@@ -657,19 +666,12 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404) return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
@api_chat.post("") async def event_generator(
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
body: ChatRequestBody, body: ChatRequestBody,
rate_limiter_per_minute=Depends( user_scope: Any,
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") common: CommonQueryParams,
), headers: Headers,
rate_limiter_per_day=Depends( request_obj: Request | WebSocket,
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
): ):
# Access the parameters from the body # Access the parameters from the body
q = body.q q = body.q
@@ -688,17 +690,14 @@ async def chat(
raw_query_files = body.files raw_query_files = body.files
interrupt_flag = body.interrupt interrupt_flag = body.interrupt
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
chat_metadata: dict = {} chat_metadata: dict = {}
conversation = None conversation = None
user: KhojUser = request.user.object user: KhojUser = user_scope.object
is_subscribed = has_required_scope(request, ["premium"]) is_subscribed = has_required_scope(request_obj, ["premium"])
q = unquote(q) q = unquote(q)
train_of_thought = [] train_of_thought = []
nonlocal conversation_id
nonlocal raw_query_files
cancellation_event = asyncio.Event() cancellation_event = asyncio.Event()
tracer: dict = { tracer: dict = {
@@ -709,13 +708,13 @@ async def chat(
} }
uploaded_images: list[str] = [] uploaded_images: list[str] = []
if images: if raw_images:
for image in images: for image in raw_images:
decoded_string = unquote(image) decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1] base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data) image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes) webp_image_bytes = convert_image_to_webp(image_bytes)
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id) uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id)
if not uploaded_image: if not uploaded_image:
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
uploaded_image = f"data:image/webp;base64,{base64_webp_image}" uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
@@ -739,14 +738,14 @@ async def chat(
generated_mermaidjs_diagram: str = None generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict() generated_asset_results: Dict = dict()
program_execution_context: List[str] = [] program_execution_context: List[str] = []
chat_history: List[ChatMessageModel] = []
# Create a task to monitor for disconnections # Create a task to monitor for disconnections
disconnect_monitor_task = None disconnect_monitor_task = None
async def monitor_disconnection(): async def monitor_disconnection():
if isinstance(request_obj, Request):
try: try:
msg = await request.receive() msg = await request_obj.receive()
if msg["type"] == "http.disconnect": if msg["type"] == "http.disconnect":
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.") logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
cancellation_event.set() cancellation_event.set()
@@ -765,7 +764,7 @@ async def chat(
operator_results=operator_results, operator_results=operator_results,
research_results=research_results, research_results=research_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=uploaded_images, query_images=uploaded_images,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
@@ -778,6 +777,36 @@ async def chat(
) )
except Exception as e: except Exception as e:
logger.error(f"Error in disconnect monitor: {e}") logger.error(f"Error in disconnect monitor: {e}")
elif isinstance(request_obj, WebSocket):
while request_obj.client_state == WebSocketState.CONNECTED:
await asyncio.sleep(1)
logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.")
cancellation_event.set()
if conversation:
await asyncio.shield(
save_to_conversation_log(
q,
chat_response="",
user=user,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=user_scope.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
)
# Cancel the disconnect monitor task if it is still running # Cancel the disconnect monitor task if it is still running
async def cancel_disconnect_monitor(): async def cancel_disconnect_monitor():
@@ -864,12 +893,12 @@ async def chat(
logger.info(f"Chat response total time: {latency:.3f} seconds") logger.info(f"Chat response total time: {latency:.3f} seconds")
logger.info(f"Chat response cost: ${cost:.5f}") logger.info(f"Chat response cost: ${cost:.5f}")
update_telemetry_state( update_telemetry_state(
request=request, request=request_obj,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
client=common.client, client=common.client,
user_agent=request.headers.get("user-agent"), user_agent=headers.get("user-agent"),
host=request.headers.get("host"), host=headers.get("host"),
metadata=chat_metadata, metadata=chat_metadata,
) )
@@ -894,7 +923,7 @@ async def chat(
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, user,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
title=title, title=title,
create_new=body.create_new, create_new=body.create_new,
@@ -935,7 +964,7 @@ async def chat(
# Refresh conversation to check if interrupted message saved to DB # Refresh conversation to check if interrupted message saved to DB
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, user,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
) )
if ( if (
@@ -1006,7 +1035,7 @@ async def chat(
cmds_to_rate_limit += conversation_commands cmds_to_rate_limit += conversation_commands
for cmd in cmds_to_rate_limit: for cmd in cmds_to_rate_limit:
try: try:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd)
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
except HTTPException as e: except HTTPException as e:
async for result in send_llm_response(str(e.detail), tracer.get("usage")): async for result in send_llm_response(str(e.detail), tracer.get("usage")):
@@ -1091,9 +1120,7 @@ async def chat(
inferred_queries.extend(result[1]) inferred_queries.extend(result[1])
defiltered_query = result[2] defiltered_query = result[2]
except Exception as e: except Exception as e:
error_message = ( error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
f"Error searching knowledge base: {e}. Attempting to respond without document references."
)
logger.error(error_message, exc_info=True) logger.error(error_message, exc_info=True)
async for result in send_event( async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references" ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
@@ -1343,9 +1370,7 @@ async def chat(
else: else:
error_message = "Failed to generate diagram. Please try again later." error_message = "Failed to generate diagram. Please try again later."
program_execution_context.append( program_execution_context.append(
prompts.failed_diagram_generation.format( prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt)
attempted_diagram=better_diagram_description_prompt
)
) )
async for result in send_event(ChatEvent.STATUS, error_message): async for result in send_event(ChatEvent.STATUS, error_message):
@@ -1423,7 +1448,7 @@ async def chat(
operator_results=operator_results, operator_results=operator_results,
research_results=research_results, research_results=research_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=str(conversation.id), conversation_id=str(conversation.id),
query_images=uploaded_images, query_images=uploaded_images,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
@@ -1450,11 +1475,119 @@ async def chat(
# Cancel the disconnect monitor task if it is still running # Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor() await cancel_disconnect_monitor()
## Stream Text Response
if stream: @api_chat.websocket("/ws")
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") @requires(["authenticated"])
## Non-Streaming Text Response async def chat_ws(
websocket: WebSocket,
common: CommonQueryParams,
):
await websocket.accept()
# Initialize rate limiters
rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
rate_limiter_per_day = ApiUserRateLimiter(
requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day"
)
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
current_task = None
try:
while True:
data = await websocket.receive_json()
# Handle regular chat messages
# Handle regular chat messages - ensure data has required fields
if "q" not in data:
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
continue
body = ChatRequestBody(**data)
# Apply rate limiting manually
try:
rate_limiter_per_minute.check_websocket(websocket)
rate_limiter_per_day.check_websocket(websocket)
image_rate_limiter.check_websocket(websocket, body)
except HTTPException as e:
await websocket.send_text(json.dumps({"error": e.detail}))
continue
# Cancel any ongoing task before starting a new one
if current_task and not current_task.done():
current_task.cancel()
try:
await current_task
except asyncio.CancelledError:
pass
# Create a new task for processing the chat request
current_task = asyncio.create_task(process_chat_request(websocket, body, common))
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
if current_task and not current_task.done():
current_task.cancel()
except Exception as e:
logger.error(f"Error in websocket chat: {e}", exc_info=True)
if current_task and not current_task.done():
current_task.cancel()
await websocket.close(code=1011, reason="Internal Server Error")
async def process_chat_request(
websocket: WebSocket,
body: ChatRequestBody,
common: CommonQueryParams,
):
"""Process a single chat request with interrupt support"""
try:
# Since we are using websockets, we can ignore the stream parameter and always stream
response_iterator = event_generator(
body,
websocket.scope["user"],
common,
websocket.headers,
websocket,
)
async for event in response_iterator:
await websocket.send_text(event)
except asyncio.CancelledError:
logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
raise
except Exception as e:
logger.error(f"Error processing chat request: {e}", exc_info=True)
await websocket.send_text(json.dumps({"error": "Internal server error"}))
raise
@api_chat.post("")
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
body: ChatRequestBody,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
):
response_iterator = event_generator(
body,
request.user,
common,
request.headers,
request,
)
# Stream Text Response
if body.stream:
return StreamingResponse(response_iterator, media_type="text/plain")
# Non-Streaming Text Response
else: else:
response_iterator = event_generator(q, images=raw_images)
response_data = await read_chat_stream(response_iterator) response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)

View File

@@ -33,7 +33,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.utils import timezone as django_timezone from django.utils import timezone as django_timezone
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from starlette.requests import URL from starlette.requests import URL
@@ -1936,6 +1936,53 @@ class ApiUserRateLimiter:
# Add the current request to the cache # Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug) UserRequests.objects.create(user=user, slug=self.slug)
def check_websocket(self, websocket: WebSocket):
"""WebSocket-specific rate limiting method"""
# Rate limiting disabled if billing is disabled
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
return
user: KhojUser = websocket.scope["user"].object
subscribed = has_required_scope(websocket, ["premium"])
# Remove requests outside of the time window
cutoff = django_timezone.now() - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
# Check if the user has exceeded the rate limit
if subscribed and count_requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
)
if not subscribed and count_requests >= self.requests:
if self.requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
)
logger.info(
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?",
)
# Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug)
class ApiImageRateLimiter: class ApiImageRateLimiter:
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10): def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
@@ -1983,6 +2030,47 @@ class ApiImageRateLimiter:
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
) )
def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
"""WebSocket-specific image rate limiting method"""
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
return
if not body.images:
return
# Check number of images
if len(body.images) > self.max_images:
logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.")
raise HTTPException(
status_code=429,
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
)
# Check total size of images
total_size_mb = 0.0
for image in body.images:
# Unquote the image in case it's URL encoded
image = unquote(image)
# Assuming the image is a base64 encoded string
# Remove the data:image/jpeg;base64, part if present
if "," in image:
image = image.split(",", 1)[1]
# Decode base64 to get the actual size
image_bytes = base64.b64decode(image)
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
if total_size_mb > self.max_combined_size_mb:
logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
raise HTTPException(
status_code=429,
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
)
class ConversationCommandRateLimiter: class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
@@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter:
self.subscribed_rate_limit = subscribed_rate_limit self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Research] self.restricted_commands = [ConversationCommand.Research]
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
if state.billing_enabled is False: if state.billing_enabled is False:
return return