Merge pull request #580 from khoj-ai/fix-upgrade-chat-to-create-images

Support Image Generation with Khoj
This commit is contained in:
sabaimran
2023-12-07 21:17:58 +05:30
committed by GitHub
22 changed files with 529 additions and 303 deletions

View File

@@ -179,7 +179,13 @@
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@@ -244,6 +250,17 @@
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@@ -328,109 +345,142 @@
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url, { headers })
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
let rawResponse = "";
let references = null;
// Call specified Khoj API
let response = await fetch(url, { headers });
let rawResponse = "";
const contentType = response.headers.get("content-type");
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = null;
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
references.appendChild(referenceExpandButton);
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
readStream();
});
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
}
}
function incrementalChat(event) {
@@ -522,7 +572,7 @@
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
});
})
.catch(err => {
@@ -625,9 +675,13 @@
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
err.status == 422
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
if (err.status === 501) {
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
} else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
});
};
@@ -810,6 +864,9 @@
margin-top: -10px;
transform: rotate(-60deg)
}
img.text-to-image {
max-width: 60%;
}
#chat-footer {
padding: 0;
@@ -1050,6 +1107,9 @@
margin: 4px;
grid-template-columns: auto;
}
img.text-to-image {
max-width: 100%;
}
}
@media only screen and (min-width: 600px) {
body {

View File

@@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
import { KhojSetting } from 'src/settings';
import fetch from "node-fetch";
export interface ChatJsonResult {
image?: string;
detail?: string;
}
export class KhojChatModal extends Modal {
result: string;
setting: KhojSetting;
@@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
return referenceButton;
}
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
if (!message) {
return;
} else if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
this.renderMessage(chatEl, imageMarkdown, sender, dt);
return;
} else if (!context) {
this.renderMessage(chatEl, message, sender, dt);
return
return;
} else if (!!context && context?.length === 0) {
this.renderMessage(chatEl, message, sender, dt);
return
return;
}
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
@@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
let response = await request({ url: chatUrl, headers: headers });
let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => {
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
});
}
@@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
this.result = "";
responseElement.innerHTML = "";
if (response.headers.get("content-type") == "application/json") {
let responseText = ""
try {
const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) {
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.detail) {
responseText = responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
responseText = response.body.read().toString()
} finally {
this.renderIncrementalMessage(responseElement, responseText);
}
}
for await (const chunk of response.body) {
const responseText = chunk.toString();
let responseText = chunk.toString();
if (responseText.includes("### compiled references:")) {
const additionalResponse = responseText.split("### compiled references:")[0];
this.renderIncrementalMessage(responseElement, additionalResponse);
@@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
} else {
if (responseText.startsWith("{") && responseText.endsWith("}")) {
} else {
// If the chunk is not a JSON object, just display it as is
continue;
}
this.renderIncrementalMessage(responseElement, responseText);
}
}
@@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
if (response.status === 200) {
console.log(response);
chatInput.value += response.json.text;
} else if (response.status === 422) {
throw new Error("⛔️ Failed to transcribe audio");
} else {
} else if (response.status === 501) {
throw new Error("⛔️ Configure speech-to-text model on server.");
} else if (response.status === 422) {
throw new Error("⛔️ Audio file to large to process.");
} else {
throw new Error("⛔️ Failed to transcribe audio.");
}
};

View File

@@ -217,7 +217,9 @@ button.copy-button:hover {
background: #f5f5f5;
cursor: pointer;
}
img {
max-width: 60%;
}
#khoj-chat-footer {
padding: 0;