mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Fix chat API for non-streaming mode json response
This commit is contained in:
@@ -41,6 +41,7 @@ from khoj.routers.helpers import (
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
read_chat_stream,
|
||||
text_to_image,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
@@ -570,8 +571,7 @@ async def chat(
|
||||
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
|
||||
return
|
||||
finally:
|
||||
if stream:
|
||||
yield event_delimiter
|
||||
yield event_delimiter
|
||||
|
||||
async def send_llm_response(response: str):
|
||||
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||
@@ -937,17 +937,6 @@ async def chat(
|
||||
return StreamingResponse(event_generator(q), media_type="text/plain")
|
||||
## Non-Streaming Text Response
|
||||
else:
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
response_obj = {}
|
||||
actual_response = ""
|
||||
iterator = event_generator(q)
|
||||
async for item in iterator:
|
||||
try:
|
||||
item_json = json.loads(item)
|
||||
if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value:
|
||||
response_obj = item_json["data"]
|
||||
except:
|
||||
actual_response += item
|
||||
response_obj["response"] = actual_response
|
||||
|
||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||
response_iterator = event_generator(q)
|
||||
response_data = await read_chat_stream(response_iterator)
|
||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||
|
||||
@@ -15,6 +15,7 @@ from random import random
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
@@ -1225,6 +1226,79 @@ class ChatEvent(Enum):
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class MessageProcessor:
|
||||
def __init__(self):
|
||||
self.references = {}
|
||||
self.raw_response = ""
|
||||
|
||||
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
||||
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
|
||||
try:
|
||||
json_chunk = json.loads(raw_chunk)
|
||||
if "type" not in json_chunk:
|
||||
json_chunk = {"type": "message", "data": json_chunk}
|
||||
return json_chunk
|
||||
except json.JSONDecodeError:
|
||||
return {"type": "message", "data": raw_chunk}
|
||||
elif raw_chunk:
|
||||
return {"type": "message", "data": raw_chunk}
|
||||
return {"type": "", "data": ""}
|
||||
|
||||
def process_message_chunk(self, raw_chunk: str) -> None:
|
||||
chunk = self.convert_message_chunk_to_json(raw_chunk)
|
||||
if not chunk or not chunk["type"]:
|
||||
return
|
||||
|
||||
chunk_type = ChatEvent(chunk["type"])
|
||||
if chunk_type == ChatEvent.REFERENCES:
|
||||
self.references = chunk["data"]
|
||||
elif chunk_type == ChatEvent.MESSAGE:
|
||||
chunk_data = chunk["data"]
|
||||
if isinstance(chunk_data, dict):
|
||||
self.raw_response = self.handle_json_response(chunk_data)
|
||||
elif (
|
||||
isinstance(chunk_data, str) and chunk_data.strip().startswith("{") and chunk_data.strip().endswith("}")
|
||||
):
|
||||
try:
|
||||
json_data = json.loads(chunk_data.strip())
|
||||
self.raw_response = self.handle_json_response(json_data)
|
||||
except json.JSONDecodeError:
|
||||
self.raw_response += chunk_data
|
||||
else:
|
||||
self.raw_response += chunk_data
|
||||
|
||||
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
|
||||
if "image" in json_data or "details" in json_data:
|
||||
return json_data
|
||||
if "response" in json_data:
|
||||
return json_data["response"]
|
||||
return json_data
|
||||
|
||||
|
||||
async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict[str, Any]:
    """Drain a chat event stream and return its final response and references.

    Chunks from `response_iterator` are concatenated and cut on the event delimiter.
    Each non-empty event is handed to a `MessageProcessor`; text left over after the
    last delimiter is processed as a final event once the stream ends.

    Returns a dict with the processor's `raw_response` under "response" and its
    `references` under "references".
    """
    event_delimiter = "␃🔚␗"
    message_processor = MessageProcessor()
    pending = ""

    async for piece in response_iterator:
        pending += piece
        # Everything before the last delimiter is a complete event; the tail stays
        # buffered until more data (or the end of the stream) arrives.
        *complete_events, pending = pending.split(event_delimiter)
        for complete_event in complete_events:
            if complete_event:
                message_processor.process_message_chunk(complete_event)

    # Flush whatever followed the final delimiter.
    if pending:
        message_processor.process_message_chunk(pending)

    return {"response": message_processor.raw_response, "references": message_processor.references}
|
||||
|
||||
|
||||
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
is_active = has_required_scope(request, ["premium"])
|
||||
|
||||
Reference in New Issue
Block a user