diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 34f6be08..94aaebcf 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -4,6 +4,7 @@ import json import logging import time import uuid +from dataclasses import dataclass from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -1572,6 +1573,37 @@ async def process_chat_request( interrupt_queue: asyncio.Queue, ): """Process a single chat request with interrupt support""" + + # Server-side message buffering for better streaming performance + @dataclass + class MessageBuffer: + """Buffer for managing streamed chat messages with timing control.""" + + content: str = "" + timeout: Optional[asyncio.Task] = None + last_flush: float = 0.0 + + def __post_init__(self): + """Initialize last_flush with current time if not provided.""" + if self.last_flush == 0.0: + self.last_flush = time.perf_counter() + + message_buffer = MessageBuffer() + BUFFER_FLUSH_INTERVAL = 0.1 # 100ms buffer interval + BUFFER_MAX_SIZE = 512 # Flush if buffer reaches this size + + async def flush_message_buffer(): + """Flush the accumulated message buffer to the client""" + nonlocal message_buffer + if message_buffer.content: + buffered_content = message_buffer.content + message_buffer.content = "" + message_buffer.last_flush = time.perf_counter() + if message_buffer.timeout: + message_buffer.timeout.cancel() + message_buffer.timeout = None + yield buffered_content + try: # Since we are using websockets, we can ignore the stream parameter and always stream response_iterator = event_generator( @@ -1583,7 +1615,43 @@ async def process_chat_request( interrupt_queue, ) async for event in response_iterator: - await websocket.send_text(event) + if event.startswith("{") and event.endswith("}"): + evt_json = json.loads(event) + if evt_json["type"] == ChatEvent.END_LLM_RESPONSE.value: + # Flush remaining buffer content on end llm response event + chunks = "".join([chunk async for chunk in flush_message_buffer()]) + await websocket.send_text(chunks) + await websocket.send_text(ChatEvent.END_EVENT.value) + await websocket.send_text(event) + await websocket.send_text(ChatEvent.END_EVENT.value) + elif event != ChatEvent.END_EVENT.value: + # Buffer MESSAGE events for better streaming performance + message_buffer.content += str(event) + + # Flush if buffer is too large or enough time has passed + current_time = time.perf_counter() + should_flush_time = (current_time - message_buffer.last_flush) >= BUFFER_FLUSH_INTERVAL + should_flush_size = len(message_buffer.content) >= BUFFER_MAX_SIZE + + if should_flush_size or should_flush_time: + chunks = "".join([chunk async for chunk in flush_message_buffer()]) + await websocket.send_text(chunks) + await websocket.send_text(ChatEvent.END_EVENT.value) + else: + # Cancel any previous timeout tasks to reset the flush timer + if message_buffer.timeout: + message_buffer.timeout.cancel() + + async def delayed_flush(): + """Flush message buffer if no new messages arrive within debounce interval.""" + await asyncio.sleep(BUFFER_FLUSH_INTERVAL) + # Check if there's still content to flush + chunks = "".join([chunk async for chunk in flush_message_buffer()]) + await websocket.send_text(chunks) + await websocket.send_text(ChatEvent.END_EVENT.value) + + # Flush buffer if no new messages arrive within debounce interval + message_buffer.timeout = asyncio.create_task(delayed_flush()) except asyncio.CancelledError: logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}") raise