Replace old chat router with new chat router with advanced streaming

- Details
  Only return notes refs, online refs, inferred queries and generated
  response in non-streaming mode. Do not return train of throught and
  other status messages

  Incorporate missing logic from old chat API router into new one.

- Motivation
  So we can halve chat API code by getting rid of the duplicate logic
  for the websocket router

  The deduplicated code:
  - Avoids inadvertant logic drift between the 2 routers
  - Improves dev velocity
This commit is contained in:
Debanjum Singh Solanky
2024-07-22 20:29:45 +05:30
parent 2d4b284218
commit daec439d52
2 changed files with 61 additions and 319 deletions

View File

@@ -709,27 +709,11 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
let references = {};
if (imageJson.context && imageJson.context.length > 0) {
const rawReferenceAsJson = imageJson.context;
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
}
if (imageJson.detail) {
// If response has detail field, response is an error message.
rawResponse += imageJson.detail;
}
return { rawResponse, references };
}
function addMessageToChatBody(rawResponse, newResponseElement, references) {
newResponseElement.innerHTML = "";
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
// If response has detail field, response is an error message.
if (imageJson.detail) rawResponse += imageJson.detail;
finalizeChatBodyResponse(references, newResponseElement);
return rawResponse;
}
function finalizeChatBodyResponse(references, newResponseElement) {
@@ -743,7 +727,6 @@ To get started, just start typing below. You can also type / to see a list of co
function collectJsonsInBufferedMessageChunk(chunk) {
// Collect list of JSON objects and raw strings in the chunk
// Return the list of objects and the remaining raw string
console.log("Raw Chunk:", chunk);
let startIndex = chunk.indexOf('{');
if (startIndex === -1) return { objects: [chunk], remainder: '' };
const objects = [chunk.slice(0, startIndex)];
@@ -819,11 +802,13 @@ To get started, just start typing below. You can also type / to see a list of co
isVoice: false,
}
} else if (chunk.type === "references") {
const rawReferenceAsJson = JSON.parse(chunk.data);
chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results};
chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results};
} else if (chunk.type === 'message') {
const chunkData = chunk.data;
if (chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
if (typeof chunkData === 'object' && chunkData !== null) {
// If chunkData is already a JSON object
handleJsonResponse(chunkData);
} else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunkData.trim());
@@ -841,17 +826,15 @@ To get started, just start typing below. You can also type / to see a list of co
function handleJsonResponse(jsonData) {
if (jsonData.image || jsonData.detail) {
let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse);
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
} else if (jsonData.response) {
chatMessageState.rawResponse = jsonData.response;
chatMessageState.references = {
notes: jsonData.context || {},
online: jsonData.online_results || {}
};
}
addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
if (chatMessageState.newResponseTextEl) {
chatMessageState.newResponseTextEl.innerHTML = "";
chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
}
}
async function sendMessageStream(query) {
@@ -866,7 +849,7 @@ To get started, just start typing below. You can also type / to see a list of co
refreshChatSessionsPanel();
}
let chatStreamUrl = `/api/chat/stream?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&client=web`;
let chatStreamUrl = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&stream=true&client=web`;
chatStreamUrl += (!!region && !!city && !!countryName && !!timezone)
? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
: '';

View File

@@ -525,13 +525,14 @@ async def set_conversation_title(
)
@api_chat.get("/stream")
async def stream_chat(
@api_chat.get("")
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: int = 7,
d: float = 0.18,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
@@ -550,7 +551,7 @@ async def stream_chat(
user: KhojUser = request.user.object
q = unquote(q)
async def send_event(event_type: str, data: str):
async def send_event(event_type: str, data: str | dict):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
connection_alive = False
@@ -559,7 +560,9 @@ async def stream_chat(
try:
if event_type == "message":
yield data
else:
elif event_type == "references":
yield json.dumps({"type": event_type, "data": data})
elif stream:
yield json.dumps({"type": event_type, "data": data})
except asyncio.CancelledError:
connection_alive = False
@@ -744,6 +747,8 @@ async def stream_chat(
yield result
return
# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
@@ -778,6 +783,7 @@ async def stream_chat(
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
## Gather Online References
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
@@ -794,6 +800,7 @@ async def stream_chat(
yield result
return
## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
@@ -818,6 +825,19 @@ async def stream_chat(
exc_info=True,
)
## Send Gathered References
async for result in send_event(
"references",
{
"inferredQueries": inferred_queries,
"context": compiled_references,
"online_results": online_results,
},
):
yield result
# Generate Output
## Generate Image Output
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
@@ -875,11 +895,7 @@ async def stream_chat(
yield result
return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
yield result
## Generate Text Output
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
@@ -897,6 +913,8 @@ async def stream_chat(
user_name,
)
cmd_set = set([cmd.value for cmd in conversation_commands])
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
@@ -905,12 +923,13 @@ async def stream_chat(
api="chat",
metadata=chat_metadata,
)
iterator = AsyncIteratorWrapper(llm_response)
# Send Response
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
iterator = AsyncIteratorWrapper(llm_response)
async for item in iterator:
if item is None:
async for result in send_event("end_llm_response", ""):
@@ -931,282 +950,22 @@ async def stream_chat(
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
return StreamingResponse(event_generator(q), media_type="text/plain")
@api_chat.get("", response_class=Response)
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.22,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
if is_query_empty(q):
return Response(
content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
media_type="text/plain",
status_code=400,
)
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Chat request by {user.username}: {q}")
await is_ready_to_chat(user)
conversation_commands = [get_conversation_command(query=q, any_references=True)]
_custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
help_str = "/" + ConversationCommand.Help
if q.strip() == help_str:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Adding specification to search online specifically on khoj.dev pages.
_custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
conversation = await ConversationAdapters.aget_conversation_by_user(
user, request.user.client_app, conversation_id, title
)
conversation_id = conversation.id if conversation else None
if not conversation:
return Response(
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
)
else:
meta_log = conversation.conversation_log
if ConversationCommand.Summarize in conversation_commands:
file_filters = conversation.file_filters
llm_response = ""
if len(file_filters) == 0:
llm_response = "No files selected for summarization. Please add files using the section on the left."
elif len(file_filters) > 1:
llm_response = "Only one file can be selected for summarization."
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
contextual_data = " ".join([file.raw_text for file in file_object])
summarizeStr = "/" + ConversationCommand.Summarize
if q.strip() == summarizeStr:
q = "Create a general summary of the file"
response = await extract_relevant_summary(q, contextual_data)
llm_response = str(response)
except Exception as e:
logger.error(f"Error summarizing file for {user.email}: {e}")
llm_response = "Error summarizing file."
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
conversation.conversation_log,
user_message_time,
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
user_name = await aget_user_name(user)
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
)
except Exception as e:
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
return Response(
content=f"Unable to create automation. Ensure the automation doesn't already exist.",
media_type="text/plain",
status_code=500,
)
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
user,
meta_log,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
)
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
else:
return Response(content=llm_response, media_type="text/plain", status_code=200)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
)
online_results: Dict[str, Dict] = {}
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
no_notes_found_format = no_notes_found.format()
if stream:
return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_notes_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try:
online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
if ConversationCommand.Webpage in conversation_commands:
try:
online_results = await read_webpages(defiltered_query, meta_log, location)
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
)
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt, intent_type = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None:
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
user_message_time,
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation.id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,
conversation_commands,
user,
request.user.client_app,
conversation.id,
location,
user_name,
)
cmd_set = set([cmd.value for cmd in conversation_commands])
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
**common.__dict__,
)
if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500)
## Stream Text Response
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
return StreamingResponse(event_generator(q), media_type="text/plain")
## Non-Streaming Text Response
else:
# Get the full response from the generator if the stream is not requested.
response_obj = {}
actual_response = ""
iterator = event_generator(q)
async for item in iterator:
try:
item_json = json.loads(item)
if "type" in item_json and item_json["type"] == "references":
response_obj = item_json["data"]
except:
actual_response += item
response_obj["response"] = actual_response
iterator = AsyncIteratorWrapper(llm_response)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
async for item in iterator:
if item is None:
break
aggregated_gpt_response += item
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
response_obj = {
"response": actual_response,
"inferredQueries": inferred_queries,
"context": compiled_references,
"online_results": online_results,
}
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)