Merge pull request #679 from khoj-ai/features/chat-socket-streaming

Add a websocket for streaming from the chat UI
This commit is contained in:
sabaimran
2024-04-03 04:43:31 -07:00
committed by GitHub
4 changed files with 603 additions and 121 deletions

View File

@@ -75,6 +75,7 @@ dependencies = [
"django-phonenumber-field == 7.3.0", "django-phonenumber-field == 7.3.0",
"phonenumbers == 8.13.27", "phonenumbers == 8.13.27",
"markdownify ~= 0.11.6", "markdownify ~= 0.11.6",
"websockets == 12.0",
] ]
dynamic = ["version"] dynamic = ["version"]

View File

@@ -47,11 +47,22 @@ To get started, just start typing below. You can also type / to see a list of co
}, 1000); }, 1000);
}); });
} }
var websocket = null;
var timeout = null;
var timeoutDuration = 600000; // 10 minutes
let region = null; let region = null;
let city = null; let city = null;
let countryName = null; let countryName = null;
let websocketState = {
newResponseText: null,
newResponseElement: null,
loadingEllipsis: null,
references: {},
rawResponse: "",
}
fetch("https://ipapi.co/json") fetch("https://ipapi.co/json")
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
@@ -415,6 +426,12 @@ To get started, just start typing below. You can also type / to see a list of co
async function chat() { async function chat() {
// Extract required fields for search from form // Extract required fields for search from form
if (websocket) {
sendMessageViaWebSocket();
return;
}
let query = document.getElementById("chat-input").value.trim(); let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5; let resultsCount = localStorage.getItem("khojResultsCount") || 5;
console.log(`Query: ${query}`); console.log(`Query: ${query}`);
@@ -440,9 +457,6 @@ To get started, just start typing below. You can also type / to see a list of co
refreshChatSessionsPanel(); refreshChatSessionsPanel();
} }
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}`;
let new_response = document.createElement("div"); let new_response = document.createElement("div");
new_response.classList.add("chat-message", "khoj"); new_response.classList.add("chat-message", "khoj");
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date()); new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
@@ -452,6 +466,79 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.classList.add("chat-message-text", "khoj"); newResponseText.classList.add("chat-message-text", "khoj");
new_response.appendChild(newResponseText); new_response.appendChild(newResponseText);
// Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = createLoadingEllipse();
newResponseText.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "none";
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}`;
// Call specified Khoj API
let response = await fetch(url);
let rawResponse = "";
let references = null;
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image || responseAsJson.detail) {
({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse));
} else {
rawResponse = responseAsJson.response;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
addMessageToChatBody(rawResponse, newResponseText, references);
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = {};
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
finalizeChatBodyResponse(references, newResponseText);
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse));
readStream();
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
handleStreamResponse(newResponseText, rawResponse, loadingEllipsis);
readStream();
}
});
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
};
}
};
function createLoadingEllipse() {
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
let loadingEllipsis = document.createElement("div"); let loadingEllipsis = document.createElement("div");
loadingEllipsis.classList.add("lds-ellipsis"); loadingEllipsis.classList.add("lds-ellipsis");
@@ -473,87 +560,25 @@ To get started, just start typing below. You can also type / to see a list of co
loadingEllipsis.appendChild(thirdEllipsis); loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis); loadingEllipsis.appendChild(fourthEllipsis);
newResponseText.appendChild(loadingEllipsis); return loadingEllipsis;
}
function handleStreamResponse(newResponseElement, rawResponse, loadingEllipsis, replace=true) {
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
newResponseElement.removeChild(loadingEllipsis);
}
if (replace) {
newResponseElement.innerHTML = "";
}
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "none";
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API
let response = await fetch(url);
let rawResponse = "";
let references = null;
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
if (responseAsJson.intentType === "text-to-image") {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += `![${query}](${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
if (responseAsJson.context && responseAsJson.context.length > 0) {
const rawReferenceAsJson = responseAsJson.context;
references = createReferenceSection(rawReferenceAsJson);
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
if (references != null) {
newResponseText.appendChild(references);
} }
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) {
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = {};
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != {}) {
newResponseText.appendChild(createReferenceSection(references));
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0]; const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse; rawResponse += additionalResponse;
newResponseText.innerHTML = ""; rawResponseElement.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse)); rawResponseElement.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1]; const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference); const rawReferenceAsJson = JSON.parse(rawReference);
@@ -562,26 +587,53 @@ To get started, just start typing below. You can also type / to see a list of co
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) { } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson; references["online"] = rawReferenceAsJson;
} }
readStream(); return { rawResponse, references };
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseText.removeChild(loadingEllipsis);
} }
// If the chunk is not a JSON object, just display it as is function handleImageResponse(imageJson, rawResponse) {
rawResponse += chunk; if (imageJson.image) {
newResponseText.innerHTML = ""; const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
});
// Scroll to bottom of chat window as chat response is streamed // If response has image field, response is a generated image.
if (imageJson.intentType === "text-to-image") {
rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image2") {
rawResponse += `![generated_image](${imageJson.image})`;
}
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
let references = {};
if (imageJson.context && imageJson.context.length > 0) {
const rawReferenceAsJson = imageJson.context;
if (rawReferenceAsJson instanceof Array) {
references["notes"] = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references["online"] = rawReferenceAsJson;
}
}
if (imageJson.detail) {
// If response has detail field, response is an error message.
rawResponse += imageJson.detail;
}
return { rawResponse, references };
}
/**
 * Render a complete response into a chat message element.
 *
 * Any existing children of the element are discarded, the formatted response is
 * inserted, and the element is then handed to finalizeChatBodyResponse together
 * with the references.
 *
 * @param {string} rawResponse - Response text passed to formatHTMLMessage.
 * @param {Element} newResponseElement - Element that receives the rendered response.
 * @param {Object|null} references - References forwarded to finalizeChatBodyResponse.
 */
function addMessageToChatBody(rawResponse, newResponseElement, references) {
    // Clear first so a previous partial render never lingers next to the final one
    newResponseElement.replaceChildren();

    const formattedMessage = formatHTMLMessage(rawResponse);
    newResponseElement.appendChild(formattedMessage);

    finalizeChatBodyResponse(references, newResponseElement);
}
function finalizeChatBodyResponse(references, newResponseElement) {
if (references != null && Object.keys(references).length > 0) {
newResponseElement.appendChild(createReferenceSection(references));
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}; document.getElementById("chat-input").removeAttribute("disabled");
} }
};
function incrementalChat(event) { function incrementalChat(event) {
if (!event.shiftKey && event.key === 'Enter') { if (!event.shiftKey && event.key === 'Enter') {
@@ -798,6 +850,180 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat; window.onload = loadChat;
/**
 * Open the chat websocket for the conversation shown in #chat-body and attach its handlers.
 *
 * Incoming frames are handled as one of:
 *   - the literal markers "start_llm_response" / "end_llm_response",
 *   - a JSON string whose "content-type" key is "application/json" (a status update,
 *     an image/error payload, or a complete response), or
 *   - a plain text chunk of a streamed response, which may contain "### compiled references:".
 *
 * Rendering progress is kept in the shared `websocketState` object. The active socket is
 * closed after `timeoutDuration` milliseconds without an incoming message. Calling this
 * function again replaces (and closes) the previously opened socket.
 */
function setupWebSocket() {
    const chatBody = document.getElementById("chat-body");
    const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';

    function createEmptyState() {
        return {
            newResponseText: null,
            newResponseElement: null,
            loadingEllipsis: null,
            references: {},
            rawResponse: "",
        };
    }

    function scrollChatToBottom() {
        const body = document.getElementById("chat-body");
        body.scrollTop = body.scrollHeight;
    }

    function resetTimeout() {
        if (timeout) {
            clearTimeout(timeout);
        }
        timeout = setTimeout(function() {
            if (websocket) {
                websocket.close();
            }
        }, timeoutDuration);
    }

    websocketState = createEmptyState();

    // Without a conversation there is nothing to connect to; bail out before touching the socket
    const conversationId = chatBody.dataset.conversationId;
    if (!conversationId) {
        return;
    }

    // loadChat can call this more than once; do not leave an orphaned connection behind
    if (websocket) {
        websocket.close();
    }

    // URLSearchParams encodes values; location fields still unknown are omitted
    // instead of being sent as the string "null"
    const params = new URLSearchParams({ conversation_id: conversationId });
    const locationFields = { region: region, city: city, country: countryName };
    for (const [key, value] of Object.entries(locationFields)) {
        if (value !== null && value !== undefined) {
            params.set(key, value);
        }
    }

    const socket = new WebSocket(`${wsProtocol}//${window.location.host}/api/chat/ws?${params.toString()}`);
    websocket = socket;

    // Handle a parsed packet that declared "content-type": "application/json"
    function handleJsonPacket(packet, rawText) {
        if (packet.type === "status") {
            // Status lines are appended (replace=false) below whatever is already rendered
            if (websocketState.newResponseText) {
                handleStreamResponse(websocketState.newResponseText, packet.message, null, false);
            }
            return;
        }

        try {
            if (packet.image || packet.detail) {
                const { rawResponse, references } = handleImageResponse(packet, websocketState.rawResponse);
                websocketState.rawResponse = rawResponse;
                websocketState.references = references;
            } else {
                websocketState.rawResponse = packet.response ?? "";
            }
        } catch (error) {
            // Fall back to showing the frame as received rather than dropping it
            console.log("Failed to handle JSON packet from websocket:", error);
            websocketState.rawResponse += rawText;
        } finally {
            if (websocketState.newResponseText) {
                addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseText, websocketState.references);
                // addMessageToChatBody already rendered the references; clear them so the
                // "end_llm_response" marker does not append the same section again
                websocketState.references = {};
            }
        }
    }

    // Handle a plain text chunk of a streamed response
    function handleTextChunk(text) {
        if (text && text.includes("### compiled references:")) {
            const { rawResponse, references } = handleCompiledReferences(websocketState.newResponseText, text, websocketState.references, websocketState.rawResponse);
            websocketState.rawResponse = rawResponse;
            websocketState.references = references;
        } else {
            websocketState.rawResponse += text;
            if (websocketState.newResponseText) {
                handleStreamResponse(websocketState.newResponseText, websocketState.rawResponse, websocketState.loadingEllipsis);
            }
        }

        // Scroll to bottom of chat window as chat response is streamed
        scrollChatToBottom();
    }

    socket.onmessage = function(event) {
        resetTimeout();

        const rawText = event.data;
        if (rawText === "start_llm_response") {
            console.log("Started streaming", new Date());
            return;
        }
        if (rawText === "end_llm_response") {
            console.log("Stopped streaming", new Date());
            // Append any references after all the data has been streamed
            finalizeChatBodyResponse(websocketState.references, websocketState.newResponseText);
            websocketState = createEmptyState();
            return;
        }

        // Only treat the frame as a packet when it parses AND declares the JSON content type;
        // streamed text that merely mentions "application/json" stays plain text
        let packet = null;
        if (typeof rawText === "string" && rawText.includes("application/json")) {
            try {
                const parsed = JSON.parse(rawText);
                if (parsed !== null && typeof parsed === "object" && parsed["content-type"] === "application/json") {
                    packet = parsed;
                }
            } catch (error) {
                // Not JSON: handled as a text chunk below
            }
        }

        if (packet) {
            handleJsonPacket(packet, rawText);
        } else {
            handleTextChunk(rawText);
        }
    };

    socket.onclose = function(event) {
        console.log("WebSocket is closed now.");

        // A newer socket may already have replaced this one; leave its state alone
        if (websocket !== socket) {
            return;
        }
        websocket = null;
        if (timeout) {
            clearTimeout(timeout);
            timeout = null;
        }

        const greenDot = document.getElementById("connected-green-dot");
        greenDot.style.display = "none";

        // If a response was still in flight the input is disabled; release it so the
        // user can keep chatting
        if (websocketState.newResponseText) {
            document.getElementById("chat-input")?.removeAttribute("disabled");
        }
    };

    socket.onerror = function(event) {
        console.log("WebSocket error observed:", event);
    };

    socket.onopen = function(event) {
        console.log("WebSocket is open now.");
        const greenDot = document.getElementById("connected-green-dot");
        greenDot.style.display = "flex";

        // Setup the timeout to close the connection after inactivity.
        resetTimeout();
    };
}
/**
 * Send the text in #chat-input over the chat websocket and prepare the UI for the reply.
 *
 * Renders the user's message, disables the input, appends an empty Khoj message
 * element with a loading indicator, and records those elements in `websocketState`
 * so the socket's message handler can fill them in.
 *
 * Nothing is sent, and the input is left untouched, when the query is empty or the
 * socket is not open.
 *
 * @param {Event} [event] - Optional triggering event; its default action is prevented.
 */
