From b8d3e3669ac14b752ee08d96e65b2f3d2d1bfb41 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 22 Jul 2024 00:20:23 +0530 Subject: [PATCH] Stream Status Messages via Streaming Response from server to web client - Overview Use simpler HTTP Streaming Response to send status messages, alongside response and references from server to clients via API. Update web client to use the streamed response to show train of thought, stream response and render references. - Motivation This should allow other Khoj clients to pass auth headers and recieve Khoj's train of thought messages from server over simple HTTP streaming API. It'll also eventually deduplicate chat logic across /websocket and /chat API endpoints and help maintainability and dev velocity - Details - Pass references as a separate streaming message type for simpler parsing. Remove passing "### compiled references" altogether once the original /api/chat API is deprecated/merged with the new one and clients have been updated to consume the references using this new mechanism - Save message to conversation even if client disconnects. This is done by not breaking out of the async iterator that is sending the llm response. As the save conversation is called at the end of the iteration - Handle parsing chunked json responses as a valid json on client. This requires additional logic on client side but makes the client more robust to server chunking json response such that each chunk isn't itself necessarily a valid json. --- pyproject.toml | 1 - src/khoj/interface/web/chat.html | 284 ++++++++++++++++++------------- src/khoj/routers/api_chat.py | 128 +++++++------- 3 files changed, 222 insertions(+), 191 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 939a1d9e..2669f5ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "dateparser >= 1.1.1", "defusedxml == 0.7.1", "fastapi >= 0.104.1", - "sse-starlette ~= 2.1.0", "python-multipart >= 0.0.7", "jinja2 == 3.1.4", "openai >= 1.0.0", diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 3e07a860..b1ff3eba 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -74,13 +74,12 @@ To get started, just start typing below. You can also type / to see a list of co }, 1000); }); } - var sseConnection = null; + let region = null; let city = null; let countryName = null; let timezone = null; let waitingForLocation = true; - let chatMessageState = { newResponseTextEl: null, newResponseEl: null, @@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co .finally(() => { console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); waitingForLocation = false; - initializeSSE(); + initMessageState(); }); function formatDate(date) { @@ -599,7 +598,7 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - sendMessageViaSSE(isVoice); + renderMessageStream(isVoice); return; let query = document.getElementById("chat-input").value.trim(); @@ -1067,7 +1066,7 @@ To get started, just start typing below. You can also type / to see a list of co window.onload = loadChat; - function initializeSSE(isVoice=false) { + function initMessageState(isVoice=false) { if (waitingForLocation) { console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); return; @@ -1084,136 +1083,180 @@ To get started, just start typing below. You can also type / to see a list of co } } - function sendSSEMessage(query) { + function sendMessageStream(query) { let chatBody = document.getElementById("chat-body"); - let sseProtocol = window.location.protocol; - let sseUrl = `/api/chat/stream?q=${query}`; + let chatStreamUrl = `/api/chat/stream?q=${query}`; if (chatBody.dataset.conversationId) { - sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; + chatStreamUrl += `&conversation_id=${chatBody.dataset.conversationId}`; + chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) + ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; - function handleChatResponse(event) { - // Get the last element in the chat-body - let chunk = event.data; - try { - if (chunk.includes("application/json")) - chunk = JSON.parse(chunk); - } catch (error) { - // If the chunk is not a JSON object, continue. + fetch(chatStreamUrl) + .then(response => { + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; + + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + console.log("Stream complete"); + handleChunk(buffer); + buffer = ''; + return; + } + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + chunks = processJsonObjects(buffer); + chunks.objects.forEach(obj => handleChunk(obj)); + buffer = chunks.remainder; + } + readStream(); + }); + } + + readStream(); + }) + .catch(error => { + console.error('Error:', error); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + } + chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" + }); + + function processJsonObjects(str) { + let startIndex = str.indexOf('{'); + if (startIndex === -1) return { objects: [str], remainder: '' }; + const objects = [str.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; + + for (let i = startIndex; i < str.length; i++) { + if (str[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (str[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = str.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } } - const contentType = chunk["content-type"] - if (contentType === "application/json") { - // Handle JSON response - try { - if (chunk.image || chunk.detail) { - ({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse)); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else { - rawResponse = chunk.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - chatMessageState.rawResponse += chunk; - } finally { - addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); - } - } else { - // Handle streamed response of type text/event-stream or text/plain - if (chunk && chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse)); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else { - // If the chunk is not a JSON object, just display it as is - chatMessageState.rawResponse += chunk; - if (chatMessageState.newResponseTextEl) { - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + return { + objects: objects, + remainder: openBraces > 0 ? str.slice(startIndex) : '' }; } - }; - sseConnection = new EventSource(sseUrl); - sseConnection.onmessage = handleChatResponse; - sseConnection.addEventListener("complete_llm_response", handleChatResponse); - sseConnection.addEventListener("status", (event) => { - console.log(`${event.data}`); - handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false); - }); - sseConnection.addEventListener("rate_limit", (event) => { - handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true); - }); - sseConnection.addEventListener("start_llm_response", (event) => { - console.log("Started streaming", new Date()); - }); - sseConnection.addEventListener("end_llm_response", (event) => { - sseConnection.close(); - console.log("Stopped streaming", new Date()); + function handleChunk(rawChunk) { + // Split the chunk into lines + console.log("Chunk:", rawChunk); + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + processChunk(jsonChunk); + } catch (e) { + const jsonChunk = {type: 'message', data: rawChunk}; + processChunk(jsonChunk); + } + } else if (rawChunk.length > 0) { + const jsonChunk = {type: 'message', data: rawChunk}; + processChunk(jsonChunk); + } + } + function processChunk(chunk) { + console.log(chunk); + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); - // Automatically respond with voice if the subscribed user has sent voice message - if (chatMessageState.isVoice && "{{ is_active }}" == "True") - textToSpeech(chatMessageState.rawResponse); + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); - // Append any references after all the data has been streamed - finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); - const liveQuery = chatMessageState.rawQuery; - // Reset variables - chatMessageState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + } + } else if (chunk.type === "references") { + const rawReferenceAsJson = JSON.parse(chunk.data); + console.log(`${chunk.type}: ${rawReferenceAsJson}`); + chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results}; + } else if (chunk.type === 'message') { + if (chunk.data.trim()?.startsWith("{") && chunk.data.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunk.data.trim()); + handleJsonResponse(jsonData); + } catch (e) { + // Handle text response chunk with compiled references + if (chunk?.data.includes("### compiled references:")) { + chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; + // Handle text response chunk + } else { + chatMessageState.rawResponse += chunk.data; + } + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + // Handle text response chunk with compiled references + if (chunk?.data.includes("### compiled references:")) { + chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; + // Handle text response chunk + } else { + chatMessageState.rawResponse += chunk.data; + } + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } } - // Reset status icon - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Ready"; - statusDotText.style.marginTop = "5px"; - }); - sseConnection.onclose = function(event) { - sseConnection = null; - console.debug("SSE is closed now."); - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Ready"; - statusDotText.style.marginTop = "5px"; - } - sseConnection.onerror = function(event) { - console.log("SSE error observed:", event); - sseConnection.close(); - sseConnection = null; - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "red"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Server Error"; - if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { - chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + chatMessageState.references = { + notes: jsonData.context || {}, + online: jsonData.online_results || {} + }; + } + addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); } - chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" - } - sseConnection.onopen = function(event) { - console.debug("SSE is open now.") - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "orange"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Processing"; } } - function sendMessageViaSSE(isVoice=false) { + function renderMessageStream(isVoice=false) { let chatBody = document.getElementById("chat-body"); var query = document.getElementById("chat-input").value.trim(); @@ -1253,7 +1296,7 @@ To get started, just start typing below. You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API - sendSSEMessage(query); + sendMessageStream(query); let rawResponse = ""; let references = {}; @@ -1267,6 +1310,7 @@ To get started, just start typing below. You can also type / to see a list of co isVoice: isVoice, } } + var userMessages = []; var userMessageIndex = -1; function loadChat() { @@ -1276,7 +1320,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatHistoryUrl = `/api/chat/history?client=web`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - initializeSSE(); + initMessageState(); loadFileFiltersFromConversation(); } @@ -1316,7 +1360,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatBody = document.getElementById("chat-body"); chatBody.dataset.conversationId = response.conversation_id; loadFileFiltersFromConversation(); - initializeSSE(); + initMessageState(); chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; let agentMetadata = response.agent; diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4c3603cf..e6b60282 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -11,7 +11,6 @@ from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse -from sse_starlette import EventSourceResponse from starlette.authentication import requires from khoj.app.settings import ALLOWED_HOSTS @@ -543,15 +542,24 @@ async def stream_chat( async def send_event(event_type: str, data: str): nonlocal connection_alive if not connection_alive or await request.is_disconnected(): + connection_alive = False return try: if event_type == "message": yield data else: - yield {"event": event_type, "data": data, "retry": 15000} + yield json.dumps({"type": event_type, "data": data}) except Exception as e: connection_alive = False - logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}") + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + + async def send_llm_response(response: str): + async for result in send_event("start_llm_response", ""): + yield result + async for result in send_event("message", response): + yield result + async for result in send_event("end_llm_response", ""): + yield result user: KhojUser = request.user.object conversation = await ConversationAdapters.aget_conversation_by_user( @@ -585,17 +593,10 @@ async def stream_chat( except HTTPException as e: async for result in send_event("rate_limit", e.detail): yield result - break + return if is_query_empty(q): - async for event in send_event("start_llm_response", ""): - yield event - async for event in send_event( - "message", - "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", - ): - yield event - async for event in send_event("end_llm_response", ""): + async for event in send_llm_response("Please ask your query to get started."): yield event return @@ -645,25 +646,19 @@ async def stream_chat( response_log = ( "No files selected for summarization. Please add files using the section on the left." ) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event elif len(file_filters) > 1: response_log = "Only one file can be selected for summarization." - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event else: try: file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) if len(file_object) == 0: response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event return contextual_data = " ".join([file.raw_text for file in file_object]) if not q: @@ -675,17 +670,13 @@ async def stream_chat( response = await extract_relevant_summary(q, contextual_data) response_log = str(response) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event except Exception as e: response_log = "Error summarizing file." logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event await sync_to_async(save_to_conversation_log)( q, response_log, @@ -714,10 +705,8 @@ async def stream_chat( formatted_help = help_message.format( model=model_type, version=state.khoj_version, device=get_device() ) - async for result in send_event("complete_llm_response", formatted_help): + async for result in send_llm_response(formatted_help): yield result - async for event in send_event("end_llm_response", ""): - yield event return custom_filters.append("site:khoj.dev") conversation_commands.append(ConversationCommand.Online) @@ -730,10 +719,8 @@ async def stream_chat( except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_event("complete_llm_response", error_message): + async for result in send_llm_response(error_message): yield result - async for event in send_event("end_llm_response", ""): - yield event return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) @@ -760,10 +747,8 @@ async def stream_chat( api="chat", **common.__dict__, ) - async for result in send_event("complete_llm_response", llm_response): + async for result in send_llm_response(llm_response): yield result - async for event in send_event("end_llm_response", ""): - yield event return compiled_references, inferred_queries, defiltered_query = [], [], None @@ -797,9 +782,7 @@ async def stream_chat( if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( user ): - async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"): - yield result - async for event in send_event("end_llm_response", ""): + async for result in send_llm_response(f"{no_entries_found.format()}"): yield event return @@ -818,10 +801,8 @@ async def stream_chat( except ValueError as e: error_message = f"Error searching online: {e}. Attempting to respond without online results" logger.warning(error_message) - async for result in send_event("complete_llm_response", error_message): + async for result in send_llm_response(error_message): yield result - async for event in send_event("end_llm_response", ""): - yield event return if ConversationCommand.Webpage in conversation_commands: @@ -873,15 +854,13 @@ async def stream_chat( if image is None or status_code != 200: content_obj = { - "image": image, + "content-type": "application/json", "intentType": intent_type, "detail": improved_image_prompt, - "content-type": "application/json", + "image": image, } - async for result in send_event("complete_llm_response", json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj)): yield result - async for event in send_event("end_llm_response", ""): - yield event return await sync_to_async(save_to_conversation_log)( @@ -898,19 +877,22 @@ async def stream_chat( online_results=online_results, ) content_obj = { - "image": image, - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "context": compiled_references, "content-type": "application/json", + "intentType": intent_type, + "context": compiled_references, "online_results": online_results, + "inferredQueries": [improved_image_prompt], + "image": image, } - async for result in send_event("complete_llm_response", json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj)): yield result - async for event in send_event("end_llm_response", ""): - yield event return + async for result in send_event( + "references", json.dumps({"context": compiled_references, "online_results": online_results}) + ): + yield result + async for result in send_event("status", f"**💭 Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( @@ -941,27 +923,33 @@ async def stream_chat( async for result in send_event("start_llm_response", ""): yield result + continue_stream = True async for item in iterator: if item is None: - break - if connection_alive: - try: - async for result in send_event("message", f"{item}"): - yield result - except Exception as e: - connection_alive = False - logger.info( - f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}" - ) - async for result in send_event("end_llm_response", ""): - yield result + async for result in send_event("end_llm_response", ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue + try: + async for result in send_event("message", f"{item}"): + yield result + except Exception as e: + continue_stream = False + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + # Stop streaming after compiled references section of response starts + # References are being processed via the references event rather than the message event + if "### compiled references:" in item: + continue_stream = False except asyncio.CancelledError: - break + logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True) + return except Exception as e: - logger.error(f"Error in SSE endpoint: {e}", exc_info=True) - break + logger.error(f"General Error in API endpoint: {e}", exc_info=True) + return - return EventSourceResponse(event_generator(q)) + return StreamingResponse(event_generator(q), media_type="text/plain") @api_chat.get("", response_class=Response)