From 167ef000f40cc27e43c9d16f1658df2aaad4b2d8 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 6 Aug 2024 19:09:52 +0530
Subject: [PATCH] Fix chat API for non-streaming mode json response

---
 src/khoj/routers/api_chat.py | 21 ++------
 src/khoj/routers/helpers.py  | 93 ++++++++++++++++++++++++++++++++++++
 2 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 6ff3ba8f..bab6db03 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -41,6 +41,7 @@ from khoj.routers.helpers import (
     get_conversation_command,
     is_query_empty,
     is_ready_to_chat,
+    read_chat_stream,
     text_to_image,
     update_telemetry_state,
     validate_conversation_config,
@@ -570,8 +571,7 @@ async def chat(
             logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
             return
         finally:
-            if stream:
-                yield event_delimiter
+            yield event_delimiter
 
     async def send_llm_response(response: str):
         async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
@@ -937,17 +937,6 @@ async def chat(
         return StreamingResponse(event_generator(q), media_type="text/plain")
     ## Non-Streaming Text Response
     else:
-        # Get the full response from the generator if the stream is not requested.
-        response_obj = {}
-        actual_response = ""
-        iterator = event_generator(q)
-        async for item in iterator:
-            try:
-                item_json = json.loads(item)
-                if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
-                    response_obj = item_json["data"]
-            except:
-                actual_response += item
-        response_obj["response"] = actual_response
-
-        return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
+        response_iterator = event_generator(q)
+        response_data = await read_chat_stream(response_iterator)
+        return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 193f6591..33500fb3 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -15,6 +15,7 @@ from random import random
 from typing import (
     Annotated,
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
     Iterator,
@@ -1225,6 +1226,98 @@ class ChatEvent(Enum):
     STATUS = "status"
 
 
+class MessageProcessor:
+    """Collect the response and references from the events of a chat stream."""
+
+    def __init__(self):
+        self.references = {}
+        self.raw_response = ""
+
+    def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
+        """Convert a raw chunk into an event dict with "type" and "data" keys."""
+        if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
+            try:
+                json_chunk = json.loads(raw_chunk)
+            except json.JSONDecodeError:
+                return {"type": "message", "data": raw_chunk}
+            if "type" not in json_chunk:
+                json_chunk = {"type": "message", "data": json_chunk}
+            return json_chunk
+        elif raw_chunk:
+            return {"type": "message", "data": raw_chunk}
+        return {"type": "", "data": ""}
+
+    def _append_text(self, text: str) -> None:
+        """Append text to the response, even if a JSON response was stored before."""
+        if not isinstance(self.raw_response, str):
+            # A previous chunk stored a JSON object. Serialize it so `+=` cannot raise TypeError
+            self.raw_response = json.dumps(self.raw_response)
+        self.raw_response += text
+
+    def process_message_chunk(self, raw_chunk: str) -> None:
+        """Update the response or references from a single chat stream event."""
+        chunk = self.convert_message_chunk_to_json(raw_chunk)
+        if not chunk or not chunk["type"]:
+            return
+
+        try:
+            chunk_type = ChatEvent(chunk["type"])
+        except ValueError:
+            # JSON content with its own, unrelated "type" field is message content, not a chat event
+            chunk_type, chunk = ChatEvent.MESSAGE, {"type": "message", "data": chunk}
+
+        if chunk_type == ChatEvent.REFERENCES:
+            self.references = chunk.get("data", {})
+        elif chunk_type == ChatEvent.MESSAGE:
+            chunk_data = chunk.get("data", "")
+            if isinstance(chunk_data, dict):
+                self.raw_response = self.handle_json_response(chunk_data)
+            elif (
+                isinstance(chunk_data, str) and chunk_data.strip().startswith("{") and chunk_data.strip().endswith("}")
+            ):
+                try:
+                    json_data = json.loads(chunk_data.strip())
+                    self.raw_response = self.handle_json_response(json_data)
+                except json.JSONDecodeError:
+                    self._append_text(chunk_data)
+            else:
+                self._append_text(str(chunk_data))
+
+    def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
+        """Return the "response" field of plain JSON responses, else the JSON object itself."""
+        if "image" in json_data or "details" in json_data:
+            return json_data
+        if "response" in json_data:
+            return json_data["response"]
+        return json_data
+
+
+async def read_chat_stream(
+    response_iterator: AsyncGenerator[str, None], event_delimiter: str = "␃🔚␗"
+) -> Dict[str, Any]:
+    """Read the whole chat stream and return its final response and references."""
+    processor = MessageProcessor()
+    buffer = ""
+
+    async for chunk in response_iterator:
+        # Start buffering chunks until complete event is received
+        buffer += chunk
+
+        # Once the buffer contains a complete event
+        while event_delimiter in buffer:
+            # Extract the event from the buffer
+            event, buffer = buffer.split(event_delimiter, 1)
+            # Process the event
+            if event:
+                processor.process_message_chunk(event)
+
+    # Process any remaining data in the buffer
+    if buffer:
+        processor.process_message_chunk(buffer)
+
+    return {"response": processor.raw_response, "references": processor.references}
+
+
 def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
     user_picture = request.session.get("user", {}).get("picture")
     is_active = has_required_scope(request, ["premium"])