function sendMessageViaWebSocket(event) {
    if (event) {
        event.preventDefault();
    }

    const chatBody = document.getElementById("chat-body");
    const chatInput = document.getElementById("chat-input");
    const query = chatInput.value.trim();

    // An empty message would only lock the input while waiting for a pointless reply
    if (query === "") {
        return;
    }

    // send() throws unless the socket is OPEN; check before the input is cleared and disabled
    if (!websocket || websocket.readyState !== WebSocket.OPEN) {
        console.log("WebSocket is not open. Message was not sent.");
        return;
    }

    console.log(`Query: ${query}`);

    // Add message by user to chat body
    renderMessage(query, "you");
    chatInput.value = "";
    autoResize();
    chatInput.setAttribute("disabled", "disabled");

    const newResponseElement = document.createElement("div");
    newResponseElement.classList.add("chat-message", "khoj");
    // NOTE(review): this sets a property on the attributes map rather than calling
    // setAttribute; kept as-is to match the non-websocket chat path - confirm intent
    newResponseElement.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
    chatBody.appendChild(newResponseElement);

    const newResponseText = document.createElement("div");
    newResponseText.classList.add("chat-message-text", "khoj");
    newResponseElement.appendChild(newResponseText);

    // Temporary status message to indicate that Khoj is thinking
    const loadingEllipsis = createLoadingEllipse();
    newResponseText.appendChild(loadingEllipsis);
    chatBody.scrollTop = chatBody.scrollHeight;

    document.getElementById("chat-tooltip").style.display = "none";
    chatInput.classList.remove("option-enabled");

    // Register the render targets before sending so the reply always has somewhere to go
    websocketState = {
        newResponseText,
        newResponseElement,
        loadingEllipsis,
        references: {},
        rawResponse: "",
    };

    websocket.send(query);
}
function loadChat() { function loadChat() {
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = ""; chatBody.innerHTML = "";
@@ -805,6 +1031,7 @@ To get started, just start typing below. You can also type / to see a list of co
let chatHistoryUrl = `/api/chat/history?client=web`; let chatHistoryUrl = `/api/chat/history?client=web`;
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
setupWebSocket();
} }
if (window.screen.width < 700) { if (window.screen.width < 700) {
@@ -841,6 +1068,7 @@ To get started, just start typing below. You can also type / to see a list of co
// Render conversation history, if any // Render conversation history, if any
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationId = response.conversation_id;
setupWebSocket();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
let agentMetadata = response.agent; let agentMetadata = response.agent;
@@ -1323,6 +1551,10 @@ To get started, just start typing below. You can also type / to see a list of co
<div id="side-panel-wrapper"> <div id="side-panel-wrapper">
<div id="side-panel"> <div id="side-panel">
<div id="new-conversation"> <div id="new-conversation">
<div id="connected-green-dot" style="display: none; align-items: center; margin-bottom: 10px;">
<div style="width: 10px; height: 10px; background-color: green; border-radius: 50%; margin-right: 5px;"></div>
<div>Connected</div>
</div>
<button class="side-panel-button" id="new-conversation-button" onclick="createNewConversation()"> <button class="side-panel-button" id="new-conversation-button" onclick="createNewConversation()">
New Topic New Topic
<svg class="new-convo-button" viewBox="0 0 35 35" fill="#000000" viewBox="0 0 32 32" version="1.1" xmlns="http://www.w3.org/2000/svg"> <svg class="new-convo-button" viewBox="0 0 35 35" fill="#000000" viewBox="0 0 32 32" version="1.1" xmlns="http://www.w3.org/2000/svg">

View File

@@ -61,6 +61,36 @@ async def search(
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
): ):
user = request.user.object user = request.user.object
results = await execute_search(
user=user,
q=q,
n=n,
t=t,
r=r,
max_distance=max_distance,
dedupe=dedupe,
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="search",
**common.__dict__,
)
return results
async def execute_search(
user: KhojUser,
q: str,
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
):
start_time = time.time() start_time = time.time()
# Run validation checks # Run validation checks
@@ -155,13 +185,6 @@ async def search(
if user: if user:
state.query_cache[user.uuid][query_cache_key] = results state.query_cache[user.uuid][query_cache_key] = results
update_telemetry_state(
request=request,
telemetry_type="api",
api="search",
**common.__dict__,
)
end_time = time.time() end_time = time.time()
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
@@ -350,14 +373,14 @@ async def extract_references_and_questions(
for query in inferred_queries: for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n n_items = min(n, 3) if using_offline_chat else n
result_list.extend( result_list.extend(
await search( await execute_search(
user,
f"{query} {filters_in_query}", f"{query} {filters_in_query}",
request=request,
n=n_items, n=n_items,
t=SearchType.All,
r=True, r=True,
max_distance=d, max_distance=d,
dedupe=False, dedupe=False,
common=common,
) )
) )
result_list = text_search.deduplicated_search_responses(result_list) result_list = text_search.deduplicated_search_responses(result_list)

View File

@@ -5,10 +5,12 @@ from typing import Dict, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request, WebSocket
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
@@ -242,6 +244,230 @@ async def set_conversation_title(
) )
@api_chat.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: int,
    city: Optional[str] = None,
    region: Optional[str] = None,
    country: Optional[str] = None,
):
    """
    Serve chat over a websocket: each text frame received is one user query.

    For every query the handler sends JSON status packets
    (``{"type": "status", "message": ..., "content-type": "application/json"}``)
    while it works, then the answer framed by the literal text frames
    ``start_llm_response`` and ``end_llm_response``. The answer is either a single
    frame (help text, error text, or a JSON image payload) or a sequence of frames
    streamed from the chat model.

    Args:
        websocket: The client connection; ``websocket.user`` supplies the user and client app.
        conversation_id: Conversation whose log is used as chat history.
        city, region, country: Optional location hints; combined into a ``LocationData``
            when at least one is truthy.
    """
    user: KhojUser = websocket.user.object
    connection_alive = True

    async def send_message(message: str):
        """Send one text frame; becomes a no-op once the client has disconnected."""
        nonlocal connection_alive
        if not connection_alive:
            return
        try:
            await websocket.send_text(message)
        except (ConnectionClosedOK, WebSocketDisconnect):
            # Remember the disconnect so later sends are skipped and the receive loop exits
            connection_alive = False
            logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")

    async def send_status_update(message: str):
        """Send a progress message wrapped in a JSON status packet."""
        status_packet = {
            "type": "status",
            "message": message,
            "content-type": "application/json",
        }
        await send_message(json.dumps(status_packet))

    async def send_complete_llm_response(llm_response: str):
        """Send a whole response in one frame, wrapped in the start/end markers."""
        await send_message("start_llm_response")
        await send_message(llm_response)
        await send_message("end_llm_response")

    conversation = await ConversationAdapters.aget_conversation_by_user(
        user, client_application=websocket.user.client_app, conversation_id=conversation_id
    )

    # The first limiter uses a 60-unit window with slug "chat_minute"; the second a 24h window
    per_minute_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
    daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")

    await is_ready_to_chat(user)

    user_name = await aget_user_name(user)

    location = None
    if city or region or country:
        location = LocationData(city=city, region=region, country=country)

    await websocket.accept()
    while connection_alive:
        try:
            q = await websocket.receive_text()
        except WebSocketDisconnect:
            logger.debug(f"User {user} disconnected web socket")
            break

        await sync_to_async(per_minute_limiter)(websocket)
        await sync_to_async(daily_limiter)(websocket)

        conversation_commands = [get_conversation_command(query=q, any_references=True)]

        await send_status_update(f"**Processing query**: {q}")

        if conversation_commands == [ConversationCommand.Help]:
            conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
            if conversation_config is None:
                conversation_config = await ConversationAdapters.aget_default_conversation_config()
            model_type = conversation_config.model_type
            formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
            await send_complete_llm_response(formatted_help)
            continue

        meta_log = conversation.conversation_log

        # No explicit command in the query: infer the information sources and output mode
        if conversation_commands == [ConversationCommand.Default]:
            conversation_commands = await aget_relevant_information_sources(q, meta_log)
            mode = await aget_relevant_output_modes(q, meta_log)
            if mode not in conversation_commands:
                conversation_commands.append(mode)

        for cmd in conversation_commands:
            await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
            # Strip the slash command itself from the query text
            q = q.replace(f"/{cmd.value}", "").strip()

        await send_status_update(
            f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
        )

        compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
            websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
        )

        if compiled_references:
            # Report only the first line of each reference
            headings = {c.split("\n")[0] for c in compiled_references}
            await send_status_update(f"**Searching references**: {headings}")

        online_results: Dict = dict()

        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
            await send_complete_llm_response(f"{no_entries_found.format()}")
            continue

        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
            conversation_commands.remove(ConversationCommand.Notes)

        if ConversationCommand.Online in conversation_commands:
            if not online_search_enabled():
                conversation_commands.remove(ConversationCommand.Online)
                # If online search is not enabled, try to read webpages directly
                if ConversationCommand.Webpage not in conversation_commands:
                    conversation_commands.append(ConversationCommand.Webpage)
            else:
                try:
                    await send_status_update("**Operation**: Searching the web for relevant information...")
                    online_results = await search_online(defiltered_query, meta_log, location)
                    online_searches = ", ".join([f"{query}" for query in online_results.keys()])
                    await send_status_update(f"**Online searches**: {online_searches}")
                except ValueError as e:
                    # NOTE(review): the message says a response will still be attempted, but the
                    # query is abandoned here - confirm which behaviour is intended
                    logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
                    await send_complete_llm_response(
                        f"Error searching online: {e}. Attempting to respond without online results"
                    )
                    continue

        if ConversationCommand.Image in conversation_commands:
            update_telemetry_state(
                request=websocket,
                telemetry_type="api",
                api="chat",
                metadata={"conversation_command": conversation_commands[0].value},
            )
            await send_status_update("**Operation**: Augmenting your query and generating a superb image...")
            intent_type = "text-to-image"
            image, status_code, improved_image_prompt, image_url = await text_to_image(
                q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
            )
            if image is None or status_code != 200:
                # Generation failed: the "detail" field carries the text shown to the user
                content_obj = {
                    "image": image,
                    "intentType": intent_type,
                    "detail": improved_image_prompt,
                    "content-type": "application/json",
                }
                await send_complete_llm_response(json.dumps(content_obj))
                continue

            if image_url:
                # Prefer the URL over the inline image when one is available
                intent_type = "text-to-image2"
                image = image_url
            await sync_to_async(save_to_conversation_log)(
                q,
                image,
                user,
                meta_log,
                intent_type=intent_type,
                inferred_queries=[improved_image_prompt],
                client_application=websocket.user.client_app,
                conversation_id=conversation_id,
                compiled_references=compiled_references,
                online_results=online_results,
            )
            content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results}  # type: ignore
            await send_complete_llm_response(json.dumps(content_obj))
            continue

        llm_response, chat_metadata = await agenerate_chat_response(
            defiltered_query,
            meta_log,
            conversation,
            compiled_references,
            online_results,
            inferred_queries,
            conversation_commands,
            user,
            websocket.user.client_app,
            conversation_id,
            location,
            user_name,
        )

        chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None

        update_telemetry_state(
            request=websocket,
            telemetry_type="api",
            api="chat",
            metadata=chat_metadata,
        )
        iterator = AsyncIteratorWrapper(llm_response)

        await send_message("start_llm_response")

        async for item in iterator:
            if item is None:
                break
            # send_message skips the send after a disconnect, but the iterator is still
            # consumed to the end (the disconnect log message refers to this)
            await send_message(f"{item}")

        await send_message("end_llm_response")
@api_chat.get("", response_class=Response) @api_chat.get("", response_class=Response)
@requires(["authenticated"]) @requires(["authenticated"])
async def chat( async def chat(