mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Merge pull request #580 from khoj-ai/fix-upgrade-chat-to-create-images
Support Image Generation with Khoj
This commit is contained in:
@@ -42,7 +42,7 @@ dependencies = [
|
|||||||
"fastapi >= 0.104.1",
|
"fastapi >= 0.104.1",
|
||||||
"python-multipart >= 0.0.5",
|
"python-multipart >= 0.0.5",
|
||||||
"jinja2 == 3.1.2",
|
"jinja2 == 3.1.2",
|
||||||
"openai >= 0.27.0, < 1.0.0",
|
"openai >= 1.0.0",
|
||||||
"tiktoken >= 0.3.2",
|
"tiktoken >= 0.3.2",
|
||||||
"tenacity >= 8.2.2",
|
"tenacity >= 8.2.2",
|
||||||
"pillow ~= 9.5.0",
|
"pillow ~= 9.5.0",
|
||||||
|
|||||||
@@ -179,7 +179,13 @@
|
|||||||
return numOnlineReferences;
|
return numOnlineReferences;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
|
||||||
|
if (intentType === "text-to-image") {
|
||||||
|
let imageMarkdown = ``;
|
||||||
|
renderMessage(imageMarkdown, by, dt);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (context == null && onlineContext == null) {
|
if (context == null && onlineContext == null) {
|
||||||
renderMessage(message, by, dt);
|
renderMessage(message, by, dt);
|
||||||
return;
|
return;
|
||||||
@@ -244,6 +250,17 @@
|
|||||||
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
||||||
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
||||||
|
|
||||||
|
// Customize the rendering of images
|
||||||
|
md.renderer.rules.image = function(tokens, idx, options, env, self) {
|
||||||
|
let token = tokens[idx];
|
||||||
|
|
||||||
|
// Add class="text-to-image" to images
|
||||||
|
token.attrPush(['class', 'text-to-image']);
|
||||||
|
|
||||||
|
// Use the default renderer to render image markdown format
|
||||||
|
return self.renderToken(tokens, idx, options);
|
||||||
|
};
|
||||||
|
|
||||||
// Render markdown
|
// Render markdown
|
||||||
newHTML = md.render(newHTML);
|
newHTML = md.render(newHTML);
|
||||||
// Get any elements with a class that starts with "language"
|
// Get any elements with a class that starts with "language"
|
||||||
@@ -328,109 +345,142 @@
|
|||||||
let chatInput = document.getElementById("chat-input");
|
let chatInput = document.getElementById("chat-input");
|
||||||
chatInput.classList.remove("option-enabled");
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
// Call specified Khoj API
|
||||||
fetch(url, { headers })
|
let response = await fetch(url, { headers });
|
||||||
.then(response => {
|
let rawResponse = "";
|
||||||
const reader = response.body.getReader();
|
const contentType = response.headers.get("content-type");
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = null;
|
|
||||||
|
|
||||||
function readStream() {
|
if (contentType === "application/json") {
|
||||||
reader.read().then(({ done, value }) => {
|
// Handle JSON response
|
||||||
if (done) {
|
try {
|
||||||
// Append any references after all the data has been streamed
|
const responseAsJson = await response.json();
|
||||||
if (references != null) {
|
if (responseAsJson.image) {
|
||||||
newResponseText.appendChild(references);
|
// If response has image field, response is a generated image.
|
||||||
}
|
rawResponse += ``;
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
}
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
if (responseAsJson.detail) {
|
||||||
return;
|
// If response has detail field, response is an error message.
|
||||||
|
rawResponse += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle streamed response of type text/event-stream or text/plain
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let references = null;
|
||||||
|
|
||||||
|
readStream();
|
||||||
|
|
||||||
|
function readStream() {
|
||||||
|
reader.read().then(({ done, value }) => {
|
||||||
|
if (done) {
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
if (references != null) {
|
||||||
|
newResponseText.appendChild(references);
|
||||||
|
}
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode message chunk from stream
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
if (chunk.includes("### compiled references:")) {
|
||||||
|
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||||
|
rawResponse += additionalResponse;
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
const rawReference = chunk.split("### compiled references:")[1];
|
||||||
|
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||||
|
references = document.createElement('div');
|
||||||
|
references.classList.add("references");
|
||||||
|
|
||||||
|
let referenceExpandButton = document.createElement('button');
|
||||||
|
referenceExpandButton.classList.add("reference-expand-button");
|
||||||
|
|
||||||
|
let referenceSection = document.createElement('div');
|
||||||
|
referenceSection.classList.add("reference-section");
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
|
||||||
|
let numReferences = 0;
|
||||||
|
|
||||||
|
// If rawReferenceAsJson is a list, then count the length
|
||||||
|
if (Array.isArray(rawReferenceAsJson)) {
|
||||||
|
numReferences = rawReferenceAsJson.length;
|
||||||
|
|
||||||
|
rawReferenceAsJson.forEach((reference, index) => {
|
||||||
|
let polishedReference = generateReference(reference, index);
|
||||||
|
referenceSection.appendChild(polishedReference);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode message chunk from stream
|
references.appendChild(referenceExpandButton);
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
referenceExpandButton.addEventListener('click', function() {
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
if (referenceSection.classList.contains("collapsed")) {
|
||||||
rawResponse += additionalResponse;
|
referenceSection.classList.remove("collapsed");
|
||||||
|
referenceSection.classList.add("expanded");
|
||||||
|
} else {
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
referenceSection.classList.remove("expanded");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||||
|
referenceExpandButton.innerHTML = expandButtonText;
|
||||||
|
references.appendChild(referenceSection);
|
||||||
|
readStream();
|
||||||
|
} else {
|
||||||
|
// Display response from Khoj
|
||||||
|
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||||
|
newResponseText.removeChild(loadingSpinner);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||||
|
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||||
|
try {
|
||||||
|
const responseAsJson = JSON.parse(chunk);
|
||||||
|
if (responseAsJson.image) {
|
||||||
|
rawResponse += ``;
|
||||||
|
}
|
||||||
|
if (responseAsJson.detail) {
|
||||||
|
rawResponse += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
newResponseText.innerHTML = "";
|
newResponseText.innerHTML = "";
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
references = document.createElement('div');
|
|
||||||
references.classList.add("references");
|
|
||||||
|
|
||||||
let referenceExpandButton = document.createElement('button');
|
|
||||||
referenceExpandButton.classList.add("reference-expand-button");
|
|
||||||
|
|
||||||
let referenceSection = document.createElement('div');
|
|
||||||
referenceSection.classList.add("reference-section");
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
|
|
||||||
let numReferences = 0;
|
|
||||||
|
|
||||||
// If rawReferenceAsJson is a list, then count the length
|
|
||||||
if (Array.isArray(rawReferenceAsJson)) {
|
|
||||||
numReferences = rawReferenceAsJson.length;
|
|
||||||
|
|
||||||
rawReferenceAsJson.forEach((reference, index) => {
|
|
||||||
let polishedReference = generateReference(reference, index);
|
|
||||||
referenceSection.appendChild(polishedReference);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
|
||||||
}
|
|
||||||
|
|
||||||
references.appendChild(referenceExpandButton);
|
|
||||||
|
|
||||||
referenceExpandButton.addEventListener('click', function() {
|
|
||||||
if (referenceSection.classList.contains("collapsed")) {
|
|
||||||
referenceSection.classList.remove("collapsed");
|
|
||||||
referenceSection.classList.add("expanded");
|
|
||||||
} else {
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
referenceSection.classList.remove("expanded");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
|
||||||
referenceExpandButton.innerHTML = expandButtonText;
|
|
||||||
references.appendChild(referenceSection);
|
|
||||||
readStream();
|
readStream();
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
|
||||||
newResponseText.removeChild(loadingSpinner);
|
|
||||||
}
|
|
||||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
|
||||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
|
||||||
try {
|
|
||||||
const responseAsJson = JSON.parse(chunk);
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
newResponseText.innerHTML += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
newResponseText.innerHTML += chunk;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
readStream();
|
}
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function incrementalChat(event) {
|
function incrementalChat(event) {
|
||||||
@@ -522,7 +572,7 @@
|
|||||||
.then(response => {
|
.then(response => {
|
||||||
// Render conversation history, if any
|
// Render conversation history, if any
|
||||||
response.forEach(chat_log => {
|
response.forEach(chat_log => {
|
||||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
|
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
@@ -625,9 +675,13 @@
|
|||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
.then(data => { chatInput.value += data.text; })
|
.then(data => { chatInput.value += data.text; })
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
err.status == 422
|
if (err.status === 501) {
|
||||||
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||||
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
|
} else if (err.status === 422) {
|
||||||
|
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||||
|
} else {
|
||||||
|
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -810,6 +864,9 @@
|
|||||||
margin-top: -10px;
|
margin-top: -10px;
|
||||||
transform: rotate(-60deg)
|
transform: rotate(-60deg)
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 60%;
|
||||||
|
}
|
||||||
|
|
||||||
#chat-footer {
|
#chat-footer {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
@@ -1050,6 +1107,9 @@
|
|||||||
margin: 4px;
|
margin: 4px;
|
||||||
grid-template-columns: auto;
|
grid-template-columns: auto;
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 100%;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@media only screen and (min-width: 600px) {
|
@media only screen and (min-width: 600px) {
|
||||||
body {
|
body {
|
||||||
|
|||||||
@@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
|
|||||||
import { KhojSetting } from 'src/settings';
|
import { KhojSetting } from 'src/settings';
|
||||||
import fetch from "node-fetch";
|
import fetch from "node-fetch";
|
||||||
|
|
||||||
|
export interface ChatJsonResult {
|
||||||
|
image?: string;
|
||||||
|
detail?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
export class KhojChatModal extends Modal {
|
export class KhojChatModal extends Modal {
|
||||||
result: string;
|
result: string;
|
||||||
setting: KhojSetting;
|
setting: KhojSetting;
|
||||||
@@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
|
|||||||
return referenceButton;
|
return referenceButton;
|
||||||
}
|
}
|
||||||
|
|
||||||
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
|
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
|
||||||
if (!message) {
|
if (!message) {
|
||||||
return;
|
return;
|
||||||
|
} else if (intentType === "text-to-image") {
|
||||||
|
let imageMarkdown = ``;
|
||||||
|
this.renderMessage(chatEl, imageMarkdown, sender, dt);
|
||||||
|
return;
|
||||||
} else if (!context) {
|
} else if (!context) {
|
||||||
this.renderMessage(chatEl, message, sender, dt);
|
this.renderMessage(chatEl, message, sender, dt);
|
||||||
return
|
return;
|
||||||
} else if (!!context && context?.length === 0) {
|
} else if (!!context && context?.length === 0) {
|
||||||
this.renderMessage(chatEl, message, sender, dt);
|
this.renderMessage(chatEl, message, sender, dt);
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
||||||
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
|
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
|
||||||
@@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
|
|||||||
let response = await request({ url: chatUrl, headers: headers });
|
let response = await request({ url: chatUrl, headers: headers });
|
||||||
let chatLogs = JSON.parse(response).response;
|
let chatLogs = JSON.parse(response).response;
|
||||||
chatLogs.forEach((chatLog: any) => {
|
chatLogs.forEach((chatLog: any) => {
|
||||||
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
|
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
|
|||||||
|
|
||||||
this.result = "";
|
this.result = "";
|
||||||
responseElement.innerHTML = "";
|
responseElement.innerHTML = "";
|
||||||
|
if (response.headers.get("content-type") == "application/json") {
|
||||||
|
let responseText = ""
|
||||||
|
try {
|
||||||
|
const responseAsJson = await response.json() as ChatJsonResult;
|
||||||
|
if (responseAsJson.image) {
|
||||||
|
responseText = ``;
|
||||||
|
} else if (responseAsJson.detail) {
|
||||||
|
responseText = responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
responseText = response.body.read().toString()
|
||||||
|
} finally {
|
||||||
|
this.renderIncrementalMessage(responseElement, responseText);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for await (const chunk of response.body) {
|
for await (const chunk of response.body) {
|
||||||
const responseText = chunk.toString();
|
let responseText = chunk.toString();
|
||||||
if (responseText.includes("### compiled references:")) {
|
if (responseText.includes("### compiled references:")) {
|
||||||
const additionalResponse = responseText.split("### compiled references:")[0];
|
const additionalResponse = responseText.split("### compiled references:")[0];
|
||||||
this.renderIncrementalMessage(responseElement, additionalResponse);
|
this.renderIncrementalMessage(responseElement, additionalResponse);
|
||||||
@@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
|
|||||||
referenceExpandButton.innerHTML = expandButtonText;
|
referenceExpandButton.innerHTML = expandButtonText;
|
||||||
references.appendChild(referenceSection);
|
references.appendChild(referenceSection);
|
||||||
} else {
|
} else {
|
||||||
|
if (responseText.startsWith("{") && responseText.endsWith("}")) {
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
this.renderIncrementalMessage(responseElement, responseText);
|
this.renderIncrementalMessage(responseElement, responseText);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
|
|||||||
if (response.status === 200) {
|
if (response.status === 200) {
|
||||||
console.log(response);
|
console.log(response);
|
||||||
chatInput.value += response.json.text;
|
chatInput.value += response.json.text;
|
||||||
} else if (response.status === 422) {
|
} else if (response.status === 501) {
|
||||||
throw new Error("⛔️ Failed to transcribe audio");
|
|
||||||
} else {
|
|
||||||
throw new Error("⛔️ Configure speech-to-text model on server.");
|
throw new Error("⛔️ Configure speech-to-text model on server.");
|
||||||
|
} else if (response.status === 422) {
|
||||||
|
throw new Error("⛔️ Audio file to large to process.");
|
||||||
|
} else {
|
||||||
|
throw new Error("⛔️ Failed to transcribe audio.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -217,7 +217,9 @@ button.copy-button:hover {
|
|||||||
background: #f5f5f5;
|
background: #f5f5f5;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
}
|
}
|
||||||
|
img {
|
||||||
|
max-width: 60%;
|
||||||
|
}
|
||||||
|
|
||||||
#khoj-chat-footer {
|
#khoj-chat-footer {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import requests
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
|
import openai
|
||||||
import schedule
|
import schedule
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
@@ -22,6 +23,7 @@ from starlette.authentication import (
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
ConversationAdapters,
|
||||||
get_all_users,
|
get_all_users,
|
||||||
get_or_create_search_model,
|
get_or_create_search_model,
|
||||||
aget_user_subscription_state,
|
aget_user_subscription_state,
|
||||||
@@ -138,6 +140,10 @@ def configure_server(
|
|||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
state.config = config
|
state.config = config
|
||||||
|
|
||||||
|
if ConversationAdapters.has_valid_openai_conversation_config():
|
||||||
|
openai_config = ConversationAdapters.get_openai_conversation_config()
|
||||||
|
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
|
||||||
|
|
||||||
# Initialize Search Models from Config and initialize content
|
# Initialize Search Models from Config and initialize content
|
||||||
try:
|
try:
|
||||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from khoj.database.models import (
|
|||||||
GithubConfig,
|
GithubConfig,
|
||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
GoogleUser,
|
GoogleUser,
|
||||||
|
TextToImageModelConfig,
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
@@ -426,6 +427,10 @@ class ConversationAdapters:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_text_to_image_model_config():
|
||||||
|
return await TextToImageModelConfig.objects.filter().afirst()
|
||||||
|
|
||||||
|
|
||||||
class EntryAdapters:
|
class EntryAdapters:
|
||||||
word_filer = WordFilter()
|
word_filer = WordFilter()
|
||||||
|
|||||||
25
src/khoj/database/migrations/0022_texttoimagemodelconfig.py
Normal file
25
src/khoj/database/migrations/0022_texttoimagemodelconfig.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Generated by Django 4.2.7 on 2023-12-04 22:17
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0021_speechtotextmodeloptions_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="TextToImageModelConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("model_name", models.CharField(default="dall-e-3", max_length=200)),
|
||||||
|
("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
|
|||||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToImageModelConfig(BaseModel):
|
||||||
|
class ModelType(models.TextChoices):
|
||||||
|
OPENAI = "openai"
|
||||||
|
|
||||||
|
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||||
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
api_key = models.CharField(max_length=200)
|
api_key = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|||||||
@@ -183,12 +183,18 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
referenceSection.appendChild(polishedReference);
|
referenceSection.appendChild(polishedReference);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return numOnlineReferences;
|
return numOnlineReferences;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
|
||||||
|
if (intentType === "text-to-image") {
|
||||||
|
let imageMarkdown = ``;
|
||||||
|
renderMessage(imageMarkdown, by, dt);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
|
|
||||||
if (context == null && onlineContext == null) {
|
if (context == null && onlineContext == null) {
|
||||||
renderMessage(message, by, dt);
|
renderMessage(message, by, dt);
|
||||||
return;
|
return;
|
||||||
@@ -253,6 +259,17 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
||||||
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
||||||
|
|
||||||
|
// Customize the rendering of images
|
||||||
|
md.renderer.rules.image = function(tokens, idx, options, env, self) {
|
||||||
|
let token = tokens[idx];
|
||||||
|
|
||||||
|
// Add class="text-to-image" to images
|
||||||
|
token.attrPush(['class', 'text-to-image']);
|
||||||
|
|
||||||
|
// Use the default renderer to render image markdown format
|
||||||
|
return self.renderToken(tokens, idx, options);
|
||||||
|
};
|
||||||
|
|
||||||
// Render markdown
|
// Render markdown
|
||||||
newHTML = md.render(newHTML);
|
newHTML = md.render(newHTML);
|
||||||
// Get any elements with a class that starts with "language"
|
// Get any elements with a class that starts with "language"
|
||||||
@@ -292,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
return element
|
return element
|
||||||
}
|
}
|
||||||
|
|
||||||
function chat() {
|
async function chat() {
|
||||||
// Extract required fields for search from form
|
// Extract required fields for search from form
|
||||||
let query = document.getElementById("chat-input").value.trim();
|
let query = document.getElementById("chat-input").value.trim();
|
||||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
||||||
@@ -333,113 +350,123 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
let chatInput = document.getElementById("chat-input");
|
let chatInput = document.getElementById("chat-input");
|
||||||
chatInput.classList.remove("option-enabled");
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
// Call specified Khoj API
|
||||||
fetch(url)
|
let response = await fetch(url);
|
||||||
.then(response => {
|
let rawResponse = "";
|
||||||
const reader = response.body.getReader();
|
const contentType = response.headers.get("content-type");
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = null;
|
|
||||||
|
|
||||||
function readStream() {
|
if (contentType === "application/json") {
|
||||||
reader.read().then(({ done, value }) => {
|
// Handle JSON response
|
||||||
if (done) {
|
try {
|
||||||
// Append any references after all the data has been streamed
|
const responseAsJson = await response.json();
|
||||||
if (references != null) {
|
if (responseAsJson.image) {
|
||||||
newResponseText.appendChild(references);
|
// If response has image field, response is a generated image.
|
||||||
}
|
rawResponse += ``;
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode message chunk from stream
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
references = document.createElement('div');
|
|
||||||
references.classList.add("references");
|
|
||||||
|
|
||||||
let referenceExpandButton = document.createElement('button');
|
|
||||||
referenceExpandButton.classList.add("reference-expand-button");
|
|
||||||
|
|
||||||
let referenceSection = document.createElement('div');
|
|
||||||
referenceSection.classList.add("reference-section");
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
|
|
||||||
let numReferences = 0;
|
|
||||||
|
|
||||||
// If rawReferenceAsJson is a list, then count the length
|
|
||||||
if (Array.isArray(rawReferenceAsJson)) {
|
|
||||||
numReferences = rawReferenceAsJson.length;
|
|
||||||
|
|
||||||
rawReferenceAsJson.forEach((reference, index) => {
|
|
||||||
let polishedReference = generateReference(reference, index);
|
|
||||||
referenceSection.appendChild(polishedReference);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
|
||||||
}
|
|
||||||
|
|
||||||
references.appendChild(referenceExpandButton);
|
|
||||||
|
|
||||||
referenceExpandButton.addEventListener('click', function() {
|
|
||||||
if (referenceSection.classList.contains("collapsed")) {
|
|
||||||
referenceSection.classList.remove("collapsed");
|
|
||||||
referenceSection.classList.add("expanded");
|
|
||||||
} else {
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
referenceSection.classList.remove("expanded");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
|
||||||
referenceExpandButton.innerHTML = expandButtonText;
|
|
||||||
references.appendChild(referenceSection);
|
|
||||||
readStream();
|
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
|
||||||
newResponseText.removeChild(loadingSpinner);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
|
||||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
|
||||||
try {
|
|
||||||
const responseAsJson = JSON.parse(chunk);
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
rawResponse += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
readStream();
|
if (responseAsJson.detail) {
|
||||||
});
|
// If response has detail field, response is an error message.
|
||||||
}
|
rawResponse += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle streamed response of type text/event-stream or text/plain
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let references = null;
|
||||||
|
|
||||||
|
readStream();
|
||||||
|
|
||||||
|
function readStream() {
|
||||||
|
reader.read().then(({ done, value }) => {
|
||||||
|
if (done) {
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
if (references != null) {
|
||||||
|
newResponseText.appendChild(references);
|
||||||
|
}
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode message chunk from stream
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
if (chunk.includes("### compiled references:")) {
|
||||||
|
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||||
|
rawResponse += additionalResponse;
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
const rawReference = chunk.split("### compiled references:")[1];
|
||||||
|
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||||
|
references = document.createElement('div');
|
||||||
|
references.classList.add("references");
|
||||||
|
|
||||||
|
let referenceExpandButton = document.createElement('button');
|
||||||
|
referenceExpandButton.classList.add("reference-expand-button");
|
||||||
|
|
||||||
|
let referenceSection = document.createElement('div');
|
||||||
|
referenceSection.classList.add("reference-section");
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
|
||||||
|
let numReferences = 0;
|
||||||
|
|
||||||
|
// If rawReferenceAsJson is a list, then count the length
|
||||||
|
if (Array.isArray(rawReferenceAsJson)) {
|
||||||
|
numReferences = rawReferenceAsJson.length;
|
||||||
|
|
||||||
|
rawReferenceAsJson.forEach((reference, index) => {
|
||||||
|
let polishedReference = generateReference(reference, index);
|
||||||
|
referenceSection.appendChild(polishedReference);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
||||||
|
}
|
||||||
|
|
||||||
|
references.appendChild(referenceExpandButton);
|
||||||
|
|
||||||
|
referenceExpandButton.addEventListener('click', function() {
|
||||||
|
if (referenceSection.classList.contains("collapsed")) {
|
||||||
|
referenceSection.classList.remove("collapsed");
|
||||||
|
referenceSection.classList.add("expanded");
|
||||||
|
} else {
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
referenceSection.classList.remove("expanded");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||||
|
referenceExpandButton.innerHTML = expandButtonText;
|
||||||
|
references.appendChild(referenceSection);
|
||||||
|
readStream();
|
||||||
|
} else {
|
||||||
|
// Display response from Khoj
|
||||||
|
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||||
|
newResponseText.removeChild(loadingSpinner);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
readStream();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
function incrementalChat(event) {
|
function incrementalChat(event) {
|
||||||
if (!event.shiftKey && event.key === 'Enter') {
|
if (!event.shiftKey && event.key === 'Enter') {
|
||||||
@@ -516,7 +543,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
.then(response => {
|
.then(response => {
|
||||||
// Render conversation history, if any
|
// Render conversation history, if any
|
||||||
response.forEach(chat_log => {
|
response.forEach(chat_log => {
|
||||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
|
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
@@ -611,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
.then(data => { chatInput.value += data.text; })
|
.then(data => { chatInput.value += data.text; })
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
err.status == 422
|
if (err.status === 501) {
|
||||||
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||||
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
|
} else if (err.status === 422) {
|
||||||
|
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||||
|
} else {
|
||||||
|
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -902,6 +933,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
margin-top: -10px;
|
margin-top: -10px;
|
||||||
transform: rotate(-60deg)
|
transform: rotate(-60deg)
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 60%;
|
||||||
|
}
|
||||||
|
|
||||||
#chat-footer {
|
#chat-footer {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
@@ -1029,6 +1063,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
margin: 4px;
|
margin: 4px;
|
||||||
grid-template-columns: auto;
|
grid-template-columns: auto;
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 100%;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@media only screen and (min-width: 700px) {
|
@media only screen and (min-width: 700px) {
|
||||||
body {
|
body {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def extract_questions_offline(
|
|||||||
|
|
||||||
if use_history:
|
if use_history:
|
||||||
for chat in conversation_log.get("chat", [])[-4:]:
|
for chat in conversation_log.get("chat", [])[-4:]:
|
||||||
if chat["by"] == "khoj":
|
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
|
||||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||||
chat_history += f"A: {chat['message']}\n"
|
chat_history += f"A: {chat['message']}\n"
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ def download_model(model_name: str):
|
|||||||
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Download the chat model
|
|
||||||
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
|
|
||||||
|
|
||||||
# Decide whether to load model to GPU or CPU
|
# Decide whether to load model to GPU or CPU
|
||||||
|
chat_model_config = None
|
||||||
try:
|
try:
|
||||||
|
# Download the chat model and its config
|
||||||
|
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
|
||||||
|
|
||||||
# Try load chat model to GPU if:
|
# Try load chat model to GPU if:
|
||||||
# 1. Loading chat model to GPU isn't disabled via CLI and
|
# 1. Loading chat model to GPU isn't disabled via CLI and
|
||||||
# 2. Machine has GPU
|
# 2. Machine has GPU
|
||||||
@@ -26,6 +27,12 @@ def download_model(model_name: str):
|
|||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
except Exception as e:
|
||||||
|
if chat_model_config is None:
|
||||||
|
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
|
||||||
|
logger.debug(f"Unable to download model config from gpt4all website: {e}")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
# Now load the downloaded chat model onto appropriate device
|
# Now load the downloaded chat model onto appropriate device
|
||||||
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
|
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def extract_questions(
|
|||||||
[
|
[
|
||||||
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
||||||
for chat in conversation_log.get("chat", [])[-4:]
|
for chat in conversation_log.get("chat", [])[-4:]
|
||||||
if chat["by"] == "khoj"
|
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,8 +123,8 @@ def send_message_to_model(
|
|||||||
|
|
||||||
def converse(
|
def converse(
|
||||||
references,
|
references,
|
||||||
online_results,
|
|
||||||
user_query,
|
user_query,
|
||||||
|
online_results=[],
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|||||||
@@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=(
|
retry=(
|
||||||
retry_if_exception_type(openai.error.Timeout)
|
retry_if_exception_type(openai._exceptions.APITimeoutError)
|
||||||
| retry_if_exception_type(openai.error.APIError)
|
| retry_if_exception_type(openai._exceptions.APIError)
|
||||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai.error.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
),
|
),
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
@@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
|
|||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=(
|
retry=(
|
||||||
retry_if_exception_type(openai.error.Timeout)
|
retry_if_exception_type(openai._exceptions.APITimeoutError)
|
||||||
| retry_if_exception_type(openai.error.APIError)
|
| retry_if_exception_type(openai._exceptions.APIError)
|
||||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai.error.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
),
|
),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ from io import BufferedReader
|
|||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
import openai
|
from openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
|
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe audio file using Whisper model via OpenAI's API
|
Transcribe audio file using Whisper model via OpenAI's API
|
||||||
"""
|
"""
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
|
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
|
||||||
return response["text"]
|
return response.text
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from time import perf_counter
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import queue
|
import queue
|
||||||
|
from typing import Any, Dict, List
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
@@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from khoj.database.adapters import ConversationAdapters
|
||||||
|
from khoj.database.models import KhojUser
|
||||||
from khoj.utils.helpers import merge_dicts
|
from khoj.utils.helpers import merge_dicts
|
||||||
|
|
||||||
|
|
||||||
@@ -89,6 +92,32 @@ def message_to_log(
|
|||||||
return conversation_log
|
return conversation_log
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_conversation_log(
    q: str,
    chat_response: str,
    user: KhojUser,
    meta_log: Dict,
    user_message_time: str = None,
    compiled_references: List[str] = None,
    online_results: Dict[str, Any] = None,
    inferred_queries: List[str] = None,
    intent_type: str = "remember",
):
    """
    Append the latest user message and Khoj response to the conversation log and save it for the user.

    Args:
        q: The user's message.
        chat_response: Khoj's response to the user's message.
        user: The user whose conversation is being saved.
        meta_log: Existing conversation log. Its "chat" entry (if any) is used as the prior history.
        user_message_time: Timestamp of the user message, formatted as "%Y-%m-%d %H:%M:%S".
            Defaults to the current time when not provided.
        compiled_references: References stored under the response's "context". Defaults to an empty list.
        online_results: Online results stored under the response's "onlineContext". Defaults to an empty dict.
        inferred_queries: Queries stored under the response's intent "inferred-queries". Defaults to an empty list.
        intent_type: Intent type stored with the response, e.g. "remember" or "text-to-image".
    """
    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Use None sentinels instead of mutable default arguments. These values are stored
    # in the saved conversation log, so a shared default list/dict could leak state across calls.
    compiled_references = compiled_references if compiled_references is not None else []
    online_results = online_results if online_results is not None else {}
    inferred_queries = inferred_queries if inferred_queries is not None else []

    updated_conversation = message_to_log(
        user_message=q,
        chat_response=chat_response,
        user_message_metadata={"created": user_message_time},
        khoj_message_metadata={
            "context": compiled_references,
            "intent": {"inferred-queries": inferred_queries, "type": intent_type},
            "onlineContext": online_results,
        },
        conversation_log=meta_log.get("chat", []),
    )
    ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
||||||
|
|
||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
user_message,
|
user_message,
|
||||||
system_message,
|
system_message,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from starlette.authentication import requires
|
|||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import ChatModelOptions
|
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
|||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
|
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||||
from khoj.processor.tools.online_search import search_with_google
|
from khoj.processor.tools.online_search import search_with_google
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
|
text_to_image,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
@@ -622,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
|
|||||||
|
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
|
||||||
if not speech_to_text_config:
|
if not speech_to_text_config:
|
||||||
# If the user has not configured a speech to text model, return an unprocessable entity error
|
# If the user has not configured a speech to text model, return an unsupported on server error
|
||||||
status_code = 422
|
status_code = 501
|
||||||
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||||
api_key = openai_chat_config.api_key
|
|
||||||
speech2text_model = speech_to_text_config.model_name
|
speech2text_model = speech_to_text_config.model_name
|
||||||
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
|
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
|
||||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||||
speech2text_model = speech_to_text_config.model_name
|
speech2text_model = speech_to_text_config.model_name
|
||||||
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
|
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||||
finally:
|
finally:
|
||||||
# Close and Delete the temporary audio file
|
# Close and Delete the temporary audio file
|
||||||
audio_file.close()
|
audio_file.close()
|
||||||
@@ -665,7 +665,7 @@ async def chat(
|
|||||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||||
@@ -703,6 +703,11 @@ async def chat(
|
|||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
status_code=200,
|
status_code=200,
|
||||||
)
|
)
|
||||||
|
elif conversation_command == ConversationCommand.Image:
|
||||||
|
image, status_code = await text_to_image(q)
|
||||||
|
await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
|
||||||
|
content_obj = {"image": image, "intentType": "text-to-image"}
|
||||||
|
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||||
|
|
||||||
# Get the (streamed) chat response from the LLM of choice.
|
# Get the (streamed) chat response from the LLM of choice.
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
@@ -786,7 +791,6 @@ async def extract_references_and_questions(
|
|||||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
|
||||||
if (
|
if (
|
||||||
offline_chat_config
|
offline_chat_config
|
||||||
and offline_chat_config.enabled
|
and offline_chat_config.enabled
|
||||||
@@ -803,7 +807,7 @@ async def extract_references_and_questions(
|
|||||||
inferred_queries = extract_questions_offline(
|
inferred_queries = extract_questions_offline(
|
||||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||||
)
|
)
|
||||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
|
|||||||
@@ -9,23 +9,23 @@ from functools import partial
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
# External Packages
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
import openai
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
|
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
||||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
|
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
|
||||||
|
|
||||||
# Internal Packages
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import GPT4AllProcessorModel
|
from khoj.utils.config import GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
elif query.startswith("/online"):
|
elif query.startswith("/online"):
|
||||||
return ConversationCommand.Online
|
return ConversationCommand.Online
|
||||||
|
elif query.startswith("/image"):
|
||||||
|
return ConversationCommand.Image
|
||||||
# If no relevant notes found for the given query
|
# If no relevant notes found for the given query
|
||||||
elif not any_references:
|
elif not any_references:
|
||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
@@ -186,30 +188,7 @@ def generate_chat_response(
|
|||||||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
def _save_to_conversation_log(
|
|
||||||
q: str,
|
|
||||||
chat_response: str,
|
|
||||||
user_message_time: str,
|
|
||||||
compiled_references: List[str],
|
|
||||||
online_results: Dict[str, Any],
|
|
||||||
inferred_queries: List[str],
|
|
||||||
meta_log,
|
|
||||||
):
|
|
||||||
updated_conversation = message_to_log(
|
|
||||||
user_message=q,
|
|
||||||
chat_response=chat_response,
|
|
||||||
user_message_metadata={"created": user_message_time},
|
|
||||||
khoj_message_metadata={
|
|
||||||
"context": compiled_references,
|
|
||||||
"intent": {"inferred-queries": inferred_queries},
|
|
||||||
"onlineContext": online_results,
|
|
||||||
},
|
|
||||||
conversation_log=meta_log.get("chat", []),
|
|
||||||
)
|
|
||||||
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
|
||||||
|
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
chat_response = None
|
chat_response = None
|
||||||
logger.debug(f"Conversation Type: {conversation_command.name}")
|
logger.debug(f"Conversation Type: {conversation_command.name}")
|
||||||
|
|
||||||
@@ -217,13 +196,13 @@ def generate_chat_response(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
_save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
q,
|
q,
|
||||||
user_message_time=user_message_time,
|
user=user,
|
||||||
|
meta_log=meta_log,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
meta_log=meta_log,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
|
||||||
@@ -251,9 +230,9 @@ def generate_chat_response(
|
|||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
chat_response = converse(
|
chat_response = converse(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
online_results,
|
|
||||||
q,
|
q,
|
||||||
meta_log,
|
online_results=online_results,
|
||||||
|
conversation_log=meta_log,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
@@ -271,6 +250,29 @@ def generate_chat_response(
|
|||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
    """
    Generate an image from the text prompt using the text-to-image model configured on the server.

    Args:
        message: The text prompt describing the image to generate.

    Returns:
        Tuple of (image, status_code):
        - image: The generated image as a base64 encoded string, or None if no image was generated.
        - status_code: 200 on success, 501 if no usable text-to-image model is configured,
          500 if the image generation request failed.
    """
    status_code = 200
    image = None

    # Look up the text to image model configured on the server
    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        text2image_model = text_to_image_config.model_name
        try:
            response = state.openai_client.images.generate(
                prompt=message, model=text2image_model, response_format="b64_json"
            )
            image = response.data[0].b64_json
        except openai.OpenAIError as e:
            # openai >= 1.0 exceptions do not have `http_status` or `error` attributes.
            # Only API status errors carry a `status_code`, so read it defensively
            # to avoid raising a new error from within this handler.
            logger.error(f"Image Generation failed with {getattr(e, 'status_code', None)}: {e}")
            status_code = 500
    else:
        # A model is configured but cannot be served (no OpenAI client or unsupported model type).
        # Report it as unsupported rather than returning a successful response with no image.
        status_code = 501

    return image, status_code
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
|||||||
@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
|
|||||||
Notes = "notes"
|
Notes = "notes"
|
||||||
Help = "help"
|
Help = "help"
|
||||||
Online = "online"
|
Online = "online"
|
||||||
|
Image = "image"
|
||||||
|
|
||||||
|
|
||||||
command_descriptions = {
|
command_descriptions = {
|
||||||
@@ -280,6 +281,7 @@ command_descriptions = {
|
|||||||
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
||||||
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
||||||
ConversationCommand.Online: "Look up information on the internet.",
|
ConversationCommand.Online: "Look up information on the internet.",
|
||||||
|
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||||
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from khoj.database.models import (
|
|||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
|
TextToImageModelConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
|
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
|
||||||
@@ -103,6 +104,15 @@ def initialization():
|
|||||||
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
|
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ask which OpenAI text to image model to use for image generation
default_text_to_image_model = "dall-e-3"
openai_text_to_image_model = input(
    f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
)
# Fall back to the default model when the prompt is left empty.
# Assign to the text to image variable (not the speech to text one), so the
# config below is never created with an empty model name.
openai_text_to_image_model = openai_text_to_image_model or default_text_to_image_model
TextToImageModelConfig.objects.create(
    model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
)
|
||||||
|
|
||||||
if use_offline_model == "y" or use_openai_model == "y":
|
if use_offline_model == "y" or use_openai_model == "y":
|
||||||
logger.info("🗣️ Chat model configuration complete")
|
logger.info("🗣️ Chat model configuration complete")
|
||||||
|
|
||||||
|
|||||||
@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseEncoder):
|
class OpenAI(BaseEncoder):
|
||||||
def __init__(self, model_name, device=None):
|
def __init__(self, model_name, client: openai.OpenAI, device=None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
if (
|
self.openai_client = client
|
||||||
not state.processor_config
|
|
||||||
or not state.processor_config.conversation
|
|
||||||
or not state.processor_config.conversation.openai_model
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
|
|
||||||
)
|
|
||||||
openai.api_key = state.processor_config.conversation.openai_model.api_key
|
|
||||||
self.embedding_dimensions = None
|
self.embedding_dimensions = None
|
||||||
|
|
||||||
def encode(self, entries, device=None, **kwargs):
|
def encode(self, entries, device=None, **kwargs):
|
||||||
@@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
|
|||||||
processed_entry = entries[index].replace("\n", " ")
|
processed_entry = entries[index].replace("\n", " ")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
|
response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
|
||||||
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
|
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
|
||||||
# Use current models embedding dimension, once available
|
# Use current models embedding dimension, once available
|
||||||
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
|
from collections import defaultdict
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from pathlib import Path
|
from openai import OpenAI
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.utils import config as utils_config
|
from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import LRU, get_device
|
from khoj.utils.helpers import LRU, get_device
|
||||||
@@ -21,6 +22,7 @@ search_models = SearchModels()
|
|||||||
embeddings_model: EmbeddingsModel = None
|
embeddings_model: EmbeddingsModel = None
|
||||||
cross_encoder_model: CrossEncoderModel = None
|
cross_encoder_model: CrossEncoderModel = None
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
|
openai_client: OpenAI = None
|
||||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||||
whisper_model: Whisper = None
|
whisper_model: Whisper = None
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
|
|||||||
@@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
|
|||||||
response_message = response_message.split("### compiled references")[0]
|
response_message = response_message.split("### compiled references")[0]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
|
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||||
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
|
"Expected links or serper not setup in response but got: " + response_message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user