mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Merge pull request #679 from khoj-ai/features/chat-socket-streaming
Add a websocket for streaming from the chat UI
This commit is contained in:
@@ -75,6 +75,7 @@ dependencies = [
|
|||||||
"django-phonenumber-field == 7.3.0",
|
"django-phonenumber-field == 7.3.0",
|
||||||
"phonenumbers == 8.13.27",
|
"phonenumbers == 8.13.27",
|
||||||
"markdownify ~= 0.11.6",
|
"markdownify ~= 0.11.6",
|
||||||
|
"websockets == 12.0",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
|||||||
@@ -47,11 +47,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
}, 1000);
|
}, 1000);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
var websocket = null;
|
||||||
|
var timeout = null;
|
||||||
|
var timeoutDuration = 600000; // 10 minutes
|
||||||
|
|
||||||
let region = null;
|
let region = null;
|
||||||
let city = null;
|
let city = null;
|
||||||
let countryName = null;
|
let countryName = null;
|
||||||
|
|
||||||
|
let websocketState = {
|
||||||
|
newResponseText: null,
|
||||||
|
newResponseElement: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
}
|
||||||
|
|
||||||
fetch("https://ipapi.co/json")
|
fetch("https://ipapi.co/json")
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
@@ -415,6 +426,12 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
|
|
||||||
async function chat() {
|
async function chat() {
|
||||||
// Extract required fields for search from form
|
// Extract required fields for search from form
|
||||||
|
|
||||||
|
if (websocket) {
|
||||||
|
sendMessageViaWebSocket();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let query = document.getElementById("chat-input").value.trim();
|
let query = document.getElementById("chat-input").value.trim();
|
||||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
||||||
console.log(`Query: ${query}`);
|
console.log(`Query: ${query}`);
|
||||||
@@ -440,9 +457,6 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
refreshChatSessionsPanel();
|
refreshChatSessionsPanel();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate backend API URL to execute query
|
|
||||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}`;
|
|
||||||
|
|
||||||
let new_response = document.createElement("div");
|
let new_response = document.createElement("div");
|
||||||
new_response.classList.add("chat-message", "khoj");
|
new_response.classList.add("chat-message", "khoj");
|
||||||
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||||
@@ -452,6 +466,79 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
newResponseText.classList.add("chat-message-text", "khoj");
|
newResponseText.classList.add("chat-message-text", "khoj");
|
||||||
new_response.appendChild(newResponseText);
|
new_response.appendChild(newResponseText);
|
||||||
|
|
||||||
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
|
let loadingEllipsis = createLoadingEllipse();
|
||||||
|
|
||||||
|
newResponseText.appendChild(loadingEllipsis);
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
|
||||||
|
let chatTooltip = document.getElementById("chat-tooltip");
|
||||||
|
chatTooltip.style.display = "none";
|
||||||
|
|
||||||
|
let chatInput = document.getElementById("chat-input");
|
||||||
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
|
// Generate backend API URL to execute query
|
||||||
|
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}`;
|
||||||
|
|
||||||
|
// Call specified Khoj API
|
||||||
|
let response = await fetch(url);
|
||||||
|
let rawResponse = "";
|
||||||
|
let references = null;
|
||||||
|
const contentType = response.headers.get("content-type");
|
||||||
|
|
||||||
|
if (contentType === "application/json") {
|
||||||
|
// Handle JSON response
|
||||||
|
try {
|
||||||
|
const responseAsJson = await response.json();
|
||||||
|
if (responseAsJson.image || responseAsJson.detail) {
|
||||||
|
({rawResponse, references } = handleImageResponse(responseAsJson, rawResponse));
|
||||||
|
} else {
|
||||||
|
rawResponse = responseAsJson.response;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
addMessageToChatBody(rawResponse, newResponseText, references);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle streamed response of type text/event-stream or text/plain
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let references = {};
|
||||||
|
|
||||||
|
readStream();
|
||||||
|
|
||||||
|
function readStream() {
|
||||||
|
reader.read().then(({ done, value }) => {
|
||||||
|
if (done) {
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
finalizeChatBodyResponse(references, newResponseText);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode message chunk from stream
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
if (chunk.includes("### compiled references:")) {
|
||||||
|
({ rawResponse, references } = handleCompiledReferences(newResponseText, chunk, references, rawResponse));
|
||||||
|
readStream();
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
handleStreamResponse(newResponseText, rawResponse, loadingEllipsis);
|
||||||
|
readStream();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
function createLoadingEllipse() {
|
||||||
// Temporary status message to indicate that Khoj is thinking
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
let loadingEllipsis = document.createElement("div");
|
let loadingEllipsis = document.createElement("div");
|
||||||
loadingEllipsis.classList.add("lds-ellipsis");
|
loadingEllipsis.classList.add("lds-ellipsis");
|
||||||
@@ -473,115 +560,80 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
loadingEllipsis.appendChild(thirdEllipsis);
|
loadingEllipsis.appendChild(thirdEllipsis);
|
||||||
loadingEllipsis.appendChild(fourthEllipsis);
|
loadingEllipsis.appendChild(fourthEllipsis);
|
||||||
|
|
||||||
newResponseText.appendChild(loadingEllipsis);
|
return loadingEllipsis;
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
}
|
||||||
|
|
||||||
let chatTooltip = document.getElementById("chat-tooltip");
|
function handleStreamResponse(newResponseElement, rawResponse, loadingEllipsis, replace=true) {
|
||||||
chatTooltip.style.display = "none";
|
if (newResponseElement.getElementsByClassName("lds-ellipsis").length > 0 && loadingEllipsis) {
|
||||||
|
newResponseElement.removeChild(loadingEllipsis);
|
||||||
let chatInput = document.getElementById("chat-input");
|
|
||||||
chatInput.classList.remove("option-enabled");
|
|
||||||
|
|
||||||
// Call specified Khoj API
|
|
||||||
let response = await fetch(url);
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = null;
|
|
||||||
const contentType = response.headers.get("content-type");
|
|
||||||
|
|
||||||
if (contentType === "application/json") {
|
|
||||||
// Handle JSON response
|
|
||||||
try {
|
|
||||||
const responseAsJson = await response.json();
|
|
||||||
if (responseAsJson.image) {
|
|
||||||
// If response has image field, response is a generated image.
|
|
||||||
if (responseAsJson.intentType === "text-to-image") {
|
|
||||||
rawResponse += ``;
|
|
||||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
|
||||||
rawResponse += ``;
|
|
||||||
}
|
|
||||||
const inferredQuery = responseAsJson.inferredQueries?.[0];
|
|
||||||
if (inferredQuery) {
|
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (responseAsJson.context && responseAsJson.context.length > 0) {
|
|
||||||
const rawReferenceAsJson = responseAsJson.context;
|
|
||||||
references = createReferenceSection(rawReferenceAsJson);
|
|
||||||
}
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
// If response has detail field, response is an error message.
|
|
||||||
rawResponse += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
if (references != null) {
|
|
||||||
newResponseText.appendChild(references);
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Handle streamed response of type text/event-stream or text/plain
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let references = {};
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
|
|
||||||
function readStream() {
|
|
||||||
reader.read().then(({ done, value }) => {
|
|
||||||
if (done) {
|
|
||||||
// Append any references after all the data has been streamed
|
|
||||||
if (references != {}) {
|
|
||||||
newResponseText.appendChild(createReferenceSection(references));
|
|
||||||
}
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode message chunk from stream
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
references["notes"] = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
references["online"] = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
readStream();
|
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
|
|
||||||
newResponseText.removeChild(loadingEllipsis);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
};
|
if (replace) {
|
||||||
|
newResponseElement.innerHTML = "";
|
||||||
|
}
|
||||||
|
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleCompiledReferences(rawResponseElement, chunk, references, rawResponse) {
|
||||||
|
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||||
|
rawResponse += additionalResponse;
|
||||||
|
rawResponseElement.innerHTML = "";
|
||||||
|
rawResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
const rawReference = chunk.split("### compiled references:")[1];
|
||||||
|
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||||
|
if (rawReferenceAsJson instanceof Array) {
|
||||||
|
references["notes"] = rawReferenceAsJson;
|
||||||
|
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
||||||
|
references["online"] = rawReferenceAsJson;
|
||||||
|
}
|
||||||
|
return { rawResponse, references };
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleImageResponse(imageJson, rawResponse) {
|
||||||
|
if (imageJson.image) {
|
||||||
|
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
||||||
|
|
||||||
|
// If response has image field, response is a generated image.
|
||||||
|
if (imageJson.intentType === "text-to-image") {
|
||||||
|
rawResponse += ``;
|
||||||
|
} else if (imageJson.intentType === "text-to-image2") {
|
||||||
|
rawResponse += ``;
|
||||||
|
}
|
||||||
|
if (inferredQuery) {
|
||||||
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let references = {};
|
||||||
|
if (imageJson.context && imageJson.context.length > 0) {
|
||||||
|
const rawReferenceAsJson = imageJson.context;
|
||||||
|
if (rawReferenceAsJson instanceof Array) {
|
||||||
|
references["notes"] = rawReferenceAsJson;
|
||||||
|
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
||||||
|
references["online"] = rawReferenceAsJson;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (imageJson.detail) {
|
||||||
|
// If response has detail field, response is an error message.
|
||||||
|
rawResponse += imageJson.detail;
|
||||||
|
}
|
||||||
|
return { rawResponse, references };
|
||||||
|
}
|
||||||
|
|
||||||
|
function addMessageToChatBody(rawResponse, newResponseElement, references) {
|
||||||
|
newResponseElement.innerHTML = "";
|
||||||
|
newResponseElement.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
finalizeChatBodyResponse(references, newResponseElement);
|
||||||
|
}
|
||||||
|
|
||||||
|
function finalizeChatBodyResponse(references, newResponseElement) {
|
||||||
|
if (references != null && Object.keys(references).length > 0) {
|
||||||
|
newResponseElement.appendChild(createReferenceSection(references));
|
||||||
|
}
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
|
||||||
function incrementalChat(event) {
|
function incrementalChat(event) {
|
||||||
if (!event.shiftKey && event.key === 'Enter') {
|
if (!event.shiftKey && event.key === 'Enter') {
|
||||||
@@ -798,6 +850,180 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
|
|
||||||
window.onload = loadChat;
|
window.onload = loadChat;
|
||||||
|
|
||||||
|
function setupWebSocket() {
|
||||||
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
let wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||||
|
let webSocketUrl = `${wsProtocol}//${window.location.host}/api/chat/ws`;
|
||||||
|
|
||||||
|
websocketState = {
|
||||||
|
newResponseText: null,
|
||||||
|
newResponseElement: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
function resetTimeout() {
|
||||||
|
if (timeout) {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout = setTimeout(function() {
|
||||||
|
if (websocket) {
|
||||||
|
websocket.close();
|
||||||
|
}
|
||||||
|
}, timeoutDuration);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chatBody.dataset.conversationId) {
|
||||||
|
webSocketUrl += `?conversation_id=${chatBody.dataset.conversationId}`;
|
||||||
|
webSocketUrl += `®ion=${region}&city=${city}&country=${countryName}`;
|
||||||
|
|
||||||
|
websocket = new WebSocket(webSocketUrl);
|
||||||
|
websocket.onmessage = function(event) {
|
||||||
|
resetTimeout();
|
||||||
|
|
||||||
|
// Get the last element in the chat-body
|
||||||
|
let chunk = event.data;
|
||||||
|
if (chunk == "start_llm_response") {
|
||||||
|
console.log("Started streaming", new Date());
|
||||||
|
} else if(chunk == "end_llm_response") {
|
||||||
|
console.log("Stopped streaming", new Date());
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
finalizeChatBodyResponse(websocketState.references, websocketState.newResponseText);
|
||||||
|
|
||||||
|
// Reset variables
|
||||||
|
websocketState = {
|
||||||
|
newResponseText: null,
|
||||||
|
newResponseElement: null,
|
||||||
|
loadingEllipsis: null,
|
||||||
|
references: {},
|
||||||
|
rawResponse: "",
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
if (chunk.includes("application/json"))
|
||||||
|
{
|
||||||
|
chunk = JSON.parse(chunk);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, continue.
|
||||||
|
}
|
||||||
|
|
||||||
|
const contentType = chunk["content-type"]
|
||||||
|
|
||||||
|
if (contentType === "application/json") {
|
||||||
|
// Handle JSON response
|
||||||
|
try {
|
||||||
|
if (chunk.image || chunk.detail) {
|
||||||
|
({rawResponse, references } = handleImageResponse(chunk, websocketState.rawResponse));
|
||||||
|
websocketState.rawResponse = rawResponse;
|
||||||
|
websocketState.references = references;
|
||||||
|
} else if (chunk.type == "status") {
|
||||||
|
handleStreamResponse(websocketState.newResponseText, chunk.message, null, false);
|
||||||
|
} else {
|
||||||
|
rawResponse = chunk.response;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
websocketState.rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
if (chunk.type != "status") {
|
||||||
|
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseText, websocketState.references);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// Handle streamed response of type text/event-stream or text/plain
|
||||||
|
if (chunk && chunk.includes("### compiled references:")) {
|
||||||
|
({ rawResponse, references } = handleCompiledReferences(websocketState.newResponseText, chunk, websocketState.references, websocketState.rawResponse));
|
||||||
|
websocketState.rawResponse = rawResponse;
|
||||||
|
websocketState.references = references;
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
websocketState.rawResponse += chunk;
|
||||||
|
if (websocketState.newResponseText) {
|
||||||
|
handleStreamResponse(websocketState.newResponseText, websocketState.rawResponse, websocketState.loadingEllipsis);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
websocket.onclose = function(event) {
|
||||||
|
websocket = null;
|
||||||
|
console.log("WebSocket is closed now.");
|
||||||
|
let greenDot = document.getElementById("connected-green-dot");
|
||||||
|
greenDot.style.display = "none";
|
||||||
|
}
|
||||||
|
websocket.onerror = function(event) {
|
||||||
|
console.log("WebSocket error observed:", event);
|
||||||
|
}
|
||||||
|
|
||||||
|
websocket.onopen = function(event) {
|
||||||
|
console.log("WebSocket is open now.")
|
||||||
|
let greenDot = document.getElementById("connected-green-dot");
|
||||||
|
greenDot.style.display = "flex";
|
||||||
|
|
||||||
|
// Setup the timeout to close the connection after inactivity.
|
||||||
|
resetTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function sendMessageViaWebSocket(event) {
|
||||||
|
if (event) {
|
||||||
|
event.preventDefault();
|
||||||
|
}
|
||||||
|
|
||||||
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
|
||||||
|
var query = document.getElementById("chat-input").value.trim();
|
||||||
|
console.log(`Query: ${query}`);
|
||||||
|
|
||||||
|
// Add message by user to chat body
|
||||||
|
renderMessage(query, "you");
|
||||||
|
document.getElementById("chat-input").value = "";
|
||||||
|
autoResize();
|
||||||
|
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||||
|
|
||||||
|
let newResponseElement = document.createElement("div");
|
||||||
|
newResponseElement.classList.add("chat-message", "khoj");
|
||||||
|
newResponseElement.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||||
|
chatBody.appendChild(newResponseElement);
|
||||||
|
|
||||||
|
let newResponseText = document.createElement("div");
|
||||||
|
newResponseText.classList.add("chat-message-text", "khoj");
|
||||||
|
newResponseElement.appendChild(newResponseText);
|
||||||
|
|
||||||
|
// Temporary status message to indicate that Khoj is thinking
|
||||||
|
let loadingEllipsis = createLoadingEllipse();
|
||||||
|
|
||||||
|
newResponseText.appendChild(loadingEllipsis);
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
|
||||||
|
let chatTooltip = document.getElementById("chat-tooltip");
|
||||||
|
chatTooltip.style.display = "none";
|
||||||
|
|
||||||
|
let chatInput = document.getElementById("chat-input");
|
||||||
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
|
// Call specified Khoj API
|
||||||
|
websocket.send(query);
|
||||||
|
let rawResponse = "";
|
||||||
|
let references = {};
|
||||||
|
|
||||||
|
websocketState = {
|
||||||
|
newResponseText,
|
||||||
|
newResponseElement,
|
||||||
|
loadingEllipsis,
|
||||||
|
references,
|
||||||
|
rawResponse,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function loadChat() {
|
function loadChat() {
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
chatBody.innerHTML = "";
|
chatBody.innerHTML = "";
|
||||||
@@ -805,6 +1031,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
let chatHistoryUrl = `/api/chat/history?client=web`;
|
let chatHistoryUrl = `/api/chat/history?client=web`;
|
||||||
if (chatBody.dataset.conversationId) {
|
if (chatBody.dataset.conversationId) {
|
||||||
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||||
|
setupWebSocket();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (window.screen.width < 700) {
|
if (window.screen.width < 700) {
|
||||||
@@ -841,6 +1068,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
// Render conversation history, if any
|
// Render conversation history, if any
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
chatBody.dataset.conversationId = response.conversation_id;
|
chatBody.dataset.conversationId = response.conversation_id;
|
||||||
|
setupWebSocket();
|
||||||
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
||||||
|
|
||||||
let agentMetadata = response.agent;
|
let agentMetadata = response.agent;
|
||||||
@@ -1323,6 +1551,10 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
<div id="side-panel-wrapper">
|
<div id="side-panel-wrapper">
|
||||||
<div id="side-panel">
|
<div id="side-panel">
|
||||||
<div id="new-conversation">
|
<div id="new-conversation">
|
||||||
|
<div id="connected-green-dot" style="display: none; align-items: center; margin-bottom: 10px;">
|
||||||
|
<div style="width: 10px; height: 10px; background-color: green; border-radius: 50%; margin-right: 5px;"></div>
|
||||||
|
<div>Connected</div>
|
||||||
|
</div>
|
||||||
<button class="side-panel-button" id="new-conversation-button" onclick="createNewConversation()">
|
<button class="side-panel-button" id="new-conversation-button" onclick="createNewConversation()">
|
||||||
New Topic
|
New Topic
|
||||||
<svg class="new-convo-button" viewBox="0 0 35 35" fill="#000000" viewBox="0 0 32 32" version="1.1" xmlns="http://www.w3.org/2000/svg">
|
<svg class="new-convo-button" viewBox="0 0 35 35" fill="#000000" viewBox="0 0 32 32" version="1.1" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
|||||||
@@ -61,6 +61,36 @@ async def search(
|
|||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
|
results = await execute_search(
|
||||||
|
user=user,
|
||||||
|
q=q,
|
||||||
|
n=n,
|
||||||
|
t=t,
|
||||||
|
r=r,
|
||||||
|
max_distance=max_distance,
|
||||||
|
dedupe=dedupe,
|
||||||
|
)
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="search",
|
||||||
|
**common.__dict__,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_search(
|
||||||
|
user: KhojUser,
|
||||||
|
q: str,
|
||||||
|
n: Optional[int] = 5,
|
||||||
|
t: Optional[SearchType] = SearchType.All,
|
||||||
|
r: Optional[bool] = False,
|
||||||
|
max_distance: Optional[Union[float, None]] = None,
|
||||||
|
dedupe: Optional[bool] = True,
|
||||||
|
):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Run validation checks
|
# Run validation checks
|
||||||
@@ -155,13 +185,6 @@ async def search(
|
|||||||
if user:
|
if user:
|
||||||
state.query_cache[user.uuid][query_cache_key] = results
|
state.query_cache[user.uuid][query_cache_key] = results
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="search",
|
|
||||||
**common.__dict__,
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||||
|
|
||||||
@@ -350,14 +373,14 @@ async def extract_references_and_questions(
|
|||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if using_offline_chat else n
|
n_items = min(n, 3) if using_offline_chat else n
|
||||||
result_list.extend(
|
result_list.extend(
|
||||||
await search(
|
await execute_search(
|
||||||
|
user,
|
||||||
f"{query} {filters_in_query}",
|
f"{query} {filters_in_query}",
|
||||||
request=request,
|
|
||||||
n=n_items,
|
n=n_items,
|
||||||
|
t=SearchType.All,
|
||||||
r=True,
|
r=True,
|
||||||
max_distance=d,
|
max_distance=d,
|
||||||
dedupe=False,
|
dedupe=False,
|
||||||
common=common,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result_list = text_search.deduplicated_search_responses(result_list)
|
result_list = text_search.deduplicated_search_responses(result_list)
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ from typing import Dict, Optional
|
|||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request, WebSocket
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
from websockets import ConnectionClosedOK
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
@@ -242,6 +244,230 @@ async def set_conversation_title(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_chat.websocket("/ws")
|
||||||
|
async def websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
conversation_id: int,
|
||||||
|
city: Optional[str] = None,
|
||||||
|
region: Optional[str] = None,
|
||||||
|
country: Optional[str] = None,
|
||||||
|
):
|
||||||
|
connection_alive = True
|
||||||
|
|
||||||
|
async def send_status_update(message: str):
|
||||||
|
nonlocal connection_alive
|
||||||
|
if not connection_alive:
|
||||||
|
return
|
||||||
|
|
||||||
|
status_packet = {
|
||||||
|
"type": "status",
|
||||||
|
"message": message,
|
||||||
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
await websocket.send_text(json.dumps(status_packet))
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
connection_alive = False
|
||||||
|
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||||
|
|
||||||
|
async def send_complete_llm_response(llm_response: str):
|
||||||
|
nonlocal connection_alive
|
||||||
|
if not connection_alive:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await websocket.send_text("start_llm_response")
|
||||||
|
await websocket.send_text(llm_response)
|
||||||
|
await websocket.send_text("end_llm_response")
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
connection_alive = False
|
||||||
|
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||||
|
|
||||||
|
async def send_message(message: str):
|
||||||
|
nonlocal connection_alive
|
||||||
|
if not connection_alive:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await websocket.send_text(message)
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
connection_alive = False
|
||||||
|
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||||
|
|
||||||
|
user: KhojUser = websocket.user.object
|
||||||
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
|
user, client_application=websocket.user.client_app, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||||
|
|
||||||
|
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||||
|
|
||||||
|
await is_ready_to_chat(user)
|
||||||
|
|
||||||
|
user_name = await aget_user_name(user)
|
||||||
|
|
||||||
|
location = None
|
||||||
|
|
||||||
|
if city or region or country:
|
||||||
|
location = LocationData(city=city, region=region, country=country)
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
while connection_alive:
|
||||||
|
try:
|
||||||
|
q = await websocket.receive_text()
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
logger.debug(f"User {user} disconnected web socket")
|
||||||
|
break
|
||||||
|
|
||||||
|
await sync_to_async(hourly_limiter)(websocket)
|
||||||
|
await sync_to_async(daily_limiter)(websocket)
|
||||||
|
|
||||||
|
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||||
|
|
||||||
|
await send_status_update(f"**Processing query**: {q}")
|
||||||
|
|
||||||
|
if conversation_commands == [ConversationCommand.Help]:
|
||||||
|
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||||
|
if conversation_config == None:
|
||||||
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
model_type = conversation_config.model_type
|
||||||
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||||
|
await send_complete_llm_response(formatted_help)
|
||||||
|
continue
|
||||||
|
|
||||||
|
meta_log = conversation.conversation_log
|
||||||
|
|
||||||
|
if conversation_commands == [ConversationCommand.Default]:
|
||||||
|
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||||
|
mode = await aget_relevant_output_modes(q, meta_log)
|
||||||
|
if mode not in conversation_commands:
|
||||||
|
conversation_commands.append(mode)
|
||||||
|
|
||||||
|
for cmd in conversation_commands:
|
||||||
|
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
||||||
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
|
|
||||||
|
await send_status_update(
|
||||||
|
f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
|
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
|
||||||
|
)
|
||||||
|
|
||||||
|
if compiled_references:
|
||||||
|
headings = set([c.split("\n")[0] for c in compiled_references])
|
||||||
|
await send_status_update(f"**Searching references**: {headings}")
|
||||||
|
|
||||||
|
online_results: Dict = dict()
|
||||||
|
|
||||||
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
|
await send_complete_llm_response(f"{no_entries_found.format()}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||||
|
conversation_commands.remove(ConversationCommand.Notes)
|
||||||
|
|
||||||
|
if ConversationCommand.Online in conversation_commands:
|
||||||
|
if not online_search_enabled():
|
||||||
|
conversation_commands.remove(ConversationCommand.Online)
|
||||||
|
# If online search is not enabled, try to read webpages directly
|
||||||
|
if ConversationCommand.Webpage not in conversation_commands:
|
||||||
|
conversation_commands.append(ConversationCommand.Webpage)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await send_status_update("**Operation**: Searching the web for relevant information...")
|
||||||
|
online_results = await search_online(defiltered_query, meta_log, location)
|
||||||
|
online_searches = ", ".join([f"{query}" for query in online_results.keys()])
|
||||||
|
await send_status_update(f"**Online searches**: {online_searches}")
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
||||||
|
await send_complete_llm_response(
|
||||||
|
f"Error searching online: {e}. Attempting to respond without online results"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ConversationCommand.Image in conversation_commands:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=websocket,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="chat",
|
||||||
|
metadata={"conversation_command": conversation_commands[0].value},
|
||||||
|
)
|
||||||
|
await send_status_update("**Operation**: Augmenting your query and generating a superb image...")
|
||||||
|
intent_type = "text-to-image"
|
||||||
|
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
||||||
|
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||||
|
)
|
||||||
|
if image is None or status_code != 200:
|
||||||
|
content_obj = {
|
||||||
|
"image": image,
|
||||||
|
"intentType": intent_type,
|
||||||
|
"detail": improved_image_prompt,
|
||||||
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
await send_complete_llm_response(json.dumps(content_obj))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if image_url:
|
||||||
|
intent_type = "text-to-image2"
|
||||||
|
image = image_url
|
||||||
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
q,
|
||||||
|
image,
|
||||||
|
user,
|
||||||
|
meta_log,
|
||||||
|
intent_type=intent_type,
|
||||||
|
inferred_queries=[improved_image_prompt],
|
||||||
|
client_application=websocket.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
compiled_references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
)
|
||||||
|
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore
|
||||||
|
|
||||||
|
await send_complete_llm_response(json.dumps(content_obj))
|
||||||
|
continue
|
||||||
|
|
||||||
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
|
defiltered_query,
|
||||||
|
meta_log,
|
||||||
|
conversation,
|
||||||
|
compiled_references,
|
||||||
|
online_results,
|
||||||
|
inferred_queries,
|
||||||
|
conversation_commands,
|
||||||
|
user,
|
||||||
|
websocket.user.client_app,
|
||||||
|
conversation_id,
|
||||||
|
location,
|
||||||
|
user_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=websocket,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="chat",
|
||||||
|
metadata=chat_metadata,
|
||||||
|
)
|
||||||
|
iterator = AsyncIteratorWrapper(llm_response)
|
||||||
|
|
||||||
|
await send_message("start_llm_response")
|
||||||
|
|
||||||
|
async for item in iterator:
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
if connection_alive:
|
||||||
|
try:
|
||||||
|
await send_message(f"{item}")
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
connection_alive = False
|
||||||
|
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
|
||||||
|
|
||||||
|
await send_message("end_llm_response")
|
||||||
|
|
||||||
|
|
||||||
@api_chat.get("", response_class=Response)
|
@api_chat.get("", response_class=Response)
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat(
|
async def chat(
|
||||||
|
|||||||
Reference in New Issue
Block a user