mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Buffer message chunks on server side for more performant ws streaming
Send larger message chunks to improve streaming efficiency and reduce rendering load on the web client. This rendering load was most evident when using high-throughput models, low-compute clients and messages with images, because message content was re-rendered on every token sent to the web app. Server-side message buffering should result in fewer re-renders and a lower compute load on the client.
This commit is contained in:
@@ -4,6 +4,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
@@ -1572,6 +1573,37 @@ async def process_chat_request(
|
|||||||
interrupt_queue: asyncio.Queue,
|
interrupt_queue: asyncio.Queue,
|
||||||
):
|
):
|
||||||
"""Process a single chat request with interrupt support"""
|
"""Process a single chat request with interrupt support"""
|
||||||
|
|
||||||
|
# Server-side message buffering for better streaming performance
|
||||||
|
@dataclass
class MessageBuffer:
    """Holds streamed chat text between flushes, plus the flush bookkeeping."""

    # Text accumulated since the previous flush.
    content: str = ""
    # Scheduled debounce task that will flush the buffer, if any.
    timeout: Optional[asyncio.Task] = None
    # time.perf_counter() reading of the last flush; 0.0 means "not set yet".
    last_flush: float = 0.0

    def __post_init__(self):
        """Stamp last_flush with the current perf_counter reading when left at 0.0."""
        if self.last_flush != 0.0:
            # Caller supplied an explicit timestamp; keep it.
            return
        self.last_flush = time.perf_counter()
|
||||||
|
|
||||||
|
message_buffer = MessageBuffer()
|
||||||
|
BUFFER_FLUSH_INTERVAL = 0.1 # 100ms buffer interval
|
||||||
|
BUFFER_MAX_SIZE = 512 # Flush if buffer reaches this size
|
||||||
|
|
||||||
|
async def flush_message_buffer():
    """Flush the accumulated message buffer to the client.

    Async generator that yields the buffered text exactly once when the
    buffer is non-empty, and yields nothing when it is empty. Flushing
    empties the buffer, records the flush time and clears any pending
    debounce task.

    Yields:
        str: The text accumulated since the previous flush.
    """
    # Only attributes of message_buffer are mutated; the name itself is never
    # rebound, so no `nonlocal` declaration is required.
    if not message_buffer.content:
        return

    buffered_content = message_buffer.content
    message_buffer.content = ""
    message_buffer.last_flush = time.perf_counter()

    pending_flush = message_buffer.timeout
    message_buffer.timeout = None
    # The debounce task flushes through this function too. Cancelling the task
    # we are currently running in would abort it at its next suspension point,
    # after the buffer has already been emptied, so the text would be lost.
    if pending_flush is not None and pending_flush is not asyncio.current_task():
        pending_flush.cancel()

    yield buffered_content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Since we are using websockets, we can ignore the stream parameter and always stream
|
# Since we are using websockets, we can ignore the stream parameter and always stream
|
||||||
response_iterator = event_generator(
|
response_iterator = event_generator(
|
||||||
@@ -1583,7 +1615,43 @@ async def process_chat_request(
|
|||||||
interrupt_queue,
|
interrupt_queue,
|
||||||
)
|
)
|
||||||
async for event in response_iterator:
|
async for event in response_iterator:
|
||||||
|
if event.startswith("{") and event.endswith("}"):
|
||||||
|
evt_json = json.loads(event)
|
||||||
|
if evt_json["type"] == ChatEvent.END_LLM_RESPONSE.value:
|
||||||
|
# Flush remaining buffer content on end llm response event
|
||||||
|
chunks = "".join([chunk async for chunk in flush_message_buffer()])
|
||||||
|
await websocket.send_text(chunks)
|
||||||
|
await websocket.send_text(ChatEvent.END_EVENT.value)
|
||||||
await websocket.send_text(event)
|
await websocket.send_text(event)
|
||||||
|
await websocket.send_text(ChatEvent.END_EVENT.value)
|
||||||
|
elif event != ChatEvent.END_EVENT.value:
|
||||||
|
# Buffer MESSAGE events for better streaming performance
|
||||||
|
message_buffer.content += str(event)
|
||||||
|
|
||||||
|
# Flush if buffer is too large or enough time has passed
|
||||||
|
current_time = time.perf_counter()
|
||||||
|
should_flush_time = (current_time - message_buffer.last_flush) >= BUFFER_FLUSH_INTERVAL
|
||||||
|
should_flush_size = len(message_buffer.content) >= BUFFER_MAX_SIZE
|
||||||
|
|
||||||
|
if should_flush_size or should_flush_time:
|
||||||
|
chunks = "".join([chunk async for chunk in flush_message_buffer()])
|
||||||
|
await websocket.send_text(chunks)
|
||||||
|
await websocket.send_text(ChatEvent.END_EVENT.value)
|
||||||
|
else:
|
||||||
|
# Cancel any previous timeout tasks to reset the flush timer
|
||||||
|
if message_buffer.timeout:
|
||||||
|
message_buffer.timeout.cancel()
|
||||||
|
|
||||||
|
async def delayed_flush():
    """Debounce timer: wait one flush interval, then send the buffered text."""
    await asyncio.sleep(BUFFER_FLUSH_INTERVAL)
    # Drain the buffer. The generator yields nothing when the buffer is
    # already empty, in which case an empty string is sent.
    pieces = []
    async for piece in flush_message_buffer():
        pieces.append(piece)
    await websocket.send_text("".join(pieces))
    await websocket.send_text(ChatEvent.END_EVENT.value)
|
||||||
|
|
||||||
|
# Flush buffer if no new messages arrive within debounce interval
|
||||||
|
message_buffer.timeout = asyncio.create_task(delayed_flush())
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
|
logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
Reference in New Issue
Block a user