Fix chat API JSON response for non-streaming mode

This commit is contained in:
Debanjum Singh Solanky
2024-08-06 19:09:52 +05:30
parent 00ee4c2697
commit 167ef000f4
2 changed files with 79 additions and 16 deletions

View File

@@ -41,6 +41,7 @@ from khoj.routers.helpers import (
get_conversation_command,
is_query_empty,
is_ready_to_chat,
read_chat_stream,
text_to_image,
update_telemetry_state,
validate_conversation_config,
@@ -570,8 +571,7 @@ async def chat(
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
return
finally:
if stream:
yield event_delimiter
yield event_delimiter
async def send_llm_response(response: str):
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
@@ -937,17 +937,6 @@ async def chat(
return StreamingResponse(event_generator(q), media_type="text/plain")
## Non-Streaming Text Response
else:
# Get the full response from the generator if the stream is not requested.
response_obj = {}
actual_response = ""
iterator = event_generator(q)
async for item in iterator:
try:
item_json = json.loads(item)
if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
response_obj = item_json["data"]
except:
actual_response += item
response_obj["response"] = actual_response
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
response_iterator = event_generator(q)
response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)

View File

@@ -15,6 +15,7 @@ from random import random
from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
Iterator,
@@ -1225,6 +1226,79 @@ class ChatEvent(Enum):
STATUS = "status"
class MessageProcessor:
def __init__(self):
self.references = {}
self.raw_response = ""
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
try:
json_chunk = json.loads(raw_chunk)
if "type" not in json_chunk:
json_chunk = {"type": "message", "data": json_chunk}
return json_chunk
except json.JSONDecodeError:
return {"type": "message", "data": raw_chunk}
elif raw_chunk:
return {"type": "message", "data": raw_chunk}
return {"type": "", "data": ""}
def process_message_chunk(self, raw_chunk: str) -> None:
chunk = self.convert_message_chunk_to_json(raw_chunk)
if not chunk or not chunk["type"]:
return
chunk_type = ChatEvent(chunk["type"])
if chunk_type == ChatEvent.REFERENCES:
self.references = chunk["data"]
elif chunk_type == ChatEvent.MESSAGE:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
self.raw_response = self.handle_json_response(chunk_data)
elif (
isinstance(chunk_data, str) and chunk_data.strip().startswith("{") and chunk_data.strip().endswith("}")
):
try:
json_data = json.loads(chunk_data.strip())
self.raw_response = self.handle_json_response(json_data)
except json.JSONDecodeError:
self.raw_response += chunk_data
else:
self.raw_response += chunk_data
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
return json_data
if "response" in json_data:
return json_data["response"]
return json_data
async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict[str, Any]:
    """Drain an async chat event stream into a single response dict.

    Events are separated by the ␃🔚␗ delimiter. Incoming chunks are buffered
    until a complete event is available; any text after the final delimiter is
    processed as a trailing event.

    Returns a dict with "response" (the collected message) and "references".
    """
    event_delimiter = "␃🔚␗"
    processor = MessageProcessor()
    pending = ""
    async for chunk in response_iterator:
        pending += chunk
        # Peel off every complete event currently buffered; keep the remainder.
        *events, pending = pending.split(event_delimiter)
        for event in events:
            if event:
                processor.process_message_chunk(event)
    # Whatever trails the last delimiter is still a processable event.
    if pending:
        processor.process_message_chunk(pending)
    return {"response": processor.raw_response, "references": processor.references}
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
user_picture = request.session.get("user", {}).get("picture")
is_active = has_required_scope(request, ["premium"])