diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 13cc0325..8670d35e 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -1596,6 +1596,7 @@ async def process_chat_request(
             self.last_flush = time.perf_counter()
 
     message_buffer = MessageBuffer()
+    thought_buffer = MessageBuffer()
     BUFFER_FLUSH_INTERVAL = 0.1  # 100ms buffer interval
     BUFFER_MAX_SIZE = 512  # Flush if buffer reaches this size
 
@@ -1611,6 +1612,18 @@ async def process_chat_request(
                 message_buffer.timeout = None
             yield buffered_content
 
+    async def flush_thought_buffer():
+        """Flush the accumulated thought buffer to the client"""
+        nonlocal thought_buffer
+        if thought_buffer.content:
+            thought_event = json.dumps({"type": ChatEvent.THOUGHT.value, "data": thought_buffer.content})
+            thought_buffer.content = ""
+            thought_buffer.last_flush = time.perf_counter()
+            if thought_buffer.timeout:
+                thought_buffer.timeout.cancel()
+                thought_buffer.timeout = None
+            yield thought_event
+
     try:
         # Since we are using websockets, we can ignore the stream parameter and always stream
         response_iterator = event_generator(
@@ -1629,6 +1642,40 @@ async def process_chat_request(
                     chunks = "".join([chunk async for chunk in flush_message_buffer()])
                     await websocket.send_text(chunks)
                     await websocket.send_text(ChatEvent.END_EVENT.value)
+                elif evt_json["type"] == ChatEvent.THOUGHT.value:
+                    # Buffer THOUGHT events for better streaming performance
+                    thought_buffer.content += str(evt_json.get("data", ""))
+
+                    # Flush if buffer is too large or enough time has passed
+                    current_time = time.perf_counter()
+                    should_flush_time = (current_time - thought_buffer.last_flush) >= BUFFER_FLUSH_INTERVAL
+                    should_flush_size = len(thought_buffer.content) >= BUFFER_MAX_SIZE
+
+                    if should_flush_size or should_flush_time:
+                        thought_event = "".join([chunk async for chunk in flush_thought_buffer()])
+                        await websocket.send_text(thought_event)
+                        await websocket.send_text(ChatEvent.END_EVENT.value)
+                    else:
+                        # Cancel any previous timeout tasks to reset the flush timer
+                        if thought_buffer.timeout:
+                            thought_buffer.timeout.cancel()
+
+                        async def delayed_thought_flush():
+                            """Flush thought buffer if no new messages arrive within debounce interval."""
+                            await asyncio.sleep(BUFFER_FLUSH_INTERVAL)
+                            # This task is the pending timeout. Drop the handle first so that
+                            # flush_thought_buffer() does not cancel the task performing the send.
+                            thought_buffer.timeout = None
+                            # Drain exactly once: a second drain returns "" and would drop the thoughts.
+                            thought_event = "".join([chunk async for chunk in flush_thought_buffer()])
+                            if thought_event:
+                                await websocket.send_text(thought_event)
+                                await websocket.send_text(ChatEvent.END_EVENT.value)
+
+                        # Flush buffer if no new thoughts arrive within debounce interval
+                        thought_buffer.timeout = asyncio.create_task(delayed_thought_flush())
+                    # The thought is delivered via the buffer; skip the raw per-event send below
+                    continue
                 await websocket.send_text(event)
                 await websocket.send_text(ChatEvent.END_EVENT.value)
             elif event != ChatEvent.END_EVENT.value: