Convert Websocket into Server Side Event (SSE) API endpoint

- Convert functions in SSE API path into async generators using yields
- Validate image generation, online, notes lookup and general paths of
  chat request are handled fine by the web client and server API
This commit is contained in:
Debanjum Singh Solanky
2024-07-21 12:10:13 +05:30
parent e694c82343
commit 91fe41106e
6 changed files with 577 additions and 489 deletions

View File

@@ -40,6 +40,7 @@ dependencies = [
"dateparser >= 1.1.1", "dateparser >= 1.1.1",
"defusedxml == 0.7.1", "defusedxml == 0.7.1",
"fastapi >= 0.104.1", "fastapi >= 0.104.1",
"sse-starlette ~= 2.1.0",
"python-multipart >= 0.0.7", "python-multipart >= 0.0.7",
"jinja2 == 3.1.4", "jinja2 == 3.1.4",
"openai >= 1.0.0", "openai >= 1.0.0",

View File

@@ -74,14 +74,14 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000); }, 1000);
}); });
} }
var websocket = null; var sseConnection = null;
let region = null; let region = null;
let city = null; let city = null;
let countryName = null; let countryName = null;
let timezone = null; let timezone = null;
let waitingForLocation = true; let waitingForLocation = true;
let websocketState = { let chatMessageState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
loadingEllipsis: null, loadingEllipsis: null,
@@ -105,7 +105,7 @@ To get started, just start typing below. You can also type / to see a list of co
.finally(() => { .finally(() => {
console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone); console.debug("Region:", region, "City:", city, "Country:", countryName, "Timezone:", timezone);
waitingForLocation = false; waitingForLocation = false;
setupWebSocket(); initializeSSE();
}); });
function formatDate(date) { function formatDate(date) {
@@ -599,10 +599,8 @@ To get started, just start typing below. You can also type / to see a list of co
} }
async function chat(isVoice=false) { async function chat(isVoice=false) {
if (websocket) { sendMessageViaSSE(isVoice);
sendMessageViaWebSocket(isVoice);
return; return;
}
let query = document.getElementById("chat-input").value.trim(); let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5; let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@@ -1069,17 +1067,13 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat; window.onload = loadChat;
function setupWebSocket(isVoice=false) { function initializeSSE(isVoice=false) {
let chatBody = document.getElementById("chat-body");
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
if (waitingForLocation) { if (waitingForLocation) {
console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available."); console.debug("Waiting for location data to be fetched. Will setup WebSocket once location data is available.");
return; return;
} }
websocketState = { chatMessageState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
loadingEllipsis: null, loadingEllipsis: null,
@@ -1088,85 +1082,55 @@ To get started, just start typing below. You can also type / to see a list of co
rawQuery: "", rawQuery: "",
isVoice: isVoice, isVoice: isVoice,
} }
}
function sendSSEMessage(query) {
let chatBody = document.getElementById("chat-body");
let sseProtocol = window.location.protocol;
let sseUrl = `/api/chat/stream?q=${query}`;
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`; sseUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
webSocketUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : ''; sseUrl += (!!region && !!city && !!countryName) && !!timezone ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';
websocket = new WebSocket(webSocketUrl);
websocket.onmessage = function(event) {
function handleChatResponse(event) {
// Get the last element in the chat-body // Get the last element in the chat-body
let chunk = event.data; let chunk = event.data;
if (chunk == "start_llm_response") {
console.log("Started streaming", new Date());
} else if (chunk == "end_llm_response") {
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (websocketState.isVoice && "{{ is_active }}" == "True")
textToSpeech(websocketState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseTextEl);
const liveQuery = websocketState.rawQuery;
// Reset variables
websocketState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
isVoice: false,
}
} else {
try { try {
if (chunk.includes("application/json")) if (chunk.includes("application/json"))
{
chunk = JSON.parse(chunk); chunk = JSON.parse(chunk);
}
} catch (error) { } catch (error) {
// If the chunk is not a JSON object, continue. // If the chunk is not a JSON object, continue.
} }
const contentType = chunk["content-type"] const contentType = chunk["content-type"]
if (contentType === "application/json") { if (contentType === "application/json") {
// Handle JSON response // Handle JSON response
try { try {
if (chunk.image || chunk.detail) { if (chunk.image || chunk.detail) {
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse)); ({rawResponse, references } = handleImageResponse(chunk, chatMessageState.rawResponse));
websocketState.rawResponse = rawResponse; chatMessageState.rawResponse = rawResponse;
websocketState.references = references; chatMessageState.references = references;
} else if (chunk.type == "status") {
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, null, false);
} else if (chunk.type == "rate_limit") {
handleStreamResponse(websocketState.newResponseTextEl, chunk.message, websocketState.rawQuery, websocketState.loadingEllipsis, true);
} else { } else {
rawResponse = chunk.response; rawResponse = chunk.response;
} }
} catch (error) { } catch (error) {
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk; chatMessageState.rawResponse += chunk;
} finally { } finally {
if (chunk.type != "status" && chunk.type != "rate_limit") { addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseTextEl, websocketState.references);
}
} }
} else { } else {
// Handle streamed response of type text/event-stream or text/plain // Handle streamed response of type text/event-stream or text/plain
if (chunk && chunk.includes("### compiled references:")) { if (chunk && chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseTextEl, chunk, websocketState.references, websocketState.rawResponse)); ({ rawResponse, references } = handleCompiledReferences(chatMessageState.newResponseTextEl, chunk, chatMessageState.references, chatMessageState.rawResponse));
websocketState.rawResponse = rawResponse; chatMessageState.rawResponse = rawResponse;
websocketState.references = references; chatMessageState.references = references;
} else { } else {
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk; chatMessageState.rawResponse += chunk;
if (websocketState.newResponseTextEl) { if (chatMessageState.newResponseTextEl) {
handleStreamResponse(websocketState.newResponseTextEl, websocketState.rawResponse, websocketState.rawQuery, websocketState.loadingEllipsis); handleStreamResponse(chatMessageState.newResponseTextEl, chatMessageState.rawResponse, chatMessageState.rawQuery, chatMessageState.loadingEllipsis);
} }
} }
@@ -1174,35 +1138,82 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}; };
} }
}
}; };
websocket.onclose = function(event) {
websocket = null; sseConnection = new EventSource(sseUrl);
console.log("WebSocket is closed now."); sseConnection.onmessage = handleChatResponse;
let setupWebSocketButton = document.createElement("button"); sseConnection.addEventListener("complete_llm_response", handleChatResponse);
setupWebSocketButton.textContent = "Reconnect to Server"; sseConnection.addEventListener("status", (event) => {
setupWebSocketButton.onclick = setupWebSocket; console.log(`${event.data}`);
let statusDotIcon = document.getElementById("connection-status-icon"); handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, null, false);
statusDotIcon.style.backgroundColor = "red"; });
let statusDotText = document.getElementById("connection-status-text"); sseConnection.addEventListener("rate_limit", (event) => {
statusDotText.innerHTML = ""; handleStreamResponse(chatMessageState.newResponseTextEl, event.data, chatMessageState.rawQuery, chatMessageState.loadingEllipsis, true);
statusDotText.style.marginTop = "5px"; });
statusDotText.appendChild(setupWebSocketButton); sseConnection.addEventListener("start_llm_response", (event) => {
} console.log("Started streaming", new Date());
websocket.onerror = function(event) { });
console.log("WebSocket error observed:", event); sseConnection.addEventListener("end_llm_response", (event) => {
sseConnection.close();
console.log("Stopped streaming", new Date());
// Automatically respond with voice if the subscribed user has sent voice message
if (chatMessageState.isVoice && "{{ is_active }}" == "True")
textToSpeech(chatMessageState.rawResponse);
// Append any references after all the data has been streamed
finalizeChatBodyResponse(chatMessageState.references, chatMessageState.newResponseTextEl);
const liveQuery = chatMessageState.rawQuery;
// Reset variables
chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
rawQuery: liveQuery,
} }
websocket.onopen = function(event) { // Reset status icon
console.log("WebSocket is open now.")
let statusDotIcon = document.getElementById("connection-status-icon"); let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green"; statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text"); let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Connected to Server"; statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
});
sseConnection.onclose = function(event) {
sseConnection = null;
console.debug("SSE is closed now.");
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "green";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Ready";
statusDotText.style.marginTop = "5px";
}
sseConnection.onerror = function(event) {
console.log("SSE error observed:", event);
sseConnection.close();
sseConnection = null;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Server Error";
if (chatMessageState.newResponseEl.getElementsByClassName("lds-ellipsis").length > 0 && chatMessageState.loadingEllipsis) {
chatMessageState.newResponseTextEl.removeChild(chatMessageState.loadingEllipsis);
}
chatMessageState.newResponseTextEl.textContent += "Failed to get response! Try again or contact developers at team@khoj.dev"
}
sseConnection.onopen = function(event) {
console.debug("SSE is open now.")
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "orange";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.textContent = "Processing";
} }
} }
function sendMessageViaWebSocket(isVoice=false) { function sendMessageViaSSE(isVoice=false) {
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
var query = document.getElementById("chat-input").value.trim(); var query = document.getElementById("chat-input").value.trim();
@@ -1242,11 +1253,11 @@ To get started, just start typing below. You can also type / to see a list of co
chatInput.classList.remove("option-enabled"); chatInput.classList.remove("option-enabled");
// Call specified Khoj API // Call specified Khoj API
websocket.send(query); sendSSEMessage(query);
let rawResponse = ""; let rawResponse = "";
let references = {}; let references = {};
websocketState = { chatMessageState = {
newResponseTextEl, newResponseTextEl,
newResponseEl, newResponseEl,
loadingEllipsis, loadingEllipsis,
@@ -1265,7 +1276,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`; let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
setupWebSocket(); initializeSSE();
loadFileFiltersFromConversation(); loadFileFiltersFromConversation();
} }
@@ -1305,7 +1316,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation(); loadFileFiltersFromConversation();
setupWebSocket(); initializeSSE();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent; let agentMetadata = response.agent;

View File

@@ -56,7 +56,8 @@ async def search_online(
query += " ".join(custom_filters) query += " ".join(custom_filters)
if not is_internet_connected(): if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet") logger.warn("Cannot search online as not connected to internet")
return {} yield {}
return
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location) subqueries = await generate_online_subqueries(query, conversation_history, location)
@@ -66,7 +67,8 @@ async def search_online(
logger.info(f"🌐 Searching the Internet for {list(subqueries)}") logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func: if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries)) subqueries_str = "\n- " + "\n- ".join(list(subqueries))
await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}") async for event in send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}"):
yield {"status": event}
with timer(f"Internet searches for {list(subqueries)} took", logger): with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
@@ -89,7 +91,8 @@ async def search_online(
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {"status": event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages] tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -98,7 +101,7 @@ async def search_online(
if webpage_extract is not None: if webpage_extract is not None:
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract} response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
return response_dict yield response_dict
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
@@ -127,13 +130,15 @@ async def read_webpages(
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
if send_status_func: if send_status_func:
await send_status_func(f"**🧐 Inferring web pages to read**") async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
yield {"status": event}
urls = await infer_webpage_urls(query, conversation_history, location) urls = await infer_webpage_urls(query, conversation_history, location)
logger.info(f"Reading web pages at: {urls}") logger.info(f"Reading web pages at: {urls}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls)) webpage_links_str = "\n- " + "\n- ".join(list(urls))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") async for event in send_status_func(f"**📖 Reading web pages**: {webpage_links_str}"):
yield {"status": event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls] tasks = [read_webpage_and_extract_content(query, url) for url in urls]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -141,7 +146,7 @@ async def read_webpages(
response[query]["webpages"] = [ response[query]["webpages"] = [
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
] ]
return response yield response
async def read_webpage_and_extract_content( async def read_webpage_and_extract_content(

View File

@@ -6,7 +6,6 @@ import os
import threading import threading
import time import time
import uuid import uuid
from random import random
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
import cron_descriptor import cron_descriptor
@@ -298,11 +297,13 @@ async def extract_references_and_questions(
not ConversationCommand.Notes in conversation_commands not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands and not ConversationCommand.Default in conversation_commands
): ):
return compiled_references, inferred_queries, q yield compiled_references, inferred_queries, q
return
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
return compiled_references, inferred_queries, q yield compiled_references, inferred_queries, q
return
# Extract filter terms from user message # Extract filter terms from user message
defiltered_query = q defiltered_query = q
@@ -313,7 +314,8 @@ async def extract_references_and_questions(
if not conversation: if not conversation:
logger.error(f"Conversation with id {conversation_id} not found.") logger.error(f"Conversation with id {conversation_id} not found.")
return compiled_references, inferred_queries, defiltered_query yield compiled_references, inferred_queries, defiltered_query
return
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False using_offline_chat = False
@@ -372,7 +374,8 @@ async def extract_references_and_questions(
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
if send_status_func: if send_status_func:
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
await send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}") async for event in send_status_func(f"**🔍 Searching Documents for:** {inferred_queries_str}"):
yield {"status": event}
for query in inferred_queries: for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n n_items = min(n, 3) if using_offline_chat else n
search_results.extend( search_results.extend(
@@ -391,7 +394,7 @@ async def extract_references_and_questions(
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
] ]
return compiled_references, inferred_queries, defiltered_query yield compiled_references, inferred_queries, defiltered_query
@api.get("/health", response_class=Response) @api.get("/health", response_class=Response)

View File

@@ -1,17 +1,18 @@
import asyncio
import json import json
import logging import logging
import math import math
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from sse_starlette import EventSourceResponse
from starlette.authentication import requires from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
@@ -526,74 +527,35 @@ async def set_conversation_title(
) )
@api_chat.websocket("/ws") @api_chat.get("/stream")
async def websocket_endpoint( async def stream_chat(
websocket: WebSocket, request: Request,
q: str,
conversation_id: int, conversation_id: int,
city: Optional[str] = None, city: Optional[str] = None,
region: Optional[str] = None, region: Optional[str] = None,
country: Optional[str] = None, country: Optional[str] = None,
timezone: Optional[str] = None, timezone: Optional[str] = None,
): ):
async def event_generator(q: str):
connection_alive = True connection_alive = True
async def send_status_update(message: str): async def send_event(event_type: str, data: str):
nonlocal connection_alive nonlocal connection_alive
if not connection_alive: if not connection_alive or await request.is_disconnected():
return
status_packet = {
"type": "status",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_complete_llm_response(llm_response: str):
nonlocal connection_alive
if not connection_alive:
return return
try: try:
await websocket.send_text("start_llm_response") if event_type == "message":
await websocket.send_text(llm_response) yield data
await websocket.send_text("end_llm_response") else:
except ConnectionClosedOK: yield {"event": event_type, "data": data, "retry": 15000}
except Exception as e:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") logger.info(f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}")
async def send_message(message: str): user: KhojUser = request.user.object
nonlocal connection_alive
if not connection_alive:
return
try:
await websocket.send_text(message)
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_rate_limit_message(message: str):
nonlocal connection_alive
if not connection_alive:
return
status_packet = {
"type": "rate_limit",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
user: KhojUser = websocket.user.object
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=websocket.user.client_app, conversation_id=conversation_id user, client_application=request.user.client_app, conversation_id=conversation_id
) )
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
@@ -609,56 +571,61 @@ async def websocket_endpoint(
if city or region or country: if city or region or country:
location = LocationData(city=city, region=region, country=country) location = LocationData(city=city, region=region, country=country)
await websocket.accept()
while connection_alive: while connection_alive:
try: try:
if conversation: if conversation:
await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"]) await sync_to_async(conversation.refresh_from_db)(fields=["conversation_log"])
q = await websocket.receive_text()
# Refresh these because the connection to the database might have been closed # Refresh these because the connection to the database might have been closed
await conversation.arefresh_from_db() await conversation.arefresh_from_db()
except WebSocketDisconnect:
logger.debug(f"User {user} disconnected web socket")
break
try: try:
await sync_to_async(hourly_limiter)(websocket) await sync_to_async(hourly_limiter)(request)
await sync_to_async(daily_limiter)(websocket) await sync_to_async(daily_limiter)(request)
except HTTPException as e: except HTTPException as e:
await send_rate_limit_message(e.detail) async for result in send_event("rate_limit", e.detail):
yield result
break break
if is_query_empty(q): if is_query_empty(q):
await send_message("start_llm_response") async for event in send_event("start_llm_response", ""):
await send_message( yield event
"It seems like your query is incomplete. Could you please provide more details or specify what you need help with?" async for event in send_event(
) "message",
await send_message("end_llm_response") "It seems like your query is incomplete. Could you please provide more details or specify what you need help with?",
continue ):
yield event
async for event in send_event("end_llm_response", ""):
yield event
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation_commands = [get_conversation_command(query=q, any_references=True)]
await send_status_update(f"**👀 Understanding Query**: {q}") async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
yield result
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}") async for result in send_event(
"status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task) mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}") async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands: if mode not in conversation_commands:
conversation_commands.append(mode) conversation_commands.append(mode)
for cmd in conversation_commands: for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
file_filters = conversation.file_filters if conversation else [] file_filters = conversation.file_filters if conversation else []
@@ -675,29 +642,50 @@ async def websocket_endpoint(
elif ConversationCommand.Summarize in conversation_commands: elif ConversationCommand.Summarize in conversation_commands:
response_log = "" response_log = ""
if len(file_filters) == 0: if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left." response_log = (
await send_complete_llm_response(response_log) "No files selected for summarization. Please add files using the section on the left."
)
async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
elif len(file_filters) > 1: elif len(file_filters) > 1:
response_log = "Only one file can be selected for summarization." response_log = "Only one file can be selected for summarization."
await send_complete_llm_response(response_log) async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
else: else:
try: try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0: if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
await send_complete_llm_response(response_log) async for result in send_event("complete_llm_response", response_log):
continue yield result
async for event in send_event("end_llm_response", ""):
yield event
return
contextual_data = " ".join([file.raw_text for file in file_object]) contextual_data = " ".join([file.raw_text for file in file_object])
if not q: if not q:
q = "Create a general summary of the file" q = "Create a general summary of the file"
await send_status_update(f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}") async for result in send_event(
"status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(q, contextual_data) response = await extract_relevant_summary(q, contextual_data)
response_log = str(response) response_log = str(response)
await send_complete_llm_response(response_log) async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
except Exception as e: except Exception as e:
response_log = "Error summarizing file." response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
await send_complete_llm_response(response_log) async for result in send_event("complete_llm_response", response_log):
yield result
async for event in send_event("end_llm_response", ""):
yield event
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
response_log, response_log,
@@ -705,16 +693,16 @@ async def websocket_endpoint(
meta_log, meta_log,
user_message_time, user_message_time,
intent_type="summarize", intent_type="summarize",
client_application=websocket.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
) )
update_telemetry_state( update_telemetry_state(
request=websocket, request=request,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
metadata={"conversation_command": conversation_commands[0].value}, metadata={"conversation_command": conversation_commands[0].value},
) )
continue return
custom_filters = [] custom_filters = []
if conversation_commands == [ConversationCommand.Help]: if conversation_commands == [ConversationCommand.Help]:
@@ -723,24 +711,30 @@ async def websocket_endpoint(
if conversation_config == None: if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) formatted_help = help_message.format(
await send_complete_llm_response(formatted_help) model=model_type, version=state.khoj_version, device=get_device()
continue )
# Adding specification to search online specifically on khoj.dev pages. async for result in send_event("complete_llm_response", formatted_help):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
custom_filters.append("site:khoj.dev") custom_filters.append("site:khoj.dev")
conversation_commands.append(ConversationCommand.Online) conversation_commands.append(ConversationCommand.Online)
if ConversationCommand.Automation in conversation_commands: if ConversationCommand.Automation in conversation_commands:
try: try:
automation, crontime, query_to_run, subject = await create_automation( automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, websocket.url, meta_log q, timezone, user, request.url, meta_log
) )
except Exception as e: except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}") logger.error(f"Error scheduling task {q} for {user.email}: {e}")
await send_complete_llm_response( error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
f"Unable to create automation. Ensure the automation doesn't already exist." async for result in send_event("complete_llm_response", error_message):
) yield result
continue async for event in send_event("end_llm_response", ""):
yield event
return
llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
@@ -750,57 +744,95 @@ async def websocket_endpoint(
meta_log, meta_log,
user_message_time, user_message_time,
intent_type="automation", intent_type="automation",
client_application=websocket.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
inferred_queries=[query_to_run], inferred_queries=[query_to_run],
automation_id=automation.id, automation_id=automation.id,
) )
common = CommonQueryParamsClass( common = CommonQueryParamsClass(
client=websocket.user.client_app, client=request.user.client_app,
user_agent=websocket.headers.get("user-agent"), user_agent=request.headers.get("user-agent"),
host=websocket.headers.get("host"), host=request.headers.get("host"),
) )
update_telemetry_state( update_telemetry_state(
request=websocket, request=request,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
**common.__dict__, **common.__dict__,
) )
await send_complete_llm_response(llm_response) async for result in send_event("complete_llm_response", llm_response):
continue yield result
async for event in send_event("end_llm_response", ""):
yield event
return
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = [], [], None
websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update async for result in extract_references_and_questions(
request,
meta_log,
q,
7,
0.18,
conversation_id,
conversation_commands,
location,
partial(send_event, "status"),
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(
set([c.get("compiled", c).split("\n")[0] for c in compiled_references])
) )
async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
if compiled_references: yield result
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
online_results: Dict = dict() online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(
await send_complete_llm_response(f"{no_entries_found.format()}") user
continue ):
async for result in send_event("complete_llm_response", f"{no_entries_found.format()}"):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes) conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
online_results = await search_online( async for result in search_online(
defiltered_query, meta_log, location, send_status_update, custom_filters defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
) ):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
online_results = result
except ValueError as e: except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results") error_message = f"Error searching online: {e}. Attempting to respond without online results"
await send_complete_llm_response( logger.warning(error_message)
f"Error searching online: {e}. Attempting to respond without online results" async for result in send_event("complete_llm_response", error_message):
) yield result
continue async for event in send_event("end_llm_response", ""):
yield event
return
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update) async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, "status")
):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
direct_web_pages = result
webpages = [] webpages = []
for query in direct_web_pages: for query in direct_web_pages:
if online_results.get(query): if online_results.get(query):
@@ -810,29 +842,35 @@ async def websocket_endpoint(
for webpage in direct_web_pages[query]["webpages"]: for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"]) webpages.append(webpage["link"])
async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
await send_status_update(f"**📚 Read web pages**: {webpages}") yield result
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True f"Error directly reading webpages: {e}. Attempting to respond without online results",
exc_info=True,
) )
if ConversationCommand.Image in conversation_commands: if ConversationCommand.Image in conversation_commands:
update_telemetry_state( update_telemetry_state(
request=websocket, request=request,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
metadata={"conversation_command": conversation_commands[0].value}, metadata={"conversation_command": conversation_commands[0].value},
) )
image, status_code, improved_image_prompt, intent_type = await text_to_image( async for result in text_to_image(
q, q,
user, user,
meta_log, meta_log,
location_data=location, location_data=location,
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
send_status_func=send_status_update, send_status_func=partial(send_event, "status"),
) ):
if isinstance(result, dict) and "status" in result:
yield result["status"]
else:
image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200: if image is None or status_code != 200:
content_obj = { content_obj = {
"image": image, "image": image,
@@ -840,8 +878,11 @@ async def websocket_endpoint(
"detail": improved_image_prompt, "detail": improved_image_prompt,
"content-type": "application/json", "content-type": "application/json",
} }
await send_complete_llm_response(json.dumps(content_obj)) async for result in send_event("complete_llm_response", json.dumps(content_obj)):
continue yield result
async for event in send_event("end_llm_response", ""):
yield event
return
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
@@ -851,17 +892,27 @@ async def websocket_endpoint(
user_message_time, user_message_time,
intent_type=intent_type, intent_type=intent_type,
inferred_queries=[improved_image_prompt], inferred_queries=[improved_image_prompt],
client_application=websocket.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
) )
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore content_obj = {
"image": image,
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"context": compiled_references,
"content-type": "application/json",
"online_results": online_results,
}
async for result in send_event("complete_llm_response", json.dumps(content_obj)):
yield result
async for event in send_event("end_llm_response", ""):
yield event
return
await send_complete_llm_response(json.dumps(content_obj)) async for result in send_event("status", f"**💭 Generating a well-informed response**"):
continue yield result
await send_status_update(f"**💭 Generating a well-informed response**")
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query, defiltered_query,
meta_log, meta_log,
@@ -871,7 +922,7 @@ async def websocket_endpoint(
inferred_queries, inferred_queries,
conversation_commands, conversation_commands,
user, user,
websocket.user.client_app, request.user.client_app,
conversation_id, conversation_id,
location, location,
user_name, user_name,
@@ -880,26 +931,37 @@ async def websocket_endpoint(
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
update_telemetry_state( update_telemetry_state(
request=websocket, request=request,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
metadata=chat_metadata, metadata=chat_metadata,
) )
iterator = AsyncIteratorWrapper(llm_response) iterator = AsyncIteratorWrapper(llm_response)
await send_message("start_llm_response") async for result in send_event("start_llm_response", ""):
yield result
async for item in iterator: async for item in iterator:
if item is None: if item is None:
break break
if connection_alive: if connection_alive:
try: try:
await send_message(f"{item}") async for result in send_event("message", f"{item}"):
except ConnectionClosedOK: yield result
except Exception as e:
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") logger.info(
f"User {user} disconnected SSE. Emitting rest of responses to clear thread: {e}"
)
async for result in send_event("end_llm_response", ""):
yield result
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in SSE endpoint: {e}", exc_info=True)
break
await send_message("end_llm_response") return EventSourceResponse(event_generator(q))
@api_chat.get("", response_class=Response) @api_chat.get("", response_class=Response)

View File

@@ -755,7 +755,7 @@ async def text_to_image(
references: List[Dict[str, Any]], references: List[Dict[str, Any]],
online_results: Dict[str, Any], online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
) -> Tuple[Optional[str], int, Optional[str], str]: ):
status_code = 200 status_code = 200
image = None image = None
response = None response = None
@@ -767,7 +767,8 @@ async def text_to_image(
# If the user has not configured a text to image model, return an unsupported on server error # If the user has not configured a text to image model, return an unsupported on server error
status_code = 501 status_code = 501
message = "Failed to generate image. Setup image generation on the server." message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
text2image_model = text_to_image_config.model_name text2image_model = text_to_image_config.model_name
chat_history = "" chat_history = ""
@@ -781,7 +782,8 @@ async def text_to_image(
with timer("Improve the original user query", logger): with timer("Improve the original user query", logger):
if send_status_func: if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**") async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
yield {"status": event}
improved_image_prompt = await generate_better_image_prompt( improved_image_prompt = await generate_better_image_prompt(
message, message,
chat_history, chat_history,
@@ -792,7 +794,8 @@ async def text_to_image(
) )
if send_status_func: if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):
yield {"status": event}
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger): with timer("Generate image with OpenAI", logger):
@@ -817,12 +820,14 @@ async def text_to_image(
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
else: else:
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger): with timer("Generate image with Stability AI", logger):
@@ -844,7 +849,8 @@ async def text_to_image(
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}" message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return
with timer("Convert image to webp", logger): with timer("Convert image to webp", logger):
# Convert png to webp for faster loading # Convert png to webp for faster loading
@@ -864,7 +870,7 @@ async def text_to_image(
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8") image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value yield image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter: class ApiUserRateLimiter: