Merge pull request #580 from khoj-ai/fix-upgrade-chat-to-create-images

Support Image Generation with Khoj
This commit is contained in:
sabaimran
2023-12-07 21:17:58 +05:30
committed by GitHub
22 changed files with 529 additions and 303 deletions

View File

@@ -42,7 +42,7 @@ dependencies = [
"fastapi >= 0.104.1", "fastapi >= 0.104.1",
"python-multipart >= 0.0.5", "python-multipart >= 0.0.5",
"jinja2 == 3.1.2", "jinja2 == 3.1.2",
"openai >= 0.27.0, < 1.0.0", "openai >= 1.0.0",
"tiktoken >= 0.3.2", "tiktoken >= 0.3.2",
"tenacity >= 8.2.2", "tenacity >= 8.2.2",
"pillow ~= 9.5.0", "pillow ~= 9.5.0",

View File

@@ -179,7 +179,13 @@
return numOnlineReferences; return numOnlineReferences;
} }
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) { if (context == null && onlineContext == null) {
renderMessage(message, by, dt); renderMessage(message, by, dt);
return; return;
@@ -244,6 +250,17 @@
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model. // Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, ''); newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown // Render markdown
newHTML = md.render(newHTML); newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language" // Get any elements with a class that starts with "language"
@@ -328,14 +345,41 @@
let chatInput = document.getElementById("chat-input"); let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled"); chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain // Call specified Khoj API
fetch(url, { headers }) let response = await fetch(url, { headers });
.then(response => { let rawResponse = "";
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader(); const reader = response.body.getReader();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let rawResponse = "";
let references = null; let references = null;
readStream();
function readStream() { function readStream() {
reader.read().then(({ done, value }) => { reader.read().then(({ done, value }) => {
if (done) { if (done) {
@@ -404,16 +448,23 @@
if (newResponseText.getElementsByClassName("spinner").length > 0) { if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner); newResponseText.removeChild(loadingSpinner);
} }
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) { if (chunk.startsWith("{") && chunk.endsWith("}")) {
try { try {
const responseAsJson = JSON.parse(chunk); const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) { if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail; rawResponse += responseAsJson.detail;
} }
} catch (error) { } catch (error) {
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk; rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
} }
} else { } else {
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
@@ -429,8 +480,7 @@
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}); });
} }
readStream(); }
});
} }
function incrementalChat(event) { function incrementalChat(event) {
@@ -522,7 +572,7 @@
.then(response => { .then(response => {
// Render conversation history, if any // Render conversation history, if any
response.forEach(chat_log => { response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
}); });
}) })
.catch(err => { .catch(err => {
@@ -625,9 +675,13 @@
.then(response => response.ok ? response.json() : Promise.reject(response)) .then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; }) .then(data => { chatInput.value += data.text; })
.catch(err => { .catch(err => {
err.status == 422 if (err.status === 501) {
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio") } else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
}); });
}; };
@@ -810,6 +864,9 @@
margin-top: -10px; margin-top: -10px;
transform: rotate(-60deg) transform: rotate(-60deg)
} }
img.text-to-image {
max-width: 60%;
}
#chat-footer { #chat-footer {
padding: 0; padding: 0;
@@ -1050,6 +1107,9 @@
margin: 4px; margin: 4px;
grid-template-columns: auto; grid-template-columns: auto;
} }
img.text-to-image {
max-width: 100%;
}
} }
@media only screen and (min-width: 600px) { @media only screen and (min-width: 600px) {
body { body {

View File

@@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
import { KhojSetting } from 'src/settings'; import { KhojSetting } from 'src/settings';
import fetch from "node-fetch"; import fetch from "node-fetch";
export interface ChatJsonResult {
image?: string;
detail?: string;
}
export class KhojChatModal extends Modal { export class KhojChatModal extends Modal {
result: string; result: string;
setting: KhojSetting; setting: KhojSetting;
@@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
return referenceButton; return referenceButton;
} }
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) { renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
if (!message) { if (!message) {
return; return;
} else if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
this.renderMessage(chatEl, imageMarkdown, sender, dt);
return;
} else if (!context) { } else if (!context) {
this.renderMessage(chatEl, message, sender, dt); this.renderMessage(chatEl, message, sender, dt);
return return;
} else if (!!context && context?.length === 0) { } else if (!!context && context?.length === 0) {
this.renderMessage(chatEl, message, sender, dt); this.renderMessage(chatEl, message, sender, dt);
return return;
} }
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt); let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0] let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
@@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
let response = await request({ url: chatUrl, headers: headers }); let response = await request({ url: chatUrl, headers: headers });
let chatLogs = JSON.parse(response).response; let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => { chatLogs.forEach((chatLog: any) => {
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created)); this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
}); });
} }
@@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
this.result = ""; this.result = "";
responseElement.innerHTML = ""; responseElement.innerHTML = "";
if (response.headers.get("content-type") == "application/json") {
let responseText = ""
try {
const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) {
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.detail) {
responseText = responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
responseText = response.body.read().toString()
} finally {
this.renderIncrementalMessage(responseElement, responseText);
}
}
for await (const chunk of response.body) { for await (const chunk of response.body) {
const responseText = chunk.toString(); let responseText = chunk.toString();
if (responseText.includes("### compiled references:")) { if (responseText.includes("### compiled references:")) {
const additionalResponse = responseText.split("### compiled references:")[0]; const additionalResponse = responseText.split("### compiled references:")[0];
this.renderIncrementalMessage(responseElement, additionalResponse); this.renderIncrementalMessage(responseElement, additionalResponse);
@@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
referenceExpandButton.innerHTML = expandButtonText; referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection); references.appendChild(referenceSection);
} else { } else {
if (responseText.startsWith("{") && responseText.endsWith("}")) {
} else {
// If the chunk is not a JSON object, just display it as is
continue;
}
this.renderIncrementalMessage(responseElement, responseText); this.renderIncrementalMessage(responseElement, responseText);
} }
} }
@@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
if (response.status === 200) { if (response.status === 200) {
console.log(response); console.log(response);
chatInput.value += response.json.text; chatInput.value += response.json.text;
} else if (response.status === 422) { } else if (response.status === 501) {
throw new Error("⛔️ Failed to transcribe audio");
} else {
throw new Error("⛔️ Configure speech-to-text model on server."); throw new Error("⛔️ Configure speech-to-text model on server.");
} else if (response.status === 422) {
throw new Error("⛔️ Audio file to large to process.");
} else {
throw new Error("⛔️ Failed to transcribe audio.");
} }
}; };

