Stream Status Messages via Streaming Response from server to web client

- Overview
Use simpler HTTP Streaming Response to send status messages, alongside
response and references from server to clients via API.

Update web client to use the streamed response to show train of thought,
stream response and render references.

- Motivation
This should allow other Khoj clients to pass auth headers and receive
Khoj's train of thought messages from the server over a simple HTTP
streaming API.

It'll also eventually deduplicate chat logic across the /websocket and
/chat API endpoints, improving maintainability and dev velocity.

- Details
  - Pass references as a separate streaming message type for simpler
    parsing. Remove passing "### compiled references" altogether once
    the original /api/chat API is deprecated/merged with the new one
    and clients have been updated to consume the references using this
    new mechanism
  - Save message to conversation even if client disconnects. This is
    done by not breaking out of the async iterator that is sending the
    LLM response, since the conversation is saved at the end of the
    iteration.
  - Handle parsing chunked JSON responses as valid JSON on the client.
    This requires additional logic on the client side but makes the
    client more robust to the server chunking the JSON response such
    that each chunk isn't itself necessarily valid JSON.
This commit is contained in:
Debanjum Singh Solanky
2024-07-22 00:20:23 +05:30
parent 91fe41106e
commit b8d3e3669a
3 changed files with 222 additions and 191 deletions

View File

