diff --git a/pyproject.toml b/pyproject.toml index 2669f5ff..939a1d9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "dateparser >= 1.1.1", "defusedxml == 0.7.1", "fastapi >= 0.104.1", + "sse-starlette ~= 2.1.0", "python-multipart >= 0.0.7", "jinja2 == 3.1.4", "openai >= 1.0.0", diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index ad8ced27..3e07a860 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -74,14 +74,14 @@ To get started, just start typing below. You can also type / to see a list of co }, 1000); }); } - var websocket = null; + var sseConnection = null; let region = null; let city = null; let countryName = null; let timezone = null; let waitingForLocation = true; - let websocketState = { + let chatMessageState = { newResponseTextEl: null, newResponseEl: null, loadingEllipsis: null, @@ -105,7 +105,7 @@ To get started, just start typing below. You can also type / to see a list of co .finally(() => { console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); waitingForLocation = false; - setupWebSocket(); + initializeSSE(); }); function formatDate(date) { @@ -599,10 +599,8 @@ To get started, just start typing below. You can also type / to see a list of co } async function chat(isVoice=false) { - if (websocket) { - sendMessageViaWebSocket(isVoice); - return; - } + sendMessageViaSSE(isVoice); + return; let query = document.getElementById("chat-input").value.trim(); let resultsCount = localStorage.getItem("khojResultsCount") || 5; @@ -1069,17 +1067,13 @@ To get started, just start typing below. You can also type / to see a list of co window.onload = loadChat; - function setupWebSocket(isVoice=false) { - let chatBody = document.getElementById("chat-body"); - let wsProtocol = window.location.protocol === 'https:' ? 
'wss:' : 'ws:'; - let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`; - + function initializeSSE(isVoice=false) { if (waitingForLocation) { console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); return; } - websocketState = { + chatMessageState = { newResponseTextEl: null, newResponseEl: null, loadingEllipsis: null, @@ -1088,121 +1082,138 @@ To get started, just start typing below. You can also type / to see a list of co rawQuery: "", isVoice: isVoice, } + } + + function sendSSEMessage(query) { + let chatBody = document.getElementById("chat-body"); + let sseProtocol = window.location.protocol; + let sseUrl = `/api/chat/stream?q=${query}`; if (chatBody.dataset.conversationId) { - webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`; - webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; - - websocket = new WebSocket(webSocketUrl); - websocket.onmessage = function(event) { + sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`; + sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; + function handleChatResponse(event) { // Get the last element in the chat-body let chunk = event.data; - if (chunk == "start_llm_response") { - console.log("Started streaming", new Date()); - } else if (chunk == "end_llm_response") { - console.log("Stopped streaming", new Date()); + try { + if (chunk.includes("application/json")) + chunk = JSON.parse(chunk); + } catch (error) { + // If the chunk is not a JSON object, continue. 
+ } - // Automatically respond with voice if the subscribed user has sent voice message - if (websocketState.isVoice && "{{ is_active }}" == "True") - textToSpeech(websocketState.rawResponse); - - // Append any references after all the data has been streamed - finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl); - - const liveQuery = websocketState.rawQuery; - // Reset variables - websocketState = { - newResponseTextEl: null, - newResponseEl: null, - loadingEllipsis: null, - references: {}, - rawResponse: "", - rawQuery: liveQuery, - isVoice: false, - } - } else { + const contentType = chunk["content-type"] + if (contentType === "application/json") { + // Handle JSON response try { - if (chunk.includes("application/json")) - { - chunk = JSON.parse(chunk); + if (chunk.image || chunk.detail) { + ({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse)); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else { + rawResponse = chunk.response; } } catch (error) { - // If the chunk is not a JSON object, continue. 
+ // If the chunk is not a JSON object, just display it as is + chatMessageState.rawResponse += chunk; + } finally { + addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references); + } + } else { + // Handle streamed response of type text/event-stream or text/plain + if (chunk && chunk.includes("### compiled references:")) { + ({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse)); + chatMessageState.rawResponse = rawResponse; + chatMessageState.references = references; + } else { + // If the chunk is not a JSON object, just display it as is + chatMessageState.rawResponse += chunk; + if (chatMessageState.newResponseTextEl) { + handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis); + } } - const contentType = chunk["content-type"] - - if (contentType === "application/json") { - // Handle JSON response - try { - if (chunk.image || chunk.detail) { - ({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse)); - websocketState.rawResponse = rawResponse; - websocketState.references = references; - } else if (chunk.type == "status") { - handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false); - } else if (chunk.type == "rate_limit") { - handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true); - } else { - rawResponse = chunk.response; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - websocketState.rawResponse += chunk; - } finally { - if (chunk.type != "status" && chunk.type != "rate_limit") { - addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references); - } - } - } else { - - // 
Handle streamed response of type text/event-stream or text/plain - if (chunk && chunk.includes("### compiled references:")) { - ({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse)); - websocketState.rawResponse = rawResponse; - websocketState.references = references; - } else { - // If the chunk is not a JSON object, just display it as is - websocketState.rawResponse += chunk; - if (websocketState.newResponseTextEl) { - handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis); - } - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }; - } + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + }; } }; - websocket.onclose = function(event) { - websocket = null; - console.log("WebSocket is closed now."); - let setupWebSocketButton = document.createElement("button"); - setupWebSocketButton.textContent = "Reconnect to Server"; - setupWebSocketButton.onclick = setupWebSocket; - let statusDotIcon = document.getElementById("connection-status-icon"); - statusDotIcon.style.backgroundColor = "red"; - let statusDotText = document.getElementById("connection-status-text"); - statusDotText.innerHTML = ""; - statusDotText.style.marginTop = "5px"; - statusDotText.appendChild(setupWebSocketButton); - } - websocket.onerror = function(event) { - console.log("WebSocket error observed:", event); - } - websocket.onopen = function(event) { - console.log("WebSocket is open now.") + sseConnection = new EventSource(sseUrl); + sseConnection.onmessage = handleChatResponse; + sseConnection.addEventListener("complete_llm_response", handleChatResponse); + 
sseConnection.addEventListener("status", (event) => { + console.log(`${event.data}`); + handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false); + }); + sseConnection.addEventListener("rate_limit", (event) => { + handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true); + }); + sseConnection.addEventListener("start_llm_response", (event) => { + console.log("Started streaming", new Date()); + }); + sseConnection.addEventListener("end_llm_response", (event) => { + sseConnection.close(); + console.log("Stopped streaming", new Date()); + + // Automatically respond with voice if the subscribed user has sent voice message + if (chatMessageState.isVoice && "{{ is_active }}" == "True") + textToSpeech(chatMessageState.rawResponse); + + // Append any references after all the data has been streamed + finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl); + + const liveQuery = chatMessageState.rawQuery; + // Reset variables + chatMessageState = { + newResponseTextEl: null, + newResponseEl: null, + loadingEllipsis: null, + references: {}, + rawResponse: "", + rawQuery: liveQuery, + } + + // Reset status icon let statusDotIcon = document.getElementById("connection-status-icon"); statusDotIcon.style.backgroundColor = "green"; let statusDotText = document.getElementById("connection-status-text"); - statusDotText.textContent = "Connected to Server"; + statusDotText.textContent = "Ready"; + statusDotText.style.marginTop = "5px"; + }); + sseConnection.onclose = function(event) { + sseConnection = null; + console.debug("SSE is closed now."); + let statusDotIcon = document.getElementById("connection-status-icon"); + statusDotIcon.style.backgroundColor = "green"; + let statusDotText = document.getElementById("connection-status-text"); + statusDotText.textContent = "Ready"; + statusDotText.style.marginTop = "5px"; + } + 
sseConnection.onerror = function(event) { + console.log("SSE error observed:", event); + sseConnection.close(); + sseConnection = null; + let statusDotIcon = document.getElementById("connection-status-icon"); + statusDotIcon.style.backgroundColor = "red"; + let statusDotText = document.getElementById("connection-status-text"); + statusDotText.textContent = "Server Error"; + if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) { + chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis); + } + chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev" + } + sseConnection.onopen = function(event) { + console.debug("SSE is open now.") + let statusDotIcon = document.getElementById("connection-status-icon"); + statusDotIcon.style.backgroundColor = "orange"; + let statusDotText = document.getElementById("connection-status-text"); + statusDotText.textContent = "Processing"; } } - function sendMessageViaWebSocket(isVoice=false) { + function sendMessageViaSSE(isVoice=false) { let chatBody = document.getElementById("chat-body"); var query = document.getElementById("chat-input").value.trim(); @@ -1242,11 +1253,11 @@ To get started, just start typing below. You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API - websocket.send(query); + sendSSEMessage(query); let rawResponse = ""; let references = {}; - websocketState = { + chatMessageState = { newResponseTextEl, newResponseEl, loadingEllipsis, @@ -1265,7 +1276,7 @@ To get started, just start typing below. 
You can also type / to see a list of co let chatHistoryUrl = `/api/chat/history?client=web`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; - setupWebSocket(); + initializeSSE(); loadFileFiltersFromConversation(); } @@ -1305,7 +1316,7 @@ To get started, just start typing below. You can also type / to see a list of co let chatBody = document.getElementById("chat-body"); chatBody.dataset.conversationId = response.conversation_id; loadFileFiltersFromConversation(); - setupWebSocket(); + initializeSSE(); chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; let agentMetadata = response.agent; diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 72191077..1f8a5c9e 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -56,7 +56,8 @@ async def search_online( query += " ".join(custom_filters) if not is_internet_connected(): logger.warn("Cannot search online as not connected to internet") - return {} + yield {} + return # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries(query, conversation_history, location) @@ -66,7 +67,8 @@ async def search_online( logger.info(f"🌐 Searching the Internet for {list(subqueries)}") if send_status_func: subqueries_str = "\n- " + "\n- ".join(list(subqueries)) - await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}") + async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"): + yield {"status": event} with timer(f"Internet searches for {list(subqueries)} took", logger): search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina @@ -89,7 +91,8 @@ async def search_online( logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) - await 
send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") + async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): + yield {"status": event} tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] results = await asyncio.gather(*tasks) @@ -98,7 +101,7 @@ async def search_online( if webpage_extract is not None: response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract} - return response_dict + yield response_dict async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: @@ -127,13 +130,15 @@ async def read_webpages( "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") if send_status_func: - await send_status_func(f"**🧐 Inferring web pages to read**") + async for event in send_status_func(f"**🧐 Inferring web pages to read**"): + yield {"status": event} urls = await infer_webpage_urls(query, conversation_history, location) logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) - await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") + async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"): + yield {"status": event} tasks = [read_webpage_and_extract_content(query, url) for url in urls] results = await asyncio.gather(*tasks) @@ -141,7 +146,7 @@ async def read_webpages( response[query]["webpages"] = [ {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None ] - return response + yield response async def read_webpage_and_extract_content( diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index cbe19891..836b963f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,7 +6,6 @@ import os import threading import time import uuid -from random import random from typing import Any, 
Callable, List, Optional, Union import cron_descriptor @@ -298,11 +297,13 @@ async def extract_references_and_questions( not ConversationCommand.Notes in conversation_commands and not ConversationCommand.Default in conversation_commands ): - return compiled_references, inferred_queries, q + yield compiled_references, inferred_queries, q + return if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") - return compiled_references, inferred_queries, q + yield compiled_references, inferred_queries, q + return # Extract filter terms from user message defiltered_query = q @@ -313,7 +314,8 @@ async def extract_references_and_questions( if not conversation: logger.error(f"Conversation with id {conversation_id} not found.") - return compiled_references, inferred_queries, defiltered_query + yield compiled_references, inferred_queries, defiltered_query + return filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) using_offline_chat = False @@ -372,7 +374,8 @@ async def extract_references_and_questions( logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") if send_status_func: inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) - await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}") + async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"): + yield {"status": event} for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n search_results.extend( @@ -391,7 +394,7 @@ async def extract_references_and_questions( {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results ] - return compiled_references, inferred_queries, defiltered_query + yield compiled_references, inferred_queries, defiltered_query @api.get("/health", response_class=Response) diff --git 
a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index be28622b..4c3603cf 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,17 +1,18 @@ +import asyncio import json import logging import math from datetime import datetime +from functools import partial from typing import Any, Dict, List, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async -from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse +from sse_starlette import EventSourceResponse from starlette.authentication import requires -from starlette.websockets import WebSocketDisconnect -from websockets import ConnectionClosedOK from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -526,380 +527,441 @@ async def set_conversation_title( ) -@api_chat.websocket("/ws") -async def websocket_endpoint( - websocket: WebSocket, +@api_chat.get("/stream") +async def stream_chat( + request: Request, + q: str, conversation_id: int, city: Optional[str] = None, region: Optional[str] = None, country: Optional[str] = None, timezone: Optional[str] = None, ): - connection_alive = True + async def event_generator(q: str): + connection_alive = True - async def send_status_update(message: str): - nonlocal connection_alive - if not connection_alive: - return - - status_packet = { - "type": "status", - "message": message, - "content-type": "application/json", - } - try: - await websocket.send_text(json.dumps(status_packet)) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. 
Emitting rest of responses to clear thread") - - async def send_complete_llm_response(llm_response: str): - nonlocal connection_alive - if not connection_alive: - return - try: - await websocket.send_text("start_llm_response") - await websocket.send_text(llm_response) - await websocket.send_text("end_llm_response") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - async def send_message(message: str): - nonlocal connection_alive - if not connection_alive: - return - try: - await websocket.send_text(message) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") - - async def send_rate_limit_message(message: str): - nonlocal connection_alive - if not connection_alive: - return - - status_packet = { - "type": "rate_limit", - "message": message, - "content-type": "application/json", - } - try: - await websocket.send_text(json.dumps(status_packet)) - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. 
Emitting rest of responses to clear thread") - - user: KhojUser = websocket.user.object - conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=websocket.user.client_app, conversation_id=conversation_id - ) - - hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - - daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - - location = None - - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - await websocket.accept() - while connection_alive: - try: - if conversation: - await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) - q = await websocket.receive_text() - - # Refresh these because the connection to the database might have been closed - await conversation.arefresh_from_db() - - except WebSocketDisconnect: - logger.debug(f"User {user} disconnected web socket") - break - - try: - await sync_to_async(hourly_limiter)(websocket) - await sync_to_async(daily_limiter)(websocket) - except HTTPException as e: - await send_rate_limit_message(e.detail) - break - - if is_query_empty(q): - await send_message("start_llm_response") - await send_message( - "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?" 
- ) - await send_message("end_llm_response") - continue - - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - await send_status_update(f"**👀 Understanding Query**: {q}") - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}") - - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) - await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}") - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - if len(file_filters) == 0: - response_log = "No files selected for summarization. Please add files using the section on the left." 
- await send_complete_llm_response(response_log) - elif len(file_filters) > 1: - response_log = "Only one file can be selected for summarization." - await send_complete_llm_response(response_log) - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - await send_complete_llm_response(response_log) - continue - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}") - response = await extract_relevant_summary(q, contextual_data) - response_log = str(response) - await send_complete_llm_response(response_log) - except Exception as e: - response_log = "Error summarizing file." - logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - await send_complete_llm_response(response_log) - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=websocket.user.client_app, - conversation_id=conversation_id, - ) - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - continue - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - await send_complete_llm_response(formatted_help) - continue - # Adding 
specification to search online specifically on khoj.dev pages. - custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: + async def send_event(event_type: str, data: str): + nonlocal connection_alive + if not connection_alive or await request.is_disconnected(): + return try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, websocket.url, meta_log - ) + if event_type == "message": + yield data + else: + yield {"event": event_type, "data": data, "retry": 15000} except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - await send_complete_llm_response( - f"Unable to create automation. Ensure the automation doesn't already exist." - ) - continue + connection_alive = False + logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}") - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=websocket.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - ) - common = CommonQueryParamsClass( - client=websocket.user.client_app, - user_agent=websocket.headers.get("user-agent"), - host=websocket.headers.get("host"), - ) - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - **common.__dict__, - ) - await send_complete_llm_response(llm_response) - continue - - compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update + user: KhojUser = request.user.object + conversation = await ConversationAdapters.aget_conversation_by_user( 
+ user, client_application=request.user.client_app, conversation_id=conversation_id ) - if compiled_references: - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - await send_status_update(f"**📜 Found Relevant Notes**: {headings}") + hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - online_results: Dict = dict() + daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - await send_complete_llm_response(f"{no_entries_found.format()}") - continue + await is_ready_to_chat(user) - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) + user_name = await aget_user_name(user) - if ConversationCommand.Online in conversation_commands: + location = None + + if city or region or country: + location = LocationData(city=city, region=region, country=country) + + while connection_alive: try: - online_results = await search_online( - defiltered_query, meta_log, location, send_status_update, custom_filters - ) - except ValueError as e: - logger.warning(f"Error searching online: {e}. Attempting to respond without online results") - await send_complete_llm_response( - f"Error searching online: {e}. 
Attempting to respond without online results" - ) - continue + if conversation: + await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) - if ConversationCommand.Webpage in conversation_commands: - try: - direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update) - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + # Refresh these because the connection to the database might have been closed + await conversation.arefresh_from_db() - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - - await send_status_update(f"**📚 Read web pages**: {webpages}") - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True - ) - - if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) - image, status_code, improved_image_prompt, intent_type = await text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - send_status_func=send_status_update, - ) - if image is None or status_code != 200: - content_obj = { - "image": image, - "intentType": intent_type, - "detail": improved_image_prompt, - "content-type": "application/json", - } - await send_complete_llm_response(json.dumps(content_obj)) - continue - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=websocket.user.client_app, - conversation_id=conversation_id, - 
compiled_references=compiled_references, - online_results=online_results, - ) - content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore - - await send_complete_llm_response(json.dumps(content_obj)) - continue - - await send_status_update(f"**💭 Generating a well-informed response**") - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - websocket.user.client_app, - conversation_id, - location, - user_name, - ) - - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=websocket, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - ) - iterator = AsyncIteratorWrapper(llm_response) - - await send_message("start_llm_response") - - async for item in iterator: - if item is None: - break - if connection_alive: try: - await send_message(f"{item}") - except ConnectionClosedOK: - connection_alive = False - logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + await sync_to_async(hourly_limiter)(request) + await sync_to_async(daily_limiter)(request) + except HTTPException as e: + async for result in send_event("rate_limit", e.detail): + yield result + break - await send_message("end_llm_response") + if is_query_empty(q): + async for event in send_event("start_llm_response", ""): + yield event + async for event in send_event( + "message", + "It seems like your query is incomplete. 
Could you please provide more details or specify what you need help with?", + ): + yield event + async for event in send_event("end_llm_response", ""): + yield event + return + + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_commands = [get_conversation_command(query=q, any_references=True)] + + async for result in send_event("status", f"**👀 Understanding Query**: {q}"): + yield result + + meta_log = conversation.conversation_log + is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] + + used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] + + if conversation_commands == [ConversationCommand.Default] or is_automated_task: + conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event( + "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}" + ): + yield result + + mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) + async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"): + yield result + if mode not in conversation_commands: + conversation_commands.append(mode) + + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + file_filters = conversation.file_filters if conversation else [] + # Skip trying to summarize if + if ( + # summarization intent was inferred + ConversationCommand.Summarize in conversation_commands + # and not triggered via slash command + and not used_slash_summarize + # but we can't actually summarize + and len(file_filters) != 1 + ): + conversation_commands.remove(ConversationCommand.Summarize) + elif ConversationCommand.Summarize in conversation_commands: + response_log = "" + if len(file_filters) == 0: + response_log = ( 
+ "No files selected for summarization. Please add files using the section on the left." + ) + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + elif len(file_filters) > 1: + response_log = "Only one file can be selected for summarization." + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + else: + try: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + if len(file_object) == 0: + response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + contextual_data = " ".join([file.raw_text for file in file_object]) + if not q: + q = "Create a general summary of the file" + async for result in send_event( + "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}" + ): + yield result + + response = await extract_relevant_summary(q, contextual_data) + response_log = str(response) + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + except Exception as e: + response_log = "Error summarizing file." 
+ logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) + async for result in send_event("complete_llm_response", response_log): + yield result + async for event in send_event("end_llm_response", ""): + yield event + await sync_to_async(save_to_conversation_log)( + q, + response_log, + user, + meta_log, + user_message_time, + intent_type="summarize", + client_application=request.user.client_app, + conversation_id=conversation_id, + ) + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + return + + custom_filters = [] + if conversation_commands == [ConversationCommand.Help]: + if not q: + conversation_config = await ConversationAdapters.aget_user_conversation_config(user) + if conversation_config == None: + conversation_config = await ConversationAdapters.aget_default_conversation_config() + model_type = conversation_config.model_type + formatted_help = help_message.format( + model=model_type, version=state.khoj_version, device=get_device() + ) + async for result in send_event("complete_llm_response", formatted_help): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + custom_filters.append("site:khoj.dev") + conversation_commands.append(ConversationCommand.Online) + + if ConversationCommand.Automation in conversation_commands: + try: + automation, crontime, query_to_run, subject = await create_automation( + q, timezone, user, request.url, meta_log + ) + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") + error_message = f"Unable to create automation. Ensure the automation doesn't already exist." 
+ async for result in send_event("complete_llm_response", error_message): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) + await sync_to_async(save_to_conversation_log)( + q, + llm_response, + user, + meta_log, + user_message_time, + intent_type="automation", + client_application=request.user.client_app, + conversation_id=conversation_id, + inferred_queries=[query_to_run], + automation_id=automation.id, + ) + common = CommonQueryParamsClass( + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + ) + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + **common.__dict__, + ) + async for result in send_event("complete_llm_response", llm_response): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + compiled_references, inferred_queries, defiltered_query = [], [], None + async for result in extract_references_and_questions( + request, + meta_log, + q, + 7, + 0.18, + conversation_id, + conversation_commands, + location, + partial(send_event, "status"), + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join( + set([c.get("compiled", c).split("\n")[0] for c in compiled_references]) + ) + async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"): + yield result + + online_results: Dict = dict() + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries( + user + ): + async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"): + yield result + async for event in 
send_event("end_llm_response", ""): + yield event + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + if ConversationCommand.Online in conversation_commands: + try: + async for result in search_online( + defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + online_results = result + except ValueError as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_event("complete_llm_response", error_message): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + if ConversationCommand.Webpage in conversation_commands: + try: + async for result in read_webpages( + defiltered_query, meta_log, location, partial(send_event, "status") + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event("status", f"**📚 Read web pages**: {webpages}"): + yield result + except ValueError as e: + logger.warning( + f"Error directly reading webpages: {e}. 
Attempting to respond without online results", + exc_info=True, + ) + + if ConversationCommand.Image in conversation_commands: + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + async for result in text_to_image( + q, + user, + meta_log, + location_data=location, + references=compiled_references, + online_results=online_results, + send_status_func=partial(send_event, "status"), + ): + if isinstance(result, dict) and "status" in result: + yield result["status"] + else: + image, status_code, improved_image_prompt, intent_type = result + + if image is None or status_code != 200: + content_obj = { + "image": image, + "intentType": intent_type, + "detail": improved_image_prompt, + "content-type": "application/json", + } + async for result in send_event("complete_llm_response", json.dumps(content_obj)): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + await sync_to_async(save_to_conversation_log)( + q, + image, + user, + meta_log, + user_message_time, + intent_type=intent_type, + inferred_queries=[improved_image_prompt], + client_application=request.user.client_app, + conversation_id=conversation_id, + compiled_references=compiled_references, + online_results=online_results, + ) + content_obj = { + "image": image, + "intentType": intent_type, + "inferredQueries": [improved_image_prompt], + "context": compiled_references, + "content-type": "application/json", + "online_results": online_results, + } + async for result in send_event("complete_llm_response", json.dumps(content_obj)): + yield result + async for event in send_event("end_llm_response", ""): + yield event + return + + async for result in send_event("status", f"**💭 Generating a well-informed response**"): + yield result + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + meta_log, + conversation, + compiled_references, + 
online_results, + inferred_queries, + conversation_commands, + user, + request.user.client_app, + conversation_id, + location, + user_name, + ) + + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata=chat_metadata, + ) + iterator = AsyncIteratorWrapper(llm_response) + + async for result in send_event("start_llm_response", ""): + yield result + + async for item in iterator: + if item is None: + break + if connection_alive: + try: + async for result in send_event("message", f"{item}"): + yield result + except Exception as e: + connection_alive = False + logger.info( + f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}" + ) + async for result in send_event("end_llm_response", ""): + yield result + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in SSE endpoint: {e}", exc_info=True) + break + + return EventSourceResponse(event_generator(q)) @api_chat.get("", response_class=Response) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e0f91df7..d23df6f0 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -755,7 +755,7 @@ async def text_to_image( references: List[Dict[str, Any]], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, -) -> Tuple[Optional[str], int, Optional[str], str]: +): status_code = 200 image = None response = None @@ -767,7 +767,8 @@ async def text_to_image( # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." 
- return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return text2image_model = text_to_image_config.model_name chat_history = "" @@ -781,7 +782,8 @@ async def text_to_image( with timer("Improve the original user query", logger): if send_status_func: - await send_status_func("**✍🏽 Enhancing the Painting Prompt**") + async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): + yield {"status": event} improved_image_prompt = await generate_better_image_prompt( message, chat_history, @@ -792,7 +794,8 @@ async def text_to_image( ) if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"): + yield {"status": event} if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with timer("Generate image with OpenAI", logger): @@ -817,12 +820,14 @@ async def text_to_image( logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: with timer("Generate image with Stability AI", logger): @@ -844,7 +849,8 @@ async def text_to_image( logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image 
generation failed with Stability AI error: {e}" status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type.value + yield image_url or image, status_code, message, intent_type.value + return with timer("Convert image to webp", logger): # Convert png to webp for faster loading @@ -864,7 +870,7 @@ async def text_to_image( intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 image = base64.b64encode(webp_image_bytes).decode("utf-8") - return image_url or image, status_code, improved_image_prompt, intent_type.value + yield image_url or image, status_code, improved_image_prompt, intent_type.value class ApiUserRateLimiter: