diff --git a/pyproject.toml b/pyproject.toml index 939a1d9e..2669f5ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "dateparser >= 1.1.1", "defusedxml == 0.7.1", "fastapi >= 0.104.1", - "sse-starlette ~= 2.1.0", "python-multipart >= 0.0.7", "jinja2 == 3.1.4", "openai >= 1.0.0", diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 3e07a860..b1ff3eba 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -74,13 +74,12 @@ To get started, just start typing below. You can also type / to see a list of co }, 1000); }); } - var sseConnection = null; + let region = null; let city = null; let countryName = null; let timezone = null; let waitingForLocation = true; - let chatMessageState = { newResponseTextEl: null, newResponseEl: null, @@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co .finally(() => { console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); waitingForLocation = false; - initializeSSE(); + initMessageState(); }); function formatDate(date) { @@ -599,7 +598,7 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - sendMessageViaSSE(isVoice); + renderMessageStream(isVoice); return; let query = document.getElementById("chat-input").value.trim(); @@ -1067,7 +1066,7 @@ To get started, just start typing below. You can also type / to see a list of co window.onload = loadChat; - function initializeSSE(isVoice=false) { + function initMessageState(isVoice=false) { if (waitingForLocation) { console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); return; @@ -1084,136 +1083,180 @@ To get started, just start typing below. 
You can also type / to see a list of co } } - function sendSSEMessage(query) { + function sendMessageStream(query) { let chatBody = document.getElementById("chat-body"); - let sseProtocol = window.location.protocol; - let sseUrl = `/api/chat/stream?q=${query}`; + let chatStreamUrl = `/api/chat/stream?q=${query}`; if (chatBody.dataset.conversationId) { - sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; + chatStreamUrl += `&conversation_id=${chatBody.dataset.conversationId}`; + chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) + ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; - function handleChatResponse(event) { - // Get the last element in the chat-body - let chunk = event.data; - try { - if (chunk.includes("application/json")) - chunk = JSON.parse(chunk); - } catch (error) { - // If the chunk is not a JSON object, continue.
+ fetch(chatStreamUrl) + .then(response => { + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; + + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + console.log("Stream complete"); + handleChunk(buffer); + buffer = ''; + return; + } + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + const chunks = processJsonObjects(buffer); + chunks.objects.forEach(obj => handleChunk(obj)); + buffer = chunks.remainder; + } + readStream(); + }); + } + + readStream(); + }) + .catch(error => { + console.error('Error:', error); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + } + chatMessageState.newResponseTextEl.textContent += "Failed to get response!
Try again or contact developers at team@khoj.dev" + }); + + function processJsonObjects(str) { + let startIndex = str.indexOf('{'); + if (startIndex === -1) return { objects: [str], remainder: '' }; + const objects = [str.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; + + for (let i = startIndex; i < str.length; i++) { + if (str[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (str[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = str.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } } - const contentType = chunk["content-type"] - if (contentType === "application/json") { - // Handle JSON response - try { - if (chunk.image || chunk.detail) { - ({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse)); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else { - rawResponse = chunk.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - chatMessageState.rawResponse += chunk; - } finally { - addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); - } - } else { - // Handle streamed response of type text/event-stream or text/plain - if (chunk && chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse)); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else { - // If the chunk is not a JSON object, just display it as is - chatMessageState.rawResponse += chunk; - if (chatMessageState.newResponseTextEl) { - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } - - // Scroll to bottom of chat 
window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + return { + objects: objects, + remainder: openBraces > 0 ? str.slice(startIndex) : '' }; } - }; - sseConnection = new EventSource(sseUrl); - sseConnection.onmessage = handleChatResponse; - sseConnection.addEventListener("complete_llm_response", handleChatResponse); - sseConnection.addEventListener("status", (event) => { - console.log(`${event.data}`); - handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false); - }); - sseConnection.addEventListener("rate_limit", (event) => { - handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true); - }); - sseConnection.addEventListener("start_llm_response", (event) => { - console.log("Started streaming", new Date()); - }); - sseConnection.addEventListener("end_llm_response", (event) => { - sseConnection.close(); - console.log("Stopped streaming", new Date()); + function handleChunk(rawChunk) { + // Split the chunk into lines + console.log("Chunk:", rawChunk); + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + processChunk(jsonChunk); + } catch (e) { + const jsonChunk = {type: 'message', data: rawChunk}; + processChunk(jsonChunk); + } + } else if (rawChunk.length > 0) { + const jsonChunk = {type: 'message', data: rawChunk}; + processChunk(jsonChunk); + } + } + function processChunk(chunk) { + console.log(chunk); + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } 
else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); - // Automatically respond with voice if the subscribed user has sent voice message - if (chatMessageState.isVoice && "{{ is_active }}" == "True") - textToSpeech(chatMessageState.rawResponse); + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); - // Append any references after all the data has been streamed - finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); - const liveQuery = chatMessageState.rawQuery; - // Reset variables - chatMessageState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + } + } else if (chunk.type === "references") { + const rawReferenceAsJson = JSON.parse(chunk.data); + console.log(`${chunk.type}: ${rawReferenceAsJson}`); + chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results}; + } else if (chunk.type === 'message') { + if (chunk.data.trim()?.startsWith("{") && chunk.data.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunk.data.trim()); + handleJsonResponse(jsonData); + } catch (e) { + // Handle text response chunk with compiled references + if (chunk?.data.includes("### compiled references:")) { + chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; + 
// Handle text response chunk + } else { + chatMessageState.rawResponse += chunk.data; + } + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + // Handle text response chunk with compiled references + if (chunk?.data.includes("### compiled references:")) { + chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; + // Handle text response chunk + } else { + chatMessageState.rawResponse += chunk.data; + } + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } } - // Reset status icon - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Ready"; - statusDotText.style.marginTop = "5px"; - }); - sseConnection.onclose = function(event) { - sseConnection = null; - console.debug("SSE is closed now."); - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Ready"; - statusDotText.style.marginTop = "5px"; - } - sseConnection.onerror = function(event) { - console.log("SSE error observed:", event); - sseConnection.close(); - sseConnection = null; - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "red"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Server Error"; - if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { - chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + function 
handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + chatMessageState.references = { + notes: jsonData.context || {}, + online: jsonData.online_results || {} + }; + } + addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); } - chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" - } - sseConnection.onopen = function(event) { - console.debug("SSE is open now.") - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "orange"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Processing"; } } - function sendMessageViaSSE(isVoice=false) { + function renderMessageStream(isVoice=false) { let chatBody = document.getElementById("chat-body"); var query = document.getElementById("chat-input").value.trim(); @@ -1253,7 +1296,7 @@ To get started, just start typing below. You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API - sendSSEMessage(query); + sendMessageStream(query); let rawResponse = ""; let references = {}; @@ -1267,6 +1310,7 @@ To get started, just start typing below. You can also type / to see a list of co isVoice: isVoice, } } + var userMessages = []; var userMessageIndex = -1; function loadChat() { @@ -1276,7 +1320,7 @@ To get started, just start typing below. 
You can also type / to see a list of co let chatHistoryUrl = `/api/chat/history?client=web`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - initializeSSE(); + initMessageState(); loadFileFiltersFromConversation(); } @@ -1316,7 +1360,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatBody = document.getElementById("chat-body"); chatBody.dataset.conversationId = response.conversation_id; loadFileFiltersFromConversation(); - initializeSSE(); + initMessageState(); chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; let agentMetadata = response.agent; diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4c3603cf..e6b60282 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -11,7 +11,6 @@ from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse -from sse_starlette import EventSourceResponse from starlette.authentication import requires from khoj.app.settings import ALLOWED_HOSTS @@ -543,15 +542,24 @@ async def stream_chat( async def send_event(event_type: str, data: str): nonlocal connection_alive if not connection_alive or await request.is_disconnected(): + connection_alive = False return try: if event_type == "message": yield data else: - yield {"event": event_type, "data": data, "retry": 15000} + yield json.dumps({"type": event_type, "data": data}) except Exception as e: connection_alive = False - logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}") + logger.info(f"User {user} disconnected. 
Emitting rest of responses to clear thread: {e}") + + async def send_llm_response(response: str): + async for result in send_event("start_llm_response", ""): + yield result + async for result in send_event("message", response): + yield result + async for result in send_event("end_llm_response", ""): + yield result user: KhojUser = request.user.object conversation = await ConversationAdapters.aget_conversation_by_user( @@ -585,17 +593,10 @@ async def stream_chat( except HTTPException as e: async for result in send_event("rate_limit", e.detail): yield result - break + return if is_query_empty(q): - async for event in send_event("start_llm_response", ""): - yield event - async for event in send_event( - "message", - "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", - ): - yield event - async for event in send_event("end_llm_response", ""): + async for event in send_llm_response("Please ask your query to get started."): yield event return @@ -645,25 +646,19 @@ async def stream_chat( response_log = ( "No files selected for summarization. Please add files using the section on the left." ) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event elif len(file_filters) > 1: response_log = "Only one file can be selected for summarization." - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event else: try: file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) if len(file_object) == 0: response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." 
- async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event return contextual_data = " ".join([file.raw_text for file in file_object]) if not q: @@ -675,17 +670,13 @@ async def stream_chat( response = await extract_relevant_summary(q, contextual_data) response_log = str(response) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event except Exception as e: response_log = "Error summarizing file." logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event await sync_to_async(save_to_conversation_log)( q, response_log, @@ -714,10 +705,8 @@ async def stream_chat( formatted_help = help_message.format( model=model_type, version=state.khoj_version, device=get_device() ) - async for result in send_event("complete_llm_response", formatted_help): + async for result in send_llm_response(formatted_help): yield result - async for event in send_event("end_llm_response", ""): - yield event return custom_filters.append("site:khoj.dev") conversation_commands.append(ConversationCommand.Online) @@ -730,10 +719,8 @@ async def stream_chat( except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") error_message = f"Unable to create automation. Ensure the automation doesn't already exist." 
- async for result in send_event("complete_llm_response", error_message): + async for result in send_llm_response(error_message): yield result - async for event in send_event("end_llm_response", ""): - yield event return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) @@ -760,10 +747,8 @@ api="chat", **common.__dict__, ) - async for result in send_event("complete_llm_response", llm_response): + async for result in send_llm_response(llm_response): yield result - async for event in send_event("end_llm_response", ""): - yield event return compiled_references, inferred_queries, defiltered_query = [], [], None @@ -797,9 +782,7 @@ if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( user ): - async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"): - yield result - async for event in send_event("end_llm_response", ""): + async for event in send_llm_response(f"{no_entries_found.format()}"): yield event return @@ -818,10 +801,8 @@ except ValueError as e: error_message = f"Error searching online: {e}.
Attempting to respond without online results" logger.warning(error_message) - async for result in send_event("complete_llm_response", error_message): + async for result in send_llm_response(error_message): yield result - async for event in send_event("end_llm_response", ""): - yield event return if ConversationCommand.Webpage in conversation_commands: @@ -873,15 +854,13 @@ async def stream_chat( if image is None or status_code != 200: content_obj = { - "image": image, + "content-type": "application/json", "intentType": intent_type, "detail": improved_image_prompt, - "content-type": "application/json", + "image": image, } - async for result in send_event("complete_llm_response", json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj)): yield result - async for event in send_event("end_llm_response", ""): - yield event return await sync_to_async(save_to_conversation_log)( @@ -898,19 +877,22 @@ async def stream_chat( online_results=online_results, ) content_obj = { - "image": image, - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "context": compiled_references, "content-type": "application/json", + "intentType": intent_type, + "context": compiled_references, "online_results": online_results, + "inferredQueries": [improved_image_prompt], + "image": image, } - async for result in send_event("complete_llm_response", json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj)): yield result - async for event in send_event("end_llm_response", ""): - yield event return + async for result in send_event( + "references", json.dumps({"context": compiled_references, "online_results": online_results}) + ): + yield result + async for result in send_event("status", f"**💭 Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( @@ -941,27 +923,33 @@ async def stream_chat( async for result in send_event("start_llm_response", ""): yield 
result + continue_stream = True async for item in iterator: if item is None: - break - if connection_alive: - try: - async for result in send_event("message", f"{item}"): - yield result - except Exception as e: - connection_alive = False - logger.info( - f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}" - ) - async for result in send_event("end_llm_response", ""): - yield result + async for result in send_event("end_llm_response", ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue + try: + async for result in send_event("message", f"{item}"): + yield result + except Exception as e: + continue_stream = False + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + # Stop streaming after compiled references section of response starts + # References are being processed via the references event rather than the message event + if "### compiled references:" in item: + continue_stream = False except asyncio.CancelledError: - break + logger.error("Cancelled Error in API endpoint", exc_info=True) + return except Exception as e: - logger.error(f"Error in SSE endpoint: {e}", exc_info=True) - break + logger.error(f"General Error in API endpoint: {e}", exc_info=True) + return - return EventSourceResponse(event_generator(q)) + return StreamingResponse(event_generator(q), media_type="text/plain") @api_chat.get("", response_class=Response)