@@ -40,7 +40,6 @@ dependencies = [
"dateparser >= 1.1.1",
"defusedxml == 0.7.1",
"fastapi >= 0.104.1",
"sse-starlette ~= 2.1.0",
"python-multipart >= 0.0.7",
"jinja2 == 3.1.4",
"openai >= 1.0.0",

View File

@@ -74,13 +74,12 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000);
});
}
var sseConnection = null;
let region = null;
let city = null;
let countryName = null;
let timezone = null;
let waitingForLocation = true;
let chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
@@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co
.finally(() => {
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
waitingForLocation = false;
initializeSSE();
initMessageState();
});
function formatDate(date) {
@@ -599,7 +598,7 @@ To get started, just start typing below. You can also type / to see a list of co
}
async function chat(isVoice=false) {
sendMessageViaSSE(isVoice);
renderMessageStream(isVoice);
return;
let query = document.getElementById("chat-input").value.trim();
@@ -1067,7 +1066,7 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat;
function initializeSSE(isVoice=false) {
function initMessageState(isVoice=false) {
if (waitingForLocation) {
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
return;
@@ -1084,136 +1083,180 @@ To get started, just start typing below. You can also type / to see a list of co
}
}
function sendSSEMessage(query) {
function sendMessageStream(query) {
let chatBody = document.getElementById("chat-body");
let sseProtocol = window.location.protocol;
let sseUrl = `/api/chat/stream?q=${query}`;
let chatStreamUrl = `/api/chat/stream?q=${query}`;
if (chatBody.dataset.conversationId) {
sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
chatStreamUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
chatStreamUrl += (!!region && !!city && !!countryName && !!timezone)
? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}`
: '';
function handleChatResponse(event) {
// Get the last element in the chat-body
let chunk = event.data;
try {
if (chunk.includes("application/json"))
chunk = JSON.parse(chunk);
} catch (error) {
// If the chunk is not a JSON object, continue.
// Stream the chat response over plain HTTP, re-assembling chunked JSON.
// Chunks are buffered until the net count of '{' vs '}' braces returns to
// zero, at which point complete JSON objects (and any interleaved plain
// text) are extracted and dispatched to handleChunk.
fetch(chatStreamUrl)
    .then(response => {
        const reader = response.body.getReader();
        const decoder = new TextDecoder();
        let buffer = '';
        let netBracketCount = 0;

        function readStream() {
            reader.read().then(({ done, value }) => {
                if (done) {
                    console.log("Stream complete");
                    // Flush whatever is left in the buffer as a final chunk
                    handleChunk(buffer);
                    buffer = '';
                    return;
                }

                // Decode the incoming chunk and append to the buffer
                const chunk = decoder.decode(value, { stream: true });
                buffer += chunk;

                // Only attempt extraction at JSON object boundaries, i.e.
                // when every opened brace seen so far has been closed
                netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length;
                if (netBracketCount === 0) {
                    // Bug fix: declare locally instead of creating an implicit global
                    const chunks = processJsonObjects(buffer);
                    chunks.objects.forEach(obj => handleChunk(obj));
                    buffer = chunks.remainder;
                }

                readStream();
            });
        }
        readStream();
    })
    .catch(error => {
        console.error('Error:', error);
        // Clear the loading ellipsis, if shown, before surfacing the failure
        if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
            chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
        }
        chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
    });
function processJsonObjects(str) {
let startIndex = str.indexOf('{');
if (startIndex === -1) return { objects: [str], remainder: '' };
const objects = [str.slice(0, startIndex)];
let openBraces = 0;
let currentObject = '';
for (let i = startIndex; i < str.length; i++) {
if (str[i] === '{') {
if (openBraces === 0) startIndex = i;
openBraces++;
}
if (str[i] === '}') {
openBraces--;
if (openBraces === 0) {
currentObject = str.slice(startIndex, i + 1);
objects.push(currentObject);
currentObject = '';
}
}
}
const contentType = chunk["content-type"]
if (contentType === "application/json") {
// Handle JSON response
try {
if (chunk.image || chunk.detail) {
({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse));
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
} else {
rawResponse = chunk.response;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
chatMessageState.rawResponse += chunk;
} finally {
addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
}
} else {
// Handle streamed response of type text/event-stream or text/plain
if (chunk && chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse));
chatMessageState.rawResponse = rawResponse;
chatMessageState.references = references;
} else {
// If the chunk is not a JSON object, just display it as is
chatMessageState.rawResponse += chunk;
if (chatMessageState.newResponseTextEl) {
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
return {
objects: objects,
remainder: openBraces > 0 ? str.slice(startIndex) : ''
};
}
};
sseConnection = new EventSource(sseUrl);
sseConnection.onmessage = handleChatResponse;
sseConnection.addEventListener("complete_llm_response", handleChatResponse);
sseConnection.addEventListener("status", (event) => {
console.log(`${event.data}`);
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false);
});
sseConnection.addEventListener("rate_limit", (event) => {
handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true);
});
sseConnection.addEventListener("start_llm_response", (event) => {
console.log("Started streaming", new Date());
});
sseConnection.addEventListener("end_llm_response", (event) => {
sseConnection.close();
console.log("Stopped streaming", new Date());
/**
 * Normalize a raw streamed chunk into a typed chunk and dispatch it.
 * Chunks that parse as JSON with a `type` field are passed through as-is.
 * Anything else (plain text, malformed JSON, or JSON without a `type`)
 * is wrapped as a `message` chunk whose `data` stays the ORIGINAL STRING,
 * so downstream string handling (e.g. `chunk.data.trim()` in processChunk)
 * does not throw on a parsed object.
 * @param {string} rawChunk - Raw text chunk from the response stream.
 */
function handleChunk(rawChunk) {
    console.log("Chunk:", rawChunk);
    if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) {
        try {
            const jsonChunk = JSON.parse(rawChunk);
            if (!jsonChunk.type) {
                // Bug fix: wrap the raw string, not the parsed object, so
                // processChunk's string operations on chunk.data keep working
                processChunk({ type: 'message', data: rawChunk });
            } else {
                processChunk(jsonChunk);
            }
        } catch (e) {
            // Looked like JSON but wasn't parseable; treat as plain text
            processChunk({ type: 'message', data: rawChunk });
        }
    } else if (rawChunk.length > 0) {
        processChunk({ type: 'message', data: rawChunk });
    }
    // Empty chunks are dropped silently
}
function processChunk(chunk) {
console.log(chunk);
if (chunk.type ==='status') {
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data;
handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false);
} else if (chunk.type === 'start_llm_response') {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse);
// Automatically respond with voice if the subscribed user has sent voice message
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
const liveQuery = chatMessageState.rawQuery;
// Reset variables
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
const liveQuery = chatMessageState.rawQuery;
// Reset variables
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
}
} else if (chunk.type === "references") {
const rawReferenceAsJson = JSON.parse(chunk.data);
console.log(`${chunk.type}: ${rawReferenceAsJson}`);
chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results};
} else if (chunk.type === 'message') {
if (chunk.data.trim()?.startsWith("{") && chunk.data.trim()?.endsWith("}")) {
// Try process chunk data as if it is a JSON object
try {
const jsonData = JSON.parse(chunk.data.trim());
handleJsonResponse(jsonData);
} catch (e) {
// Handle text response chunk with compiled references
if (chunk?.data.includes("### compiled references:")) {
chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0];
// Handle text response chunk
} else {
chatMessageState.rawResponse += chunk.data;
}
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
} else {
// Handle text response chunk with compiled references
if (chunk?.data.includes("### compiled references:")) {
chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0];
// Handle text response chunk
} else {
chatMessageState.rawResponse += chunk.data;
}
handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
}
}
}
// Reset status icon
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
});
sseConnection.onclose = function(event) {
sseConnection = null;
console.debug("SSE is closed now.");
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
}
sseConnection.onerror = function(event) {
console.log("SSE error observed:", event);
sseConnection.close();
sseConnection = null;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Server Error";
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
/**
 * Fold a parsed JSON payload from the chat stream into the shared
 * chatMessageState, then re-render the chat body.
 * Payloads carrying an image or an error detail are delegated to
 * handleImageResponse; plain response payloads update the raw response
 * text and its references directly.
 * @param {Object} jsonData - Parsed JSON chunk received from the server.
 */
function handleJsonResponse(jsonData) {
    if (jsonData.image || jsonData.detail) {
        // Image (or error detail) payload: let handleImageResponse compute
        // the updated response text and references.
        const updated = handleImageResponse(jsonData, chatMessageState.rawResponse);
        chatMessageState.rawResponse = updated.rawResponse;
        chatMessageState.references = updated.references;
    } else if (jsonData.response) {
        // Plain chat response payload with optional context references.
        chatMessageState.rawResponse = jsonData.response;
        const noteRefs = jsonData.context || {};
        const onlineRefs = jsonData.online_results || {};
        chatMessageState.references = { notes: noteRefs, online: onlineRefs };
    }
    addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
}
chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
}
sseConnection.onopen = function(event) {
console.debug("SSE is open now.")
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "orange";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Processing";
}
}
function sendMessageViaSSE(isVoice=false) {
function renderMessageStream(isVoice=false) {
let chatBody = document.getElementById("chat-body");
var query = document.getElementById("chat-input").value.trim();
@@ -1253,7 +1296,7 @@ To get started, just start typing below. You can also type / to see a list of co
chatInput.classList.remove("option-enabled");
// Call specified Khoj API
sendSSEMessage(query);
sendMessageStream(query);
let rawResponse = "";
let references = {};
@@ -1267,6 +1310,7 @@ To get started, just start typing below. You can also type / to see a list of co
isVoice: isVoice,
}
}
var userMessages = [];
var userMessageIndex = -1;
function loadChat() {
@@ -1276,7 +1320,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
initializeSSE();
initMessageState();
loadFileFiltersFromConversation();
}
@@ -1316,7 +1360,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation();
initializeSSE();
initMessageState();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent;

View File

@@ -11,7 +11,6 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from sse_starlette import EventSourceResponse
from starlette.authentication import requires
from khoj.app.settings import ALLOWED_HOSTS
@@ -543,15 +542,24 @@ async def stream_chat(
async def send_event(event_type: str, data: str):
nonlocal connection_alive
if not connection_alive or await request.is_disconnected():
connection_alive = False
return
try:
if event_type == "message":
yield data
else:
yield {"event": event_type, "data": data, "retry": 15000}
yield json.dumps({"type": event_type, "data": data})
except Exception as e:
connection_alive = False
logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}")
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
async def send_llm_response(response: str):
    # Emit a complete LLM response as a start/message/end event sequence,
    # so clients can bracket the streamed message with lifecycle markers.
    events = (
        ("start_llm_response", ""),
        ("message", response),
        ("end_llm_response", ""),
    )
    for event_type, payload in events:
        async for result in send_event(event_type, payload):
            yield result
user: KhojUser = request.user.object
conversation = await ConversationAdapters.aget_conversation_by_user(
@@ -585,17 +593,10 @@ async def stream_chat(
except HTTPException as e:
async for result in send_event("rate_limit", e.detail):
yield result
break
return
if is_query_empty(q):
async for event in send_event("start_llm_response", ""):
yield event
async for event in send_event(
"message",
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
):
yield event
async for event in send_event("end_llm_response", ""):
async for event in send_llm_response("Please ask your query to get started."):
yield event
return
@@ -645,25 +646,19 @@ async def stream_chat(
response_log = (
"No files selected for summarization. Please add files using the section on the left."
)
async for result in send_event("complete_llm_response", response_log):
async for result in send_llm_response(response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization."
async for result in send_event("complete_llm_response", response_log):
async for result in send_llm_response(response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_event("complete_llm_response", response_log):
async for result in send_llm_response(response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
@@ -675,17 +670,13 @@ async def stream_chat(
response = await extract_relevant_summary(q, contextual_data)
response_log = str(response)
async for result in send_event("complete_llm_response", response_log):
async for result in send_llm_response(response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_event("complete_llm_response", response_log):
async for result in send_llm_response(response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
await sync_to_async(save_to_conversation_log)(
q,
response_log,
@@ -714,10 +705,8 @@ async def stream_chat(
formatted_help = help_message.format(
model=model_type, version=state.khoj_version, device=get_device()
)
async for result in send_event("complete_llm_response", formatted_help):
async for result in send_llm_response(formatted_help):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online)
@@ -730,10 +719,8 @@ async def stream_chat(
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_event("complete_llm_response", error_message):
async for result in send_llm_response(error_message):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
@@ -760,10 +747,8 @@ async def stream_chat(
api="chat",
**common.__dict__,
)
async for result in send_event("complete_llm_response", llm_response):
async for result in send_llm_response(llm_response):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
compiled_references, inferred_queries, defiltered_query = [], [], None
@@ -797,9 +782,7 @@ async def stream_chat(
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
user
):
async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"):
yield result
async for event in send_event("end_llm_response", ""):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield event
return
@@ -818,10 +801,8 @@ async def stream_chat(
except ValueError as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_event("complete_llm_response", error_message):
async for result in send_llm_response(error_message):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
if ConversationCommand.Webpage in conversation_commands:
@@ -873,15 +854,13 @@ async def stream_chat(
if image is None or status_code != 200:
content_obj = {
"image": image,
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"content-type": "application/json",
"image": image,
}
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj)):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
await sync_to_async(save_to_conversation_log)(
@@ -898,19 +877,22 @@ async def stream_chat(
online_results=online_results,
)
content_obj = {
"image": image,
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"context": compiled_references,
"content-type": "application/json",
"intentType": intent_type,
"context": compiled_references,
"online_results": online_results,
"inferredQueries": [improved_image_prompt],
"image": image,
}
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj)):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
async for result in send_event(
"references", json.dumps({"context": compiled_references, "online_results": online_results})
):
yield result
async for result in send_event("status", f"**💭 Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(
@@ -941,27 +923,33 @@ async def stream_chat(
async for result in send_event("start_llm_response", ""):
yield result
continue_stream = True
async for item in iterator:
if item is None:
break
if connection_alive:
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
connection_alive = False
logger.info(
f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}"
)
async for result in send_event("end_llm_response", ""):
yield result
async for result in send_event("end_llm_response", ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:
continue
try:
async for result in send_event("message", f"{item}"):
yield result
except Exception as e:
continue_stream = False
logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")
# Stop streaming after compiled references section of response starts
# References are being processed via the references event rather than the message event
if "### compiled references:" in item:
continue_stream = False
except asyncio.CancelledError:
break
logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True)
return
except Exception as e:
logger.error(f"Error in SSE endpoint: {e}", exc_info=True)
break
logger.error(f"General Error in API endpoint: {e}", exc_info=True)
return
return EventSourceResponse(event_generator(q))
return StreamingResponse(event_generator(q), media_type="text/plain")
@api_chat.get("", response_class=Response)