Stream Status Messages via Streaming Response from server to web client

- Overview
Use simpler HTTP Streaming Response to send status messages, alongside
response and references from server to clients via API.

Update web client to use the streamed response to show train of thought,
stream response and render references.

- Motivation
This should allow other Khoj clients to pass auth headers and receive
Khoj's train of thought messages from server over simple HTTP
streaming API.

It'll also eventually deduplicate chat logic across the /websocket and
/chat API endpoints, improving maintainability and dev velocity.

- Details
  - Pass references as a separate streaming message type for simpler
    parsing. Remove passing "### compiled references" altogether once
    the original /api/chat API is deprecated/merged with the new one
    and clients have been updated to consume the references using this
    new mechanism
  - Save message to conversation even if the client disconnects. This is
    done by not breaking out of the async iterator that is sending the
    LLM response, since the conversation is saved at the end of the
    iteration
  - Handle parsing chunked JSON responses as valid JSON on the client.
    This requires additional logic on the client side but makes the
    client more robust to the server chunking the JSON response such
    that each chunk isn't itself necessarily valid JSON.
This commit is contained in:
Debanjum Singh Solanky
2024-07-22 00:20:23 +05:30
parent 91fe41106e
commit b8d3e3669a
3 changed files with 222 additions and 191 deletions

View File