View File

@@ -217,7 +217,9 @@ button.copy-button:hover {
background: #f5f5f5; background: #f5f5f5;
cursor: pointer; cursor: pointer;
} }
img {
max-width: 60%;
}
#khoj-chat-footer { #khoj-chat-footer {
padding: 0; padding: 0;

View File

@@ -7,6 +7,7 @@ import requests
import os import os
# External Packages # External Packages
import openai
import schedule import schedule
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
@@ -22,6 +23,7 @@ from starlette.authentication import (
# Internal Packages # Internal Packages
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import ( from khoj.database.adapters import (
ConversationAdapters,
get_all_users, get_all_users,
get_or_create_search_model, get_or_create_search_model,
aget_user_subscription_state, aget_user_subscription_state,
@@ -138,6 +140,10 @@ def configure_server(
config = FullConfig() config = FullConfig()
state.config = config state.config = config
if ConversationAdapters.has_valid_openai_conversation_config():
openai_config = ConversationAdapters.get_openai_conversation_config()
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
# Initialize Search Models from Config and initialize content # Initialize Search Models from Config and initialize content
try: try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)

View File

@@ -22,6 +22,7 @@ from khoj.database.models import (
GithubConfig, GithubConfig,
GithubRepoConfig, GithubRepoConfig,
GoogleUser, GoogleUser,
TextToImageModelConfig,
KhojApiUser, KhojApiUser,
KhojUser, KhojUser,
NotionConfig, NotionConfig,
@@ -426,6 +427,10 @@ class ConversationAdapters:
else: else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat") raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
@staticmethod
async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()
class EntryAdapters: class EntryAdapters:
word_filer = WordFilter() word_filer = WordFilter()

View File

@@ -0,0 +1,25 @@
# Generated by Django 4.2.7 on 2023-12-04 22:17
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0021_speechtotextmodeloptions_and_more"),
]
operations = [
migrations.CreateModel(
name="TextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("model_name", models.CharField(default="dall-e-3", max_length=200)),
("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
],
options={
"abstract": False,
},
),
]

View File

@@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
class OpenAIProcessorConversationConfig(BaseModel): class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200) api_key = models.CharField(max_length=200)

View File

@@ -188,7 +188,13 @@ To get started, just start typing below. You can also type / to see a list of co
return numOnlineReferences; return numOnlineReferences;
} }
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) { if (context == null && onlineContext == null) {
renderMessage(message, by, dt); renderMessage(message, by, dt);
return; return;
@@ -253,6 +259,17 @@ To get started, just start typing below. You can also type / to see a list of co
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model. // Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, ''); newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown // Render markdown
newHTML = md.render(newHTML); newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language" // Get any elements with a class that starts with "language"
@@ -292,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co
return element return element
} }
function chat() { async function chat() {
// Extract required fields for search from form // Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim(); let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5; let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@@ -333,14 +350,41 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input"); let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled"); chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain // Call specified Khoj API
fetch(url) let response = await fetch(url);
.then(response => { let rawResponse = "";
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader(); const reader = response.body.getReader();
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let rawResponse = "";
let references = null; let references = null;
readStream();
function readStream() { function readStream() {
reader.read().then(({ done, value }) => { reader.read().then(({ done, value }) => {
if (done) { if (done) {
@@ -410,36 +454,19 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.removeChild(loadingSpinner); newResponseText.removeChild(loadingSpinner);
} }
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
rawResponse += chunk; rawResponse += chunk;
newResponseText.innerHTML = ""; newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse)); newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream(); readStream();
} }
} });
// Scroll to bottom of chat window as chat response is streamed // Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}); };
}
readStream();
});
} }
};
function incrementalChat(event) { function incrementalChat(event) {
if (!event.shiftKey && event.key === 'Enter') { if (!event.shiftKey && event.key === 'Enter') {
@@ -516,7 +543,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => { .then(response => {
// Render conversation history, if any // Render conversation history, if any
response.forEach(chat_log => { response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
}); });
}) })
.catch(err => { .catch(err => {
@@ -611,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => response.ok ? response.json() : Promise.reject(response)) .then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; }) .then(data => { chatInput.value += data.text; })
.catch(err => { .catch(err => {
err.status == 422 if (err.status === 501) {
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio") } else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
}); });
}; };
@@ -902,6 +933,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin-top: -10px; margin-top: -10px;
transform: rotate(-60deg) transform: rotate(-60deg)
} }
img.text-to-image {
max-width: 60%;
}
#chat-footer { #chat-footer {
padding: 0; padding: 0;
@@ -1029,6 +1063,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 4px; margin: 4px;
grid-template-columns: auto; grid-template-columns: auto;
} }
img.text-to-image {
max-width: 100%;
}
} }
@media only screen and (min-width: 700px) { @media only screen and (min-width: 700px) {
body { body {

View File

@@ -47,7 +47,7 @@ def extract_questions_offline(
if use_history: if use_history:
for chat in conversation_log.get("chat", [])[-4:]: for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj": if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n" chat_history += f"A: {chat['message']}\n"

View File

@@ -12,11 +12,12 @@ def download_model(model_name: str):
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e raise e
# Download the chat model # Decide whether to load model to GPU or CPU
chat_model_config = None
try:
# Download the chat model and its config
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True) chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
# Decide whether to load model to GPU or CPU
try:
# Try load chat model to GPU if: # Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and # 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU # 2. Machine has GPU
@@ -26,6 +27,12 @@ def download_model(model_name: str):
) )
except ValueError: except ValueError:
device = "cpu" device = "cpu"
except Exception as e:
if chat_model_config is None:
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
logger.debug(f"Unable to download model config from gpt4all website: {e}")
else:
raise e
# Now load the downloaded chat model onto appropriate device # Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False) chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)

View File

@@ -41,7 +41,7 @@ def extract_questions(
[ [
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n' f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:] for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
] ]
) )
@@ -123,8 +123,8 @@ def send_message_to_model(
def converse( def converse(
references, references,
online_results,
user_query, user_query,
online_results=[],
conversation_log={}, conversation_log={},
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None, api_key: Optional[str] = None,

View File

@@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
@retry( @retry(
retry=( retry=(
retry_if_exception_type(openai.error.Timeout) retry_if_exception_type(openai._exceptions.APITimeoutError)
| retry_if_exception_type(openai.error.APIError) | retry_if_exception_type(openai._exceptions.APIError)
| retry_if_exception_type(openai.error.APIConnectionError) | retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError) | retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError) | retry_if_exception_type(openai._exceptions.APIStatusError)
), ),
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
@@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
@retry( @retry(
retry=( retry=(
retry_if_exception_type(openai.error.Timeout) retry_if_exception_type(openai._exceptions.APITimeoutError)
| retry_if_exception_type(openai.error.APIError) | retry_if_exception_type(openai._exceptions.APIError)
| retry_if_exception_type(openai.error.APIConnectionError) | retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError) | retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError) | retry_if_exception_type(openai._exceptions.APIStatusError)
), ),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),

View File

@@ -3,13 +3,13 @@ from io import BufferedReader
# External Packages # External Packages
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
import openai from openai import OpenAI
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str: async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
""" """
Transcribe audio file using Whisper model via OpenAI's API Transcribe audio file using Whisper model via OpenAI's API
""" """
# Send the audio data to the Whisper API # Send the audio data to the Whisper API
response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key) response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
return response["text"] return response.text

View File

@@ -4,6 +4,7 @@ from time import perf_counter
import json import json
from datetime import datetime from datetime import datetime
import queue import queue
from typing import Any, Dict, List
import tiktoken import tiktoken
# External packages # External packages
@@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
from transformers import AutoTokenizer from transformers import AutoTokenizer
# Internal Packages # Internal Packages
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser
from khoj.utils.helpers import merge_dicts from khoj.utils.helpers import merge_dicts
@@ -89,6 +92,32 @@ def message_to_log(
return conversation_log return conversation_log
def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
meta_log: Dict,
user_message_time: str = None,
compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
intent_type: str = "remember",
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
def generate_chatml_messages_with_context( def generate_chatml_messages_with_context(
user_message, user_message,
system_message, system_message,

View File

@@ -19,7 +19,7 @@ from starlette.authentication import requires
from khoj.configure import configure_server from khoj.configure import configure_server
from khoj.database import adapters from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import ChatModelOptions from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry from khoj.database.models import Entry as DbEntry
from khoj.database.models import ( from khoj.database.models import (
GithubConfig, GithubConfig,
@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
CommonQueryParams, CommonQueryParams,
agenerate_chat_response, agenerate_chat_response,
get_conversation_command, get_conversation_command,
text_to_image,
is_ready_to_chat, is_ready_to_chat,
update_telemetry_state, update_telemetry_state,
validate_conversation_config, validate_conversation_config,
@@ -622,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
# Send the audio data to the Whisper API # Send the audio data to the Whisper API
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not speech_to_text_config: if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unprocessable entity error # If the user has not configured a speech to text model, return an unsupported on server error
status_code = 422 status_code = 501
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
api_key = openai_chat_config.api_key
speech2text_model = speech_to_text_config.model_name speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
finally: finally:
# Close and Delete the temporary audio file # Close and Delete the temporary audio file
audio_file.close() audio_file.close()
@@ -665,7 +665,7 @@ async def chat(
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user = request.user.object user: KhojUser = request.user.object
await is_ready_to_chat(user) await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True) conversation_command = get_conversation_command(query=q, any_references=True)
@@ -703,6 +703,11 @@ async def chat(
media_type="text/event-stream", media_type="text/event-stream",
status_code=200, status_code=200,
) )
elif conversation_command == ConversationCommand.Image:
image, status_code = await text_to_image(q)
await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
content_obj = {"image": image, "intentType": "text-to-image"}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
@@ -786,7 +791,6 @@ async def extract_references_and_questions(
conversation_config = await ConversationAdapters.aget_conversation_config(user) conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None: if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if ( if (
offline_chat_config offline_chat_config
and offline_chat_config.enabled and offline_chat_config.enabled
@@ -803,7 +807,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline( inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
) )
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat() openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key

View File

@@ -9,23 +9,23 @@ from functools import partial
from time import time from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
import openai
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
# Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.utils.helpers import ConversationCommand, log_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1) executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.General return ConversationCommand.General
elif query.startswith("/online"): elif query.startswith("/online"):
return ConversationCommand.Online return ConversationCommand.Online
elif query.startswith("/image"):
return ConversationCommand.Image
# If no relevant notes found for the given query # If no relevant notes found for the given query
elif not any_references: elif not any_references:
return ConversationCommand.General return ConversationCommand.General
@@ -186,30 +188,7 @@ def generate_chat_response(
conversation_command: ConversationCommand = ConversationCommand.Default, conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None, user: KhojUser = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
def _save_to_conversation_log(
    q: str,
    chat_response: str,
    user_message_time: str,
    compiled_references: List[str],
    online_results: Dict[str, Any],
    inferred_queries: List[str],
    meta_log,
):
    # Append the latest user query and the model's chat response to the
    # conversation log, then persist the updated log for this user.
    # NOTE(review): this is a nested helper; `user` is a free variable
    # captured from the enclosing function's scope — confirm against caller.
    updated_conversation = message_to_log(
        user_message=q,
        chat_response=chat_response,
        # Timestamp of the user's message, preformatted by the caller
        user_message_metadata={"created": user_message_time},
        khoj_message_metadata={
            # Notes/references the response was grounded on
            "context": compiled_references,
            # Search queries inferred from the user's message
            "intent": {"inferred-queries": inferred_queries},
            # Results gathered from online search, if any
            "onlineContext": online_results,
        },
        # Extend the existing chat history (empty list when no prior chat)
        conversation_log=meta_log.get("chat", []),
    )
    ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables # Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_response = None chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}") logger.debug(f"Conversation Type: {conversation_command.name}")
@@ -217,13 +196,13 @@ def generate_chat_response(
try: try:
partial_completion = partial( partial_completion = partial(
_save_to_conversation_log, save_to_conversation_log,
q, q,
user_message_time=user_message_time, user=user,
meta_log=meta_log,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
meta_log=meta_log,
) )
conversation_config = ConversationAdapters.get_valid_conversation_config(user) conversation_config = ConversationAdapters.get_valid_conversation_config(user)
@@ -251,9 +230,9 @@ def generate_chat_response(
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
chat_response = converse( chat_response = converse(
compiled_references, compiled_references,
online_results,
q, q,
meta_log, online_results=online_results,
conversation_log=meta_log,
model=chat_model, model=chat_model,
api_key=api_key, api_key=api_key,
completion_func=partial_completion, completion_func=partial_completion,
@@ -271,6 +250,29 @@ def generate_chat_response(
return chat_response, metadata return chat_response, metadata
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
    """Generate an image for the given prompt with the configured text-to-image model.

    Args:
        message: The user's image description prompt.

    Returns:
        Tuple of (base64-encoded image or None, HTTP status code):
        200 on success, 501 when no text-to-image model is configured,
        500 when the image generation request fails.
    """
    status_code = 200
    image = None

    # Look up the server's configured text-to-image model
    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        text2image_model = text_to_image_config.model_name
        try:
            # Request the image as base64 so it can be embedded directly in the chat response
            response = state.openai_client.images.generate(
                prompt=message, model=text2image_model, response_format="b64_json"
            )
            image = response.data[0].b64_json
        except openai.OpenAIError as e:
            # openai>=1.0 errors do not expose `http_status`/`error` on the base
            # OpenAIError class; referencing them here raised AttributeError inside
            # the handler and masked the real failure. Log the exception itself.
            logger.error(f"Image Generation failed: {e}", exc_info=True)
            status_code = 500

    return image, status_code
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int): def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests self.requests = requests

View File

@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
Notes = "notes" Notes = "notes"
Help = "help" Help = "help"
Online = "online" Online = "online"
Image = "image"
command_descriptions = { command_descriptions = {
@@ -280,6 +281,7 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.", ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
} }

View File

@@ -7,6 +7,7 @@ from khoj.database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ChatModelOptions, ChatModelOptions,
SpeechToTextModelOptions, SpeechToTextModelOptions,
TextToImageModelConfig,
) )
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@@ -103,6 +104,15 @@ def initialization():
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
) )
# Configure the OpenAI text-to-image model, falling back to the default when
# the user presses Enter without typing a model name.
default_text_to_image_model = "dall-e-3"
openai_text_to_image_model = (
    input(f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): ")
    or default_text_to_image_model
)
# Original code assigned the fallback to the unrelated `openai_speech2text_model`
# variable and persisted the raw input, storing an empty model name whenever the
# user accepted the default.
TextToImageModelConfig.objects.create(
    model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
)
if use_offline_model == "y" or use_openai_model == "y": if use_offline_model == "y" or use_openai_model == "y":
logger.info("🗣️ Chat model configuration complete") logger.info("🗣️ Chat model configuration complete")

View File

@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
class OpenAI(BaseEncoder): class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None): def __init__(self, model_name, client: openai.OpenAI, device=None):
self.model_name = model_name self.model_name = model_name
if ( self.openai_client = client
not state.processor_config
or not state.processor_config.conversation
or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
openai.api_key = state.processor_config.conversation.openai_model.api_key
self.embedding_dimensions = None self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs): def encode(self, entries, device=None, **kwargs):
@@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
processed_entry = entries[index].replace("\n", " ") processed_entry = entries[index].replace("\n", " ")
try: try:
response = openai.Embedding.create(input=processed_entry, model=self.model_name) response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)] embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
# Use current models embedding dimension, once available # Use current models embedding dimension, once available
# Else default to embedding dimensions of the text-embedding-ada-002 model # Else default to embedding dimensions of the text-embedding-ada-002 model

View File

@@ -1,15 +1,16 @@
# Standard Packages # Standard Packages
from collections import defaultdict
import os import os
from pathlib import Path
import threading import threading
from typing import List, Dict from typing import List, Dict
from collections import defaultdict
# External Packages # External Packages
from pathlib import Path from openai import OpenAI
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from whisper import Whisper from whisper import Whisper
# Internal Packages # Internal Packages
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device from khoj.utils.helpers import LRU, get_device
@@ -21,6 +22,7 @@ search_models = SearchModels()
embeddings_model: EmbeddingsModel = None embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex() content_index = ContentIndex()
openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None whisper_model: Whisper = None
config_file: Path = None config_file: Path = None

View File

@@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0] response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"] expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response_message "Expected links or serper not setup in response but got: " + response_message
) )