Replace old chat router with new chat router with advanced streaming
- Details
  Only return notes refs, online refs, inferred queries and the generated response in non-streaming mode. Do not return the train of thought and other status messages.
  Incorporate missing logic from the old chat API router into the new one.
- Motivation
  So we can halve the chat API code by getting rid of the duplicate logic for the websocket router.
  The deduplicated code:
  - Avoids inadvertent logic drift between the 2 routers
  - Improves dev velocity
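To make the new contract concrete, here is a minimal sketch of calling the consolidated endpoint in non-streaming mode. The host, port, and auth header are assumptions for illustration, not part of this commit; the response keys mirror the response_obj assembled by the new router in the diff below.

```python
import requests

# Minimal sketch, assuming a local Khoj server and an API token.
# In non-streaming mode the new router returns a single JSON object with
# only the generated response, inferred queries, and note/online references.
resp = requests.get(
    "http://localhost:42110/api/chat",  # assumed default host/port
    params={"q": "Summarize my notes on FastAPI", "stream": "false", "client": "web"},
    headers={"Authorization": "Bearer <your-khoj-api-token>"},  # assumed auth scheme
)
data = resp.json()
print(data["response"])         # generated answer
print(data["inferredQueries"])  # queries inferred from the user message
print(data["context"])          # compiled note references
print(data["online_results"])   # online references
```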
@@ -709,27 +709,11 @@ To get started, just start typing below. You can also type / to see a list of co
            rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
        }
    }
    let references = {};
    if (imageJson.context && imageJson.context.length > 0) {
        const rawReferenceAsJson = imageJson.context;
        if (rawReferenceAsJson instanceof Array) {
            references["notes"] = rawReferenceAsJson;
        } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
            references["online"] = rawReferenceAsJson;
        }
    }
    if (imageJson.detail) {
        // If response has detail field, response is an error message.
        rawResponse += imageJson.detail;
    }
    return { rawResponse, references };
}

function addMessageToChatBody(rawResponse, newResponseElement, references) {
    newResponseElement.innerHTML = "";
    newResponseElement.appendChild(formatHTMLMessage(rawResponse));
    // If response has detail field, response is an error message.
    if (imageJson.detail) rawResponse += imageJson.detail;

    finalizeChatBodyResponse(references, newResponseElement);
    return rawResponse;
}

function finalizeChatBodyResponse(references, newResponseElement) {
@@ -743,7 +727,6 @@ To get started, just start typing below. You can also type / to see a list of co
function collectJsonsInBufferedMessageChunk(chunk) {
    // Collect list of JSON objects and raw strings in the chunk
    // Return the list of objects and the remaining raw string
    console.log("Raw Chunk:", chunk);
    let startIndex = chunk.indexOf('{');
    if (startIndex === -1) return { objects: [chunk], remainder: '' };
    const objects = [chunk.slice(0, startIndex)];
@@ -819,11 +802,13 @@ To get started, just start typing below. You can also type / to see a list of co
        isVoice: false,
    }
} else if (chunk.type === "references") {
    const rawReferenceAsJson = JSON.parse(chunk.data);
    chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results};
    chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results};
} else if (chunk.type === 'message') {
    const chunkData = chunk.data;
    if (chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
    if (typeof chunkData === 'object' && chunkData !== null) {
        // If chunkData is already a JSON object
        handleJsonResponse(chunkData);
    } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
        // Try process chunk data as if it is a JSON object
        try {
            const jsonData = JSON.parse(chunkData.trim());
@@ -841,17 +826,15 @@ To get started, just start typing below. You can also type / to see a list of co

function handleJsonResponse(jsonData) {
    if (jsonData.image || jsonData.detail) {
        let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse);
        chatMessageState.rawResponse = rawResponse;
        chatMessageState.references = references;
        chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
    } else if (jsonData.response) {
        chatMessageState.rawResponse = jsonData.response;
        chatMessageState.references = {
            notes: jsonData.context || {},
            online: jsonData.online_results || {}
        };
    }
    addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);

    if (chatMessageState.newResponseTextEl) {
        chatMessageState.newResponseTextEl.innerHTML = "";
        chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
    }
}

async function sendMessageStream(query) {
@@ -866,7 +849,7 @@ To get started, just start typing below. You can also type / to see a list of co
        refreshChatSessionsPanel();
    }

    let chatStreamUrl = `/api/chat/stream?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&client=web`;
    let chatStreamUrl = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&stream=true&client=web`;
    chatStreamUrl += (!!region && !!city && !!countryName && !!timezone)
        ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
        : '';
@@ -525,13 +525,14 @@ async def set_conversation_title(
    )


@api_chat.get("/stream")
async def stream_chat(
@api_chat.get("")
async def chat(
    request: Request,
    common: CommonQueryParams,
    q: str,
    n: int = 7,
    d: float = 0.18,
    stream: Optional[bool] = False,
    title: Optional[str] = None,
    conversation_id: Optional[int] = None,
    city: Optional[str] = None,
@@ -550,7 +551,7 @@ async def stream_chat(
    user: KhojUser = request.user.object
    q = unquote(q)

    async def send_event(event_type: str, data: str):
    async def send_event(event_type: str, data: str | dict):
        nonlocal connection_alive
        if not connection_alive or await request.is_disconnected():
            connection_alive = False
@@ -559,7 +560,9 @@ async def stream_chat(
        try:
            if event_type == "message":
                yield data
            else:
            elif event_type == "references":
                yield json.dumps({"type": event_type, "data": data})
            elif stream:
                yield json.dumps({"type": event_type, "data": data})
        except asyncio.CancelledError:
            connection_alive = False
@@ -744,6 +747,8 @@ async def stream_chat(
                yield result
            return

        # Gather Context
        ## Extract Document References
        compiled_references, inferred_queries, defiltered_query = [], [], None
        async for result in extract_references_and_questions(
            request,
@@ -778,6 +783,7 @@ async def stream_chat(
        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
            conversation_commands.remove(ConversationCommand.Notes)

        ## Gather Online References
        if ConversationCommand.Online in conversation_commands:
            try:
                async for result in search_online(
@@ -794,6 +800,7 @@ async def stream_chat(
                    yield result
                return

        ## Gather Webpage References
        if ConversationCommand.Webpage in conversation_commands:
            try:
                async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
@@ -818,6 +825,19 @@ async def stream_chat(
                    exc_info=True,
                )

        ## Send Gathered References
        async for result in send_event(
            "references",
            {
                "inferredQueries": inferred_queries,
                "context": compiled_references,
                "online_results": online_results,
            },
        ):
            yield result

        # Generate Output
        ## Generate Image Output
        if ConversationCommand.Image in conversation_commands:
            update_telemetry_state(
                request=request,
@@ -875,11 +895,7 @@ async def stream_chat(
                yield result
            return

        async for result in send_event(
            "references", json.dumps({"context": compiled_references, "online_results": online_results})
        ):
            yield result

        ## Generate Text Output
        async for result in send_event("status", f"**💭 Generating a well-informed response**"):
            yield result
        llm_response, chat_metadata = await agenerate_chat_response(
@@ -897,6 +913,8 @@ async def stream_chat(
            user_name,
        )

        cmd_set = set([cmd.value for cmd in conversation_commands])
        chat_metadata["conversation_command"] = cmd_set
        chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None

        update_telemetry_state(
@@ -905,12 +923,13 @@ async def stream_chat(
            api="chat",
            metadata=chat_metadata,
        )
        iterator = AsyncIteratorWrapper(llm_response)

        # Send Response
        async for result in send_event("start_llm_response", ""):
            yield result

        continue_stream = True
        iterator = AsyncIteratorWrapper(llm_response)
        async for item in iterator:
            if item is None:
                async for result in send_event("end_llm_response", ""):
@@ -931,282 +950,22 @@ async def stream_chat(
            continue_stream = False
            logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")

    return StreamingResponse(event_generator(q), media_type="text/plain")

@api_chat.get("", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.22,
|
||||
stream: Optional[bool] = False,
|
||||
title: Optional[str] = None,
|
||||
conversation_id: Optional[int] = None,
|
||||
city: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
country: Optional[str] = None,
|
||||
timezone: Optional[str] = None,
|
||||
rate_limiter_per_minute=Depends(
|
||||
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||
),
|
||||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
),
|
||||
) -> Response:
|
||||
user: KhojUser = request.user.object
|
||||
q = unquote(q)
|
||||
if is_query_empty(q):
|
||||
return Response(
|
||||
content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
|
||||
media_type="text/plain",
|
||||
status_code=400,
|
||||
)
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info(f"Chat request by {user.username}: {q}")
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
_custom_filters = []
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
help_str = "/" + ConversationCommand.Help
|
||||
if q.strip() == help_str:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
model_type = conversation_config.model_type
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
# Adding specification to search online specifically on khoj.dev pages.
|
||||
_custom_filters.append("site:khoj.dev")
|
||||
conversation_commands.append(ConversationCommand.Online)
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, request.user.client_app, conversation_id, title
|
||||
)
|
||||
conversation_id = conversation.id if conversation else None
|
||||
|
||||
if not conversation:
|
||||
return Response(
|
||||
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
|
||||
)
|
||||
else:
|
||||
meta_log = conversation.conversation_log
|
||||
|
||||
if ConversationCommand.Summarize in conversation_commands:
|
||||
file_filters = conversation.file_filters
|
||||
llm_response = ""
|
||||
if len(file_filters) == 0:
|
||||
llm_response = "No files selected for summarization. Please add files using the section on the left."
|
||||
elif len(file_filters) > 1:
|
||||
llm_response = "Only one file can be selected for summarization."
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
summarizeStr = "/" + ConversationCommand.Summarize
|
||||
if q.strip() == summarizeStr:
|
||||
q = "Create a general summary of the file"
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
llm_response = str(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}")
|
||||
llm_response = "Error summarizing file."
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
conversation.conversation_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
location = None
|
||||
|
||||
if city or region or country:
|
||||
location = LocationData(city=city, region=region, country=country)
|
||||
|
||||
user_name = await aget_user_name(user)
|
||||
|
||||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
|
||||
return Response(
|
||||
content=f"Unable to create automation. Ensure the automation doesn't already exist.",
|
||||
media_type="text/plain",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="automation",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
|
||||
else:
|
||||
return Response(content=llm_response, media_type="text/plain", status_code=200)
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
|
||||
)
|
||||
online_results: Dict[str, Dict] = {}
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
no_entries_found_format = no_entries_found.format()
|
||||
if stream:
|
||||
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
|
||||
else:
|
||||
response_obj = {"response": no_entries_found_format}
|
||||
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
no_notes_found_format = no_notes_found.format()
|
||||
if stream:
|
||||
return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
|
||||
else:
|
||||
response_obj = {"response": no_notes_found_format}
|
||||
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
|
||||
|
||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||
conversation_commands.remove(ConversationCommand.Notes)
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
try:
|
||||
online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
||||
|
||||
if ConversationCommand.Webpage in conversation_commands:
|
||||
try:
|
||||
online_results = await read_webpages(defiltered_query, meta_log, location)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
|
||||
)
|
||||
|
||||
if ConversationCommand.Image in conversation_commands:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||
)
|
||||
if image is None:
|
||||
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type=intent_type,
|
||||
inferred_queries=[improved_image_prompt],
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation.id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
)
|
||||
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
inferred_queries,
|
||||
conversation_commands,
|
||||
user,
|
||||
request.user.client_app,
|
||||
conversation.id,
|
||||
location,
|
||||
user_name,
|
||||
)
|
||||
|
||||
cmd_set = set([cmd.value for cmd in conversation_commands])
|
||||
chat_metadata["conversation_command"] = cmd_set
|
||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata=chat_metadata,
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
if llm_response is None:
|
||||
return Response(content=llm_response, media_type="text/plain", status_code=500)
|
||||
|
||||
## Stream Text Response
|
||||
if stream:
|
||||
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
|
||||
return StreamingResponse(event_generator(q), media_type="text/plain")
|
||||
## Non-Streaming Text Response
|
||||
else:
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
response_obj = {}
|
||||
actual_response = ""
|
||||
iterator = event_generator(q)
|
||||
async for item in iterator:
|
||||
try:
|
||||
item_json = json.loads(item)
|
||||
if "type" in item_json and item_json["type"] == "references":
|
||||
response_obj = item_json["data"]
|
||||
except:
|
||||
actual_response += item
|
||||
response_obj["response"] = actual_response
|
||||
|
||||
iterator = AsyncIteratorWrapper(llm_response)
|
||||
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
aggregated_gpt_response = ""
|
||||
async for item in iterator:
|
||||
if item is None:
|
||||
break
|
||||
aggregated_gpt_response += item
|
||||
|
||||
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
|
||||
|
||||
response_obj = {
|
||||
"response": actual_response,
|
||||
"inferredQueries": inferred_queries,
|
||||
"context": compiled_references,
|
||||
"online_results": online_results,
|
||||
}
|
||||
|
||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
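As a usage note, here is a minimal sketch of consuming the same endpoint with stream=true. Per send_event above, "message" events arrive as raw response text, while typed events such as "references" arrive as JSON lines of the form {"type": ..., "data": ...}. The URL, auth header, and the simplistic chunk handling are assumptions for illustration; a real client buffers partial chunks the way collectJsonsInBufferedMessageChunk does in the web client above.

```python
import json
import requests

# Minimal sketch, assuming a local Khoj server and an API token.
with requests.get(
    "http://localhost:42110/api/chat",  # assumed default host/port
    params={"q": "What did I write about vector search?", "stream": "true", "client": "web"},
    headers={"Authorization": "Bearer <your-khoj-api-token>"},  # assumed auth scheme
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        try:
            # Typed events (e.g. references, status) are JSON lines: {"type": ..., "data": ...}.
            # Caveat: a JSON event may be split across chunks; real clients buffer, as the
            # web client's collectJsonsInBufferedMessageChunk does.
            event = json.loads(chunk)
            if event.get("type") == "references":
                references = event["data"]  # inferredQueries, context, online_results
        except json.JSONDecodeError:
            # "message" events are raw response text, yielded as-is by send_event.
            print(chunk, end="")
```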