@@ -40,7 +40,6 @@ dependencies = [
"dateparser >= 1.1.1", "dateparser >= 1.1.1",
"defusedxml == 0.7.1", "defusedxml == 0.7.1",
"fastapi >= 0.104.1", "fastapi >= 0.104.1",
"sse-starlette ~= 2.1.0",
"python-multipart >= 0.0.7", "python-multipart >= 0.0.7",
"jinja2 == 3.1.4", "jinja2 == 3.1.4",
"openai >= 1.0.0", "openai >= 1.0.0",

View File

@@ -74,13 +74,12 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000); }, 1000);
}); });
} }
var sseConnection = null;
let region = null; let region = null;
let city = null; let city = null;
let countryName = null; let countryName = null;
let timezone = null; let timezone = null;
let waitingForLocation = true; let waitingForLocation = true;
let chatMessageState = { let chatMessageState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
@@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co
.finally(() => { .finally(() => {
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
waitingForLocation = false; waitingForLocation = false;
initializeSSE(); initMessageState();
}); });
function formatDate(date) { function formatDate(date) {
@@ -599,7 +598,7 @@ To get started, just start typing below. You can also type / to see a list of co
} }
async function chat(isVoice=false) { async function chat(isVoice=false) {
sendMessageViaSSE(isVoice); renderMessageStream(isVoice);
return; return;
let query = document.getElementById("chat-input").value.trim(); let query = document.getElementById("chat-input").value.trim();
@@ -1067,7 +1066,7 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat; window.onload = loadChat;
function initializeSSE(isVoice=false) { function initMessageState(isVoice=false) {
if (waitingForLocation) { if (waitingForLocation) {
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
return; return;
@@ -1084,136 +1083,180 @@ To get started, just start typing below. You can also type / to see a list of co
} }
} }
function sendSSEMessage(query) { function sendMessageStream(query) {
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
let sseProtocol = window.location.protocol; let chatStreamUrl = `/api/chat/stream?q=${query}`;
let sseUrl = `/api/chat/stream?q=${query}`;
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`; chatStreamUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; chatStreamUrl += (!!region && !!city && !!countryName && !!timezone)
? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
: '';
function handleChatResponse(event) { fetch(chatStreamUrl)
// Get the last element in the chat-body .then(response => {
let chunk = event.data; const reader = response.body.getReader();
try { const decoder = new TextDecoder();
if (chunk.includes("application/json")) let buffer = '';
chunk = JSON.parse(chunk); let netBracketCount = 0;
} catch (error) {
// If the chunk is not a JSON object, continue. function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
console.log("Stream complete");
handleChunk(buffer);
buffer = '';
return;
}
const chunk = decoder.decode(value, { stream: true });
buffer += chunk;
netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length;
if (netBracketCount === 0) {
chunks = processJsonObjects(buffer);
chunks.objects.forEach(obj => handleChunk(obj));
buffer = chunks.remainder;
}
readStream();
});
}
readStream();
})
.catch(error => {
console.error('Error:', error);
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
}
chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
});
function processJsonObjects(str) {
let startIndex = str.indexOf('{');
if (startIndex === -1) return { objects: [str], remainder: '' };
const objects = [str.slice(0, startIndex)];
let openBraces = 0;
let currentObject = '';
for (let i = startIndex; i < str.length; i++) {
if (str[i] === '{') {
if (openBraces === 0) startIndex = i;
openBraces++;
}
if (str[i] === '}') {
openBraces--;
if (openBraces === 0) {
currentObject = str.slice(startIndex, i + 1);
objects.push(currentObject);
currentObject = '';
}
}
} }
const contentType = chunk["content-type"] return {
if (contentType === "application/json") { objects: objects,
// Handle JSON response remainder: openBraces > 0 ? str.slice(startIndex) : ''
try {
if (chunk.image || chunk.detail) {
({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse));
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
} else {
rawResponse = chunk.response;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
chatMessageState.rawResponse += chunk;
} finally {
addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
}
} else {
// Handle streamed response of type text/event-stream or text/plain
if (chunk && chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse));
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
} else {
// If the chunk is not a JSON object, just display it as is
chatMessageState.rawResponse += chunk;
if (chatMessageState.newResponseTextEl) {
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}; };
} }
};
sseConnection = new EventSource(sseUrl); function handleChunk(rawChunk) {
sseConnection.onmessage = handleChatResponse; // Split the chunk into lines
sseConnection.addEventListener("complete_llm_response", handleChatResponse); console.log("Chunk:", rawChunk);
sseConnection.addEventListener("status", (event) => { if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
console.log(`${event.data}`); try {
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false); let jsonChunk = JSON.parse(rawChunk);
}); if (!jsonChunk.type)
sseConnection.addEventListener("rate_limit", (event) => { jsonChunk = {type: 'message', data: jsonChunk};
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true); processChunk(jsonChunk);
}); } catch (e) {
sseConnection.addEventListener("start_llm_response", (event) => { const jsonChunk = {type: 'message', data: rawChunk};
console.log("Started streaming", new Date()); processChunk(jsonChunk);
}); }
sseConnection.addEventListener("end_llm_response", (event) => { } else if (rawChunk.length > 0) {
sseConnection.close(); const jsonChunk = {type: 'message', data: rawChunk};
console.log("Stopped streaming", new Date()); processChunk(jsonChunk);
}
}
function processChunk(chunk) {
console.log(chunk);
if (chunk.type ==='status') {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data;
handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false);
} else if (chunk.type === 'start_llm_response') {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message // Automatically respond with voice if the subscribed user has sent voice message
if (chatMessageState.isVoice && "{{ is_active }}" == "True") if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse); textToSpeech(chatMessageState.rawResponse);
// Append any references after all the data has been streamed // Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
const liveQuery = chatMessageState.rawQuery; const liveQuery = chatMessageState.rawQuery;
// Reset variables // Reset variables
chatMessageState = { chatMessageState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
loadingEllipsis: null, loadingEllipsis: null,
references: {}, references: {},
rawResponse: "", rawResponse: "",
rawQuery: liveQuery, rawQuery: liveQuery,
}
} else if (chunk.type === "references") {
const rawReferenceAsJson = JSON.parse(chunk.data);
console.log(`${chunk.type}: ${rawReferenceAsJson}`);
chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results};
} else if (chunk.type === 'message') {
if (chunk.data.trim()?.startsWith("{") && chunk.data.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunk.data.trim());
handleJsonResponse(jsonData);
} catch (e) {
// Handle text response chunk with compiled references
if (chunk?.data.includes("### compiled references:")) {
chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0];
// Handle text response chunk
} else {
chatMessageState.rawResponse += chunk.data;
}
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
} else {
// Handle text response chunk with compiled references
if (chunk?.data.includes("### compiled references:")) {
chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0];
// Handle text response chunk
} else {
chatMessageState.rawResponse += chunk.data;
}
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
} }
// Reset status icon function handleJsonResponse(jsonData) {
let statusDotIcon = document.getElementById("connection-status-icon"); if (jsonData.image || jsonData.detail) {
statusDotIcon.style.backgroundColor = "green"; let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse);
let statusDotText = document.getElementById("connection-status-text"); chatMessageState.rawResponse = rawResponse;
statusDotText.textContent = "Ready"; chatMessageState.references = references;
statusDotText.style.marginTop = "5px"; } else if (jsonData.response) {
}); chatMessageState.rawResponse = jsonData.response;
sseConnection.onclose = function(event) { chatMessageState.references = {
sseConnection = null; notes: jsonData.context || {},
console.debug("SSE is closed now."); online: jsonData.online_results || {}
let statusDotIcon = document.getElementById("connection-status-icon"); };
statusDotIcon.style.backgroundColor = "green"; }
let statusDotText = document.getElementById("connection-status-text"); addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
}
sseConnection.onerror = function(event) {
console.log("SSE error observed:", event);
sseConnection.close();
sseConnection = null;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Server Error";
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
} }
chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
}
sseConnection.onopen = function(event) {
console.debug("SSE is open now.")
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "orange";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Processing";
} }
} }
function sendMessageViaSSE(isVoice=false) { function renderMessageStream(isVoice=false) {
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
var query = document.getElementById("chat-input").value.trim(); var query = document.getElementById("chat-input").value.trim();
@@ -1253,7 +1296,7 @@ To get started, just start typing below. You can also type / to see a list of co
chatInput.classList.remove("option-enabled"); chatInput.classList.remove("option-enabled");
// Call specified Khoj API // Call specified Khoj API
sendSSEMessage(query); sendMessageStream(query);
let rawResponse = ""; let rawResponse = "";
let references = {}; let references = {};
@@ -1267,6 +1310,7 @@ To get started, just start typing below. You can also type / to see a list of co
isVoice: isVoice, isVoice: isVoice,
} }
} }
var userMessages = []; var userMessages = [];
var userMessageIndex = -1; var userMessageIndex = -1;
function loadChat() { function loadChat() {
@@ -1276,7 +1320,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`; let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
initializeSSE(); initMessageState();
loadFileFiltersFromConversation(); loadFileFiltersFromConversation();
} }
@@ -1316,7 +1360,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation(); loadFileFiltersFromConversation();
initializeSSE(); initMessageState();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent; let agentMetadata = response.agent;

View File

@@ -11,7 +11,6 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from sse_starlette import EventSourceResponse
from starlette.authentication import requires from starlette.authentication import requires
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
@@ -543,15 +542,24 @@ async def stream_chat(
async def send_event(event_type: str, data: str): async def send_event(event_type: str, data: str):
nonlocal connection_alive nonlocal connection_alive
if not connection_alive or await request.is_disconnected(): if not connection_alive or await request.is_disconnected():
connection_alive = False
return return
try: try:
if event_type == "message": if event_type == "message":
yield data yield data
else: else:
yield {"event": event_type, "data": data, "retry": 15000} yield json.dumps({"type": event_type, "data": data})
except Exception as e: except Exception as e:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}") logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
async def send_llm_response(response: str):
async for result in send_event("start_llm_response", ""):
yield result
async for result in send_event("message", response):
yield result
async for result in send_event("end_llm_response", ""):
yield result
user: KhojUser = request.user.object user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
@@ -585,17 +593,10 @@ async def stream_chat(
except HTTPException as e: except HTTPException as e:
async for result in send_event("rate_limit", e.detail): async for result in send_event("rate_limit", e.detail):
yield result yield result
break return
if is_query_empty(q): if is_query_empty(q):
async for event in send_event("start_llm_response", ""): async for event in send_llm_response("Please ask your query to get started."):
yield event
async for event in send_event(
"message",
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
):
yield event
async for event in send_event("end_llm_response", ""):
yield event yield event
return return
@@ -645,25 +646,19 @@ async def stream_chat(
response_log = ( response_log = (
"No files selected for summarization. Please add files using the section on the left." "No files selected for summarization. Please add files using the section on the left."
) )
async for result in send_event("complete_llm_response", response_log): async for result in send_llm_response(response_log):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
elif len(file_filters) > 1: elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization." response_log = "Only one file can be selected for summarization."
async for result in send_event("complete_llm_response", response_log): async for result in send_llm_response(response_log):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
else: else:
try: try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0: if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_event("complete_llm_response", response_log): async for result in send_llm_response(response_log):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
contextual_data = " ".join([file.raw_text for file in file_object]) contextual_data = " ".join([file.raw_text for file in file_object])
if not q: if not q:
@@ -675,17 +670,13 @@ async def stream_chat(
response = await extract_relevant_summary(q, contextual_data) response = await extract_relevant_summary(q, contextual_data)
response_log = str(response) response_log = str(response)
async for result in send_event("complete_llm_response", response_log): async for result in send_llm_response(response_log):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
except Exception as e: except Exception as e:
response_log = "Error summarizing file." response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_event("complete_llm_response", response_log): async for result in send_llm_response(response_log):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
response_log, response_log,
@@ -714,10 +705,8 @@ async def stream_chat(
formatted_help = help_message.format( formatted_help = help_message.format(
model=model_type, version=state.khoj_version, device=get_device() model=model_type, version=state.khoj_version, device=get_device()
) )
async for result in send_event("complete_llm_response", formatted_help): async for result in send_llm_response(formatted_help):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
custom_filters.append("site:khoj.dev") custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online) conversation_commands.append(ConversationCommand.Online)
@@ -730,10 +719,8 @@ async def stream_chat(
except Exception as e: except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}") logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist." error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_event("complete_llm_response", error_message): async for result in send_llm_response(error_message):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
@@ -760,10 +747,8 @@ async def stream_chat(
api="chat", api="chat",
**common.__dict__, **common.__dict__,
) )
async for result in send_event("complete_llm_response", llm_response): async for result in send_llm_response(llm_response):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
compiled_references, inferred_queries, defiltered_query = [], [], None compiled_references, inferred_queries, defiltered_query = [], [], None
@@ -797,9 +782,7 @@ async def stream_chat(
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
user user
): ):
async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"): async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
async for event in send_event("end_llm_response", ""):
yield event yield event
return return
@@ -818,10 +801,8 @@ async def stream_chat(
except ValueError as e: except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results" error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message) logger.warning(error_message)
async for result in send_event("complete_llm_response", error_message): async for result in send_llm_response(error_message):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
@@ -873,15 +854,13 @@ async def stream_chat(
if image is None or status_code != 200: if image is None or status_code != 200:
content_obj = { content_obj = {
"image": image, "content-type": "application/json",
"intentType": intent_type, "intentType": intent_type,
"detail": improved_image_prompt, "detail": improved_image_prompt,
"content-type": "application/json", "image": image,
} }
async for result in send_event("complete_llm_response", json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
@@ -898,19 +877,22 @@ async def stream_chat(
online_results=online_results, online_results=online_results,
) )
content_obj = { content_obj = {
"image": image,
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"context": compiled_references,
"content-type": "application/json", "content-type": "application/json",
"intentType": intent_type,
"context": compiled_references,
"online_results": online_results, "online_results": online_results,
"inferredQueries": [improved_image_prompt],
"image": image,
} }
async for result in send_event("complete_llm_response", json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
yield result yield result
async for event in send_event("end_llm_response", ""):
yield event
return return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
yield result
async for result in send_event("status", f"**💭 Generating a well-informed response**"): async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result yield result
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
@@ -941,27 +923,33 @@ async def stream_chat(
async for result in send_event("start_llm_response", ""): async for result in send_event("start_llm_response", ""):
yield result yield result
continue_stream = True
async for item in iterator: async for item in iterator:
if item is None: if item is None:
break async for result in send_event("end_llm_response", ""):
if connection_alive: yield result
try: logger.debug("Finished streaming response")
async for result in send_event("message", f"{item}"): return
yield result if not connection_alive or not continue_stream:
except Exception as e: continue
connection_alive = False try:
logger.info( async for result in send_event("message", f"{item}"):
f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}" yield result
) except Exception as e:
async for result in send_event("end_llm_response", ""): continue_stream = False
yield result logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
except asyncio.CancelledError: except asyncio.CancelledError:
break logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
return
except Exception as e: except Exception as e:
logger.error(f"Error in SSE endpoint: {e}", exc_info=True) logger.error(f"General Error in API endpoint: {e}", exc_info=True)
break return
return EventSourceResponse(event_generator(q)) return StreamingResponse(event_generator(q), media_type="text/plain")
@api_chat.get("", response_class=Response) @api_chat.get("", response_class=Response)