diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 00139232..6855c196 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -709,27 +709,11 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
- let references = {};
- if (imageJson.context && imageJson.context.length > 0) {
- const rawReferenceAsJson = imageJson.context;
- if (rawReferenceAsJson instanceof Array) {
- references["notes"] = rawReferenceAsJson;
- } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
- references["online"] = rawReferenceAsJson;
- }
- }
- if (imageJson.detail) {
- // If response has detail field, response is an error message.
- rawResponse += imageJson.detail;
- }
- return { rawResponse, references };
- }
- function addMessageToChatBody(rawResponse, newResponseElement, references) {
- newResponseElement.innerHTML = "";
- newResponseElement.appendChild(formatHTMLMessage(rawResponse));
+ // If response has detail field, response is an error message.
+ if (imageJson.detail) rawResponse += imageJson.detail;
- finalizeChatBodyResponse(references, newResponseElement);
+ return rawResponse;
}
function finalizeChatBodyResponse(references, newResponseElement) {
@@ -743,7 +727,6 @@ To get started, just start typing below. You can also type / to see a list of co
function collectJsonsInBufferedMessageChunk(chunk) {
// Collect list of JSON objects and raw strings in the chunk
// Return the list of objects and the remaining raw string
- console.log("Raw Chunk:", chunk);
let startIndex = chunk.indexOf('{');
if (startIndex === -1) return { objects: [chunk], remainder: '' };
const objects = [chunk.slice(0, startIndex)];
@@ -819,11 +802,13 @@ To get started, just start typing below. You can also type / to see a list of co
isVoice: false,
}
} else if (chunk.type === "references") {
- const rawReferenceAsJson = JSON.parse(chunk.data);
- chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results};
+ chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results};
} else if (chunk.type === 'message') {
const chunkData = chunk.data;
- if (chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
+ if (typeof chunkData === 'object' && chunkData !== null) {
+ // If chunkData is already a JSON object
+ handleJsonResponse(chunkData);
+ } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunkData.trim());
@@ -841,17 +826,15 @@ To get started, just start typing below. You can also type / to see a list of co
function handleJsonResponse(jsonData) {
if (jsonData.image || jsonData.detail) {
- let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse);
- chatMessageState.rawResponse = rawResponse;
- chatMessageState.references = references;
+ chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
} else if (jsonData.response) {
chatMessageState.rawResponse = jsonData.response;
- chatMessageState.references = {
- notes: jsonData.context || {},
- online: jsonData.online_results || {}
- };
}
- addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
+
+ if (chatMessageState.newResponseTextEl) {
+ chatMessageState.newResponseTextEl.innerHTML = "";
+ chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
+ }
}
async function sendMessageStream(query) {
@@ -866,7 +849,7 @@ To get started, just start typing below. You can also type / to see a list of co
refreshChatSessionsPanel();
}
- let chatStreamUrl = `/api/chat/stream?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&client=web`;
+ let chatStreamUrl = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&stream=true&client=web`;
chatStreamUrl += (!!region && !!city && !!countryName && !!timezone)
? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
: '';
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 34879b86..d8826264 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -525,13 +525,14 @@ async def set_conversation_title(
)
-@api_chat.get("/stream")
-async def stream_chat(
+@api_chat.get("")
+async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: int = 7,
d: float = 0.18,
+ stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
@@ -550,7 +551,7 @@ async def stream_chat(
user: KhojUser = request.user.object
q = unquote(q)
- async def send_event(event_type: str, data: str):
+ async def send_event(event_type: str, data: str | dict):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
connection_alive = False
@@ -559,7 +560,9 @@ async def stream_chat(
try:
if event_type == "message":
yield data
- else:
+ elif event_type == "references":
+ yield json.dumps({"type": event_type, "data": data})
+ elif stream:
yield json.dumps({"type": event_type, "data": data})
except asyncio.CancelledError:
connection_alive = False
@@ -744,6 +747,8 @@ async def stream_chat(
yield result
return
+ # Gather Context
+ ## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
@@ -778,6 +783,7 @@ async def stream_chat(
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
+ ## Gather Online References
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
@@ -794,6 +800,7 @@ async def stream_chat(
yield result
return
+ ## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
@@ -818,6 +825,19 @@ async def stream_chat(
exc_info=True,
)
+ ## Send Gathered References
+ async for result in send_event(
+ "references",
+ {
+ "inferredQueries": inferred_queries,
+ "context": compiled_references,
+ "online_results": online_results,
+ },
+ ):
+ yield result
+
+ # Generate Output
+ ## Generate Image Output
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
@@ -875,11 +895,7 @@ async def stream_chat(
yield result
return
- async for result in send_event(
- "references", json.dumps({"context": compiled_references, "online_results": online_results})
- ):
- yield result
-
+ ## Generate Text Output
async for result in send_event("status", f"**💠Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
@@ -897,6 +913,8 @@ async def stream_chat(
user_name,
)
+ cmd_set = set([cmd.value for cmd in conversation_commands])
+ chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state(
@@ -905,12 +923,13 @@ async def stream_chat(
api="chat",
metadata=chat_metadata,
)
- iterator = AsyncIteratorWrapper(llm_response)
+ # Send Response
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
+ iterator = AsyncIteratorWrapper(llm_response)
async for item in iterator:
if item is None:
async for result in send_event("end_llm_response", ""):
@@ -931,282 +950,22 @@ async def stream_chat(
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
- return StreamingResponse(event_generator(q), media_type="text/plain")
-
-
-@api_chat.get("", response_class=Response)
-@requires(["authenticated"])
-async def chat(
- request: Request,
- common: CommonQueryParams,
- q: str,
- n: Optional[int] = 5,
- d: Optional[float] = 0.22,
- stream: Optional[bool] = False,
- title: Optional[str] = None,
- conversation_id: Optional[int] = None,
- city: Optional[str] = None,
- region: Optional[str] = None,
- country: Optional[str] = None,
- timezone: Optional[str] = None,
- rate_limiter_per_minute=Depends(
- ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
- ),
- rate_limiter_per_day=Depends(
- ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
- ),
-) -> Response:
- user: KhojUser = request.user.object
- q = unquote(q)
- if is_query_empty(q):
- return Response(
- content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
- media_type="text/plain",
- status_code=400,
- )
- user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- logger.info(f"Chat request by {user.username}: {q}")
-
- await is_ready_to_chat(user)
- conversation_commands = [get_conversation_command(query=q, any_references=True)]
-
- _custom_filters = []
- if conversation_commands == [ConversationCommand.Help]:
- help_str = "/" + ConversationCommand.Help
- if q.strip() == help_str:
- conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
- if conversation_config == None:
- conversation_config = await ConversationAdapters.aget_default_conversation_config()
- model_type = conversation_config.model_type
- formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
- return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
- # Adding specification to search online specifically on khoj.dev pages.
- _custom_filters.append("site:khoj.dev")
- conversation_commands.append(ConversationCommand.Online)
-
- conversation = await ConversationAdapters.aget_conversation_by_user(
- user, request.user.client_app, conversation_id, title
- )
- conversation_id = conversation.id if conversation else None
-
- if not conversation:
- return Response(
- content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
- )
- else:
- meta_log = conversation.conversation_log
-
- if ConversationCommand.Summarize in conversation_commands:
- file_filters = conversation.file_filters
- llm_response = ""
- if len(file_filters) == 0:
- llm_response = "No files selected for summarization. Please add files using the section on the left."
- elif len(file_filters) > 1:
- llm_response = "Only one file can be selected for summarization."
- else:
- try:
- file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
- if len(file_object) == 0:
- llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
- contextual_data = " ".join([file.raw_text for file in file_object])
- summarizeStr = "/" + ConversationCommand.Summarize
- if q.strip() == summarizeStr:
- q = "Create a general summary of the file"
- response = await extract_relevant_summary(q, contextual_data)
- llm_response = str(response)
- except Exception as e:
- logger.error(f"Error summarizing file for {user.email}: {e}")
- llm_response = "Error summarizing file."
- await sync_to_async(save_to_conversation_log)(
- q,
- llm_response,
- user,
- conversation.conversation_log,
- user_message_time,
- intent_type="summarize",
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- )
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- metadata={"conversation_command": conversation_commands[0].value},
- **common.__dict__,
- )
- return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
-
- is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
-
- if conversation_commands == [ConversationCommand.Default] or is_automated_task:
- conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
- mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
- if mode not in conversation_commands:
- conversation_commands.append(mode)
-
- for cmd in conversation_commands:
- await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
- q = q.replace(f"/{cmd.value}", "").strip()
-
- location = None
-
- if city or region or country:
- location = LocationData(city=city, region=region, country=country)
-
- user_name = await aget_user_name(user)
-
- if ConversationCommand.Automation in conversation_commands:
- try:
- automation, crontime, query_to_run, subject = await create_automation(
- q, timezone, user, request.url, meta_log
- )
- except Exception as e:
- logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
- return Response(
- content=f"Unable to create automation. Ensure the automation doesn't already exist.",
- media_type="text/plain",
- status_code=500,
- )
-
- llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
- await sync_to_async(save_to_conversation_log)(
- q,
- llm_response,
- user,
- meta_log,
- user_message_time,
- intent_type="automation",
- client_application=request.user.client_app,
- conversation_id=conversation_id,
- inferred_queries=[query_to_run],
- automation_id=automation.id,
- )
-
- if stream:
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
- else:
- return Response(content=llm_response, media_type="text/plain", status_code=200)
-
- compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
- )
- online_results: Dict[str, Dict] = {}
-
- if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
- no_entries_found_format = no_entries_found.format()
- if stream:
- return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
- else:
- response_obj = {"response": no_entries_found_format}
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
-
- if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
- no_notes_found_format = no_notes_found.format()
- if stream:
- return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
- else:
- response_obj = {"response": no_notes_found_format}
- return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
-
- if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
- conversation_commands.remove(ConversationCommand.Notes)
-
- if ConversationCommand.Online in conversation_commands:
- try:
- online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
- except ValueError as e:
- logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
-
- if ConversationCommand.Webpage in conversation_commands:
- try:
- online_results = await read_webpages(defiltered_query, meta_log, location)
- except ValueError as e:
- logger.warning(
- f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
- )
-
- if ConversationCommand.Image in conversation_commands:
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- metadata={"conversation_command": conversation_commands[0].value},
- **common.__dict__,
- )
- image, status_code, improved_image_prompt, intent_type = await text_to_image(
- q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
- )
- if image is None:
- content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-
- await sync_to_async(save_to_conversation_log)(
- q,
- image,
- user,
- meta_log,
- user_message_time,
- intent_type=intent_type,
- inferred_queries=[improved_image_prompt],
- client_application=request.user.client_app,
- conversation_id=conversation.id,
- compiled_references=compiled_references,
- online_results=online_results,
- )
- content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
- return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-
- # Get the (streamed) chat response from the LLM of choice.
- llm_response, chat_metadata = await agenerate_chat_response(
- defiltered_query,
- meta_log,
- conversation,
- compiled_references,
- online_results,
- inferred_queries,
- conversation_commands,
- user,
- request.user.client_app,
- conversation.id,
- location,
- user_name,
- )
-
- cmd_set = set([cmd.value for cmd in conversation_commands])
- chat_metadata["conversation_command"] = cmd_set
- chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
-
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- metadata=chat_metadata,
- **common.__dict__,
- )
-
- if llm_response is None:
- return Response(content=llm_response, media_type="text/plain", status_code=500)
-
+ ## Stream Text Response
if stream:
- return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
+ return StreamingResponse(event_generator(q), media_type="text/plain")
+ ## Non-Streaming Text Response
+ else:
+ # Get the full response from the generator if the stream is not requested.
+ response_obj = {}
+ actual_response = ""
+ iterator = event_generator(q)
+ async for item in iterator:
+ try:
+ item_json = json.loads(item)
+ if "type" in item_json and item_json["type"] == "references":
+ response_obj = item_json["data"]
+ except:
+ actual_response += item
+ response_obj["response"] = actual_response
- iterator = AsyncIteratorWrapper(llm_response)
-
- # Get the full response from the generator if the stream is not requested.
- aggregated_gpt_response = ""
- async for item in iterator:
- if item is None:
- break
- aggregated_gpt_response += item
-
- actual_response = aggregated_gpt_response.split("### compiled references:")[0]
-
- response_obj = {
- "response": actual_response,
- "inferredQueries": inferred_queries,
- "context": compiled_references,
- "online_results": online_results,
- }
-
- return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
+ return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)