From 91fe41106eb244191cffac5d8742d7a48b1a9b9d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 21 Jul 2024 12:10:13 +0530 Subject: [PATCH 01/20] Convert Websocket into Server Side Event (SSE) API endpoint - Convert functions in SSE API path into async generators using yields - Validate image generation, online, notes lookup and general paths of chat request are handled fine by the web client and server API --- pyproject.toml | 1 + src/khoj/interface/web/chat.html | 235 +++---- src/khoj/processor/tools/online_search.py | 19 +- src/khoj/routers/api.py | 15 +- src/khoj/routers/api_chat.py | 774 ++++++++++++---------- src/khoj/routers/helpers.py | 22 +- 6 files changed, 577 insertions(+), 489 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2669f5ff..939a1d9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "dateparser >= 1.1.1", "defusedxml == 0.7.1", "fastapi >= 0.104.1", + "sse-starlette ~= 2.1.0", "python-multipart >= 0.0.7", "jinja2 == 3.1.4", "openai >= 1.0.0", diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index ad8ced27..3e07a860 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -74,14 +74,14 @@ To get started, just start typing below. You can also type / to see a list of co }, 1000); }); } - var websocket = null; + var sseConnection = null; let region = null; let city = null; let countryName = null; let timezone = null; let waitingForLocation = true; - let websocketState = { + let chatMessageState = { newResponseTextEl: null, newResponseEl: null, loadingEllipsis: null, @@ -105,7 +105,7 @@ To get started, just start typing below. You can also type / to see a list of co .finally(() => { console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); waitingForLocation = false; - setupWebSocket(); + initializeSSE(); }); function formatDate(date) { @@ -599,10 +599,8 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - if (websocket) { - sendMessageViaWebSocket(isVoice); - return; - } + sendMessageViaSSE(isVoice); + return; let query = document.getElementById("chat-input").value.trim(); let resultsCount = localStorage.getItem("khojResultsCount") || 5; @@ -1069,17 +1067,13 @@ To get started, just start typing below. You can also type / to see a list of co window.onload = loadChat; - function setupWebSocket(isVoice=false) { - let chatBody = document.getElementById("chat-body"); - let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`; - + function initializeSSE(isVoice=false) { if (waitingForLocation) { console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); return; } - websocketState = { + chatMessageState = { newResponseTextEl: null, newResponseEl: null, loadingEllipsis: null, @@ -1088,121 +1082,138 @@ To get started, just start typing below. You can also type / to see a list of co rawQuery: "", isVoice: isVoice, } + } + + function sendSSEMessage(query) { + let chatBody = document.getElementById("chat-body"); + let sseProtocol = window.location.protocol; + let sseUrl = `/api/chat/stream?q=${query}`; if (chatBody.dataset.conversationId) { - webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`; - webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; - - websocket = new WebSocket(webSocketUrl); - websocket.onmessage = function(event) { + sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`; + sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; + function handleChatResponse(event) { // Get the last element in the chat-body let chunk = event.data; - if (chunk == "start_llm_response") { - console.log("Started streaming", new Date()); - } else if (chunk == "end_llm_response") { - console.log("Stopped streaming", new Date()); + try { + if (chunk.includes("application/json")) + chunk = JSON.parse(chunk); + } catch (error) { + // If the chunk is not a JSON object, continue. + } - // Automatically respond with voice if the subscribed user has sent voice message - if (websocketState.isVoice && "{{ is_active }}" == "True") - textToSpeech(websocketState.rawResponse); - - // Append any references after all the data has been streamed - finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl); - - const liveQuery = websocketState.rawQuery; - // Reset variables - websocketState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, - isVoice: false, - } - } else { + const contentType = chunk["content-type"] + if (contentType === "application/json") { + // Handle JSON response try { - if (chunk.includes("application/json")) - { - chunk = JSON.parse(chunk); + if (chunk.image || chunk.detail) { + ({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse)); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else { + rawResponse = chunk.response; } } catch (error) { - // If the chunk is not a JSON object, continue. + // If the chunk is not a JSON object, just display it as is + chatMessageState.rawResponse += chunk; + } finally { + addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); + } + } else { + // Handle streamed response of type text/event-stream or text/plain + if (chunk && chunk.includes("### compiled references:")) { + ({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse)); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else { + // If the chunk is not a JSON object, just display it as is + chatMessageState.rawResponse += chunk; + if (chatMessageState.newResponseTextEl) { + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } } - const contentType = chunk["content-type"] - - if (contentType === "application/json") { - // Handle JSON response - try { - if (chunk.image || chunk.detail) { - ({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse)); - websocketState.rawResponse = rawResponse; - websocketState.references = references; - } else if (chunk.type == "status") { - handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false); - } else if (chunk.type == "rate_limit") { - handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true); - } else { - rawResponse = chunk.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - websocketState.rawResponse += chunk; - } finally { - if (chunk.type != "status" && chunk.type != "rate_limit") { - addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references); - } - } - } else { - - // Handle streamed response of type text/event-stream or text/plain - if (chunk && chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse)); - websocketState.rawResponse = rawResponse; - websocketState.references = references; - } else { - // If the chunk is not a JSON object, just display it as is - websocketState.rawResponse += chunk; - if (websocketState.newResponseTextEl) { - handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis); - } - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }; - } + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + }; } }; - websocket.onclose = function(event) { - websocket = null; - console.log("WebSocket is closed now."); - let setupWebSocketButton = document.createElement("button"); - setupWebSocketButton.textContent = "Reconnect to Server"; - setupWebSocketButton.onclick = setupWebSocket; - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "red"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.innerHTML = ""; - statusDotText.style.marginTop = "5px"; - statusDotText.appendChild(setupWebSocketButton); - } - websocket.onerror = function(event) { - console.log("WebSocket error observed:", event); - } - websocket.onopen = function(event) { - console.log("WebSocket is open now.") + sseConnection = new EventSource(sseUrl); + sseConnection.onmessage = handleChatResponse; + sseConnection.addEventListener("complete_llm_response", handleChatResponse); + sseConnection.addEventListener("status", (event) => { + console.log(`${event.data}`); + handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false); + }); + sseConnection.addEventListener("rate_limit", (event) => { + handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true); + }); + sseConnection.addEventListener("start_llm_response", (event) => { + console.log("Started streaming", new Date()); + }); + sseConnection.addEventListener("end_llm_response", (event) => { + sseConnection.close(); + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + } + + // Reset status icon let statusDotIcon = document.getElementById("connection-status-icon"); statusDotIcon.style.backgroundColor = "green"; let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Connected to Server"; + statusDotText.textContent = "Ready"; + statusDotText.style.marginTop = "5px"; + }); + sseConnection.onclose = function(event) { + sseConnection = null; + console.debug("SSE is closed now."); + let statusDotIcon = document.getElementById("connection-status-icon"); + statusDotIcon.style.backgroundColor = "green"; + let statusDotText = document.getElementById("connection-status-text"); + statusDotText.textContent = "Ready"; + statusDotText.style.marginTop = "5px"; + } + sseConnection.onerror = function(event) { + console.log("SSE error observed:", event); + sseConnection.close(); + sseConnection = null; + let statusDotIcon = document.getElementById("connection-status-icon"); + statusDotIcon.style.backgroundColor = "red"; + let statusDotText = document.getElementById("connection-status-text"); + statusDotText.textContent = "Server Error"; + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + } + chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" + } + sseConnection.onopen = function(event) { + console.debug("SSE is open now.") + let statusDotIcon = document.getElementById("connection-status-icon"); + statusDotIcon.style.backgroundColor = "orange"; + let statusDotText = document.getElementById("connection-status-text"); + statusDotText.textContent = "Processing"; } } - function sendMessageViaWebSocket(isVoice=false) { + function sendMessageViaSSE(isVoice=false) { let chatBody = document.getElementById("chat-body"); var query = document.getElementById("chat-input").value.trim(); @@ -1242,11 +1253,11 @@ To get started, just start typing below. You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API - websocket.send(query); + sendSSEMessage(query); let rawResponse = ""; let references = {}; - websocketState = { + chatMessageState = { newResponseTextEl, newResponseEl, loadingEllipsis, @@ -1265,7 +1276,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatHistoryUrl = `/api/chat/history?client=web`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - setupWebSocket(); + initializeSSE(); loadFileFiltersFromConversation(); } @@ -1305,7 +1316,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatBody = document.getElementById("chat-body"); chatBody.dataset.conversationId = response.conversation_id; loadFileFiltersFromConversation(); - setupWebSocket(); + initializeSSE(); chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; let agentMetadata = response.agent; diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 72191077..1f8a5c9e 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -56,7 +56,8 @@ async def search_online( query += " ".join(custom_filters) if not is_internet_connected(): logger.warn("Cannot search online as not connected to internet") - return {} + yield {} + return # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries(query, conversation_history, location) @@ -66,7 +67,8 @@ async def search_online( logger.info(f"🌐 Searching the Internet for {list(subqueries)}") if send_status_func: subqueries_str = "\n- " + "\n- ".join(list(subqueries)) - await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}") + async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"): + yield {"status": event} with timer(f"Internet searches for {list(subqueries)} took", logger): search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina @@ -89,7 +91,8 @@ async def search_online( logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) - await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") + async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): + yield {"status": event} tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] results = await asyncio.gather(*tasks) @@ -98,7 +101,7 @@ async def search_online( if webpage_extract is not None: response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract} - return response_dict + yield response_dict async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: @@ -127,13 +130,15 @@ async def read_webpages( "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") if send_status_func: - await send_status_func(f"**🧐 Inferring web pages to read**") + async for event in send_status_func(f"**🧐 Inferring web pages to read**"): + yield {"status": event} urls = await infer_webpage_urls(query, conversation_history, location) logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) - await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") + async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): + yield {"status": event} tasks = [read_webpage_and_extract_content(query, url) for url in urls] results = await asyncio.gather(*tasks) @@ -141,7 +146,7 @@ async def read_webpages( response[query]["webpages"] = [ {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None ] - return response + yield response async def read_webpage_and_extract_content( diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index cbe19891..836b963f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,7 +6,6 @@ import os import threading import time import uuid -from random import random from typing import Any, Callable, List, Optional, Union import cron_descriptor @@ -298,11 +297,13 @@ async def extract_references_and_questions( not ConversationCommand.Notes in conversation_commands and not ConversationCommand.Default in conversation_commands ): - return compiled_references, inferred_queries, q + yield compiled_references, inferred_queries, q + return if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") - return compiled_references, inferred_queries, q + yield compiled_references, inferred_queries, q + return # Extract filter terms from user message defiltered_query = q @@ -313,7 +314,8 @@ async def extract_references_and_questions( if not conversation: logger.error(f"Conversation with id {conversation_id} not found.") - return compiled_references, inferred_queries, defiltered_query + yield compiled_references, inferred_queries, defiltered_query + return filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) using_offline_chat = False @@ -372,7 +374,8 @@ async def extract_references_and_questions( logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") if send_status_func: inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) - await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}") + async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"): + yield {"status": event} for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n search_results.extend( @@ -391,7 +394,7 @@ async def extract_references_and_questions( {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results ] - return compiled_references, inferred_queries, defiltered_query + yield compiled_references, inferred_queries, defiltered_query @api.get("/health", response_class=Response) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index be28622b..4c3603cf 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,17 +1,18 @@ +import asyncio import json import logging import math from datetime import datetime +from functools import partial from typing import Any, Dict, List, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async -from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse +from sse_starlette import EventSourceResponse from starlette.authentication import requires -from starlette.websockets import WebSocketDisconnect -from websockets import ConnectionClosedOK from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -526,380 +527,441 @@ async def set_conversation_title( ) -@api_chat.websocket("/ws") -async def websocket_endpoint( - websocket: WebSocket, +@api_chat.get("/stream") +async def stream_chat( + request: Request, + q: str, conversation_id: int, city: Optional[str] = None, region: Optional[str] = None, country: Optional[str] = None, timezone: Optional[str] = None, ): - connection_alive = True + async def event_generator(q: str): + connection_alive = True - async def send_status_update(message: str): - nonlocal connection_alive - if not connection_alive: - return - - status_packet = { - "type": "status", - "message": message, - "content-type": "application/json", - } - try: - await websocket.send_text(json.dumps(status_packet)) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - async def send_complete_llm_response(llm_response: str): - nonlocal connection_alive - if not connection_alive: - return - try: - await websocket.send_text("start_llm_response") - await websocket.send_text(llm_response) - await websocket.send_text("end_llm_response") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - async def send_message(message: str): - nonlocal connection_alive - if not connection_alive: - return - try: - await websocket.send_text(message) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - async def send_rate_limit_message(message: str): - nonlocal connection_alive - if not connection_alive: - return - - status_packet = { - "type": "rate_limit", - "message": message, - "content-type": "application/json", - } - try: - await websocket.send_text(json.dumps(status_packet)) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - user: KhojUser = websocket.user.object - conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=websocket.user.client_app, conversation_id=conversation_id - ) - - hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - - daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - - location = None - - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - await websocket.accept() - while connection_alive: - try: - if conversation: - await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) - q = await websocket.receive_text() - - # Refresh these because the connection to the database might have been closed - await conversation.arefresh_from_db() - - except WebSocketDisconnect: - logger.debug(f"User {user} disconnected web socket") - break - - try: - await sync_to_async(hourly_limiter)(websocket) - await sync_to_async(daily_limiter)(websocket) - except HTTPException as e: - await send_rate_limit_message(e.detail) - break - - if is_query_empty(q): - await send_message("start_llm_response") - await send_message( - "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?" - ) - await send_message("end_llm_response") - continue - - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - await send_status_update(f"**👀 Understanding Query**: {q}") - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}") - - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}") - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - if len(file_filters) == 0: - response_log = "No files selected for summarization. Please add files using the section on the left." - await send_complete_llm_response(response_log) - elif len(file_filters) > 1: - response_log = "Only one file can be selected for summarization." - await send_complete_llm_response(response_log) - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - await send_complete_llm_response(response_log) - continue - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}") - response = await extract_relevant_summary(q, contextual_data) - response_log = str(response) - await send_complete_llm_response(response_log) - except Exception as e: - response_log = "Error summarizing file." - logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - await send_complete_llm_response(response_log) - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=websocket.user.client_app, - conversation_id=conversation_id, - ) - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - continue - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - await send_complete_llm_response(formatted_help) - continue - # Adding specification to search online specifically on khoj.dev pages. - custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: + async def send_event(event_type: str, data: str): + nonlocal connection_alive + if not connection_alive or await request.is_disconnected(): + return try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, websocket.url, meta_log - ) + if event_type == "message": + yield data + else: + yield {"event": event_type, "data": data, "retry": 15000} except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - await send_complete_llm_response( - f"Unable to create automation. Ensure the automation doesn't already exist." - ) - continue + connection_alive = False + logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}") - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=websocket.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - ) - common = CommonQueryParamsClass( - client=websocket.user.client_app, - user_agent=websocket.headers.get("user-agent"), - host=websocket.headers.get("host"), - ) - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - **common.__dict__, - ) - await send_complete_llm_response(llm_response) - continue - - compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update + user: KhojUser = request.user.object + conversation = await ConversationAdapters.aget_conversation_by_user( + user, client_application=request.user.client_app, conversation_id=conversation_id ) - if compiled_references: - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - await send_status_update(f"**📜 Found Relevant Notes**: {headings}") + hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - online_results: Dict = dict() + daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - await send_complete_llm_response(f"{no_entries_found.format()}") - continue + await is_ready_to_chat(user) - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) + user_name = await aget_user_name(user) - if ConversationCommand.Online in conversation_commands: + location = None + + if city or region or country: + location = LocationData(city=city, region=region, country=country) + + while connection_alive: try: - online_results = await search_online( - defiltered_query, meta_log, location, send_status_update, custom_filters - ) - except ValueError as e: - logger.warning(f"Error searching online: {e}. Attempting to respond without online results") - await send_complete_llm_response( - f"Error searching online: {e}. Attempting to respond without online results" - ) - continue + if conversation: + await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) - if ConversationCommand.Webpage in conversation_commands: - try: - direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update) - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + # Refresh these because the connection to the database might have been closed + await conversation.arefresh_from_db() - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - - await send_status_update(f"**📚 Read web pages**: {webpages}") - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True - ) - - if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - image, status_code, improved_image_prompt, intent_type = await text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - send_status_func=send_status_update, - ) - if image is None or status_code != 200: - content_obj = { - "image": image, - "intentType": intent_type, - "detail": improved_image_prompt, - "content-type": "application/json", - } - await send_complete_llm_response(json.dumps(content_obj)) - continue - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=websocket.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - ) - content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore - - await send_complete_llm_response(json.dumps(content_obj)) - continue - - await send_status_update(f"**💭 Generating a well-informed response**") - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - websocket.user.client_app, - conversation_id, - location, - user_name, - ) - - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - ) - iterator = AsyncIteratorWrapper(llm_response) - - await send_message("start_llm_response") - - async for item in iterator: - if item is None: - break - if connection_alive: try: - await send_message(f"{item}") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + await sync_to_async(hourly_limiter)(request) + await sync_to_async(daily_limiter)(request) + except HTTPException as e: + async for result in send_event("rate_limit", e.detail): + yield result + break - await send_message("end_llm_response") + if is_query_empty(q): + async for event in send_event("start_llm_response", ""): + yield event + async for event in send_event( + "message", + "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", + ): + yield event + async for event in send_event("end_llm_response", ""): + yield event + return + + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_commands = [get_conversation_command(query=q, any_references=True)] + + async for result in send_event("status", f"**👀 Understanding Query**: {q}"): + yield result + + meta_log = conversation.conversation_log + is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] + + used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] + + if conversation_commands == [ConversationCommand.Default] or is_automated_task: + conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event( + "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + ): + yield result + + mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) + async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + yield result + if mode not in conversation_commands: + conversation_commands.append(mode) + + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + file_filters = conversation.file_filters if conversation else [] + # Skip trying to summarize if + if ( + # summarization intent was inferred + ConversationCommand.Summarize in conversation_commands + # and not triggered via slash command + and not used_slash_summarize + # but we can't actually summarize + and len(file_filters) != 1 + ): + conversation_commands.remove(ConversationCommand.Summarize) + elif ConversationCommand.Summarize in conversation_commands: + response_log = "" + if len(file_filters) == 0: + response_log = ( + "No files selected for summarization. Please add files using the section on the left." + ) + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + elif len(file_filters) > 1: + response_log = "Only one file can be selected for summarization." + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + else: + try: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + if len(file_object) == 0: + response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + contextual_data = " ".join([file.raw_text for file in file_object]) + if not q: + q = "Create a general summary of the file" + async for result in send_event( + "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" + ): + yield result + + response = await extract_relevant_summary(q, contextual_data) + response_log = str(response) + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + except Exception as e: + response_log = "Error summarizing file." + logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + await sync_to_async(save_to_conversation_log)( + q, + response_log, + user, + meta_log, + user_message_time, + intent_type="summarize", + client_application=request.user.client_app, + conversation_id=conversation_id, + ) + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + return + + custom_filters = [] + if conversation_commands == [ConversationCommand.Help]: + if not q: + conversation_config = await ConversationAdapters.aget_user_conversation_config(user) + if conversation_config == None: + conversation_config = await ConversationAdapters.aget_default_conversation_config() + model_type = conversation_config.model_type + formatted_help = help_message.format( + model=model_type, version=state.khoj_version, device=get_device() + ) + async for result in send_event("complete_llm_response", formatted_help): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + custom_filters.append("site:khoj.dev") + conversation_commands.append(ConversationCommand.Online) + + if ConversationCommand.Automation in conversation_commands: + try: + automation, crontime, query_to_run, subject = await create_automation( + q, timezone, user, request.url, meta_log + ) + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") + error_message = f"Unable to create automation. Ensure the automation doesn't already exist." + async for result in send_event("complete_llm_response", error_message): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) + await sync_to_async(save_to_conversation_log)( + q, + llm_response, + user, + meta_log, + user_message_time, + intent_type="automation", + client_application=request.user.client_app, + conversation_id=conversation_id, + inferred_queries=[query_to_run], + automation_id=automation.id, + ) + common = CommonQueryParamsClass( + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + ) + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + **common.__dict__, + ) + async for result in send_event("complete_llm_response", llm_response): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + compiled_references, inferred_queries, defiltered_query = [], [], None + async for result in extract_references_and_questions( + request, + meta_log, + q, + 7, + 0.18, + conversation_id, + conversation_commands, + location, + partial(send_event, "status"), + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join( + set([c.get("compiled", c).split("\n")[0] for c in compiled_references]) + ) + async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): + yield result + + online_results: Dict = dict() + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( + user + ): + async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + if ConversationCommand.Online in conversation_commands: + try: + async for result in search_online( + defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + online_results = result + except ValueError as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_event("complete_llm_response", error_message): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + if ConversationCommand.Webpage in conversation_commands: + try: + async for result in read_webpages( + defiltered_query, meta_log, location, partial(send_event, "status") + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): + yield result + except ValueError as e: + logger.warning( + f"Error directly reading webpages: {e}. Attempting to respond without online results", + exc_info=True, + ) + + if ConversationCommand.Image in conversation_commands: + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + async for result in text_to_image( + q, + user, + meta_log, + location_data=location, + references=compiled_references, + online_results=online_results, + send_status_func=partial(send_event, "status"), + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + image, status_code, improved_image_prompt, intent_type = result + + if image is None or status_code != 200: + content_obj = { + "image": image, + "intentType": intent_type, + "detail": improved_image_prompt, + "content-type": "application/json", + } + async for result in send_event("complete_llm_response", json.dumps(content_obj)): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + await sync_to_async(save_to_conversation_log)( + q, + image, + user, + meta_log, + user_message_time, + intent_type=intent_type, + inferred_queries=[improved_image_prompt], + client_application=request.user.client_app, + conversation_id=conversation_id, + compiled_references=compiled_references, + online_results=online_results, + ) + content_obj = { + "image": image, + "intentType": intent_type, + "inferredQueries": [improved_image_prompt], + "context": compiled_references, + "content-type": "application/json", + "online_results": online_results, + } + async for result in send_event("complete_llm_response", json.dumps(content_obj)): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + async for result in send_event("status", f"**💭 Generating a well-informed response**"): + yield result + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + meta_log, + conversation, + compiled_references, + online_results, + inferred_queries, + conversation_commands, + user, + request.user.client_app, + conversation_id, + location, + user_name, + ) + + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata=chat_metadata, + ) + iterator = AsyncIteratorWrapper(llm_response) + + async for result in send_event("start_llm_response", ""): + yield result + + async for item in iterator: + if item is None: + break + if connection_alive: + try: + async for result in send_event("message", f"{item}"): + yield result + except Exception as e: + connection_alive = False + logger.info( + f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}" + ) + async for result in send_event("end_llm_response", ""): + yield result + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in SSE endpoint: {e}", exc_info=True) + break + + return EventSourceResponse(event_generator(q)) @api_chat.get("", response_class=Response) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e0f91df7..d23df6f0 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -755,7 +755,7 @@ async def text_to_image( references: List[Dict[str, Any]], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, -) -> Tuple[Optional[str], int, Optional[str], str]: +): status_code = 200 image = None response = None @@ -767,7 +767,8 @@ async def text_to_image( # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return text2image_model = text_to_image_config.model_name chat_history = "" @@ -781,7 +782,8 @@ async def text_to_image( with timer("Improve the original user query", logger): if send_status_func: - await send_status_func("**✍🏽 Enhancing the Painting Prompt**") + async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): + yield {"status": event} improved_image_prompt = await generate_better_image_prompt( message, chat_history, @@ -792,7 +794,8 @@ async def text_to_image( ) if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): + yield {"status": event} if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with timer("Generate image with OpenAI", logger): @@ -817,12 +820,14 @@ async def text_to_image( logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: with timer("Generate image with Stability AI", logger): @@ -844,7 +849,8 @@ async def text_to_image( logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with Stability AI error: {e}" status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return with timer("Convert image to webp", logger): # Convert png to webp for faster loading @@ -864,7 +870,7 @@ async def text_to_image( intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 image = base64.b64encode(webp_image_bytes).decode("utf-8") - return image_url or image, status_code, improved_image_prompt, intent_type.value + yield image_url or image, status_code, improved_image_prompt, intent_type.value class ApiUserRateLimiter: From b8d3e3669ac14b752ee08d96e65b2f3d2d1bfb41 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 22 Jul 2024 00:20:23 +0530 Subject: [PATCH 02/20] Stream Status Messages via Streaming Response from server to web client - Overview Use simpler HTTP Streaming Response to send status messages, alongside response and references from server to clients via API. Update web client to use the streamed response to show train of thought, stream response and render references. - Motivation This should allow other Khoj clients to pass auth headers and recieve Khoj's train of thought messages from server over simple HTTP streaming API. It'll also eventually deduplicate chat logic across /websocket and /chat API endpoints and help maintainability and dev velocity - Details - Pass references as a separate streaming message type for simpler parsing. Remove passing "### compiled references" altogether once the original /api/chat API is deprecated/merged with the new one and clients have been updated to consume the references using this new mechanism - Save message to conversation even if client disconnects. This is done by not breaking out of the async iterator that is sending the llm response. As the save conversation is called at the end of the iteration - Handle parsing chunked json responses as a valid json on client. This requires additional logic on client side but makes the client more robust to server chunking json response such that each chunk isn't itself necessarily a valid json. --- pyproject.toml | 1 - src/khoj/interface/web/chat.html | 284 ++++++++++++++++++------------- src/khoj/routers/api_chat.py | 128 +++++++------- 3 files changed, 222 insertions(+), 191 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 939a1d9e..2669f5ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "dateparser >= 1.1.1", "defusedxml == 0.7.1", "fastapi >= 0.104.1", - "sse-starlette ~= 2.1.0", "python-multipart >= 0.0.7", "jinja2 == 3.1.4", "openai >= 1.0.0", diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 3e07a860..b1ff3eba 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -74,13 +74,12 @@ To get started, just start typing below. You can also type / to see a list of co }, 1000); }); } - var sseConnection = null; + let region = null; let city = null; let countryName = null; let timezone = null; let waitingForLocation = true; - let chatMessageState = { newResponseTextEl: null, newResponseEl: null, @@ -105,7 +104,7 @@ To get started, just start typing below. You can also type / to see a list of co .finally(() => { console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); waitingForLocation = false; - initializeSSE(); + initMessageState(); }); function formatDate(date) { @@ -599,7 +598,7 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - sendMessageViaSSE(isVoice); + renderMessageStream(isVoice); return; let query = document.getElementById("chat-input").value.trim(); @@ -1067,7 +1066,7 @@ To get started, just start typing below. You can also type / to see a list of co window.onload = loadChat; - function initializeSSE(isVoice=false) { + function initMessageState(isVoice=false) { if (waitingForLocation) { console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); return; @@ -1084,136 +1083,180 @@ To get started, just start typing below. You can also type / to see a list of co } } - function sendSSEMessage(query) { + function sendMessageStream(query) { let chatBody = document.getElementById("chat-body"); - let sseProtocol = window.location.protocol; - let sseUrl = `/api/chat/stream?q=${query}`; + let chatStreamUrl = `/api/chat/stream?q=${query}`; if (chatBody.dataset.conversationId) { - sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; + chatStreamUrl += `&conversation_id=${chatBody.dataset.conversationId}`; + chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) + ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; - function handleChatResponse(event) { - // Get the last element in the chat-body - let chunk = event.data; - try { - if (chunk.includes("application/json")) - chunk = JSON.parse(chunk); - } catch (error) { - // If the chunk is not a JSON object, continue. + fetch(chatStreamUrl) + .then(response => { + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; + + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + console.log("Stream complete"); + handleChunk(buffer); + buffer = ''; + return; + } + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + chunks = processJsonObjects(buffer); + chunks.objects.forEach(obj => handleChunk(obj)); + buffer = chunks.remainder; + } + readStream(); + }); + } + + readStream(); + }) + .catch(error => { + console.error('Error:', error); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + } + chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" + }); + + function processJsonObjects(str) { + let startIndex = str.indexOf('{'); + if (startIndex === -1) return { objects: [str], remainder: '' }; + const objects = [str.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; + + for (let i = startIndex; i < str.length; i++) { + if (str[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (str[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = str.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } } - const contentType = chunk["content-type"] - if (contentType === "application/json") { - // Handle JSON response - try { - if (chunk.image || chunk.detail) { - ({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse)); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else { - rawResponse = chunk.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - chatMessageState.rawResponse += chunk; - } finally { - addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); - } - } else { - // Handle streamed response of type text/event-stream or text/plain - if (chunk && chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse)); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else { - // If the chunk is not a JSON object, just display it as is - chatMessageState.rawResponse += chunk; - if (chatMessageState.newResponseTextEl) { - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + return { + objects: objects, + remainder: openBraces > 0 ? str.slice(startIndex) : '' }; } - }; - sseConnection = new EventSource(sseUrl); - sseConnection.onmessage = handleChatResponse; - sseConnection.addEventListener("complete_llm_response", handleChatResponse); - sseConnection.addEventListener("status", (event) => { - console.log(`${event.data}`); - handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false); - }); - sseConnection.addEventListener("rate_limit", (event) => { - handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true); - }); - sseConnection.addEventListener("start_llm_response", (event) => { - console.log("Started streaming", new Date()); - }); - sseConnection.addEventListener("end_llm_response", (event) => { - sseConnection.close(); - console.log("Stopped streaming", new Date()); + function handleChunk(rawChunk) { + // Split the chunk into lines + console.log("Chunk:", rawChunk); + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + processChunk(jsonChunk); + } catch (e) { + const jsonChunk = {type: 'message', data: rawChunk}; + processChunk(jsonChunk); + } + } else if (rawChunk.length > 0) { + const jsonChunk = {type: 'message', data: rawChunk}; + processChunk(jsonChunk); + } + } + function processChunk(chunk) { + console.log(chunk); + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); - // Automatically respond with voice if the subscribed user has sent voice message - if (chatMessageState.isVoice && "{{ is_active }}" == "True") - textToSpeech(chatMessageState.rawResponse); + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); - // Append any references after all the data has been streamed - finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); - const liveQuery = chatMessageState.rawQuery; - // Reset variables - chatMessageState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + } + } else if (chunk.type === "references") { + const rawReferenceAsJson = JSON.parse(chunk.data); + console.log(`${chunk.type}: ${rawReferenceAsJson}`); + chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results}; + } else if (chunk.type === 'message') { + if (chunk.data.trim()?.startsWith("{") && chunk.data.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunk.data.trim()); + handleJsonResponse(jsonData); + } catch (e) { + // Handle text response chunk with compiled references + if (chunk?.data.includes("### compiled references:")) { + chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; + // Handle text response chunk + } else { + chatMessageState.rawResponse += chunk.data; + } + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + // Handle text response chunk with compiled references + if (chunk?.data.includes("### compiled references:")) { + chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; + // Handle text response chunk + } else { + chatMessageState.rawResponse += chunk.data; + } + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } } - // Reset status icon - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Ready"; - statusDotText.style.marginTop = "5px"; - }); - sseConnection.onclose = function(event) { - sseConnection = null; - console.debug("SSE is closed now."); - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "green"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Ready"; - statusDotText.style.marginTop = "5px"; - } - sseConnection.onerror = function(event) { - console.log("SSE error observed:", event); - sseConnection.close(); - sseConnection = null; - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "red"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Server Error"; - if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { - chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + chatMessageState.references = { + notes: jsonData.context || {}, + online: jsonData.online_results || {} + }; + } + addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); } - chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" - } - sseConnection.onopen = function(event) { - console.debug("SSE is open now.") - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "orange"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Processing"; } } - function sendMessageViaSSE(isVoice=false) { + function renderMessageStream(isVoice=false) { let chatBody = document.getElementById("chat-body"); var query = document.getElementById("chat-input").value.trim(); @@ -1253,7 +1296,7 @@ To get started, just start typing below. You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API - sendSSEMessage(query); + sendMessageStream(query); let rawResponse = ""; let references = {}; @@ -1267,6 +1310,7 @@ To get started, just start typing below. You can also type / to see a list of co isVoice: isVoice, } } + var userMessages = []; var userMessageIndex = -1; function loadChat() { @@ -1276,7 +1320,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatHistoryUrl = `/api/chat/history?client=web`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - initializeSSE(); + initMessageState(); loadFileFiltersFromConversation(); } @@ -1316,7 +1360,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatBody = document.getElementById("chat-body"); chatBody.dataset.conversationId = response.conversation_id; loadFileFiltersFromConversation(); - initializeSSE(); + initMessageState(); chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; let agentMetadata = response.agent; diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4c3603cf..e6b60282 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -11,7 +11,6 @@ from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse -from sse_starlette import EventSourceResponse from starlette.authentication import requires from khoj.app.settings import ALLOWED_HOSTS @@ -543,15 +542,24 @@ async def stream_chat( async def send_event(event_type: str, data: str): nonlocal connection_alive if not connection_alive or await request.is_disconnected(): + connection_alive = False return try: if event_type == "message": yield data else: - yield {"event": event_type, "data": data, "retry": 15000} + yield json.dumps({"type": event_type, "data": data}) except Exception as e: connection_alive = False - logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}") + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + + async def send_llm_response(response: str): + async for result in send_event("start_llm_response", ""): + yield result + async for result in send_event("message", response): + yield result + async for result in send_event("end_llm_response", ""): + yield result user: KhojUser = request.user.object conversation = await ConversationAdapters.aget_conversation_by_user( @@ -585,17 +593,10 @@ async def stream_chat( except HTTPException as e: async for result in send_event("rate_limit", e.detail): yield result - break + return if is_query_empty(q): - async for event in send_event("start_llm_response", ""): - yield event - async for event in send_event( - "message", - "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", - ): - yield event - async for event in send_event("end_llm_response", ""): + async for event in send_llm_response("Please ask your query to get started."): yield event return @@ -645,25 +646,19 @@ async def stream_chat( response_log = ( "No files selected for summarization. Please add files using the section on the left." ) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event elif len(file_filters) > 1: response_log = "Only one file can be selected for summarization." - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event else: try: file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) if len(file_object) == 0: response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event return contextual_data = " ".join([file.raw_text for file in file_object]) if not q: @@ -675,17 +670,13 @@ async def stream_chat( response = await extract_relevant_summary(q, contextual_data) response_log = str(response) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event except Exception as e: response_log = "Error summarizing file." logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_event("complete_llm_response", response_log): + async for result in send_llm_response(response_log): yield result - async for event in send_event("end_llm_response", ""): - yield event await sync_to_async(save_to_conversation_log)( q, response_log, @@ -714,10 +705,8 @@ async def stream_chat( formatted_help = help_message.format( model=model_type, version=state.khoj_version, device=get_device() ) - async for result in send_event("complete_llm_response", formatted_help): + async for result in send_llm_response(formatted_help): yield result - async for event in send_event("end_llm_response", ""): - yield event return custom_filters.append("site:khoj.dev") conversation_commands.append(ConversationCommand.Online) @@ -730,10 +719,8 @@ async def stream_chat( except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_event("complete_llm_response", error_message): + async for result in send_llm_response(error_message): yield result - async for event in send_event("end_llm_response", ""): - yield event return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) @@ -760,10 +747,8 @@ async def stream_chat( api="chat", **common.__dict__, ) - async for result in send_event("complete_llm_response", llm_response): + async for result in send_llm_response(llm_response): yield result - async for event in send_event("end_llm_response", ""): - yield event return compiled_references, inferred_queries, defiltered_query = [], [], None @@ -797,9 +782,7 @@ async def stream_chat( if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( user ): - async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"): - yield result - async for event in send_event("end_llm_response", ""): + async for result in send_llm_response(f"{no_entries_found.format()}"): yield event return @@ -818,10 +801,8 @@ async def stream_chat( except ValueError as e: error_message = f"Error searching online: {e}. Attempting to respond without online results" logger.warning(error_message) - async for result in send_event("complete_llm_response", error_message): + async for result in send_llm_response(error_message): yield result - async for event in send_event("end_llm_response", ""): - yield event return if ConversationCommand.Webpage in conversation_commands: @@ -873,15 +854,13 @@ async def stream_chat( if image is None or status_code != 200: content_obj = { - "image": image, + "content-type": "application/json", "intentType": intent_type, "detail": improved_image_prompt, - "content-type": "application/json", + "image": image, } - async for result in send_event("complete_llm_response", json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj)): yield result - async for event in send_event("end_llm_response", ""): - yield event return await sync_to_async(save_to_conversation_log)( @@ -898,19 +877,22 @@ async def stream_chat( online_results=online_results, ) content_obj = { - "image": image, - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "context": compiled_references, "content-type": "application/json", + "intentType": intent_type, + "context": compiled_references, "online_results": online_results, + "inferredQueries": [improved_image_prompt], + "image": image, } - async for result in send_event("complete_llm_response", json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj)): yield result - async for event in send_event("end_llm_response", ""): - yield event return + async for result in send_event( + "references", json.dumps({"context": compiled_references, "online_results": online_results}) + ): + yield result + async for result in send_event("status", f"**💭 Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( @@ -941,27 +923,33 @@ async def stream_chat( async for result in send_event("start_llm_response", ""): yield result + continue_stream = True async for item in iterator: if item is None: - break - if connection_alive: - try: - async for result in send_event("message", f"{item}"): - yield result - except Exception as e: - connection_alive = False - logger.info( - f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}" - ) - async for result in send_event("end_llm_response", ""): - yield result + async for result in send_event("end_llm_response", ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue + try: + async for result in send_event("message", f"{item}"): + yield result + except Exception as e: + continue_stream = False + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + # Stop streaming after compiled references section of response starts + # References are being processed via the references event rather than the message event + if "### compiled references:" in item: + continue_stream = False except asyncio.CancelledError: - break + logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True) + return except Exception as e: - logger.error(f"Error in SSE endpoint: {e}", exc_info=True) - break + logger.error(f"General Error in API endpoint: {e}", exc_info=True) + return - return EventSourceResponse(event_generator(q)) + return StreamingResponse(event_generator(q), media_type="text/plain") @api_chat.get("", response_class=Response) From 6b9550238f33e947886ca7cf35ffdb6a3fc93655 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 22 Jul 2024 17:09:41 +0530 Subject: [PATCH 03/20] Simplify advanced streaming chat API, align params with normal chat API --- src/khoj/routers/api_chat.py | 702 +++++++++++++++++------------------ 1 file changed, 342 insertions(+), 360 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e6b60282..34879b86 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,7 +1,6 @@ import asyncio import json import logging -import math from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -529,29 +528,47 @@ async def set_conversation_title( @api_chat.get("/stream") async def stream_chat( request: Request, + common: CommonQueryParams, q: str, - conversation_id: int, + n: int = 7, + d: float = 0.18, + title: Optional[str] = None, + conversation_id: Optional[int] = None, city: Optional[str] = None, region: Optional[str] = None, country: Optional[str] = None, timezone: Optional[str] = None, + rate_limiter_per_minute=Depends( + ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") + ), + rate_limiter_per_day=Depends( + ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + ), ): async def event_generator(q: str): connection_alive = True + user: KhojUser = request.user.object + q = unquote(q) async def send_event(event_type: str, data: str): nonlocal connection_alive if not connection_alive or await request.is_disconnected(): connection_alive = False + logger.warn(f"User {user} disconnected from {common.client} client") return try: if event_type == "message": yield data else: yield json.dumps({"type": event_type, "data": data}) + except asyncio.CancelledError: + connection_alive = False + logger.warn(f"User {user} disconnected from {common.client} client") + return except Exception as e: connection_alive = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) + return async def send_llm_response(response: str): async for result in send_event("start_llm_response", ""): @@ -561,393 +578,358 @@ async def stream_chat( async for result in send_event("end_llm_response", ""): yield result - user: KhojUser = request.user.object conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=request.user.client_app, conversation_id=conversation_id + user, client_application=request.user.client_app, conversation_id=conversation_id, title=title ) - - hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - - daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + if not conversation: + async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"): + yield result await is_ready_to_chat(user) user_name = await aget_user_name(user) - location = None - if city or region or country: location = LocationData(city=city, region=region, country=country) - while connection_alive: - try: - if conversation: - await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get started."): + yield result + return - # Refresh these because the connection to the database might have been closed - await conversation.arefresh_from_db() + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_commands = [get_conversation_command(query=q, any_references=True)] - try: - await sync_to_async(hourly_limiter)(request) - await sync_to_async(daily_limiter)(request) - except HTTPException as e: - async for result in send_event("rate_limit", e.detail): - yield result - return + async for result in send_event("status", f"**👀 Understanding Query**: {q}"): + yield result - if is_query_empty(q): - async for event in send_llm_response("Please ask your query to get started."): - yield event - return + meta_log = conversation.conversation_log + is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - conversation_commands = [get_conversation_command(query=q, any_references=True)] + if conversation_commands == [ConversationCommand.Default] or is_automated_task: + conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event( + "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + ): + yield result - async for result in send_event("status", f"**👀 Understanding Query**: {q}"): + mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) + async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + yield result + if mode not in conversation_commands: + conversation_commands.append(mode) + + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] + file_filters = conversation.file_filters if conversation else [] + # Skip trying to summarize if + if ( + # summarization intent was inferred + ConversationCommand.Summarize in conversation_commands + # and not triggered via slash command + and not used_slash_summarize + # but we can't actually summarize + and len(file_filters) != 1 + ): + conversation_commands.remove(ConversationCommand.Summarize) + elif ConversationCommand.Summarize in conversation_commands: + response_log = "" + if len(file_filters) == 0: + response_log = "No files selected for summarization. Please add files using the section on the left." + async for result in send_llm_response(response_log): yield result - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + elif len(file_filters) > 1: + response_log = "Only one file can be selected for summarization." + async for result in send_llm_response(response_log): + yield result + else: + try: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + if len(file_object) == 0: + response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." + async for result in send_llm_response(response_log): + yield result + return + contextual_data = " ".join([file.raw_text for file in file_object]) + if not q: + q = "Create a general summary of the file" async for result in send_event( - "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" ): yield result - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + response = await extract_relevant_summary(q, contextual_data) + response_log = str(response) + async for result in send_llm_response(response_log): yield result - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - if len(file_filters) == 0: - response_log = ( - "No files selected for summarization. Please add files using the section on the left." - ) - async for result in send_llm_response(response_log): - yield result - elif len(file_filters) > 1: - response_log = "Only one file can be selected for summarization." - async for result in send_llm_response(response_log): - yield result - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - async for result in send_llm_response(response_log): - yield result - return - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - async for result in send_event( - "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" - ): - yield result - - response = await extract_relevant_summary(q, contextual_data) - response_log = str(response) - async for result in send_llm_response(response_log): - yield result - except Exception as e: - response_log = "Error summarizing file." - logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_llm_response(response_log): - yield result - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - return - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format( - model=model_type, version=state.khoj_version, device=get_device() - ) - async for result in send_llm_response(formatted_help): - yield result - return - custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_llm_response(error_message): - yield result - return - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - ) - common = CommonQueryParamsClass( - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - **common.__dict__, - ) - async for result in send_llm_response(llm_response): + except Exception as e: + response_log = "Error summarizing file." + logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) + async for result in send_llm_response(response_log): yield result - return + await sync_to_async(save_to_conversation_log)( + q, + response_log, + user, + meta_log, + user_message_time, + intent_type="summarize", + client_application=request.user.client_app, + conversation_id=conversation_id, + ) + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + return - compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - 7, - 0.18, - conversation_id, - conversation_commands, - location, - partial(send_event, "status"), + custom_filters = [] + if conversation_commands == [ConversationCommand.Help]: + if not q: + conversation_config = await ConversationAdapters.aget_user_conversation_config(user) + if conversation_config == None: + conversation_config = await ConversationAdapters.aget_default_conversation_config() + model_type = conversation_config.model_type + formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) + async for result in send_llm_response(formatted_help): + yield result + return + # Adding specification to search online specifically on khoj.dev pages. + custom_filters.append("site:khoj.dev") + conversation_commands.append(ConversationCommand.Online) + + if ConversationCommand.Automation in conversation_commands: + try: + automation, crontime, query_to_run, subject = await create_automation( + q, timezone, user, request.url, meta_log + ) + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") + error_message = f"Unable to create automation. Ensure the automation doesn't already exist." + async for result in send_llm_response(error_message): + yield result + return + + llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) + await sync_to_async(save_to_conversation_log)( + q, + llm_response, + user, + meta_log, + user_message_time, + intent_type="automation", + client_application=request.user.client_app, + conversation_id=conversation_id, + inferred_queries=[query_to_run], + automation_id=automation.id, + ) + common = CommonQueryParamsClass( + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + ) + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + **common.__dict__, + ) + async for result in send_llm_response(llm_response): + yield result + return + + compiled_references, inferred_queries, defiltered_query = [], [], None + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + (d or 0.18), + conversation_id, + conversation_commands, + location, + partial(send_event, "status"), + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): + yield result + + online_results: Dict = dict() + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): + async for result in send_llm_response(f"{no_entries_found.format()}"): + yield result + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + if ConversationCommand.Online in conversation_commands: + try: + async for result in search_online( + defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters ): if isinstance(result, dict) and "status" in result: yield result["status"] else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join( - set([c.get("compiled", c).split("\n")[0] for c in compiled_references]) - ) - async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): - yield result - - online_results: Dict = dict() - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( - user - ): - async for result in send_llm_response(f"{no_entries_found.format()}"): - yield event - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - if ConversationCommand.Online in conversation_commands: - try: - async for result in search_online( - defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters - ): - if isinstance(result, dict) and "status" in result: - yield result["status"] - else: - online_results = result - except ValueError as e: - error_message = f"Error searching online: {e}. Attempting to respond without online results" - logger.warning(error_message) - async for result in send_llm_response(error_message): - yield result - return - - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, meta_log, location, partial(send_event, "status") - ): - if isinstance(result, dict) and "status" in result: - yield result["status"] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): - yield result - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", - exc_info=True, - ) - - if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - async for result in text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - send_status_func=partial(send_event, "status"), - ): - if isinstance(result, dict) and "status" in result: - yield result["status"] - else: - image, status_code, improved_image_prompt, intent_type = result - - if image is None or status_code != 200: - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "detail": improved_image_prompt, - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - ) - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "context": compiled_references, - "online_results": online_results, - "inferredQueries": [improved_image_prompt], - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - async for result in send_event( - "references", json.dumps({"context": compiled_references, "online_results": online_results}) - ): + online_results = result + except ValueError as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_llm_response(error_message): yield result - - async for result in send_event("status", f"**💭 Generating a well-informed response**"): - yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - request.user.client_app, - conversation_id, - location, - user_name, - ) - - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - ) - iterator = AsyncIteratorWrapper(llm_response) - - async for result in send_event("start_llm_response", ""): - yield result - - continue_stream = True - async for item in iterator: - if item is None: - async for result in send_event("end_llm_response", ""): - yield result - logger.debug("Finished streaming response") - return - if not connection_alive or not continue_stream: - continue - try: - async for result in send_event("message", f"{item}"): - yield result - except Exception as e: - continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - # Stop streaming after compiled references section of response starts - # References are being processed via the references event rather than the message event - if "### compiled references:" in item: - continue_stream = False - except asyncio.CancelledError: - logger.error(f"Cancelled Error in API endpoint: {e}", exc_info=True) return + + if ConversationCommand.Webpage in conversation_commands: + try: + async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): + yield result + except ValueError as e: + logger.warning( + f"Error directly reading webpages: {e}. Attempting to respond without online results", + exc_info=True, + ) + + if ConversationCommand.Image in conversation_commands: + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + async for result in text_to_image( + q, + user, + meta_log, + location_data=location, + references=compiled_references, + online_results=online_results, + send_status_func=partial(send_event, "status"), + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + image, status_code, improved_image_prompt, intent_type = result + + if image is None or status_code != 200: + content_obj = { + "content-type": "application/json", + "intentType": intent_type, + "detail": improved_image_prompt, + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return + + await sync_to_async(save_to_conversation_log)( + q, + image, + user, + meta_log, + user_message_time, + intent_type=intent_type, + inferred_queries=[improved_image_prompt], + client_application=request.user.client_app, + conversation_id=conversation_id, + compiled_references=compiled_references, + online_results=online_results, + ) + content_obj = { + "content-type": "application/json", + "intentType": intent_type, + "context": compiled_references, + "online_results": online_results, + "inferredQueries": [improved_image_prompt], + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return + + async for result in send_event( + "references", json.dumps({"context": compiled_references, "online_results": online_results}) + ): + yield result + + async for result in send_event("status", f"**💭 Generating a well-informed response**"): + yield result + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + meta_log, + conversation, + compiled_references, + online_results, + inferred_queries, + conversation_commands, + user, + request.user.client_app, + conversation_id, + location, + user_name, + ) + + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata=chat_metadata, + ) + iterator = AsyncIteratorWrapper(llm_response) + + async for result in send_event("start_llm_response", ""): + yield result + + continue_stream = True + async for item in iterator: + if item is None: + async for result in send_event("end_llm_response", ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue + # Stop streaming after compiled references section of response starts + # References are being processed via the references event rather than the message event + if "### compiled references:" in item: + continue_stream = False + item = item.split("### compiled references:")[0] + try: + async for result in send_event("message", f"{item}"): + yield result except Exception as e: - logger.error(f"General Error in API endpoint: {e}", exc_info=True) - return + continue_stream = False + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") return StreamingResponse(event_generator(q), media_type="text/plain") From 2d4b284218eb396bc7f42d01a0434cad80e77a9f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 22 Jul 2024 17:31:17 +0530 Subject: [PATCH 04/20] Simplify streaming chat function in web client --- src/khoj/interface/web/chat.html | 524 ++++++++++++------------------- 1 file changed, 200 insertions(+), 324 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index b1ff3eba..00139232 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -598,11 +598,9 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - renderMessageStream(isVoice); - return; + let chatBody = document.getElementById("chat-body"); - let query = document.getElementById("chat-input").value.trim(); - let resultsCount = localStorage.getItem("khojResultsCount") || 5; + var query = document.getElementById("chat-input").value.trim(); console.log(`Query: ${query}`); // Short circuit on empty query @@ -621,31 +619,20 @@ To get started, just start typing below. You can also type / to see a list of co document.getElementById("chat-input").value = ""; autoResize(); document.getElementById("chat-input").setAttribute("disabled", "disabled"); - let chat_body = document.getElementById("chat-body"); - let conversationID = chat_body.dataset.conversationId; + let newResponseEl = document.createElement("div"); + newResponseEl.classList.add("chat-message", "khoj"); + newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); + chatBody.appendChild(newResponseEl); - if (!conversationID) { - let response = await fetch('/api/chat/sessions', { method: "POST" }); - let data = await response.json(); - conversationID = data.conversation_id; - chat_body.dataset.conversationId = conversationID; - refreshChatSessionsPanel(); - } - - let new_response = document.createElement("div"); - new_response.classList.add("chat-message", "khoj"); - new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); - chat_body.appendChild(new_response); - - let newResponseText = document.createElement("div"); - newResponseText.classList.add("chat-message-text", "khoj"); - new_response.appendChild(newResponseText); + let newResponseTextEl = document.createElement("div"); + newResponseTextEl.classList.add("chat-message-text", "khoj"); + newResponseEl.appendChild(newResponseTextEl); // Temporary status message to indicate that Khoj is thinking let loadingEllipsis = createLoadingEllipse(); - newResponseText.appendChild(loadingEllipsis); + newResponseTextEl.appendChild(loadingEllipsis); document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; let chatTooltip = document.getElementById("chat-tooltip"); @@ -654,65 +641,21 @@ To get started, just start typing below. You can also type / to see a list of co let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Generate backend API URL to execute query - let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; - // Call specified Khoj API - let response = await fetch(url); + await sendMessageStream(query); let rawResponse = ""; - let references = null; - const contentType = response.headers.get("content-type"); + let references = {}; - if (contentType === "application/json") { - // Handle JSON response - try { - const responseAsJson = await response.json(); - if (responseAsJson.image || responseAsJson.detail) { - ({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse)); - } else { - rawResponse = responseAsJson.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - addMessageToChatBody(rawResponse, newResponseText, references); - } - } else { - // Handle streamed response of type text/event-stream or text/plain - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let references = {}; - - readStream(); - - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - finalizeChatBodyResponse(references, newResponseText); - return; - } - - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); - - if (chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse)); - readStream(); - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - handleStreamResponse(newResponseText, rawResponse, query, loadingEllipsis); - readStream(); - } - }); - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }; + chatMessageState = { + newResponseTextEl, + newResponseEl, + loadingEllipsis, + references, + rawResponse, + rawQuery: query, + isVoice: isVoice, } - }; + } function createLoadingEllipse() { // Temporary status message to indicate that Khoj is thinking @@ -750,22 +693,6 @@ To get started, just start typing below. You can also type / to see a list of co document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; } - function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - rawResponseElement.innerHTML = ""; - rawResponseElement.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - return { rawResponse, references }; - } - function handleImageResponse(imageJson, rawResponse) { if (imageJson.image) { const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; @@ -806,11 +733,188 @@ To get started, just start typing below. You can also type / to see a list of co } function finalizeChatBodyResponse(references, newResponseElement) { - if (references != null && Object.keys(references).length > 0) { + if (!!newResponseElement && references != null && Object.keys(references).length > 0) { newResponseElement.appendChild(createReferenceSection(references)); } document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); + document.getElementById("chat-input")?.removeAttribute("disabled"); + } + + function collectJsonsInBufferedMessageChunk(chunk) { + // Collect list of JSON objects and raw strings in the chunk + // Return the list of objects and the remaining raw string + console.log("Raw Chunk:", chunk); + let startIndex = chunk.indexOf('{'); + if (startIndex === -1) return { objects: [chunk], remainder: '' }; + const objects = [chunk.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; + + for (let i = startIndex; i < chunk.length; i++) { + if (chunk[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (chunk[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = chunk.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } + } + + return { + objects: objects, + remainder: openBraces > 0 ? chunk.slice(startIndex) : '' + }; + } + + function convertMessageChunkToJson(rawChunk) { + // Split the chunk into lines + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return {type: 'message', data: rawChunk}; + } + } + + function processMessageChunk(rawChunk) { + const chunk = convertMessageChunkToJson(rawChunk); + console.debug("Chunk:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + } + } else if (chunk.type === "references") { + const rawReferenceAsJson = JSON.parse(chunk.data); + chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + handleJsonResponse(jsonData); + } catch (e) { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } + } + + function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + chatMessageState.references = { + notes: jsonData.context || {}, + online: jsonData.online_results || {} + }; + } + addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); + } + + async function sendMessageStream(query) { + let chatBody = document.getElementById("chat-body"); + let conversationId = chatBody.dataset.conversationId; + + if (!conversationId) { + let response = await fetch('/api/chat/sessions', { method: "POST" }); + let data = await response.json(); + conversationId = data.conversation_id; + chatBody.dataset.conversationId = conversationId; + refreshChatSessionsPanel(); + } + + let chatStreamUrl = `/api/chat/stream?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&client=web`; + chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) + ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; + + fetch(chatStreamUrl) + .then(response => { + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; + + function readStream() { + reader.read().then(({ done, value }) => { + // If the stream is done + if (done) { + // Process the last chunk + processMessageChunk(buffer); + buffer = ''; + console.log("Stream complete"); + return; + } + + // Read chunk from stream and append it to the buffer + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + // Check if the buffer contains (0 or more) complete JSON objects + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + let chunks = collectJsonsInBufferedMessageChunk(buffer); + chunks.objects.forEach(processMessageChunk); + buffer = chunks.remainder; + } + + // Continue reading the stream + readStream(); + }); + } + + readStream(); + }) + .catch(error => { + console.error('Error:', error); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + } + chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" + }); } function incrementalChat(event) { @@ -1083,234 +1187,6 @@ To get started, just start typing below. You can also type / to see a list of co } } - function sendMessageStream(query) { - let chatBody = document.getElementById("chat-body"); - let chatStreamUrl = `/api/chat/stream?q=${query}`; - - if (chatBody.dataset.conversationId) { - chatStreamUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) - ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` - : ''; - - fetch(chatStreamUrl) - .then(response => { - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; - let netBracketCount = 0; - - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - console.log("Stream complete"); - handleChunk(buffer); - buffer = ''; - return; - } - - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; - if (netBracketCount === 0) { - chunks = processJsonObjects(buffer); - chunks.objects.forEach(obj => handleChunk(obj)); - buffer = chunks.remainder; - } - readStream(); - }); - } - - readStream(); - }) - .catch(error => { - console.error('Error:', error); - if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { - chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); - } - chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" - }); - - function processJsonObjects(str) { - let startIndex = str.indexOf('{'); - if (startIndex === -1) return { objects: [str], remainder: '' }; - const objects = [str.slice(0, startIndex)]; - let openBraces = 0; - let currentObject = ''; - - for (let i = startIndex; i < str.length; i++) { - if (str[i] === '{') { - if (openBraces === 0) startIndex = i; - openBraces++; - } - if (str[i] === '}') { - openBraces--; - if (openBraces === 0) { - currentObject = str.slice(startIndex, i + 1); - objects.push(currentObject); - currentObject = ''; - } - } - } - - return { - objects: objects, - remainder: openBraces > 0 ? str.slice(startIndex) : '' - }; - } - - function handleChunk(rawChunk) { - // Split the chunk into lines - console.log("Chunk:", rawChunk); - if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { - try { - let jsonChunk = JSON.parse(rawChunk); - if (!jsonChunk.type) - jsonChunk = {type: 'message', data: jsonChunk}; - processChunk(jsonChunk); - } catch (e) { - const jsonChunk = {type: 'message', data: rawChunk}; - processChunk(jsonChunk); - } - } else if (rawChunk.length > 0) { - const jsonChunk = {type: 'message', data: rawChunk}; - processChunk(jsonChunk); - } - } - function processChunk(chunk) { - console.log(chunk); - if (chunk.type ==='status') { - console.log(`status: ${chunk.data}`); - const statusMessage = chunk.data; - handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false); - } else if (chunk.type === 'start_llm_response') { - console.log("Started streaming", new Date()); - } else if (chunk.type === 'end_llm_response') { - console.log("Stopped streaming", new Date()); - - // Automatically respond with voice if the subscribed user has sent voice message - if (chatMessageState.isVoice && "{{ is_active }}" == "True") - textToSpeech(chatMessageState.rawResponse); - - // Append any references after all the data has been streamed - finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); - - const liveQuery = chatMessageState.rawQuery; - // Reset variables - chatMessageState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, - } - } else if (chunk.type === "references") { - const rawReferenceAsJson = JSON.parse(chunk.data); - console.log(`${chunk.type}: ${rawReferenceAsJson}`); - chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results}; - } else if (chunk.type === 'message') { - if (chunk.data.trim()?.startsWith("{") && chunk.data.trim()?.endsWith("}")) { - // Try process chunk data as if it is a JSON object - try { - const jsonData = JSON.parse(chunk.data.trim()); - handleJsonResponse(jsonData); - } catch (e) { - // Handle text response chunk with compiled references - if (chunk?.data.includes("### compiled references:")) { - chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; - // Handle text response chunk - } else { - chatMessageState.rawResponse += chunk.data; - } - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } else { - // Handle text response chunk with compiled references - if (chunk?.data.includes("### compiled references:")) { - chatMessageState.rawResponse += chunk.data.split("### compiled references:")[0]; - // Handle text response chunk - } else { - chatMessageState.rawResponse += chunk.data; - } - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } - } - - function handleJsonResponse(jsonData) { - if (jsonData.image || jsonData.detail) { - let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; - } else if (jsonData.response) { - chatMessageState.rawResponse = jsonData.response; - chatMessageState.references = { - notes: jsonData.context || {}, - online: jsonData.online_results || {} - }; - } - addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); - } - } - } - - function renderMessageStream(isVoice=false) { - let chatBody = document.getElementById("chat-body"); - - var query = document.getElementById("chat-input").value.trim(); - console.log(`Query: ${query}`); - - if (userMessages.length >= 10) { - userMessages.shift(); - } - userMessages.push(query); - resetUserMessageIndex(); - - // Add message by user to chat body - renderMessage(query, "you"); - document.getElementById("chat-input").value = ""; - autoResize(); - document.getElementById("chat-input").setAttribute("disabled", "disabled"); - - let newResponseEl = document.createElement("div"); - newResponseEl.classList.add("chat-message", "khoj"); - newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); - chatBody.appendChild(newResponseEl); - - let newResponseTextEl = document.createElement("div"); - newResponseTextEl.classList.add("chat-message-text", "khoj"); - newResponseEl.appendChild(newResponseTextEl); - - // Temporary status message to indicate that Khoj is thinking - let loadingEllipsis = createLoadingEllipse(); - - newResponseTextEl.appendChild(loadingEllipsis); - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - - let chatTooltip = document.getElementById("chat-tooltip"); - chatTooltip.style.display = "none"; - - let chatInput = document.getElementById("chat-input"); - chatInput.classList.remove("option-enabled"); - - // Call specified Khoj API - sendMessageStream(query); - let rawResponse = ""; - let references = {}; - - chatMessageState = { - newResponseTextEl, - newResponseEl, - loadingEllipsis, - references, - rawResponse, - rawQuery: query, - isVoice: isVoice, - } - } - var userMessages = []; var userMessageIndex = -1; function loadChat() { From daec439d5250f4440ddf6006eb2804ef08b185a3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 22 Jul 2024 20:29:45 +0530 Subject: [PATCH 05/20] Replace old chat router with new chat router with advanced streaming - Details Only return notes refs, online refs, inferred queries and generated response in non-streaming mode. Do not return train of throught and other status messages Incorporate missing logic from old chat API router into new one. - Motivation So we can halve chat API code by getting rid of the duplicate logic for the websocket router The deduplicated code: - Avoids inadvertant logic drift between the 2 routers - Improves dev velocity --- src/khoj/interface/web/chat.html | 47 ++--- src/khoj/routers/api_chat.py | 333 +++++-------------------------- 2 files changed, 61 insertions(+), 319 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 00139232..6855c196 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -709,27 +709,11 @@ To get started, just start typing below. You can also type / to see a list of co rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } } - let references = {}; - if (imageJson.context && imageJson.context.length > 0) { - const rawReferenceAsJson = imageJson.context; - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - } - if (imageJson.detail) { - // If response has detail field, response is an error message. - rawResponse += imageJson.detail; - } - return { rawResponse, references }; - } - function addMessageToChatBody(rawResponse, newResponseElement, references) { - newResponseElement.innerHTML = ""; - newResponseElement.appendChild(formatHTMLMessage(rawResponse)); + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; - finalizeChatBodyResponse(references, newResponseElement); + return rawResponse; } function finalizeChatBodyResponse(references, newResponseElement) { @@ -743,7 +727,6 @@ To get started, just start typing below. You can also type / to see a list of co function collectJsonsInBufferedMessageChunk(chunk) { // Collect list of JSON objects and raw strings in the chunk // Return the list of objects and the remaining raw string - console.log("Raw Chunk:", chunk); let startIndex = chunk.indexOf('{'); if (startIndex === -1) return { objects: [chunk], remainder: '' }; const objects = [chunk.slice(0, startIndex)]; @@ -819,11 +802,13 @@ To get started, just start typing below. You can also type / to see a list of co isVoice: false, } } else if (chunk.type === "references") { - const rawReferenceAsJson = JSON.parse(chunk.data); - chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results}; + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; } else if (chunk.type === 'message') { const chunkData = chunk.data; - if (chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { // Try process chunk data as if it is a JSON object try { const jsonData = JSON.parse(chunkData.trim()); @@ -841,17 +826,15 @@ To get started, just start typing below. You can also type / to see a list of co function handleJsonResponse(jsonData) { if (jsonData.image || jsonData.detail) { - let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse); - chatMessageState.rawResponse = rawResponse; - chatMessageState.references = references; + chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse); } else if (jsonData.response) { chatMessageState.rawResponse = jsonData.response; - chatMessageState.references = { - notes: jsonData.context || {}, - online: jsonData.online_results || {} - }; } - addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); + + if (chatMessageState.newResponseTextEl) { + chatMessageState.newResponseTextEl.innerHTML = ""; + chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse)); + } } async function sendMessageStream(query) { @@ -866,7 +849,7 @@ To get started, just start typing below. You can also type / to see a list of co refreshChatSessionsPanel(); } - let chatStreamUrl = `/api/chat/stream?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&client=web`; + let chatStreamUrl = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&stream=true&client=web`; chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 34879b86..d8826264 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -525,13 +525,14 @@ async def set_conversation_title( ) -@api_chat.get("/stream") -async def stream_chat( +@api_chat.get("") +async def chat( request: Request, common: CommonQueryParams, q: str, n: int = 7, d: float = 0.18, + stream: Optional[bool] = False, title: Optional[str] = None, conversation_id: Optional[int] = None, city: Optional[str] = None, @@ -550,7 +551,7 @@ async def stream_chat( user: KhojUser = request.user.object q = unquote(q) - async def send_event(event_type: str, data: str): + async def send_event(event_type: str, data: str | dict): nonlocal connection_alive if not connection_alive or await request.is_disconnected(): connection_alive = False @@ -559,7 +560,9 @@ async def stream_chat( try: if event_type == "message": yield data - else: + elif event_type == "references": + yield json.dumps({"type": event_type, "data": data}) + elif stream: yield json.dumps({"type": event_type, "data": data}) except asyncio.CancelledError: connection_alive = False @@ -744,6 +747,8 @@ async def stream_chat( yield result return + # Gather Context + ## Extract Document References compiled_references, inferred_queries, defiltered_query = [], [], None async for result in extract_references_and_questions( request, @@ -778,6 +783,7 @@ async def stream_chat( if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): conversation_commands.remove(ConversationCommand.Notes) + ## Gather Online References if ConversationCommand.Online in conversation_commands: try: async for result in search_online( @@ -794,6 +800,7 @@ async def stream_chat( yield result return + ## Gather Webpage References if ConversationCommand.Webpage in conversation_commands: try: async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")): @@ -818,6 +825,19 @@ async def stream_chat( exc_info=True, ) + ## Send Gathered References + async for result in send_event( + "references", + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "online_results": online_results, + }, + ): + yield result + + # Generate Output + ## Generate Image Output if ConversationCommand.Image in conversation_commands: update_telemetry_state( request=request, @@ -875,11 +895,7 @@ async def stream_chat( yield result return - async for result in send_event( - "references", json.dumps({"context": compiled_references, "online_results": online_results}) - ): - yield result - + ## Generate Text Output async for result in send_event("status", f"**💭 Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( @@ -897,6 +913,8 @@ async def stream_chat( user_name, ) + cmd_set = set([cmd.value for cmd in conversation_commands]) + chat_metadata["conversation_command"] = cmd_set chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None update_telemetry_state( @@ -905,12 +923,13 @@ async def stream_chat( api="chat", metadata=chat_metadata, ) - iterator = AsyncIteratorWrapper(llm_response) + # Send Response async for result in send_event("start_llm_response", ""): yield result continue_stream = True + iterator = AsyncIteratorWrapper(llm_response) async for item in iterator: if item is None: async for result in send_event("end_llm_response", ""): @@ -931,282 +950,22 @@ async def stream_chat( continue_stream = False logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - return StreamingResponse(event_generator(q), media_type="text/plain") - - -@api_chat.get("", response_class=Response) -@requires(["authenticated"]) -async def chat( - request: Request, - common: CommonQueryParams, - q: str, - n: Optional[int] = 5, - d: Optional[float] = 0.22, - stream: Optional[bool] = False, - title: Optional[str] = None, - conversation_id: Optional[int] = None, - city: Optional[str] = None, - region: Optional[str] = None, - country: Optional[str] = None, - timezone: Optional[str] = None, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), -) -> Response: - user: KhojUser = request.user.object - q = unquote(q) - if is_query_empty(q): - return Response( - content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", - media_type="text/plain", - status_code=400, - ) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - logger.info(f"Chat request by {user.username}: {q}") - - await is_ready_to_chat(user) - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - _custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - help_str = "/" + ConversationCommand.Help - if q.strip() == help_str: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) - # Adding specification to search online specifically on khoj.dev pages. - _custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, request.user.client_app, conversation_id, title - ) - conversation_id = conversation.id if conversation else None - - if not conversation: - return Response( - content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400 - ) - else: - meta_log = conversation.conversation_log - - if ConversationCommand.Summarize in conversation_commands: - file_filters = conversation.file_filters - llm_response = "" - if len(file_filters) == 0: - llm_response = "No files selected for summarization. Please add files using the section on the left." - elif len(file_filters) > 1: - llm_response = "Only one file can be selected for summarization." - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200) - contextual_data = " ".join([file.raw_text for file in file_object]) - summarizeStr = "/" + ConversationCommand.Summarize - if q.strip() == summarizeStr: - q = "Create a general summary of the file" - response = await extract_relevant_summary(q, contextual_data) - llm_response = str(response) - except Exception as e: - logger.error(f"Error summarizing file for {user.email}: {e}") - llm_response = "Error summarizing file." - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - conversation.conversation_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - **common.__dict__, - ) - return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200) - - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - location = None - - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - user_name = await aget_user_name(user) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True) - return Response( - content=f"Unable to create automation. Ensure the automation doesn't already exist.", - media_type="text/plain", - status_code=500, - ) - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - ) - - if stream: - return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200) - else: - return Response(content=llm_response, media_type="text/plain", status_code=200) - - compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location - ) - online_results: Dict[str, Dict] = {} - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - no_entries_found_format = no_entries_found.format() - if stream: - return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) - else: - response_obj = {"response": no_entries_found_format} - return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200) - - if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references): - no_notes_found_format = no_notes_found.format() - if stream: - return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200) - else: - response_obj = {"response": no_notes_found_format} - return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200) - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - if ConversationCommand.Online in conversation_commands: - try: - online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters) - except ValueError as e: - logger.warning(f"Error searching online: {e}. Attempting to respond without online results") - - if ConversationCommand.Webpage in conversation_commands: - try: - online_results = await read_webpages(defiltered_query, meta_log, location) - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True - ) - - if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - **common.__dict__, - ) - image, status_code, improved_image_prompt, intent_type = await text_to_image( - q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results - ) - if image is None: - content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt} - return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation.id, - compiled_references=compiled_references, - online_results=online_results, - ) - content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore - return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - - # Get the (streamed) chat response from the LLM of choice. - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - request.user.client_app, - conversation.id, - location, - user_name, - ) - - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - **common.__dict__, - ) - - if llm_response is None: - return Response(content=llm_response, media_type="text/plain", status_code=500) - + ## Stream Text Response if stream: - return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200) + return StreamingResponse(event_generator(q), media_type="text/plain") + ## Non-Streaming Text Response + else: + # Get the full response from the generator if the stream is not requested. + response_obj = {} + actual_response = "" + iterator = event_generator(q) + async for item in iterator: + try: + item_json = json.loads(item) + if "type" in item_json and item_json["type"] == "references": + response_obj = item_json["data"] + except: + actual_response += item + response_obj["response"] = actual_response - iterator = AsyncIteratorWrapper(llm_response) - - # Get the full response from the generator if the stream is not requested. - aggregated_gpt_response = "" - async for item in iterator: - if item is None: - break - aggregated_gpt_response += item - - actual_response = aggregated_gpt_response.split("### compiled references:")[0] - - response_obj = { - "response": actual_response, - "inferredQueries": inferred_queries, - "context": compiled_references, - "online_results": online_results, - } - - return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) + return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) From b224d7ffad8b0260fb5230aa07c20a0a538d9cb0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 22 Jul 2024 20:34:30 +0530 Subject: [PATCH 06/20] Simplify get_conversation_by_user DB adapter code --- src/khoj/database/adapters/__init__.py | 32 +++++++------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 0c0724ee..2dae40ed 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -680,34 +680,18 @@ class ConversationAdapters: async def aget_conversation_by_user( user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None ) -> Optional[Conversation]: + query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent") + if conversation_id: - return ( - await Conversation.objects.filter(user=user, client=client_application, id=conversation_id) - .prefetch_related("agent") - .afirst() - ) + return await query.filter(id=conversation_id).afirst() elif title: - return ( - await Conversation.objects.filter(user=user, client=client_application, title=title) - .prefetch_related("agent") - .afirst() - ) - else: - conversation = ( - Conversation.objects.filter(user=user, client=client_application) - .prefetch_related("agent") - .order_by("-updated_at") - ) + return await query.filter(title=title).afirst() - if await conversation.aexists(): - return await conversation.prefetch_related("agent").afirst() + conversation = await query.order_by("-updated_at").afirst() - return await ( - Conversation.objects.filter(user=user, client=client_application) - .prefetch_related("agent") - .order_by("-updated_at") - .afirst() - ) or await Conversation.objects.prefetch_related("agent").acreate(user=user, client=client_application) + return conversation or await Conversation.objects.prefetch_related("agent").acreate( + user=user, client=client_application + ) @staticmethod async def adelete_conversation_by_user( From 8303b091290784d249df666929810586a9459d4b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 14:36:53 +0530 Subject: [PATCH 07/20] Convert snake case to camel case in chat view of obsidian plugin --- src/interface/obsidian/src/chat_view.ts | 46 ++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index b8d95d6b..9ad187b0 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -409,16 +409,16 @@ export class KhojChatView extends KhojPaneView { message = DOMPurify.sanitize(message); // Convert the message to html, sanitize the message html and render it to the real DOM - let chat_message_body_text_el = this.contentEl.createDiv(); - chat_message_body_text_el.className = "chat-message-text-response"; - chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this); + let chatMessageBodyTextEl = this.contentEl.createDiv(); + chatMessageBodyTextEl.className = "chat-message-text-response"; + chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this); // Add a copy button to each chat message, if it doesn't already exist if (willReplace === true) { - this.renderActionButtons(message, chat_message_body_text_el); + this.renderActionButtons(message, chatMessageBodyTextEl); } - return chat_message_body_text_el; + return chatMessageBodyTextEl; } markdownTextToSanitizedHtml(markdownText: string, component: ItemView): string { @@ -502,23 +502,23 @@ export class KhojChatView extends KhojPaneView { class: `khoj-chat-message ${sender}` }, }) - let chat_message_body_el = chatMessageEl.createDiv(); - chat_message_body_el.addClasses(["khoj-chat-message-text", sender]); - let chat_message_body_text_el = chat_message_body_el.createDiv(); + let chatMessageBodyEl = chatMessageEl.createDiv(); + chatMessageBodyEl.addClasses(["khoj-chat-message-text", sender]); + let chatMessageBodyTextEl = chatMessageBodyEl.createDiv(); // Sanitize the markdown to render message = DOMPurify.sanitize(message); if (raw) { - chat_message_body_text_el.innerHTML = message; + chatMessageBodyTextEl.innerHTML = message; } else { // @ts-ignore - chat_message_body_text_el.innerHTML = this.markdownTextToSanitizedHtml(message, this); + chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this); } // Add action buttons to each chat message element if (willReplace === true) { - this.renderActionButtons(message, chat_message_body_text_el); + this.renderActionButtons(message, chatMessageBodyTextEl); } // Remove user-select: none property to make text selectable @@ -531,14 +531,14 @@ export class KhojChatView extends KhojPaneView { } createKhojResponseDiv(dt?: Date): HTMLDivElement { - let message_time = this.formatDate(dt ?? new Date()); + let messageTime = this.formatDate(dt ?? new Date()); // Append message to conversation history HTML element. // The chat logs should display above the message input box to follow standard UI semantics - let chat_body_el = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; - let chat_message_el = chat_body_el.createDiv({ + let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; + let chatMessageEl = chatBodyEl.createDiv({ attr: { - "data-meta": `🏮 Khoj at ${message_time}`, + "data-meta": `🏮 Khoj at ${messageTime}`, class: `khoj-chat-message khoj` }, }).createDiv({ @@ -550,7 +550,7 @@ export class KhojChatView extends KhojPaneView { // Scroll to bottom after inserting chat messages this.scrollChatToBottom(); - return chat_message_el; + return chatMessageEl; } async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) { @@ -566,7 +566,7 @@ export class KhojChatView extends KhojPaneView { this.scrollChatToBottom(); } - renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) { + renderActionButtons(message: string, chatMessageBodyTextEl: HTMLElement) { let copyButton = this.contentEl.createEl('button'); copyButton.classList.add("chat-action-button"); copyButton.title = "Copy Message to Clipboard"; @@ -593,10 +593,10 @@ export class KhojChatView extends KhojPaneView { } // Append buttons to parent element - chat_message_body_text_el.append(copyButton, pasteToFile); + chatMessageBodyTextEl.append(copyButton, pasteToFile); if (speechButton) { - chat_message_body_text_el.append(speechButton); + chatMessageBodyTextEl.append(speechButton); } } @@ -895,16 +895,16 @@ export class KhojChatView extends KhojPaneView { let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement; this.renderMessage(chatBodyEl, query, "you"); - let conversationID = chatBodyEl.dataset.conversationId; - if (!conversationID) { + let conversationId = chatBodyEl.dataset.conversationId; + if (!conversationId) { let chatUrl = `${this.setting.khojUrl}/api/chat/sessions?client=obsidian`; let response = await fetch(chatUrl, { method: "POST", headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` }, }); let data = await response.json(); - conversationID = data.conversation_id; - chatBodyEl.dataset.conversationId = conversationID; + conversationId = data.conversation_id; + chatBodyEl.dataset.conversationId = conversationId; } // Get chat response from Khoj backend From 3f5f418d0ea87205914c2c6d4fb9f534bb53a008 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 15:02:31 +0530 Subject: [PATCH 08/20] Use new chat streaming API to show Khoj train of thought in Obsidian client --- src/interface/obsidian/src/chat_view.ts | 292 +++++++++++++++--------- 1 file changed, 179 insertions(+), 113 deletions(-) diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index 9ad187b0..121d0a87 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -12,6 +12,25 @@ export interface ChatJsonResult { inferredQueries?: string[]; } +interface ChunkResult { + objects: string[]; + remainder: string; +} + +interface MessageChunk { + type: string; + data: any; +} + +interface ChatMessageState { + newResponseTextEl: HTMLElement | null; + newResponseEl: HTMLElement | null; + loadingEllipsis: HTMLElement | null; + references: any; + rawResponse: string; + rawQuery: string; + isVoice: boolean; +} interface Location { region: string; @@ -26,6 +45,7 @@ export class KhojChatView extends KhojPaneView { waitingForLocation: boolean; location: Location; keyPressTimeout: NodeJS.Timeout | null = null; + chatMessageState: ChatMessageState; constructor(leaf: WorkspaceLeaf, setting: KhojSetting) { super(leaf, setting); @@ -410,7 +430,6 @@ export class KhojChatView extends KhojPaneView { // Convert the message to html, sanitize the message html and render it to the real DOM let chatMessageBodyTextEl = this.contentEl.createDiv(); - chatMessageBodyTextEl.className = "chat-message-text-response"; chatMessageBodyTextEl.innerHTML = this.markdownTextToSanitizedHtml(message, this); // Add a copy button to each chat message, if it doesn't already exist @@ -541,11 +560,7 @@ export class KhojChatView extends KhojPaneView { "data-meta": `🏮 Khoj at ${messageTime}`, class: `khoj-chat-message khoj` }, - }).createDiv({ - attr: { - class: `khoj-chat-message-text khoj` - }, - }).createDiv(); + }) // Scroll to bottom after inserting chat messages this.scrollChatToBottom(); @@ -554,14 +569,14 @@ export class KhojChatView extends KhojPaneView { } async renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) { - this.result += additionalMessage; + this.chatMessageState.rawResponse += additionalMessage; htmlElement.innerHTML = ""; // Sanitize the markdown to render - this.result = DOMPurify.sanitize(this.result); + this.chatMessageState.rawResponse = DOMPurify.sanitize(this.chatMessageState.rawResponse); // @ts-ignore - htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.result, this); + htmlElement.innerHTML = this.markdownTextToSanitizedHtml(this.chatMessageState.rawResponse, this); // Render action buttons for the message - this.renderActionButtons(this.result, htmlElement); + this.renderActionButtons(this.chatMessageState.rawResponse, htmlElement); // Scroll to bottom of modal, till the send message input box this.scrollChatToBottom(); } @@ -854,35 +869,147 @@ export class KhojChatView extends KhojPaneView { return true; } - async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise { + collectJsonsInBufferedMessageChunk(chunk: string): ChunkResult { + // Collect list of JSON objects and raw strings in the chunk + // Return the list of objects and the remaining raw string + let startIndex = chunk.indexOf('{'); + if (startIndex === -1) return { objects: [chunk], remainder: '' }; + const objects: string[] = [chunk.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; + + for (let i = startIndex; i < chunk.length; i++) { + if (chunk[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (chunk[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = chunk.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } + } + + return { + objects: objects, + remainder: openBraces > 0 ? chunk.slice(startIndex) : '' + }; + } + + convertMessageChunkToJson(rawChunk: string): MessageChunk { + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return {type: 'message', data: rawChunk}; + } + return {type: '', data: ''}; + } + + processMessageChunk(rawChunk: string): void { + const chunk = this.convertMessageChunkToJson(rawChunk); + console.debug("Chunk:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type === 'status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, statusMessage, this.chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active) + this.textToSpeech(this.chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + this.finalizeChatBodyResponse(this.chatMessageState.references, this.chatMessageState.newResponseTextEl); + + const liveQuery = this.chatMessageState.rawQuery; + // Reset variables + this.chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + }; + } else if (chunk.type === "references") { + this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + this.handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + this.handleJsonResponse(jsonData); + } catch (e) { + this.chatMessageState.rawResponse += chunkData; + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis); + } + } else { + this.chatMessageState.rawResponse += chunkData; + this.handleStreamResponse(this.chatMessageState.newResponseTextEl, this.chatMessageState.rawResponse, this.chatMessageState.loadingEllipsis); + } + } + } + + handleJsonResponse(jsonData: any): void { + if (jsonData.image || jsonData.detail) { + this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse); + } else if (jsonData.response) { + this.chatMessageState.rawResponse = jsonData.response; + } + + if (this.chatMessageState.newResponseTextEl) { + this.chatMessageState.newResponseTextEl.innerHTML = ""; + this.chatMessageState.newResponseTextEl.appendChild(this.formatHTMLMessage(this.chatMessageState.rawResponse)); + } + } + + async readChatStream(response: Response): Promise { // Exit if response body is empty if (response.body == null) return; const reader = response.body.getReader(); const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; while (true) { const { value, done } = await reader.read(); if (done) { - // Automatically respond with voice if the subscribed user has sent voice message - if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result); + this.processMessageChunk(buffer); + buffer = ''; // Break if the stream is done break; } - let responseText = decoder.decode(value); - if (responseText.includes("### compiled references:")) { - // Render any references used to generate the response - const [additionalResponse, rawReference] = responseText.split("### compiled references:", 2); - await this.renderIncrementalMessage(responseElement, additionalResponse); + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; - const rawReferenceAsJson = JSON.parse(rawReference); - let references = this.extractReferences(rawReferenceAsJson); - responseElement.appendChild(this.createReferenceSection(references)); - } else { - // Render incremental chat response - await this.renderIncrementalMessage(responseElement, responseText); + // Check if the buffer contains (0 or more) complete JSON objects + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + let chunks = this.collectJsonsInBufferedMessageChunk(buffer); + chunks.objects.forEach((chunk) => this.processMessageChunk(chunk)); + buffer = chunks.remainder; } } } @@ -909,69 +1036,45 @@ export class KhojChatView extends KhojPaneView { // Get chat response from Khoj backend let encodedQuery = encodeURIComponent(query); - let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true®ion=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`; - let responseElement = this.createKhojResponseDiv(); + let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&conversation_id=${conversationId}&n=${this.setting.resultsCount}&stream=true&client=obsidian`; + if (!!this.location) chatUrl += `®ion=${this.location.region}&city=${this.location.city}&country=${this.location.countryName}&timezone=${this.location.timezone}`; + + let newResponseEl = this.createKhojResponseDiv(); + let newResponseTextEl = newResponseEl.createDiv(); + newResponseTextEl.classList.add("khoj-chat-message-text", "khoj"); // Temporary status message to indicate that Khoj is thinking - this.result = ""; let loadingEllipsis = this.createLoadingEllipse(); - responseElement.appendChild(loadingEllipsis); + newResponseTextEl.appendChild(loadingEllipsis); + + // Set chat message state + this.chatMessageState = { + newResponseEl: newResponseEl, + newResponseTextEl: newResponseTextEl, + loadingEllipsis: loadingEllipsis, + references: {}, + rawQuery: query, + rawResponse: "", + isVoice: isVoice, + }; let response = await fetch(chatUrl, { method: "GET", headers: { - "Content-Type": "text/event-stream", + "Content-Type": "text/plain", "Authorization": `Bearer ${this.setting.khojApiKey}`, }, }) try { - if (response.body === null) { - throw new Error("Response body is null"); - } + if (response.body === null) throw new Error("Response body is null"); - // Clear loading status message - if (responseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { - responseElement.removeChild(loadingEllipsis); - } - - // Reset collated chat result to empty string - this.result = ""; - responseElement.innerHTML = ""; - if (response.headers.get("content-type") === "application/json") { - let responseText = "" - try { - const responseAsJson = await response.json() as ChatJsonResult; - if (responseAsJson.image) { - // If response has image field, response is a generated image. - if (responseAsJson.intentType === "text-to-image") { - responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image2") { - responseText += `![${query}](${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image-v3") { - responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`; - } - const inferredQuery = responseAsJson.inferredQueries?.[0]; - if (inferredQuery) { - responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`; - } - } else if (responseAsJson.detail) { - responseText = responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - responseText = await response.text(); - } finally { - await this.renderIncrementalMessage(responseElement, responseText); - } - } else { - // Stream and render chat response - await this.readChatStream(response, responseElement, isVoice); - } + // Stream and render chat response + await this.readChatStream(response); } catch (err) { - console.log(`Khoj chat response failed with\n${err}`); + console.error(`Khoj chat response failed with\n${err}`); let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at team@khoj.dev or on Discord"; - responseElement.innerHTML = errorMsg + newResponseTextEl.textContent = errorMsg; } } @@ -1196,7 +1299,7 @@ export class KhojChatView extends KhojPaneView { handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) { if (!newResponseElement) return; - if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { + if (replace && newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { newResponseElement.removeChild(loadingEllipsis); } if (replace) { @@ -1206,20 +1309,6 @@ export class KhojChatView extends KhojPaneView { this.scrollChatToBottom(); } - handleCompiledReferences(rawResponseElement: HTMLElement | null, chunk: string, references: any, rawResponse: string) { - if (!rawResponseElement || !chunk) return { rawResponse, references }; - - const [additionalResponse, rawReference] = chunk.split("### compiled references:", 2); - rawResponse += additionalResponse; - rawResponseElement.innerHTML = ""; - rawResponseElement.appendChild(this.formatHTMLMessage(rawResponse)); - - const rawReferenceAsJson = JSON.parse(rawReference); - references = this.extractReferences(rawReferenceAsJson); - - return { rawResponse, references }; - } - handleImageResponse(imageJson: any, rawResponse: string) { if (imageJson.image) { const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; @@ -1236,33 +1325,10 @@ export class KhojChatView extends KhojPaneView { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } } - let references = {}; - if (imageJson.context && imageJson.context.length > 0) { - references = this.extractReferences(imageJson.context); - } - if (imageJson.detail) { - // If response has detail field, response is an error message. - rawResponse += imageJson.detail; - } - return { rawResponse, references }; - } + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; - extractReferences(rawReferenceAsJson: any): object { - let references: any = {}; - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - return references; - } - - addMessageToChatBody(rawResponse: string, newResponseElement: HTMLElement | null, references: any) { - if (!newResponseElement) return; - newResponseElement.innerHTML = ""; - newResponseElement.appendChild(this.formatHTMLMessage(rawResponse)); - - this.finalizeChatBodyResponse(references, newResponseElement); + return rawResponse; } finalizeChatBodyResponse(references: object, newResponseElement: HTMLElement | null) { From 54b42036836967cde68565ceebe3d390f35437dd Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 15:05:06 +0530 Subject: [PATCH 09/20] Update chat API client tests to mix testing of batch and streaming mode --- src/khoj/utils/fs_syncer.py | 2 +- tests/test_client.py | 6 ++-- tests/test_offline_chat_director.py | 12 +++----- tests/test_openai_chat_director.py | 48 +++++++++++++---------------- 4 files changed, 30 insertions(+), 38 deletions(-) diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 5a20f418..3177d7ee 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -22,7 +22,7 @@ magika = Magika() def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: - files = {} + files: dict[str, dict] = {"docx": {}, "image": {}} if search_type == SearchType.All or search_type == SearchType.Org: org_config = LocalOrgConfig.objects.filter(user=user).first() diff --git a/tests/test_client.py b/tests/test_client.py index 24d2dff6..c4246a78 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -455,13 +455,13 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU @pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY") @pytest.mark.django_db(transaction=True) -def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): +async def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): # Arrange headers = {"Authorization": f"Bearer {api_user2.token}"} # Act - auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers) - no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true') + auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"', headers=headers) + no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"') # Assert assert auth_response.status_code == 200 diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index a72dae56..f9cec075 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -68,10 +68,8 @@ def test_chat_with_online_content(client_offline_chat): # Act q = "/online give me the link to paul graham's essay how to do great work" encoded_q = quote(q, safe="") - response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = client_offline_chat.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = [ @@ -92,10 +90,8 @@ def test_chat_with_online_webpage_content(client_offline_chat): # Act q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" encoded_q = quote(q, safe="") - response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = client_offline_chat.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = ["185", "1871", "horse"] diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 26d93d31..7a05a3dd 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -49,8 +49,8 @@ def create_conversation(message_list, user, agent=None): @pytest.mark.django_db(transaction=True) def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # Act - response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') - response_message = response.content.decode("utf-8") + response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"') + response_message = response.json()["response"] # Assert expected_responses = ["Khoj", "khoj"] @@ -67,10 +67,8 @@ def test_chat_with_online_content(chat_client): # Act q = "/online give me the link to paul graham's essay how to do great work" encoded_q = quote(q, safe="") - response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = chat_client.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = [ @@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(chat_client): # Act q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" encoded_q = quote(q, safe="") - response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") - response_message = response.content.decode("utf-8") - - response_message = response_message.split("### compiled references")[0] + response = chat_client.get(f"/api/chat?q={encoded_q}") + response_message = response.json()["response"] # Assert expected_responses = ["185", "1871", "horse"] @@ -144,7 +140,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho # Act response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') - response_message = response.content.decode("utf-8") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -168,7 +164,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n # Act response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"') - response_message = response.content.decode("utf-8") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -191,7 +187,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d # Act response = chat_client.get(f'/api/chat?q="Where was I born?"') - response_message = response.content.decode("utf-8") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -215,8 +211,8 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use create_conversation(message_list, default_user2) # Act - response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') - response_message = response.content.decode("utf-8") + response = chat_client.get(f'/api/chat?q="Where was I born?"') + response_message = response.json()["response"] # Assert expected_responses = [ @@ -226,6 +222,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use "do not have", "don't have", "where were you born?", + "where you were born?", ] assert response.status_code == 200 @@ -280,8 +277,8 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default create_conversation(message_list, default_user2) # Act - response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true") - response_message = response.content.decode("utf-8") + response = chat_client_no_background.get(f"/api/chat?q={query}") + response_message = response.json()["response"] # Assert assert response.status_code == 200 @@ -527,8 +524,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c create_conversation(message_list, default_user2) # Act - response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true') - response_message = response.content.decode("utf-8").split("### compiled references")[0] + response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else.') + response_message = response.json()["response"] # Assert expected_responses = ["test", "Test"] @@ -544,9 +541,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c @pytest.mark.chatquality def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): # Act - - response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true') - response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() + response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"') + response_message = response.json()["response"].lower() # Assert expected_responses = [ @@ -658,8 +654,8 @@ def test_answer_in_chat_history_by_conversation_id_with_agent( def test_answer_requires_multiple_independent_searches(chat_client): "Chat director should be able to answer by doing multiple independent searches for required information" # Act - response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true') - response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() + response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"') + response_message = response.json()["response"].lower() # Assert expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] @@ -683,8 +679,8 @@ def test_answer_using_file_filter(chat_client): 'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"' ) - response = chat_client.get(f"/api/chat?q={query}&stream=true") - response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() + response = chat_client.get(f"/api/chat?q={query}") + response_message = response.json()["response"].lower() # Assert expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] From c5ad17261642af6a0454bf6bafde0d3931103018 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 16:52:05 +0530 Subject: [PATCH 10/20] Keep loading animation at message end & reduce lists padding in Obsidian Previously loading animation would be at top of message. Moving it to bottom is more intuitve and easier to track. Remove white-space: pre from list elements. It was adding too much y axis padding to chat messages (and train of thought) --- src/interface/obsidian/src/chat_view.ts | 15 ++++++++++----- src/interface/obsidian/styles.css | 6 ++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index 121d0a87..efde958b 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -1299,13 +1299,18 @@ export class KhojChatView extends KhojPaneView { handleStreamResponse(newResponseElement: HTMLElement | null, rawResponse: string, loadingEllipsis: HTMLElement | null, replace = true) { if (!newResponseElement) return; - if (replace && newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) newResponseElement.removeChild(loadingEllipsis); - } - if (replace) { - newResponseElement.innerHTML = ""; - } + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element newResponseElement.appendChild(this.formatHTMLMessage(rawResponse, false, replace)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view this.scrollChatToBottom(); } diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index afd8fd19..42c1b3ce 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -85,6 +85,12 @@ If your plugin does not need CSS, delete this file. margin-left: auto; white-space: pre-line; } +/* Override white-space for ul, ol, li under khoj-chat-message-text.khoj */ +.khoj-chat-message-text.khoj ul, +.khoj-chat-message-text.khoj ol, +.khoj-chat-message-text.khoj li { + white-space: normal; +} /* add left protrusion to khoj chat bubble */ .khoj-chat-message-text.khoj:after { content: ''; From fc33162ec6ab71d66e36887ed78ff85447e38be8 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 17:44:07 +0530 Subject: [PATCH 11/20] Use new chat streaming API to show Khoj train of thought in Desktop app Show loading spinner at end of current message --- src/interface/desktop/chat.html | 337 ++++++++++++++++++++++---------- 1 file changed, 234 insertions(+), 103 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 383fc536..3550799e 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -61,6 +61,14 @@ let city = null; let countryName = null; let timezone = null; + let chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + isVoice: false, + } fetch("https://ipapi.co/json") .then(response => response.json()) @@ -75,10 +83,9 @@ return; }); - async function chat() { - // Extract required fields for search from form + async function chat(isVoice=false) { + // Extract chat message from chat input form let query = document.getElementById("chat-input").value.trim(); - let resultsCount = localStorage.getItem("khojResultsCount") || 5; console.log(`Query: ${query}`); // Short circuit on empty query @@ -106,9 +113,6 @@ await refreshChatSessionsPanel(); } - // Generate backend API URL to execute query - let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; - let newResponseEl = document.createElement("div"); newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); @@ -119,6 +123,51 @@ newResponseEl.appendChild(newResponseTextEl); // Temporary status message to indicate that Khoj is thinking + let loadingEllipsis = createLoadingEllipsis(); + + newResponseTextEl.appendChild(loadingEllipsis); + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + + let chatTooltip = document.getElementById("chat-tooltip"); + chatTooltip.style.display = "none"; + + let chatInput = document.getElementById("chat-input"); + chatInput.classList.remove("option-enabled"); + + // Setup chat message state + chatMessageState = { + newResponseTextEl, + newResponseEl, + loadingEllipsis, + references: {}, + rawResponse: "", + rawQuery: query, + isVoice: isVoice, + } + + // Call Khoj chat API + let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`; + chatApi += (!!region && !!city && !!countryName && !!timezone) + ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; + + const response = await fetch(chatApi, { headers }); + + try { + if (!response.ok) throw new Error(response.statusText); + if (!response.body) throw new Error("Response body is empty"); + // Stream and render chat response + await readChatStream(response); + } catch (err) { + console.error(`Khoj chat response failed with\n${err}`); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at team@khoj.dev or on Discord"; + newResponseTextEl.textContent = errorMsg; + } + } + + function createLoadingEllipsis() { let loadingEllipsis = document.createElement("div"); loadingEllipsis.classList.add("lds-ellipsis"); @@ -139,115 +188,197 @@ loadingEllipsis.appendChild(thirdEllipsis); loadingEllipsis.appendChild(fourthEllipsis); - newResponseTextEl.appendChild(loadingEllipsis); + return loadingEllipsis; + } + + function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { + if (!newResponseElement) return; + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) + newResponseElement.removeChild(loadingEllipsis); + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element + newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + } - let chatTooltip = document.getElementById("chat-tooltip"); - chatTooltip.style.display = "none"; + function handleImageResponse(imageJson, rawResponse) { + if (imageJson.image) { + const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; - let chatInput = document.getElementById("chat-input"); - chatInput.classList.remove("option-enabled"); - - // Call Khoj chat API - let response = await fetch(chatApi, { headers }); - let rawResponse = ""; - let references = null; - const contentType = response.headers.get("content-type"); - - if (contentType === "application/json") { - // Handle JSON response - try { - const responseAsJson = await response.json(); - if (responseAsJson.image) { - // If response has image field, response is a generated image. - if (responseAsJson.intentType === "text-to-image") { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image2") { - rawResponse += `![${query}](${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image-v3") { - rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; - } - const inferredQueries = responseAsJson.inferredQueries?.[0]; - if (inferredQueries) { - rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`; - } - } - if (responseAsJson.context) { - const rawReferenceAsJson = responseAsJson.context; - references = createReferenceSection(rawReferenceAsJson); - } - if (responseAsJson.detail) { - // If response has detail field, response is an error message. - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - if (references != null) { - newResponseTextEl.appendChild(references); - } - - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); + // If response has image field, response is a generated image. + if (imageJson.intentType === "text-to-image") { + rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image2") { + rawResponse += `![generated_image](${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image-v3") { + rawResponse = `![](data:image/webp;base64,${imageJson.image})`; } - } else { - // Handle streamed response of type text/event-stream or text/plain - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let references = {}; + if (inferredQuery) { + rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; + } + } - readStream(); + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != {}) { - newResponseTextEl.appendChild(createReferenceSection(references)); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; - } + return rawResponse; + } - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + function finalizeChatBodyResponse(references, newResponseElement) { + if (!!newResponseElement && references != null && Object.keys(references).length > 0) { + newResponseElement.appendChild(createReferenceSection(references)); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input")?.removeAttribute("disabled"); + } - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); + function collectJsonsInBufferedMessageChunk(chunk) { + // Collect list of JSON objects and raw strings in the chunk + // Return the list of objects and the remaining raw string + let startIndex = chunk.indexOf('{'); + if (startIndex === -1) return { objects: [chunk], remainder: '' }; + const objects = [chunk.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - readStream(); - } else { - // Display response from Khoj - if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) { - newResponseTextEl.removeChild(loadingEllipsis); - } + for (let i = startIndex; i < chunk.length; i++) { + if (chunk[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (chunk[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = chunk.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } + } - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); + return { + objects: objects, + remainder: openBraces > 0 ? chunk.slice(startIndex) : '' + }; + } - readStream(); - } + function convertMessageChunkToJson(rawChunk) { + // Split the chunk into lines + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return {type: 'message', data: rawChunk}; + } + } - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + function processMessageChunk(rawChunk) { + const chunk = convertMessageChunkToJson(rawChunk); + console.debug("Chunk:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + } + } else if (chunk.type === "references") { + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + handleJsonResponse(jsonData); + } catch (e) { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } + } + + function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse); + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + } + + if (chatMessageState.newResponseTextEl) { + chatMessageState.newResponseTextEl.innerHTML = ""; + chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse)); + } + } + + async function readChatStream(response) { + if (!response.body) return; + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; + + while (true) { + const { value, done } = await reader.read(); + // If the stream is done + if (done) { + // Process the last chunk + processMessageChunk(buffer); + buffer = ''; + break; + } + + // Read chunk from stream and append it to the buffer + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + // Check if the buffer contains (0 or more) complete JSON objects + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + let chunks = collectJsonsInBufferedMessageChunk(buffer); + chunks.objects.forEach((chunk) => processMessageChunk(chunk)); + buffer = chunks.remainder; } } } From fafc4671737e07895241e42b8b89a5b5837eaf1b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 17:59:41 +0530 Subject: [PATCH 12/20] Put loading spinner at bottom of chat message in web client --- src/khoj/interface/web/chat.html | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 6855c196..81865da2 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -683,13 +683,19 @@ To get started, just start typing below. You can also type / to see a list of co } function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { - if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) { + if (!newResponseElement) return; + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) newResponseElement.removeChild(loadingEllipsis); - } - if (replace) { - newResponseElement.innerHTML = ""; - } + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; } @@ -777,7 +783,7 @@ To get started, just start typing below. You can also type / to see a list of co if (chunk.type ==='status') { console.log(`status: ${chunk.data}`); const statusMessage = chunk.data; - handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, null, false); + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false); } else if (chunk.type === 'start_llm_response') { console.log("Started streaming", new Date()); } else if (chunk.type === 'end_llm_response') { From e439a6ddac0f95caa88ff08e31295ac492b167a5 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 18:15:01 +0530 Subject: [PATCH 13/20] Use async/await in web client chat stream instead of promises Align streaming logic across web, desktop and obsidian clients --- src/khoj/interface/web/chat.html | 130 +++++++++++++++---------------- 1 file changed, 62 insertions(+), 68 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 81865da2..b9ed5609 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -598,8 +598,7 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - let chatBody = document.getElementById("chat-body"); - + // Extract chat message from chat input form var query = document.getElementById("chat-input").value.trim(); console.log(`Query: ${query}`); @@ -620,6 +619,16 @@ To get started, just start typing below. You can also type / to see a list of co autoResize(); document.getElementById("chat-input").setAttribute("disabled", "disabled"); + let chatBody = document.getElementById("chat-body"); + let conversationID = chatBody.dataset.conversationId; + if (!conversationID) { + let response = await fetch(`${hostURL}/api/chat/sessions`, { method: "POST" }); + let data = await response.json(); + conversationID = data.conversation_id; + chatBody.dataset.conversationId = conversationID; + await refreshChatSessionsPanel(); + } + let newResponseEl = document.createElement("div"); newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); @@ -641,20 +650,37 @@ To get started, just start typing below. You can also type / to see a list of co let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API - await sendMessageStream(query); - let rawResponse = ""; - let references = {}; - + // Setup chat message state chatMessageState = { newResponseTextEl, newResponseEl, loadingEllipsis, - references, - rawResponse, + references: {}, + rawResponse: "", rawQuery: query, isVoice: isVoice, } + + // Call Khoj chat API + let chatApi = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=web`; + chatApi += (!!region && !!city && !!countryName && !!timezone) + ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; + + const response = await fetch(chatApi); + + try { + if (!response.ok) throw new Error(response.statusText); + if (!response.body) throw new Error("Response body is empty"); + // Stream and render chat response + await readChatStream(response); + } catch (err) { + console.error(`Khoj chat response failed with\n${err}`); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at team@khoj.dev or on Discord"; + newResponseTextEl.innerHTML = errorMsg; + } } function createLoadingEllipse() { @@ -843,67 +869,35 @@ To get started, just start typing below. You can also type / to see a list of co } } - async function sendMessageStream(query) { - let chatBody = document.getElementById("chat-body"); - let conversationId = chatBody.dataset.conversationId; + async function readChatStream(response) { + if (!response.body) return; + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; - if (!conversationId) { - let response = await fetch('/api/chat/sessions', { method: "POST" }); - let data = await response.json(); - conversationId = data.conversation_id; - chatBody.dataset.conversationId = conversationId; - refreshChatSessionsPanel(); + while (true) { + const { value, done } = await reader.read(); + // If the stream is done + if (done) { + // Process the last chunk + processMessageChunk(buffer); + buffer = ''; + break; + } + + // Read chunk from stream and append it to the buffer + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + // Check if the buffer contains (0 or more) complete JSON objects + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + let chunks = collectJsonsInBufferedMessageChunk(buffer); + chunks.objects.forEach((chunk) => processMessageChunk(chunk)); + buffer = chunks.remainder; + } } - - let chatStreamUrl = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&stream=true&client=web`; - chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) - ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` - : ''; - - fetch(chatStreamUrl) - .then(response => { - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; - let netBracketCount = 0; - - function readStream() { - reader.read().then(({ done, value }) => { - // If the stream is done - if (done) { - // Process the last chunk - processMessageChunk(buffer); - buffer = ''; - console.log("Stream complete"); - return; - } - - // Read chunk from stream and append it to the buffer - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - // Check if the buffer contains (0 or more) complete JSON objects - netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; - if (netBracketCount === 0) { - let chunks = collectJsonsInBufferedMessageChunk(buffer); - chunks.objects.forEach(processMessageChunk); - buffer = chunks.remainder; - } - - // Continue reading the stream - readStream(); - }); - } - - readStream(); - }) - .catch(error => { - console.error('Error:', error); - if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { - chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); - } - chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" - }); } function incrementalChat(event) { From 0277d16daf068894065fba73e0c924f25a90edc0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 18:41:12 +0530 Subject: [PATCH 14/20] Share desktop chat streaming utility funcs across chat, shortcut views Null check menu, menuContainer to avoid errors on Khoj mini --- src/interface/desktop/chat.html | 216 ---------------------------- src/interface/desktop/chatutils.js | 216 ++++++++++++++++++++++++++++ src/interface/desktop/shortcut.html | 148 ++++--------------- src/interface/desktop/utils.js | 4 +- 4 files changed, 247 insertions(+), 337 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 3550799e..57657ef1 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -167,222 +167,6 @@ } } - function createLoadingEllipsis() { - let loadingEllipsis = document.createElement("div"); - loadingEllipsis.classList.add("lds-ellipsis"); - - let firstEllipsis = document.createElement("div"); - firstEllipsis.classList.add("lds-ellipsis-item"); - - let secondEllipsis = document.createElement("div"); - secondEllipsis.classList.add("lds-ellipsis-item"); - - let thirdEllipsis = document.createElement("div"); - thirdEllipsis.classList.add("lds-ellipsis-item"); - - let fourthEllipsis = document.createElement("div"); - fourthEllipsis.classList.add("lds-ellipsis-item"); - - loadingEllipsis.appendChild(firstEllipsis); - loadingEllipsis.appendChild(secondEllipsis); - loadingEllipsis.appendChild(thirdEllipsis); - loadingEllipsis.appendChild(fourthEllipsis); - - return loadingEllipsis; - } - - function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { - if (!newResponseElement) return; - // Remove loading ellipsis if it exists - if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) - newResponseElement.removeChild(loadingEllipsis); - // Clear the response element if replace is true - if (replace) newResponseElement.innerHTML = ""; - - // Append response to the response element - newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); - - // Append loading ellipsis if it exists - if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); - // Scroll to bottom of chat view - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - } - - function handleImageResponse(imageJson, rawResponse) { - if (imageJson.image) { - const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; - - // If response has image field, response is a generated image. - if (imageJson.intentType === "text-to-image") { - rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`; - } else if (imageJson.intentType === "text-to-image2") { - rawResponse += `![generated_image](${imageJson.image})`; - } else if (imageJson.intentType === "text-to-image-v3") { - rawResponse = `![](data:image/webp;base64,${imageJson.image})`; - } - if (inferredQuery) { - rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; - } - } - - // If response has detail field, response is an error message. - if (imageJson.detail) rawResponse += imageJson.detail; - - return rawResponse; - } - - function finalizeChatBodyResponse(references, newResponseElement) { - if (!!newResponseElement && references != null && Object.keys(references).length > 0) { - newResponseElement.appendChild(createReferenceSection(references)); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input")?.removeAttribute("disabled"); - } - - function collectJsonsInBufferedMessageChunk(chunk) { - // Collect list of JSON objects and raw strings in the chunk - // Return the list of objects and the remaining raw string - let startIndex = chunk.indexOf('{'); - if (startIndex === -1) return { objects: [chunk], remainder: '' }; - const objects = [chunk.slice(0, startIndex)]; - let openBraces = 0; - let currentObject = ''; - - for (let i = startIndex; i < chunk.length; i++) { - if (chunk[i] === '{') { - if (openBraces === 0) startIndex = i; - openBraces++; - } - if (chunk[i] === '}') { - openBraces--; - if (openBraces === 0) { - currentObject = chunk.slice(startIndex, i + 1); - objects.push(currentObject); - currentObject = ''; - } - } - } - - return { - objects: objects, - remainder: openBraces > 0 ? chunk.slice(startIndex) : '' - }; - } - - function convertMessageChunkToJson(rawChunk) { - // Split the chunk into lines - if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { - try { - let jsonChunk = JSON.parse(rawChunk); - if (!jsonChunk.type) - jsonChunk = {type: 'message', data: jsonChunk}; - return jsonChunk; - } catch (e) { - return {type: 'message', data: rawChunk}; - } - } else if (rawChunk.length > 0) { - return {type: 'message', data: rawChunk}; - } - } - - function processMessageChunk(rawChunk) { - const chunk = convertMessageChunkToJson(rawChunk); - console.debug("Chunk:", chunk); - if (!chunk || !chunk.type) return; - if (chunk.type ==='status') { - console.log(`status: ${chunk.data}`); - const statusMessage = chunk.data; - handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false); - } else if (chunk.type === 'start_llm_response') { - console.log("Started streaming", new Date()); - } else if (chunk.type === 'end_llm_response') { - console.log("Stopped streaming", new Date()); - - // Automatically respond with voice if the subscribed user has sent voice message - if (chatMessageState.isVoice && "{{ is_active }}" == "True") - textToSpeech(chatMessageState.rawResponse); - - // Append any references after all the data has been streamed - finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); - - const liveQuery = chatMessageState.rawQuery; - // Reset variables - chatMessageState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, - isVoice: false, - } - } else if (chunk.type === "references") { - chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; - } else if (chunk.type === 'message') { - const chunkData = chunk.data; - if (typeof chunkData === 'object' && chunkData !== null) { - // If chunkData is already a JSON object - handleJsonResponse(chunkData); - } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { - // Try process chunk data as if it is a JSON object - try { - const jsonData = JSON.parse(chunkData.trim()); - handleJsonResponse(jsonData); - } catch (e) { - chatMessageState.rawResponse += chunkData; - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } else { - chatMessageState.rawResponse += chunkData; - handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); - } - } - } - - function handleJsonResponse(jsonData) { - if (jsonData.image || jsonData.detail) { - chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse); - } else if (jsonData.response) { - chatMessageState.rawResponse = jsonData.response; - } - - if (chatMessageState.newResponseTextEl) { - chatMessageState.newResponseTextEl.innerHTML = ""; - chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse)); - } - } - - async function readChatStream(response) { - if (!response.body) return; - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; - let netBracketCount = 0; - - while (true) { - const { value, done } = await reader.read(); - // If the stream is done - if (done) { - // Process the last chunk - processMessageChunk(buffer); - buffer = ''; - break; - } - - // Read chunk from stream and append it to the buffer - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - // Check if the buffer contains (0 or more) complete JSON objects - netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; - if (netBracketCount === 0) { - let chunks = collectJsonsInBufferedMessageChunk(buffer); - chunks.objects.forEach((chunk) => processMessageChunk(chunk)); - buffer = chunks.remainder; - } - } - } - function incrementalChat(event) { if (!event.shiftKey && event.key === 'Enter') { event.preventDefault(); diff --git a/src/interface/desktop/chatutils.js b/src/interface/desktop/chatutils.js index 42cfa986..84f5e431 100644 --- a/src/interface/desktop/chatutils.js +++ b/src/interface/desktop/chatutils.js @@ -364,3 +364,219 @@ function createReferenceSection(references, createLinkerSection=false) { return referencesDiv; } + +function createLoadingEllipsis() { + let loadingEllipsis = document.createElement("div"); + loadingEllipsis.classList.add("lds-ellipsis"); + + let firstEllipsis = document.createElement("div"); + firstEllipsis.classList.add("lds-ellipsis-item"); + + let secondEllipsis = document.createElement("div"); + secondEllipsis.classList.add("lds-ellipsis-item"); + + let thirdEllipsis = document.createElement("div"); + thirdEllipsis.classList.add("lds-ellipsis-item"); + + let fourthEllipsis = document.createElement("div"); + fourthEllipsis.classList.add("lds-ellipsis-item"); + + loadingEllipsis.appendChild(firstEllipsis); + loadingEllipsis.appendChild(secondEllipsis); + loadingEllipsis.appendChild(thirdEllipsis); + loadingEllipsis.appendChild(fourthEllipsis); + + return loadingEllipsis; +} + +function handleStreamResponse(newResponseElement, rawResponse, rawQuery, loadingEllipsis, replace=true) { + if (!newResponseElement) return; + // Remove loading ellipsis if it exists + if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) + newResponseElement.removeChild(loadingEllipsis); + // Clear the response element if replace is true + if (replace) newResponseElement.innerHTML = ""; + + // Append response to the response element + newResponseElement.appendChild(formatHTMLMessage(rawResponse, false, replace, rawQuery)); + + // Append loading ellipsis if it exists + if (!replace && loadingEllipsis) newResponseElement.appendChild(loadingEllipsis); + // Scroll to bottom of chat view + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; +} + +function handleImageResponse(imageJson, rawResponse) { + if (imageJson.image) { + const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image"; + + // If response has image field, response is a generated image. + if (imageJson.intentType === "text-to-image") { + rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image2") { + rawResponse += `![generated_image](${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image-v3") { + rawResponse = `![](data:image/webp;base64,${imageJson.image})`; + } + if (inferredQuery) { + rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; + } + } + + // If response has detail field, response is an error message. + if (imageJson.detail) rawResponse += imageJson.detail; + + return rawResponse; +} + +function finalizeChatBodyResponse(references, newResponseElement) { + if (!!newResponseElement && references != null && Object.keys(references).length > 0) { + newResponseElement.appendChild(createReferenceSection(references)); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input")?.removeAttribute("disabled"); +} + +function collectJsonsInBufferedMessageChunk(chunk) { + // Collect list of JSON objects and raw strings in the chunk + // Return the list of objects and the remaining raw string + let startIndex = chunk.indexOf('{'); + if (startIndex === -1) return { objects: [chunk], remainder: '' }; + const objects = [chunk.slice(0, startIndex)]; + let openBraces = 0; + let currentObject = ''; + + for (let i = startIndex; i < chunk.length; i++) { + if (chunk[i] === '{') { + if (openBraces === 0) startIndex = i; + openBraces++; + } + if (chunk[i] === '}') { + openBraces--; + if (openBraces === 0) { + currentObject = chunk.slice(startIndex, i + 1); + objects.push(currentObject); + currentObject = ''; + } + } + } + + return { + objects: objects, + remainder: openBraces > 0 ? chunk.slice(startIndex) : '' + }; +} + +function convertMessageChunkToJson(rawChunk) { + // Split the chunk into lines + if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { + try { + let jsonChunk = JSON.parse(rawChunk); + if (!jsonChunk.type) + jsonChunk = {type: 'message', data: jsonChunk}; + return jsonChunk; + } catch (e) { + return {type: 'message', data: rawChunk}; + } + } else if (rawChunk.length > 0) { + return {type: 'message', data: rawChunk}; + } +} + +function processMessageChunk(rawChunk) { + const chunk = convertMessageChunkToJson(rawChunk); + console.debug("Chunk:", chunk); + if (!chunk || !chunk.type) return; + if (chunk.type ==='status') { + console.log(`status: ${chunk.data}`); + const statusMessage = chunk.data; + handleStreamResponse(chatMessageState.newResponseTextEl, statusMessage, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, false); + } else if (chunk.type === 'start_llm_response') { + console.log("Started streaming", new Date()); + } else if (chunk.type === 'end_llm_response') { + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + isVoice: false, + } + } else if (chunk.type === "references") { + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; + } else if (chunk.type === 'message') { + const chunkData = chunk.data; + if (typeof chunkData === 'object' && chunkData !== null) { + // If chunkData is already a JSON object + handleJsonResponse(chunkData); + } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) { + // Try process chunk data as if it is a JSON object + try { + const jsonData = JSON.parse(chunkData.trim()); + handleJsonResponse(jsonData); + } catch (e) { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } else { + chatMessageState.rawResponse += chunkData; + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } + } +} + +function handleJsonResponse(jsonData) { + if (jsonData.image || jsonData.detail) { + chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse); + } else if (jsonData.response) { + chatMessageState.rawResponse = jsonData.response; + } + + if (chatMessageState.newResponseTextEl) { + chatMessageState.newResponseTextEl.innerHTML = ""; + chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse)); + } +} + +async function readChatStream(response) { + if (!response.body) return; + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let netBracketCount = 0; + + while (true) { + const { value, done } = await reader.read(); + // If the stream is done + if (done) { + // Process the last chunk + processMessageChunk(buffer); + buffer = ''; + break; + } + + // Read chunk from stream and append it to the buffer + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + // Check if the buffer contains (0 or more) complete JSON objects + netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; + if (netBracketCount === 0) { + let chunks = collectJsonsInBufferedMessageChunk(buffer); + chunks.objects.forEach((chunk) => processMessageChunk(chunk)); + buffer = chunks.remainder; + } + } +} diff --git a/src/interface/desktop/shortcut.html b/src/interface/desktop/shortcut.html index 4af26f0d..52207f20 100644 --- a/src/interface/desktop/shortcut.html +++ b/src/interface/desktop/shortcut.html @@ -346,7 +346,7 @@ inp.focus(); } - async function chat() { + async function chat(isVoice=false) { //set chat body to empty let chatBody = document.getElementById("chat-body"); chatBody.innerHTML = ""; @@ -375,9 +375,6 @@ chat_body.dataset.conversationId = conversationID; } - // Generate backend API URL to execute query - let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}`; - let newResponseEl = document.createElement("div"); newResponseEl.classList.add("chat-message", "khoj"); newResponseEl.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); @@ -388,128 +385,41 @@ newResponseEl.appendChild(newResponseTextEl); // Temporary status message to indicate that Khoj is thinking - let loadingEllipsis = document.createElement("div"); - loadingEllipsis.classList.add("lds-ellipsis"); - - let firstEllipsis = document.createElement("div"); - firstEllipsis.classList.add("lds-ellipsis-item"); - - let secondEllipsis = document.createElement("div"); - secondEllipsis.classList.add("lds-ellipsis-item"); - - let thirdEllipsis = document.createElement("div"); - thirdEllipsis.classList.add("lds-ellipsis-item"); - - let fourthEllipsis = document.createElement("div"); - fourthEllipsis.classList.add("lds-ellipsis-item"); - - loadingEllipsis.appendChild(firstEllipsis); - loadingEllipsis.appendChild(secondEllipsis); - loadingEllipsis.appendChild(thirdEllipsis); - loadingEllipsis.appendChild(fourthEllipsis); - - newResponseTextEl.appendChild(loadingEllipsis); + let loadingEllipsis = createLoadingEllipsis(); document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - // Call Khoj chat API - let response = await fetch(chatApi, { headers }); - let rawResponse = ""; - let references = null; - const contentType = response.headers.get("content-type"); toggleLoading(); - if (contentType === "application/json") { - // Handle JSON response - try { - const responseAsJson = await response.json(); - if (responseAsJson.image) { - // If response has image field, response is a generated image. - if (responseAsJson.intentType === "text-to-image") { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image2") { - rawResponse += `![${query}](${responseAsJson.image})`; - } else if (responseAsJson.intentType === "text-to-image-v3") { - rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; - } - const inferredQueries = responseAsJson.inferredQueries?.[0]; - if (inferredQueries) { - rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`; - } - } - if (responseAsJson.context) { - const rawReferenceAsJson = responseAsJson.context; - references = createReferenceSection(rawReferenceAsJson, createLinkerSection=true); - } - if (responseAsJson.detail) { - // If response has detail field, response is an error message. - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - if (references != null) { - newResponseTextEl.appendChild(references); - } + // Setup chat message state + chatMessageState = { + newResponseTextEl, + newResponseEl, + loadingEllipsis, + references: {}, + rawResponse: "", + rawQuery: query, + isVoice: isVoice, + } - document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - } - } else { - // Handle streamed response of type text/event-stream or text/plain - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let references = {}; + // Construct API URL to execute chat query + let chatApi = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationID}&stream=true&client=desktop`; + chatApi += (!!region && !!city && !!countryName && !!timezone) + ? `®ion=${region}&city=${city}&country=${countryName}&timezone=${timezone}` + : ''; - readStream(); + const response = await fetch(chatApi, { headers }); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != {}) { - newResponseTextEl.appendChild(createReferenceSection(references, createLinkerSection=true)); - } - document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - return; - } - - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); - - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - if (rawReferenceAsJson instanceof Array) { - references["notes"] = rawReferenceAsJson; - } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { - references["online"] = rawReferenceAsJson; - } - readStream(); - } else { - // Display response from Khoj - if (newResponseTextEl.getElementsByClassName("lds-ellipsis").length > 0) { - newResponseTextEl.removeChild(loadingEllipsis); - } - - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseTextEl.innerHTML = ""; - newResponseTextEl.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); - } - - // Scroll to bottom of chat window as chat response is streamed - document.body.scrollTop = document.getElementById("chat-body").scrollHeight; - }); - } + try { + if (!response.ok) throw new Error(response.statusText); + if (!response.body) throw new Error("Response body is empty"); + // Stream and render chat response + await readChatStream(response); + } catch (err) { + console.error(`Khoj chat response failed with\n${err}`); + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + let errorMsg = "Sorry, unable to get response from Khoj backend ❤️‍🩹. Retry or contact developers for help at team@khoj.dev or on Discord"; + newResponseTextEl.textContent = errorMsg; } document.body.scrollTop = document.getElementById("chat-body").scrollHeight; } diff --git a/src/interface/desktop/utils.js b/src/interface/desktop/utils.js index c880a7cd..af0234ea 100644 --- a/src/interface/desktop/utils.js +++ b/src/interface/desktop/utils.js @@ -34,8 +34,8 @@ function toggleNavMenu() { document.addEventListener('click', function(event) { let menu = document.getElementById("khoj-nav-menu"); let menuContainer = document.getElementById("khoj-nav-menu-container"); - let isClickOnMenu = menuContainer.contains(event.target) || menuContainer === event.target; - if (isClickOnMenu === false && menu.classList.contains("show")) { + let isClickOnMenu = menuContainer?.contains(event.target) || menuContainer === event.target; + if (menu && isClickOnMenu === false && menu.classList.contains("show")) { menu.classList.remove("show"); } }); From eb4e12d3c57cee94c9012ee9c13b0a40debd0be4 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 19:50:43 +0530 Subject: [PATCH 15/20] s/online_context/onlineContext chat API response field for consistency This will align the name of the online context field returned by current chat message and chat history --- src/interface/desktop/chatutils.js | 2 +- src/interface/obsidian/src/chat_view.ts | 2 +- src/khoj/interface/web/chat.html | 2 +- src/khoj/routers/api_chat.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/interface/desktop/chatutils.js b/src/interface/desktop/chatutils.js index 84f5e431..4f4fb64e 100644 --- a/src/interface/desktop/chatutils.js +++ b/src/interface/desktop/chatutils.js @@ -515,7 +515,7 @@ function processMessageChunk(rawChunk) { isVoice: false, } } else if (chunk.type === "references") { - chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; } else if (chunk.type === 'message') { const chunkData = chunk.data; if (typeof chunkData === 'object' && chunkData !== null) { diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index efde958b..a6c62fd5 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -947,7 +947,7 @@ export class KhojChatView extends KhojPaneView { isVoice: false, }; } else if (chunk.type === "references") { - this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; + this.chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; } else if (chunk.type === 'message') { const chunkData = chunk.data; if (typeof chunkData === 'object' && chunkData !== null) { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index b9ed5609..616e66bc 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -834,7 +834,7 @@ To get started, just start typing below. You can also type / to see a list of co isVoice: false, } } else if (chunk.type === "references") { - chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results}; + chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.onlineContext}; } else if (chunk.type === 'message') { const chunkData = chunk.data; if (typeof chunkData === 'object' && chunkData !== null) { diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index d8826264..019d0fa9 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -831,7 +831,7 @@ async def chat( { "inferredQueries": inferred_queries, "context": compiled_references, - "online_results": online_results, + "onlineContext": online_results, }, ): yield result @@ -887,7 +887,7 @@ async def chat( "content-type": "application/json", "intentType": intent_type, "context": compiled_references, - "online_results": online_results, + "onlineContext": online_results, "inferredQueries": [improved_image_prompt], "image": image, } From b36a7833a66d2e0d793bdfeb918665e17bd84d78 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 19:53:51 +0530 Subject: [PATCH 16/20] Remove the old mechanism of streaming compiled references Do not need response generator to stuff compiled references in chat stream using "### compiled references:" separator. References are now sent to clients as structured json while streaming --- src/khoj/processor/conversation/utils.py | 4 ---- src/khoj/routers/api_chat.py | 5 ----- 2 files changed, 9 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 5d68d17d..f675d2eb 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -62,10 +62,6 @@ class ThreadedGenerator: self.queue.put(data) def close(self): - if self.compiled_references and len(self.compiled_references) > 0: - self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}") - if self.online_results and len(self.online_results) > 0: - self.queue.put(f"### compiled references:{json.dumps(self.online_results)}") self.queue.put(StopIteration) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 019d0fa9..a6c4cd57 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -938,11 +938,6 @@ async def chat( return if not connection_alive or not continue_stream: continue - # Stop streaming after compiled references section of response starts - # References are being processed via the references event rather than the message event - if "### compiled references:" in item: - continue_stream = False - item = item.split("### compiled references:")[0] try: async for result in send_event("message", f"{item}"): yield result From 70201e8db82cb86fdbd92a504b549171163b9bed Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 23 Jul 2024 22:02:45 +0530 Subject: [PATCH 17/20] Log total, ttft chat response time on start, end llm_response events - Deduplicate code to collect chat telemetry by relying on end_llm_response event - Log time to first token and total chat response time for latency analysis of Khoj as an agent. Not just the latency of the LLM - Remove duplicate timer in the image generation path --- src/khoj/routers/api_chat.py | 81 ++++++++++++++++-------------------- src/khoj/routers/helpers.py | 23 +++++----- 2 files changed, 47 insertions(+), 57 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a6c4cd57..22fb4f03 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import time from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -22,11 +23,7 @@ from khoj.database.adapters import ( aget_user_name, ) from khoj.database.models import KhojUser -from khoj.processor.conversation.prompts import ( - help_message, - no_entries_found, - no_notes_found, -) +from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online @@ -34,7 +31,6 @@ from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, CommonQueryParams, - CommonQueryParamsClass, ConversationCommandRateLimiter, agenerate_chat_response, aget_relevant_information_sources, @@ -547,22 +543,27 @@ async def chat( ), ): async def event_generator(q: str): + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} connection_alive = True user: KhojUser = request.user.object q = unquote(q) async def send_event(event_type: str, data: str | dict): - nonlocal connection_alive + nonlocal connection_alive, ttft if not connection_alive or await request.is_disconnected(): connection_alive = False logger.warn(f"User {user} disconnected from {common.client} client") return try: + if event_type == "end_llm_response": + collect_telemetry() + if event_type == "start_llm_response": + ttft = time.perf_counter() - start_time if event_type == "message": yield data - elif event_type == "references": - yield json.dumps({"type": event_type, "data": data}) - elif stream: + elif event_type == "references" or stream: yield json.dumps({"type": event_type, "data": data}) except asyncio.CancelledError: connection_alive = False @@ -581,12 +582,36 @@ async def chat( async for result in send_event("end_llm_response", ""): yield result + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + chat_metadata["latency"] = f"{latency:.3f}" + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + metadata=chat_metadata, + ) + conversation = await ConversationAdapters.aget_conversation_by_user( user, client_application=request.user.client_app, conversation_id=conversation_id, title=title ) if not conversation: - async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"): + async for result in send_llm_response(f"Conversation {conversation_id} not found"): yield result + return await is_ready_to_chat(user) @@ -684,12 +709,6 @@ async def chat( client_application=request.user.client_app, conversation_id=conversation_id, ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) return custom_filters = [] @@ -732,17 +751,6 @@ async def chat( inferred_queries=[query_to_run], automation_id=automation.id, ) - common = CommonQueryParamsClass( - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - **common.__dict__, - ) async for result in send_llm_response(llm_response): yield result return @@ -839,12 +847,6 @@ async def chat( # Generate Output ## Generate Image Output if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) async for result in text_to_image( q, user, @@ -913,17 +915,6 @@ async def chat( user_name, ) - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - ) - # Send Response async for result in send_event("start_llm_response", ""): yield result diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d23df6f0..7b8af5d9 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -780,18 +780,17 @@ async def text_to_image( chat_history += f"Q: Prompt: {chat['intent']['query']}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" - with timer("Improve the original user query", logger): - if send_status_func: - async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): - yield {"status": event} - improved_image_prompt = await generate_better_image_prompt( - message, - chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - model_type=text_to_image_config.model_type, - ) + if send_status_func: + async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): + yield {"status": event} + improved_image_prompt = await generate_better_image_prompt( + message, + chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, + model_type=text_to_image_config.model_type, + ) if send_status_func: async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): From 37b8fc5577ad8b1dd154faba47fbf4d0aacd2819 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 24 Jul 2024 16:51:04 +0530 Subject: [PATCH 18/20] Extract events even when http chunk contains partial or mutiple events Previous logic was more brittle to break with simple unbalanced '{' or '}' string present in the event data. This method of trying to identify valid json obj was fairly brittle. It only allowed json objects or processed event as raw strings. Now we buffer chunk until we see our unicode magic delimiter and only then process it. This is much less likely to break based on event data and the delimiter is more tunable if we want to reduce rendering breakage likelihood further --- src/interface/desktop/chatutils.js | 49 ++++++----------------- src/interface/obsidian/src/chat_view.ts | 49 ++++++----------------- src/khoj/interface/web/chat.html | 52 +++++++------------------ src/khoj/routers/api_chat.py | 6 ++- 4 files changed, 43 insertions(+), 113 deletions(-) diff --git a/src/interface/desktop/chatutils.js b/src/interface/desktop/chatutils.js index 4f4fb64e..5213979f 100644 --- a/src/interface/desktop/chatutils.js +++ b/src/interface/desktop/chatutils.js @@ -437,36 +437,6 @@ function finalizeChatBodyResponse(references, newResponseElement) { document.getElementById("chat-input")?.removeAttribute("disabled"); } -function collectJsonsInBufferedMessageChunk(chunk) { - // Collect list of JSON objects and raw strings in the chunk - // Return the list of objects and the remaining raw string - let startIndex = chunk.indexOf('{'); - if (startIndex === -1) return { objects: [chunk], remainder: '' }; - const objects = [chunk.slice(0, startIndex)]; - let openBraces = 0; - let currentObject = ''; - - for (let i = startIndex; i < chunk.length; i++) { - if (chunk[i] === '{') { - if (openBraces === 0) startIndex = i; - openBraces++; - } - if (chunk[i] === '}') { - openBraces--; - if (openBraces === 0) { - currentObject = chunk.slice(startIndex, i + 1); - objects.push(currentObject); - currentObject = ''; - } - } - } - - return { - objects: objects, - remainder: openBraces > 0 ? chunk.slice(startIndex) : '' - }; -} - function convertMessageChunkToJson(rawChunk) { // Split the chunk into lines if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { @@ -554,8 +524,8 @@ async function readChatStream(response) { if (!response.body) return; const reader = response.body.getReader(); const decoder = new TextDecoder(); + const eventDelimiter = '␃🔚␗'; let buffer = ''; - let netBracketCount = 0; while (true) { const { value, done } = await reader.read(); @@ -569,14 +539,19 @@ async function readChatStream(response) { // Read chunk from stream and append it to the buffer const chunk = decoder.decode(value, { stream: true }); + console.debug("Raw Chunk:", chunk) + // Start buffering chunks until complete event is received buffer += chunk; - // Check if the buffer contains (0 or more) complete JSON objects - netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; - if (netBracketCount === 0) { - let chunks = collectJsonsInBufferedMessageChunk(buffer); - chunks.objects.forEach((chunk) => processMessageChunk(chunk)); - buffer = chunks.remainder; + // Once the buffer contains a complete event + let newEventIndex; + while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { + // Extract the event from the buffer + const event = buffer.slice(0, newEventIndex); + buffer = buffer.slice(newEventIndex + eventDelimiter.length); + + // Process the event + if (event) processMessageChunk(event); } } } diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index a6c62fd5..cbd0f7bf 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -869,36 +869,6 @@ export class KhojChatView extends KhojPaneView { return true; } - collectJsonsInBufferedMessageChunk(chunk: string): ChunkResult { - // Collect list of JSON objects and raw strings in the chunk - // Return the list of objects and the remaining raw string - let startIndex = chunk.indexOf('{'); - if (startIndex === -1) return { objects: [chunk], remainder: '' }; - const objects: string[] = [chunk.slice(0, startIndex)]; - let openBraces = 0; - let currentObject = ''; - - for (let i = startIndex; i < chunk.length; i++) { - if (chunk[i] === '{') { - if (openBraces === 0) startIndex = i; - openBraces++; - } - if (chunk[i] === '}') { - openBraces--; - if (openBraces === 0) { - currentObject = chunk.slice(startIndex, i + 1); - objects.push(currentObject); - currentObject = ''; - } - } - } - - return { - objects: objects, - remainder: openBraces > 0 ? chunk.slice(startIndex) : '' - }; - } - convertMessageChunkToJson(rawChunk: string): MessageChunk { if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { try { @@ -988,8 +958,8 @@ export class KhojChatView extends KhojPaneView { const reader = response.body.getReader(); const decoder = new TextDecoder(); + const eventDelimiter = '␃🔚␗'; let buffer = ''; - let netBracketCount = 0; while (true) { const { value, done } = await reader.read(); @@ -1002,14 +972,19 @@ export class KhojChatView extends KhojPaneView { } const chunk = decoder.decode(value, { stream: true }); + console.debug("Raw Chunk:", chunk) + // Start buffering chunks until complete event is received buffer += chunk; - // Check if the buffer contains (0 or more) complete JSON objects - netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; - if (netBracketCount === 0) { - let chunks = this.collectJsonsInBufferedMessageChunk(buffer); - chunks.objects.forEach((chunk) => this.processMessageChunk(chunk)); - buffer = chunks.remainder; + // Once the buffer contains a complete event + let newEventIndex; + while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { + // Extract the event from the buffer + const event = buffer.slice(0, newEventIndex); + buffer = buffer.slice(newEventIndex + eventDelimiter.length); + + // Process the event + if (event) this.processMessageChunk(event); } } } diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 616e66bc..024af9ad 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -756,38 +756,9 @@ To get started, just start typing below. You can also type / to see a list of co document.getElementById("chat-input")?.removeAttribute("disabled"); } - function collectJsonsInBufferedMessageChunk(chunk) { - // Collect list of JSON objects and raw strings in the chunk - // Return the list of objects and the remaining raw string - let startIndex = chunk.indexOf('{'); - if (startIndex === -1) return { objects: [chunk], remainder: '' }; - const objects = [chunk.slice(0, startIndex)]; - let openBraces = 0; - let currentObject = ''; - - for (let i = startIndex; i < chunk.length; i++) { - if (chunk[i] === '{') { - if (openBraces === 0) startIndex = i; - openBraces++; - } - if (chunk[i] === '}') { - openBraces--; - if (openBraces === 0) { - currentObject = chunk.slice(startIndex, i + 1); - objects.push(currentObject); - currentObject = ''; - } - } - } - - return { - objects: objects, - remainder: openBraces > 0 ? chunk.slice(startIndex) : '' - }; - } - function convertMessageChunkToJson(rawChunk) { // Split the chunk into lines + console.debug("Raw Event:", rawChunk); if (rawChunk?.startsWith("{") && rawChunk?.endsWith("}")) { try { let jsonChunk = JSON.parse(rawChunk); @@ -804,7 +775,7 @@ To get started, just start typing below. You can also type / to see a list of co function processMessageChunk(rawChunk) { const chunk = convertMessageChunkToJson(rawChunk); - console.debug("Chunk:", chunk); + console.debug("Json Event:", chunk); if (!chunk || !chunk.type) return; if (chunk.type ==='status') { console.log(`status: ${chunk.data}`); @@ -873,8 +844,8 @@ To get started, just start typing below. You can also type / to see a list of co if (!response.body) return; const reader = response.body.getReader(); const decoder = new TextDecoder(); + const eventDelimiter = '␃🔚␗'; let buffer = ''; - let netBracketCount = 0; while (true) { const { value, done } = await reader.read(); @@ -888,14 +859,19 @@ To get started, just start typing below. You can also type / to see a list of co // Read chunk from stream and append it to the buffer const chunk = decoder.decode(value, { stream: true }); + console.debug("Raw Chunk:", chunk) + // Start buffering chunks until complete event is received buffer += chunk; - // Check if the buffer contains (0 or more) complete JSON objects - netBracketCount += (chunk.match(/{/g) || []).length - (chunk.match(/}/g) || []).length; - if (netBracketCount === 0) { - let chunks = collectJsonsInBufferedMessageChunk(buffer); - chunks.objects.forEach((chunk) => processMessageChunk(chunk)); - buffer = chunks.remainder; + // Once the buffer contains a complete event + let newEventIndex; + while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) { + // Extract the event from the buffer + const event = buffer.slice(0, newEventIndex); + buffer = buffer.slice(newEventIndex + eventDelimiter.length); + + // Process the event + if (event) processMessageChunk(event); } } } diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 22fb4f03..9154bff8 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -548,6 +548,7 @@ async def chat( chat_metadata: dict = {} connection_alive = True user: KhojUser = request.user.object + event_delimiter = "␃🔚␗" q = unquote(q) async def send_event(event_type: str, data: str | dict): @@ -564,7 +565,7 @@ async def chat( if event_type == "message": yield data elif event_type == "references" or stream: - yield json.dumps({"type": event_type, "data": data}) + yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False) except asyncio.CancelledError: connection_alive = False logger.warn(f"User {user} disconnected from {common.client} client") @@ -573,6 +574,9 @@ async def chat( connection_alive = False logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) return + finally: + if stream: + yield event_delimiter async def send_llm_response(response: str): async for result in send_event("start_llm_response", ""): From ebe92ef16de3740935b41005e4ba82dbe0f9c106 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 24 Jul 2024 17:18:14 +0530 Subject: [PATCH 19/20] Do not send references twice in streamed image response Remove unused image content to reduce response payload size. References are collated, sent separately --- src/khoj/routers/api_chat.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 9154bff8..5e1cb1a8 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -890,10 +890,7 @@ async def chat( online_results=online_results, ) content_obj = { - "content-type": "application/json", "intentType": intent_type, - "context": compiled_references, - "onlineContext": online_results, "inferredQueries": [improved_image_prompt], "image": image, } From 778c571288ec116b873c94fa1aea0ab5ca3c0262 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 26 Jul 2024 00:18:37 +0530 Subject: [PATCH 20/20] Use enum to track chat stream event types in chat api router --- src/khoj/processor/tools/online_search.py | 9 +-- src/khoj/routers/api.py | 3 +- src/khoj/routers/api_chat.py | 73 ++++++++++++----------- src/khoj/routers/helpers.py | 13 +++- 4 files changed, 56 insertions(+), 42 deletions(-) diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 1f8a5c9e..c087de70 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -11,6 +11,7 @@ from bs4 import BeautifulSoup from markdownify import markdownify from khoj.routers.helpers import ( + ChatEvent, extract_relevant_info, generate_online_subqueries, infer_webpage_urls, @@ -68,7 +69,7 @@ async def search_online( if send_status_func: subqueries_str = "\n- " + "\n- ".join(list(subqueries)) async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} with timer(f"Internet searches for {list(subqueries)} took", logger): search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina @@ -92,7 +93,7 @@ async def search_online( if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] results = await asyncio.gather(*tasks) @@ -131,14 +132,14 @@ async def read_webpages( logger.info(f"Inferring web pages to read") if send_status_func: async for event in send_status_func(f"**🧐 Inferring web pages to read**"): - yield {"status": event} + yield {ChatEvent.STATUS: event} urls = await infer_webpage_urls(query, conversation_history, location) logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} tasks = [read_webpage_and_extract_content(query, url) for url in urls] results = await asyncio.gather(*tasks) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 836b963f..81599dd6 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -36,6 +36,7 @@ from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, + ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, acreate_title_from_query, @@ -375,7 +376,7 @@ async def extract_references_and_questions( if send_status_func: inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n search_results.extend( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 5e1cb1a8..63529b8e 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -30,6 +30,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, + ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, agenerate_chat_response, @@ -551,24 +552,24 @@ async def chat( event_delimiter = "␃🔚␗" q = unquote(q) - async def send_event(event_type: str, data: str | dict): + async def send_event(event_type: ChatEvent, data: str | dict): nonlocal connection_alive, ttft if not connection_alive or await request.is_disconnected(): connection_alive = False logger.warn(f"User {user} disconnected from {common.client} client") return try: - if event_type == "end_llm_response": + if event_type == ChatEvent.END_LLM_RESPONSE: collect_telemetry() - if event_type == "start_llm_response": + if event_type == ChatEvent.START_LLM_RESPONSE: ttft = time.perf_counter() - start_time - if event_type == "message": + if event_type == ChatEvent.MESSAGE: yield data - elif event_type == "references" or stream: - yield json.dumps({"type": event_type, "data": data}, ensure_ascii=False) - except asyncio.CancelledError: + elif event_type == ChatEvent.REFERENCES or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except asyncio.CancelledError as e: connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client") + logger.warn(f"User {user} disconnected from {common.client} client: {e}") return except Exception as e: connection_alive = False @@ -579,11 +580,11 @@ async def chat( yield event_delimiter async def send_llm_response(response: str): - async for result in send_event("start_llm_response", ""): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): yield result - async for result in send_event("message", response): + async for result in send_event(ChatEvent.MESSAGE, response): yield result - async for result in send_event("end_llm_response", ""): + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result def collect_telemetry(): @@ -632,7 +633,7 @@ async def chat( user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_commands = [get_conversation_command(query=q, any_references=True)] - async for result in send_event("status", f"**👀 Understanding Query**: {q}"): + async for result in send_event(ChatEvent.STATUS, f"**👀 Understanding Query**: {q}"): yield result meta_log = conversation.conversation_log @@ -642,12 +643,12 @@ async def chat( conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) async for result in send_event( - "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + ChatEvent.STATUS, f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" ): yield result mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + async for result in send_event(ChatEvent.STATUS, f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): yield result if mode not in conversation_commands: conversation_commands.append(mode) @@ -690,7 +691,7 @@ async def chat( if not q: q = "Create a general summary of the file" async for result in send_event( - "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" + ChatEvent.STATUS, f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" ): yield result @@ -771,10 +772,10 @@ async def chat( conversation_id, conversation_commands, location, - partial(send_event, "status"), + partial(send_event, ChatEvent.STATUS), ): - if isinstance(result, dict) and "status" in result: - yield result["status"] + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: compiled_references.extend(result[0]) inferred_queries.extend(result[1]) @@ -782,7 +783,7 @@ async def chat( if not is_none_or_empty(compiled_references): headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): + async for result in send_event(ChatEvent.STATUS, f"**📜 Found Relevant Notes**: {headings}"): yield result online_results: Dict = dict() @@ -799,10 +800,10 @@ async def chat( if ConversationCommand.Online in conversation_commands: try: async for result in search_online( - defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters + defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters ): - if isinstance(result, dict) and "status" in result: - yield result["status"] + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: online_results = result except ValueError as e: @@ -815,9 +816,11 @@ async def chat( ## Gather Webpage References if ConversationCommand.Webpage in conversation_commands: try: - async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")): - if isinstance(result, dict) and "status" in result: - yield result["status"] + async for result in read_webpages( + defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: direct_web_pages = result webpages = [] @@ -829,7 +832,7 @@ async def chat( for webpage in direct_web_pages[query]["webpages"]: webpages.append(webpage["link"]) - async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): + async for result in send_event(ChatEvent.STATUS, f"**📚 Read web pages**: {webpages}"): yield result except ValueError as e: logger.warning( @@ -839,7 +842,7 @@ async def chat( ## Send Gathered References async for result in send_event( - "references", + ChatEvent.REFERENCES, { "inferredQueries": inferred_queries, "context": compiled_references, @@ -858,10 +861,10 @@ async def chat( location_data=location, references=compiled_references, online_results=online_results, - send_status_func=partial(send_event, "status"), + send_status_func=partial(send_event, ChatEvent.STATUS), ): - if isinstance(result, dict) and "status" in result: - yield result["status"] + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] else: image, status_code, improved_image_prompt, intent_type = result @@ -899,7 +902,7 @@ async def chat( return ## Generate Text Output - async for result in send_event("status", f"**💭 Generating a well-informed response**"): + async for result in send_event(ChatEvent.STATUS, f"**💭 Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, @@ -917,21 +920,21 @@ async def chat( ) # Send Response - async for result in send_event("start_llm_response", ""): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): yield result continue_stream = True iterator = AsyncIteratorWrapper(llm_response) async for item in iterator: if item is None: - async for result in send_event("end_llm_response", ""): + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result logger.debug("Finished streaming response") return if not connection_alive or not continue_stream: continue try: - async for result in send_event("message", f"{item}"): + async for result in send_event(ChatEvent.MESSAGE, f"{item}"): yield result except Exception as e: continue_stream = False @@ -949,7 +952,7 @@ async def chat( async for item in iterator: try: item_json = json.loads(item) - if "type" in item_json and item_json["type"] == "references": + if "type" in item_json and item_json["type"] == ChatEvent.REFERENCES.value: response_obj = item_json["data"] except: actual_response += item diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 7b8af5d9..538b571b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -8,6 +8,7 @@ import math import re from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone +from enum import Enum from functools import partial from random import random from typing import ( @@ -782,7 +783,7 @@ async def text_to_image( if send_status_func: async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): - yield {"status": event} + yield {ChatEvent.STATUS: event} improved_image_prompt = await generate_better_image_prompt( message, chat_history, @@ -794,7 +795,7 @@ async def text_to_image( if send_status_func: async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): - yield {"status": event} + yield {ChatEvent.STATUS: event} if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with timer("Generate image with OpenAI", logger): @@ -1191,3 +1192,11 @@ def construct_automation_created_message(automation: Job, crontime: str, query_t Manage your automations [here](/automations). """.strip() + + +class ChatEvent(Enum): + START_LLM_RESPONSE = "start_llm_response" + END_LLM_RESPONSE = "end_llm_response" + MESSAGE = "message" + REFERENCES = "references" + STATUS = "status"