mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Improve Chat Page Load Perf, Offline Chat Perf and Miscellaneous Fixes (#703)
### Store Generated Images as WebP
- 78bac4ae Add migration script to convert PNG to WebP references in database
- c6e84436 Update clients to support rendering webp images inline
- d21f22ff Store Khoj generated images as webp instead of png for faster loading

### Lazy Fetch Chat Messages to Improve Time, Data to First Render
This is especially helpful for long conversations with lots of images
- 128829c4 Render latest msgs on chat session load. Fetch, render rest as they near viewport
- 9e558577 Support getting latest N chat messages via chat history API

### Intelligently set Context Window of Offline Chat to Improve Performance
- 4977b551 Use offline chat prompt config to set context window of loaded chat model

### Fixes
- 148923c1 Fix to raise error on hitting rate limit during Github indexing
- b8bc6bee Always remove loading animation on Desktop app if can't login to server
- 38250705 Fix `get_user_photo` to only return photo, not user name from DB

### Miscellaneous Improvements
- 689202e0 Update recommended CMAKE flag to enable using CUDA on linux in Docs
- b820daf3 Makes logs less noisy
This commit is contained in:
@@ -134,7 +134,7 @@ python -m pip install khoj-assistant
|
|||||||
# CPU
|
# CPU
|
||||||
python -m pip install khoj-assistant
|
python -m pip install khoj-assistant
|
||||||
# NVIDIA (CUDA) GPU
|
# NVIDIA (CUDA) GPU
|
||||||
CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
CMAKE_ARGS="-DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
||||||
# AMD (ROCm) GPU
|
# AMD (ROCm) GPU
|
||||||
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
||||||
# VULCAN GPU
|
# VULCAN GPU
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ dependencies = [
|
|||||||
"phonenumbers == 8.13.27",
|
"phonenumbers == 8.13.27",
|
||||||
"markdownify ~= 0.11.6",
|
"markdownify ~= 0.11.6",
|
||||||
"websockets == 12.0",
|
"websockets == 12.0",
|
||||||
|
"psutil >= 5.8.0",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
@@ -105,7 +106,6 @@ dev = [
|
|||||||
"pytest-asyncio == 0.21.1",
|
"pytest-asyncio == 0.21.1",
|
||||||
"freezegun >= 1.2.0",
|
"freezegun >= 1.2.0",
|
||||||
"factory-boy >= 3.2.1",
|
"factory-boy >= 3.2.1",
|
||||||
"psutil >= 5.8.0",
|
|
||||||
"mypy >= 1.0.1",
|
"mypy >= 1.0.1",
|
||||||
"black >= 23.1.0",
|
"black >= 23.1.0",
|
||||||
"pre-commit >= 3.0.4",
|
"pre-commit >= 3.0.4",
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||||
<title>Khoj - Chat</title>
|
<title>Khoj - Chat</title>
|
||||||
|
|
||||||
|
<link rel="stylesheet" href="./assets/khoj.css">
|
||||||
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
|
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
|
||||||
<link rel="manifest" href="/static/khoj.webmanifest">
|
<link rel="manifest" href="/static/khoj.webmanifest">
|
||||||
<link rel="stylesheet" href="./assets/khoj.css">
|
|
||||||
</head>
|
</head>
|
||||||
<script type="text/javascript" src="./assets/markdown-it.min.js"></script>
|
<script type="text/javascript" src="./assets/markdown-it.min.js"></script>
|
||||||
<script src="./utils.js"></script>
|
<script src="./utils.js"></script>
|
||||||
@@ -130,7 +130,7 @@
|
|||||||
return referenceButton;
|
return referenceButton;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderMessage(message, by, dt=null, annotations=null, raw=false) {
|
function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
|
||||||
let message_time = formatDate(dt ?? new Date());
|
let message_time = formatDate(dt ?? new Date());
|
||||||
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
|
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
|
||||||
let formattedMessage = formatHTMLMessage(message, raw);
|
let formattedMessage = formatHTMLMessage(message, raw);
|
||||||
@@ -153,10 +153,15 @@
|
|||||||
|
|
||||||
// Append chat message div to chat body
|
// Append chat message div to chat body
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
if (renderType === "append") {
|
||||||
chatBody.appendChild(chatMessage);
|
chatBody.appendChild(chatMessage);
|
||||||
|
|
||||||
// Scroll to bottom of chat-body element
|
// Scroll to bottom of chat-body element
|
||||||
chatBody.scrollTop = chatBody.scrollHeight;
|
chatBody.scrollTop = chatBody.scrollHeight;
|
||||||
|
} else if (renderType === "prepend") {
|
||||||
|
chatBody.insertBefore(chatMessage, chatBody.firstChild);
|
||||||
|
} else if (renderType === "return") {
|
||||||
|
return chatMessage;
|
||||||
|
}
|
||||||
|
|
||||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||||
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||||
@@ -207,6 +212,7 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||||
|
// If no document or online context is provided, render the message as is
|
||||||
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||||
if (intentType?.includes("text-to-image")) {
|
if (intentType?.includes("text-to-image")) {
|
||||||
let imageMarkdown;
|
let imageMarkdown;
|
||||||
@@ -214,30 +220,29 @@
|
|||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
} else if (intentType === "text-to-image2") {
|
} else if (intentType === "text-to-image2") {
|
||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
|
} else if (intentType === "text-to-image-v3") {
|
||||||
|
imageMarkdown = ``;
|
||||||
}
|
}
|
||||||
|
|
||||||
const inferredQuery = inferredQueries?.[0];
|
const inferredQuery = inferredQueries?.[0];
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
renderMessage(imageMarkdown, by, dt);
|
return renderMessage(imageMarkdown, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
renderMessage(message, by, dt);
|
return renderMessage(message, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (context == null && onlineContext == null) {
|
if (context == null && onlineContext == null) {
|
||||||
renderMessage(message, by, dt);
|
return renderMessage(message, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||||
renderMessage(message, by, dt);
|
return renderMessage(message, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If document or online context is provided, render the message with its references
|
||||||
let references = document.createElement('div');
|
let references = document.createElement('div');
|
||||||
|
|
||||||
let referenceExpandButton = document.createElement('button');
|
let referenceExpandButton = document.createElement('button');
|
||||||
@@ -288,16 +293,17 @@
|
|||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
} else if (intentType === "text-to-image2") {
|
} else if (intentType === "text-to-image2") {
|
||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
|
} else if (intentType === "text-to-image-v3") {
|
||||||
|
imageMarkdown = ``;
|
||||||
}
|
}
|
||||||
const inferredQuery = inferredQueries?.[0];
|
const inferredQuery = inferredQueries?.[0];
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
renderMessage(imageMarkdown, by, dt, references);
|
return renderMessage(imageMarkdown, by, dt, references, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
renderMessage(message, by, dt, references);
|
return renderMessage(message, by, dt, references, false, "return");
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
|
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
|
||||||
@@ -509,6 +515,8 @@
|
|||||||
rawResponse += ``;
|
rawResponse += ``;
|
||||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
} else if (responseAsJson.intentType === "text-to-image2") {
|
||||||
rawResponse += ``;
|
rawResponse += ``;
|
||||||
|
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
||||||
|
rawResponse += ``;
|
||||||
}
|
}
|
||||||
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
||||||
if (inferredQueries) {
|
if (inferredQueries) {
|
||||||
@@ -671,7 +679,7 @@
|
|||||||
let firstRunSetupMessageRendered = false;
|
let firstRunSetupMessageRendered = false;
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
chatBody.innerHTML = "";
|
chatBody.innerHTML = "";
|
||||||
let chatHistoryUrl = `/api/chat/history?client=desktop`;
|
let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`;
|
||||||
if (chatBody.dataset.conversationId) {
|
if (chatBody.dataset.conversationId) {
|
||||||
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||||
}
|
}
|
||||||
@@ -683,7 +691,8 @@
|
|||||||
loadingScreen.appendChild(yellowOrb);
|
loadingScreen.appendChild(yellowOrb);
|
||||||
chatBody.appendChild(loadingScreen);
|
chatBody.appendChild(loadingScreen);
|
||||||
|
|
||||||
fetch(`${hostURL}${chatHistoryUrl}`, { headers })
|
// Get the most recent 10 chat messages from conversation history
|
||||||
|
fetch(`${chatHistoryUrl}&n=10`, { headers })
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.detail) {
|
if (data.detail) {
|
||||||
@@ -703,11 +712,21 @@
|
|||||||
chatBody.dataset.conversationId = response.conversation_id;
|
chatBody.dataset.conversationId = response.conversation_id;
|
||||||
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
||||||
|
|
||||||
const fullChatLog = response.chat || [];
|
// Create a new IntersectionObserver
|
||||||
|
let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
|
||||||
|
entries.forEach(entry => {
|
||||||
|
// If the element is in the viewport, fetch the remaining message and unobserve the element
|
||||||
|
if (entry.isIntersecting) {
|
||||||
|
fetchRemainingChatMessages(chatHistoryUrl);
|
||||||
|
observer.unobserve(entry.target);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, {rootMargin: '0px 0px 0px 0px'});
|
||||||
|
|
||||||
fullChatLog.forEach(chat_log => {
|
const fullChatLog = response.chat || [];
|
||||||
|
fullChatLog.forEach((chat_log, index) => {
|
||||||
if (chat_log.message != null) {
|
if (chat_log.message != null) {
|
||||||
renderMessageWithReference(
|
let messageElement = renderMessageWithReference(
|
||||||
chat_log.message,
|
chat_log.message,
|
||||||
chat_log.by,
|
chat_log.by,
|
||||||
chat_log.context,
|
chat_log.context,
|
||||||
@@ -715,10 +734,25 @@
|
|||||||
chat_log.onlineContext,
|
chat_log.onlineContext,
|
||||||
chat_log.intent?.type,
|
chat_log.intent?.type,
|
||||||
chat_log.intent?.["inferred-queries"]);
|
chat_log.intent?.["inferred-queries"]);
|
||||||
|
chatBody.appendChild(messageElement);
|
||||||
|
|
||||||
|
// When the 4th oldest message is within viewing distance (~60% scrolled up)
|
||||||
|
// Fetch the remaining chat messages
|
||||||
|
if (index === 4) {
|
||||||
|
fetchRemainingMessagesObserver.observe(messageElement);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
loadingScreen.style.height = chatBody.scrollHeight + 'px';
|
loadingScreen.style.height = chatBody.scrollHeight + 'px';
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Scroll to bottom of chat-body element
|
||||||
|
chatBody.scrollTop = chatBody.scrollHeight;
|
||||||
|
|
||||||
|
// Set height of chat-body element to the height of the chat-body-wrapper
|
||||||
|
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||||
|
let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||||
|
chatBody.style.height = chatBodyWrapperHeight;
|
||||||
|
|
||||||
// Add fade out animation to loading screen and remove it after the animation ends
|
// Add fade out animation to loading screen and remove it after the animation ends
|
||||||
fadeOutLoadingAnimation(loadingScreen);
|
fadeOutLoadingAnimation(loadingScreen);
|
||||||
})
|
})
|
||||||
@@ -726,8 +760,8 @@
|
|||||||
// If the server returns a 500 error with detail, render a setup hint.
|
// If the server returns a 500 error with detail, render a setup hint.
|
||||||
if (!firstRunSetupMessageRendered) {
|
if (!firstRunSetupMessageRendered) {
|
||||||
renderFirstRunSetupMessage();
|
renderFirstRunSetupMessage();
|
||||||
fadeOutLoadingAnimation(loadingScreen);
|
|
||||||
}
|
}
|
||||||
|
fadeOutLoadingAnimation(loadingScreen);
|
||||||
return;
|
return;
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -778,6 +812,65 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Lazily load the chat messages that were not part of the initial render.
 *
 * Requests `${chatHistoryUrl}&n=-10` (presumably every message except the latest 10 —
 * confirm against the chat history API), inserts a small placeholder element per
 * message after the first child of the chat body and only renders a message once its
 * placeholder intersects the observer's root.
 *
 * @param {string} chatHistoryUrl - Chat history API URL that already carries a query string.
 * @param {Object} [headers={}] - Request headers (e.g. authorization) to send with the
 *     history request, matching what the initial history fetch sends.
 */
function fetchRemainingChatMessages(chatHistoryUrl, headers = {}) {
    // Swap a placeholder for its fully rendered message once it scrolls into range.
    // rootMargin grows the bottom edge of the observer root by 200px.
    const placeholderObserver = new IntersectionObserver((entries, activeObserver) => {
        entries.forEach((entry) => {
            if (!entry.isIntersecting) {
                return;
            }

            const chat_log = entry.target.chat_log;
            const messageElement = renderMessageWithReference(
                chat_log.message,
                chat_log.by,
                chat_log.context,
                new Date(chat_log.created),
                chat_log.onlineContext,
                chat_log.intent?.type,
                chat_log.intent?.["inferred-queries"]
            );
            entry.target.replaceWith(messageElement);

            // The placeholder is gone from the DOM, so stop tracking it
            activeObserver.unobserve(entry.target);
        });
    }, { rootMargin: '0px 0px 200px 0px' });

    // Fetch remaining chat messages from conversation history
    fetch(`${chatHistoryUrl}&n=-10`, { method: "GET", headers })
        .then((response) => response.json())
        .then((data) => {
            if (data.status != "ok") {
                throw new Error(data.message);
            }
            return data.response;
        })
        .then((response) => {
            const remainingChatLog = response.chat || [];
            const chatBody = document.getElementById("chat-body");

            // Walk from newest to oldest: every placeholder is inserted at the same
            // position, so the oldest message ends up closest to the top.
            remainingChatLog
                .reverse()
                .forEach((chat_log) => {
                    if (chat_log.message == null) {
                        return;
                    }

                    // Create a lightweight stand-in that remembers which message it represents
                    const placeholder = document.createElement('div');
                    placeholder.chat_log = chat_log;

                    // Insert after the first child of the chat body (the welcome message).
                    // A null reference node makes insertBefore append, which also avoids a
                    // TypeError when the chat body has no children yet.
                    const insertionPoint = chatBody.firstChild?.nextSibling ?? null;
                    chatBody.insertBefore(placeholder, insertionPoint);

                    // Give the placeholder some height so it can intersect, then observe it
                    placeholder.style.height = "20px";
                    placeholderObserver.observe(placeholder);
                });
        })
        .catch((err) => {
            // Best effort: older messages stay hidden, but surface the failure as an error
            console.error(err);
        });
}
|
||||||
|
|
||||||
function fadeOutLoadingAnimation(loadingScreen) {
|
function fadeOutLoadingAnimation(loadingScreen) {
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||||
|
|||||||
@@ -156,6 +156,8 @@ export class KhojChatModal extends Modal {
|
|||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
} else if (intentType === "text-to-image2") {
|
} else if (intentType === "text-to-image2") {
|
||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
|
} else if (intentType === "text-to-image-v3") {
|
||||||
|
imageMarkdown = ``;
|
||||||
}
|
}
|
||||||
if (inferredQueries) {
|
if (inferredQueries) {
|
||||||
imageMarkdown += "\n\n**Inferred Query**:";
|
imageMarkdown += "\n\n**Inferred Query**:";
|
||||||
@@ -429,6 +431,8 @@ export class KhojChatModal extends Modal {
|
|||||||
responseText += ``;
|
responseText += ``;
|
||||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
} else if (responseAsJson.intentType === "text-to-image2") {
|
||||||
responseText += ``;
|
responseText += ``;
|
||||||
|
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
||||||
|
responseText += ``;
|
||||||
}
|
}
|
||||||
const inferredQuery = responseAsJson.inferredQueries?.[0];
|
const inferredQuery = responseAsJson.inferredQueries?.[0];
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled
|
|||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import SearchType
|
from khoj.utils.config import SearchType
|
||||||
from khoj.utils.fs_syncer import collect_files
|
from khoj.utils.fs_syncer import collect_files
|
||||||
from khoj.utils.helpers import is_none_or_empty
|
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -232,6 +232,9 @@ def configure_server(
|
|||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
setup_default_agent()
|
setup_default_agent()
|
||||||
|
|
||||||
|
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
|
||||||
|
logger.info(message)
|
||||||
|
|
||||||
if not init:
|
if not init:
|
||||||
initialize_content(regenerate, search_type, user)
|
initialize_content(regenerate, search_type, user)
|
||||||
|
|
||||||
@@ -329,9 +332,7 @@ def configure_search_types():
|
|||||||
|
|
||||||
@schedule.repeat(schedule.every(2).minutes)
|
@schedule.repeat(schedule.every(2).minutes)
|
||||||
def upload_telemetry():
|
def upload_telemetry():
|
||||||
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
|
if telemetry_disabled(state.config.app) or not state.telemetry:
|
||||||
message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
|
|
||||||
logger.debug(message)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -197,9 +197,6 @@ def get_user_name(user: KhojUser):
|
|||||||
|
|
||||||
|
|
||||||
def get_user_photo(user: KhojUser):
|
def get_user_photo(user: KhojUser):
|
||||||
full_name = user.get_full_name()
|
|
||||||
if not is_none_or_empty(full_name):
|
|
||||||
return full_name
|
|
||||||
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
||||||
if google_profile:
|
if google_profile:
|
||||||
return google_profile.picture
|
return google_profile.picture
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from khoj.database.models import (
|
|||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
)
|
)
|
||||||
|
from khoj.utils.helpers import ImageIntentType
|
||||||
|
|
||||||
|
|
||||||
class KhojUserAdmin(UserAdmin):
|
class KhojUserAdmin(UserAdmin):
|
||||||
@@ -114,9 +115,12 @@ class ConversationAdmin(admin.ModelAdmin):
|
|||||||
log["by"] == "khoj"
|
log["by"] == "khoj"
|
||||||
and log["intent"]
|
and log["intent"]
|
||||||
and log["intent"]["type"]
|
and log["intent"]["type"]
|
||||||
and log["intent"]["type"] == "text-to-image"
|
and (
|
||||||
|
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
|
||||||
|
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||||
|
)
|
||||||
):
|
):
|
||||||
log["message"] = "image redacted for space"
|
log["message"] = "inline image redacted for space"
|
||||||
chat_log[idx] = log
|
chat_log[idx] = log
|
||||||
modified_log["chat"] = chat_log
|
modified_log["chat"] = chat_log
|
||||||
|
|
||||||
@@ -154,9 +158,12 @@ class ConversationAdmin(admin.ModelAdmin):
|
|||||||
log["by"] == "khoj"
|
log["by"] == "khoj"
|
||||||
and log["intent"]
|
and log["intent"]
|
||||||
and log["intent"]["type"]
|
and log["intent"]["type"]
|
||||||
and log["intent"]["type"] == "text-to-image"
|
and (
|
||||||
|
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
|
||||||
|
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||||
|
)
|
||||||
):
|
):
|
||||||
updated_log["message"] = "image redacted for space"
|
updated_log["message"] = "inline image redacted for space"
|
||||||
chat_log[idx] = updated_log
|
chat_log[idx] = updated_log
|
||||||
return_log["chat"] = chat_log
|
return_log["chat"] = chat_log
|
||||||
|
|
||||||
|
|||||||
0
src/khoj/database/management/__init__.py
Normal file
0
src/khoj/database/management/__init__.py
Normal file
0
src/khoj/database/management/commands/__init__.py
Normal file
0
src/khoj/database/management/commands/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from django.core.management.base import BaseCommand
|
||||||
|
|
||||||
|
from khoj.database.models import Conversation
|
||||||
|
from khoj.utils.helpers import ImageIntentType
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
    """Rewrite the extension of stored generated-image URLs between PNG and WebP.

    Only chat messages by Khoj whose intent type is ``ImageIntentType.TEXT_TO_IMAGE2``
    are considered. The command edits the message text saved in the conversation log;
    it does not touch any image file.

    By default messages ending in ``.png`` are rewritten to end in ``.webp``.
    With ``--reverse`` messages ending in ``.webp`` are rewritten to end in ``.png``.
    """

    help = "Convert all images to WebP format or reverse."

    def add_arguments(self, parser):
        """Register the optional ``--reverse`` flag on the command's argument parser."""
        parser.add_argument(
            "--reverse",
            action="store_true",
            help="Convert from WebP to PNG instead of PNG to WebP",
        )

    def handle(self, *args, **options):
        """Update matching chat messages and report how many were converted."""
        reverse = options["reverse"]
        # Fix the conversion direction up front so a reverse run can never
        # fall through and convert PNG references to WebP (or vice versa)
        source_ext, target_ext = (".webp", ".png") if reverse else (".png", ".webp")

        updated_count = 0
        for conversation in Conversation.objects.all():
            conversation_updated = False
            # Tolerate conversations without a chat log and messages without an intent
            for chat in (conversation.conversation_log or {}).get("chat", []):
                intent_type = (chat.get("intent") or {}).get("type")
                if chat.get("by") != "khoj" or intent_type != ImageIntentType.TEXT_TO_IMAGE2.value:
                    continue

                message = chat.get("message")
                if not isinstance(message, str) or not message.endswith(source_ext):
                    continue

                # Swap only the trailing extension, leaving the rest of the url untouched
                chat["message"] = message[: -len(source_ext)] + target_ext
                conversation_updated = True
                updated_count += 1

            # Only write back conversations that actually changed
            if conversation_updated:
                conversation.save()

        if updated_count > 0 and reverse:
            self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format."))
        elif updated_count > 0:
            self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format."))
|
||||||
69
src/khoj/database/migrations/0035_convert_png_to_webp.py
Normal file
69
src/khoj/database/migrations/0035_convert_png_to_webp.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# Generated by Django 4.2.10 on 2024-04-13 17:54
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from khoj.utils.helpers import ImageIntentType
|
||||||
|
|
||||||
|
|
||||||
|
def convert_png_images_to_webp(apps, schema_editor):
    """Forward data migration: re-encode inline PNG chat images as WebP.

    For every Khoj chat message with intent type ``ImageIntentType.TEXT_TO_IMAGE`` the
    base64 encoded message is decoded, re-saved in WebP format, base64 encoded again
    and its intent type is set to ``ImageIntentType.TEXT_TO_IMAGE_V3``.

    Conversations without a chat log and messages without an intent are skipped.
    Only conversations that had at least one image converted are saved.

    Args:
        apps: Versioned app registry passed in by the migration runner.
        schema_editor: Unused. Part of the ``RunPython`` callable signature.
    """
    # Get the model from the versioned app registry to ensure the correct version is used
    Conversations = apps.get_model("database", "Conversation")
    for conversation in Conversations.objects.all():
        conversation_updated = False
        for chat in (conversation.conversation_log or {}).get("chat", []):
            intent_type = (chat.get("intent") or {}).get("type")
            # Skip user messages and messages without an intent type
            if chat.get("by") != "khoj" or not intent_type:
                continue
            if intent_type != ImageIntentType.TEXT_TO_IMAGE.value:
                continue

            # Decode the base64 encoded PNG image
            png_image_bytes = base64.b64decode(chat["message"])

            # Convert image from PNG to WebP format
            with Image.open(io.BytesIO(png_image_bytes)) as png_image, io.BytesIO() as webp_image_io:
                png_image.save(webp_image_io, "WEBP")
                webp_image_bytes = webp_image_io.getvalue()

            # Encode the WebP image back to base64 and mark the message as converted
            chat["message"] = base64.b64encode(webp_image_bytes).decode()
            chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
            conversation_updated = True

        # Save the updated conversation history
        if conversation_updated:
            conversation.save()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_webp_images_to_png(apps, schema_editor):
    """Reverse data migration: re-encode inline WebP chat images as PNG.

    For every Khoj chat message with intent type ``ImageIntentType.TEXT_TO_IMAGE_V3`` the
    base64 encoded message is decoded, re-saved in PNG format, base64 encoded again
    and its intent type is set back to ``ImageIntentType.TEXT_TO_IMAGE``.

    Conversations without a chat log and messages without an intent are skipped.
    Only conversations that had at least one image converted are saved.

    Args:
        apps: Versioned app registry passed in by the migration runner.
        schema_editor: Unused. Part of the ``RunPython`` callable signature.
    """
    # Get the model from the versioned app registry to ensure the correct version is used
    Conversations = apps.get_model("database", "Conversation")
    for conversation in Conversations.objects.all():
        conversation_updated = False
        for chat in (conversation.conversation_log or {}).get("chat", []):
            intent_type = (chat.get("intent") or {}).get("type")
            # Skip user messages and messages without an intent type
            if chat.get("by") != "khoj" or not intent_type:
                continue
            if intent_type != ImageIntentType.TEXT_TO_IMAGE_V3.value:
                continue

            # Decode the base64 encoded WebP image
            webp_image_bytes = base64.b64decode(chat["message"])

            # Convert image from WebP to PNG format
            with Image.open(io.BytesIO(webp_image_bytes)) as webp_image, io.BytesIO() as png_image_io:
                webp_image.save(png_image_io, "PNG")
                png_image_bytes = png_image_io.getvalue()

            # Encode the PNG image back to base64 and restore the original intent type
            chat["message"] = base64.b64encode(png_image_bytes).decode()
            chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
            conversation_updated = True

        # Save the updated conversation history
        if conversation_updated:
            conversation.save()
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Data migration that converts stored inline chat images from PNG to WebP.

    Applying it runs ``convert_png_images_to_webp``; unapplying it runs
    ``convert_webp_images_to_png``.
    """

    dependencies = [("database", "0034_alter_chatmodeloptions_chat_model")]

    operations = [
        migrations.RunPython(
            code=convert_png_images_to_webp,
            reverse_code=convert_webp_images_to_png,
        ),
    ]
|
||||||
@@ -4,10 +4,10 @@
|
|||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||||
<title>Khoj - Chat</title>
|
<title>Khoj - Chat</title>
|
||||||
|
|
||||||
|
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
|
||||||
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
|
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
|
||||||
<link rel="apple-touch-icon" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
|
<link rel="apple-touch-icon" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
|
||||||
<link rel="manifest" href="/static/khoj.webmanifest?v={{ khoj_version }}">
|
<link rel="manifest" href="/static/khoj.webmanifest?v={{ khoj_version }}">
|
||||||
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
|
|
||||||
</head>
|
</head>
|
||||||
<script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
|
<script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
|
||||||
<script type="text/javascript" src="/static/assets/markdown-it.min.js?v={{ khoj_version }}"></script>
|
<script type="text/javascript" src="/static/assets/markdown-it.min.js?v={{ khoj_version }}"></script>
|
||||||
@@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
return referenceButton;
|
return referenceButton;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderMessage(message, by, dt=null, annotations=null, raw=false) {
|
function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
|
||||||
let message_time = formatDate(dt ?? new Date());
|
let message_time = formatDate(dt ?? new Date());
|
||||||
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
|
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
|
||||||
let formattedMessage = formatHTMLMessage(message, raw);
|
let formattedMessage = formatHTMLMessage(message, raw);
|
||||||
@@ -183,10 +183,16 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
|
|
||||||
// Append chat message div to chat body
|
// Append chat message div to chat body
|
||||||
let chatBody = document.getElementById("chat-body");
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
if (renderType === "append") {
|
||||||
chatBody.appendChild(chatMessage);
|
chatBody.appendChild(chatMessage);
|
||||||
|
|
||||||
// Scroll to bottom of chat-body element
|
// Scroll to bottom of chat-body element
|
||||||
chatBody.scrollTop = chatBody.scrollHeight;
|
chatBody.scrollTop = chatBody.scrollHeight;
|
||||||
|
} else if (renderType === "prepend"){
|
||||||
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
chatBody.insertBefore(chatMessage, chatBody.firstChild);
|
||||||
|
} else if (renderType === "return") {
|
||||||
|
return chatMessage;
|
||||||
|
}
|
||||||
|
|
||||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||||
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||||
@@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
}
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||||
|
// If no document or online context is provided, render the message as is
|
||||||
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||||
if (intentType?.includes("text-to-image")) {
|
if (intentType?.includes("text-to-image")) {
|
||||||
let imageMarkdown;
|
let imageMarkdown;
|
||||||
@@ -244,24 +251,24 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
} else if (intentType === "text-to-image2") {
|
} else if (intentType === "text-to-image2") {
|
||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
|
} else if (intentType === "text-to-image-v3") {
|
||||||
|
imageMarkdown = ``;
|
||||||
}
|
}
|
||||||
const inferredQuery = inferredQueries?.[0];
|
const inferredQuery = inferredQueries?.[0];
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
renderMessage(imageMarkdown, by, dt);
|
return renderMessage(imageMarkdown, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
renderMessage(message, by, dt);
|
return renderMessage(message, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||||
renderMessage(message, by, dt);
|
return renderMessage(message, by, dt, null, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If document or online context is provided, render the message with its references
|
||||||
let references = document.createElement('div');
|
let references = document.createElement('div');
|
||||||
|
|
||||||
let referenceExpandButton = document.createElement('button');
|
let referenceExpandButton = document.createElement('button');
|
||||||
@@ -312,16 +319,17 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
} else if (intentType === "text-to-image2") {
|
} else if (intentType === "text-to-image2") {
|
||||||
imageMarkdown = ``;
|
imageMarkdown = ``;
|
||||||
|
} else if (intentType === "text-to-image-v3") {
|
||||||
|
imageMarkdown = ``;
|
||||||
}
|
}
|
||||||
const inferredQuery = inferredQueries?.[0];
|
const inferredQuery = inferredQueries?.[0];
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
renderMessage(imageMarkdown, by, dt, references);
|
return renderMessage(imageMarkdown, by, dt, references, false, "return");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
renderMessage(message, by, dt, references);
|
return renderMessage(message, by, dt, references, false, "return");
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
|
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
|
||||||
@@ -619,6 +627,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
rawResponse += ``;
|
rawResponse += ``;
|
||||||
} else if (imageJson.intentType === "text-to-image2") {
|
} else if (imageJson.intentType === "text-to-image2") {
|
||||||
rawResponse += ``;
|
rawResponse += ``;
|
||||||
|
} else if (imageJson.intentType === "text-to-image-v3") {
|
||||||
|
rawResponse = ``;
|
||||||
}
|
}
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
@@ -1064,7 +1074,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
loadingScreen.appendChild(yellowOrb);
|
loadingScreen.appendChild(yellowOrb);
|
||||||
chatBody.appendChild(loadingScreen);
|
chatBody.appendChild(loadingScreen);
|
||||||
|
|
||||||
fetch(chatHistoryUrl, { method: "GET" })
|
// Get the most recent 10 chat messages from conversation history
|
||||||
|
fetch(`${chatHistoryUrl}&n=10`, { method: "GET" })
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.detail) {
|
if (data.detail) {
|
||||||
@@ -1134,11 +1145,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
agentMetadataElement.style.display = "none";
|
agentMetadataElement.style.display = "none";
|
||||||
}
|
}
|
||||||
|
|
||||||
const fullChatLog = response.chat || [];
|
// Create a new IntersectionObserver
|
||||||
|
let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
|
||||||
|
entries.forEach(entry => {
|
||||||
|
// If the element is in the viewport, fetch the remaining message and unobserve the element
|
||||||
|
if (entry.isIntersecting) {
|
||||||
|
fetchRemainingChatMessages(chatHistoryUrl);
|
||||||
|
observer.unobserve(entry.target);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, {rootMargin: '0px 0px 0px 0px'});
|
||||||
|
|
||||||
fullChatLog.forEach(chat_log => {
|
const fullChatLog = response.chat || [];
|
||||||
|
fullChatLog.forEach((chat_log, index) => {
|
||||||
|
// Render the last 10 messages immediately
|
||||||
if (chat_log.message != null) {
|
if (chat_log.message != null) {
|
||||||
renderMessageWithReference(
|
let messageElement = renderMessageWithReference(
|
||||||
chat_log.message,
|
chat_log.message,
|
||||||
chat_log.by,
|
chat_log.by,
|
||||||
chat_log.context,
|
chat_log.context,
|
||||||
@@ -1146,14 +1168,26 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
chat_log.onlineContext,
|
chat_log.onlineContext,
|
||||||
chat_log.intent?.type,
|
chat_log.intent?.type,
|
||||||
chat_log.intent?.["inferred-queries"]);
|
chat_log.intent?.["inferred-queries"]);
|
||||||
|
chatBody.appendChild(messageElement);
|
||||||
|
|
||||||
|
// When the 4th oldest message is within viewing distance (~60% scroll up)
|
||||||
|
// Fetch the remaining chat messages
|
||||||
|
if (index === 4) {
|
||||||
|
fetchRemainingMessagesObserver.observe(messageElement);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
loadingScreen.style.height = chatBody.scrollHeight + 'px';
|
loadingScreen.style.height = chatBody.scrollHeight + 'px';
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add fade out animation to loading screen and remove it after the animation ends
|
// Scroll to bottom of chat-body element
|
||||||
|
chatBody.scrollTop = chatBody.scrollHeight;
|
||||||
|
|
||||||
|
// Set height of chat-body element to the height of the chat-body-wrapper
|
||||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||||
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||||
chatBody.style.height = chatBodyWrapperHeight;
|
chatBody.style.height = chatBodyWrapperHeight;
|
||||||
|
|
||||||
|
// Add fade out animation to loading screen and remove it after the animation ends
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
loadingScreen.remove();
|
loadingScreen.remove();
|
||||||
chatBody.classList.remove("relative-position");
|
chatBody.classList.remove("relative-position");
|
||||||
@@ -1211,6 +1245,66 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
document.getElementById("chat-input").value = query_via_url;
|
document.getElementById("chat-input").value = query_via_url;
|
||||||
chat();
|
chat();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Lazily load the older part of a conversation.
 *
 * Requests the chat history with `n=-10` (the history API returns all messages
 * except the latest 10 for a non-positive `n`) and inserts one small placeholder
 * element per message after the first child of the chat body. A placeholder is
 * replaced by its rendered message only once it intersects the viewport.
 *
 * @param {string} chatHistoryUrl - Chat history API URL. `&n=-10` is appended
 *     to it, so it must already carry a query string.
 * @returns {Promise<void>} Settles once the placeholders are inserted. Failures
 *     are logged to the console, not rethrown.
 */
function fetchRemainingChatMessages(chatHistoryUrl) {
    // Render a placeholder's message when it scrolls into view, then stop watching it
    const renderOnViewObserver = new IntersectionObserver((entries) => {
        entries.forEach((entry) => {
            if (!entry.isIntersecting) return;

            const chatLog = entry.target.chat_log;
            const messageElement = renderMessageWithReference(
                chatLog.message,
                chatLog.by,
                chatLog.context,
                new Date(chatLog.created),
                chatLog.onlineContext,
                chatLog.intent?.type,
                chatLog.intent?.["inferred-queries"]
            );
            entry.target.replaceWith(messageElement);

            // The placeholder has left the DOM, so there is nothing left to observe
            renderOnViewObserver.unobserve(entry.target);
        });
    }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport

    // Fetch remaining chat messages from conversation history
    return fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
        .then((response) => response.json())
        .then((data) => {
            if (data.status !== "ok") {
                throw new Error(data.message);
            }
            return data.response;
        })
        .then((response) => {
            const olderChatLog = response.chat || [];
            const chatBody = document.getElementById("chat-body");

            // Walk newest to oldest: each placeholder is inserted at the same spot,
            // so the oldest message ends up first and chronological order is kept
            [...olderChatLog].reverse().forEach((chatLog) => {
                if (chatLog.message == null) return;

                // Create a placeholder carrying the chat log it stands in for
                const placeholder = document.createElement('div');
                placeholder.chat_log = chatLog;

                // Insert the placeholder after the welcome message (first child of chat body).
                // When the chat body is empty the reference node is null, which appends instead of throwing.
                const insertionPoint = chatBody.firstChild?.nextSibling ?? null;
                chatBody.insertBefore(placeholder, insertionPoint);

                // Give the placeholder some height before observing it
                placeholder.style.height = "20px";
                renderOnViewObserver.observe(placeholder);
            });
        })
        .catch((err) => {
            console.error(err);
        });
}
|
||||||
|
|
||||||
function flashStatusInChatInput(message) {
|
function flashStatusInChatInput(message) {
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries):
|
|||||||
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
|
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
|
||||||
except ConnectionAbortedError as e:
|
except ConnectionAbortedError as e:
|
||||||
logger.error(f"Github rate limit reached. Skip indexing github repo {repo_shorthand}")
|
logger.error(f"Github rate limit reached. Skip indexing github repo {repo_shorthand}")
|
||||||
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
|
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries):
|
|||||||
|
|
||||||
for response in responses:
|
for response in responses:
|
||||||
with timer("Processing response", logger=logger):
|
with timer("Processing response", logger=logger):
|
||||||
pages_or_databases = response["results"] if response.get("results") else []
|
pages_or_databases = response.get("results", [])
|
||||||
|
|
||||||
# Get all pages content
|
# Get all pages content
|
||||||
for p_or_d in pages_or_databases:
|
for p_or_d in pages_or_databases:
|
||||||
@@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries):
|
|||||||
|
|
||||||
current_entries = []
|
current_entries = []
|
||||||
curr_heading = ""
|
curr_heading = ""
|
||||||
for block in content["results"]:
|
for block in content.get("results", []):
|
||||||
block_type = block.get("type")
|
block_type = block.get("type")
|
||||||
|
|
||||||
if block_type == None:
|
if block_type == None:
|
||||||
@@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries):
|
|||||||
return f"\n<b>{heading}</b>\n"
|
return f"\n<b>{heading}</b>\n"
|
||||||
|
|
||||||
def process_nested_children(self, children, raw_content, block_type=None):
|
def process_nested_children(self, children, raw_content, block_type=None):
|
||||||
results = children["results"] if children.get("results") else []
|
results = children.get("results", [])
|
||||||
for child in results:
|
for child in results:
|
||||||
child_type = child.get("type")
|
child_type = child.get("type")
|
||||||
if child_type == None:
|
if child_type == None:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ def extract_questions_offline(
|
|||||||
use_history: bool = True,
|
use_history: bool = True,
|
||||||
should_extract_questions: bool = True,
|
should_extract_questions: bool = True,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
|
max_prompt_size: int = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
@@ -41,7 +42,7 @@ def extract_questions_offline(
|
|||||||
return all_questions
|
return all_questions
|
||||||
|
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
|
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||||
|
|
||||||
@@ -67,12 +68,14 @@ def extract_questions_offline(
|
|||||||
location=location,
|
location=location,
|
||||||
)
|
)
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
example_questions, model_name=model, loaded_model=offline_chat_model
|
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
||||||
)
|
)
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
|
response = send_message_to_model_offline(
|
||||||
|
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
state.chat_lock.release()
|
state.chat_lock.release()
|
||||||
|
|
||||||
@@ -138,7 +141,7 @@ def converse_offline(
|
|||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
@@ -190,18 +193,18 @@ def converse_offline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
|
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages: List[ChatMessage], model: Any):
|
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
||||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
response_iterator = send_message_to_model_offline(
|
response_iterator = send_message_to_model_offline(
|
||||||
messages, loaded_model=model, stop=stop_phrases, streaming=True
|
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||||
)
|
)
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
g.send(response["choices"][0]["delta"].get("content", ""))
|
||||||
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
|
|||||||
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
|
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
|
||||||
streaming=False,
|
streaming=False,
|
||||||
stop=[],
|
stop=[],
|
||||||
|
max_prompt_size: int = None,
|
||||||
):
|
):
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
|
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
|
||||||
if streaming:
|
if streaming:
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from huggingface_hub.constants import HF_HUB_CACHE
|
from huggingface_hub.constants import HF_HUB_CACHE
|
||||||
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import get_device_memory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
|
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
|
||||||
from llama_cpp.llama import Llama
|
# Initialize Model Parameters
|
||||||
|
# Use n_ctx=0 to get context size from the model
|
||||||
# Initialize Model Parameters. Use n_ctx=0 to get context size from the model
|
|
||||||
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
||||||
|
|
||||||
# Decide whether to load model to GPU or CPU
|
# Decide whether to load model to GPU or CPU
|
||||||
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
|
|||||||
model_path = load_model_from_cache(repo_id, filename)
|
model_path = load_model_from_cache(repo_id, filename)
|
||||||
chat_model = None
|
chat_model = None
|
||||||
try:
|
try:
|
||||||
if model_path:
|
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||||
chat_model = Llama(model_path, **kwargs)
|
|
||||||
else:
|
|
||||||
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
|
||||||
except:
|
except:
|
||||||
# Load model on CPU if GPU is not available
|
# Load model on CPU if GPU is not available
|
||||||
kwargs["n_gpu_layers"], device = 0, "cpu"
|
kwargs["n_gpu_layers"], device = 0, "cpu"
|
||||||
if model_path:
|
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||||
chat_model = Llama(model_path, **kwargs)
|
|
||||||
else:
|
|
||||||
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
|
||||||
|
|
||||||
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
|
# Now load the model with context size set based on:
|
||||||
|
# 1. context size supported by model and
|
||||||
|
# 2. configured size or machine (V)RAM
|
||||||
|
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
|
||||||
|
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
|
||||||
|
)
|
||||||
return chat_model
|
return chat_model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = None):
    """Load the chat model from a local file when a path is given, else by repo id.

    Args:
        model_path: Path to a model file. When falsy, the model is loaded with
            `Llama.from_pretrained` using `repo_id` and `filename` instead.
        repo_id: Repository id passed to `Llama.from_pretrained`.
        filename: File name pattern passed to `Llama.from_pretrained`.
        kwargs: Keyword arguments forwarded to the Llama constructor or
            `Llama.from_pretrained`. Treated as empty when omitted.

    Returns:
        The Llama model instance that was constructed.
    """
    # Local import: llama_cpp is only imported when this function runs
    from llama_cpp.llama import Llama

    # Use a fresh dict per call instead of a shared mutable default argument
    model_kwargs = kwargs if kwargs is not None else {}

    if model_path:
        return Llama(model_path, **model_kwargs)
    return Llama.from_pretrained(repo_id=repo_id, filename=filename, **model_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
||||||
# Construct the path to the model file in the cache directory
|
# Construct the path to the model file in the cache directory
|
||||||
repo_org, repo_name = repo_id.split("/")
|
repo_org, repo_name = repo_id.split("/")
|
||||||
@@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
|||||||
return paths[0]
|
return paths[0]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
    """Infer max prompt size based on device memory and max context window supported by the model

    Args:
        model_context_window: Context window supported by the model. The result
            never exceeds this value.
        configured_max_tokens: Configured prompt size. When truthy, it is used as
            the cap instead of the device-memory heuristic. The default
            (math.inf) is truthy, so omitting it returns model_context_window;
            only a falsy value such as None or 0 selects the heuristic.

    Returns:
        The smaller of model_context_window and the chosen cap.
    """
    if configured_max_tokens:
        return min(configured_max_tokens, model_context_window)

    # Query device memory only on the fallback path, where the value is actually used
    vram_based_n_ctx = int(get_device_memory() / 1e6)  # based on heuristic
    return min(vram_based_n_ctx, model_context_window)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import queue
|
import queue
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
@@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
|
|||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
):
|
):
|
||||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||||
# Set max prompt size from user config, pre-configured for model or to default prompt size
|
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||||
try:
|
if not max_prompt_size:
|
||||||
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
|
if loaded_model:
|
||||||
except:
|
max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
|
||||||
max_prompt_size = 2000
|
else:
|
||||||
logger.warning(
|
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
|
||||||
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Scale lookback turns proportional to max prompt size supported by model
|
# Scale lookback turns proportional to max prompt size supported by model
|
||||||
lookback_turns = max_prompt_size // 750
|
lookback_turns = max_prompt_size // 750
|
||||||
@@ -187,7 +186,7 @@ def truncate_messages(
|
|||||||
max_prompt_size,
|
max_prompt_size,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
tokenizer_name=None,
|
tokenizer_name="hf-internal-testing/llama-tokenizer",
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
@@ -197,15 +196,11 @@ def truncate_messages(
|
|||||||
elif model_name.startswith("gpt-"):
|
elif model_name.startswith("gpt-"):
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
else:
|
else:
|
||||||
try:
|
|
||||||
encoder = download_model(model_name).tokenizer()
|
encoder = download_model(model_name).tokenizer()
|
||||||
except:
|
except:
|
||||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
|
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
except:
|
|
||||||
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
|
||||||
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract system message from messages
|
# Extract system message from messages
|
||||||
|
|||||||
@@ -289,9 +289,7 @@ async def extract_references_and_questions(
|
|||||||
return compiled_references, inferred_queries, q
|
return compiled_references, inferred_queries, q
|
||||||
|
|
||||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||||
logger.warning(
|
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
||||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
|
||||||
)
|
|
||||||
return compiled_references, inferred_queries, q
|
return compiled_references, inferred_queries, q
|
||||||
|
|
||||||
# Extract filter terms from user message
|
# Extract filter terms from user message
|
||||||
@@ -317,8 +315,9 @@ async def extract_references_and_questions(
|
|||||||
using_offline_chat = True
|
using_offline_chat = True
|
||||||
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
||||||
chat_model = default_offline_llm.chat_model
|
chat_model = default_offline_llm.chat_model
|
||||||
|
max_tokens = default_offline_llm.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
|
|
||||||
@@ -328,6 +327,7 @@ async def extract_references_and_questions(
|
|||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
should_extract_questions=True,
|
should_extract_questions=True,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
)
|
)
|
||||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ def chat_history(
|
|||||||
request: Request,
|
request: Request,
|
||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
conversation_id: Optional[int] = None,
|
conversation_id: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
validate_conversation_config()
|
validate_conversation_config()
|
||||||
@@ -109,6 +110,13 @@ def chat_history(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get latest N messages if N > 0
|
||||||
|
if n > 0:
|
||||||
|
meta_log["chat"] = meta_log["chat"][-n:]
|
||||||
|
# Else return all messages except latest N
|
||||||
|
else:
|
||||||
|
meta_log["chat"] = meta_log["chat"][:n]
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
@@ -425,8 +433,7 @@ async def websocket_endpoint(
|
|||||||
api="chat",
|
api="chat",
|
||||||
metadata={"conversation_command": conversation_commands[0].value},
|
metadata={"conversation_command": conversation_commands[0].value},
|
||||||
)
|
)
|
||||||
intent_type = "text-to-image"
|
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
|
||||||
q,
|
q,
|
||||||
user,
|
user,
|
||||||
meta_log,
|
meta_log,
|
||||||
@@ -445,9 +452,6 @@ async def websocket_endpoint(
|
|||||||
await send_complete_llm_response(json.dumps(content_obj))
|
await send_complete_llm_response(json.dumps(content_obj))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if image_url:
|
|
||||||
intent_type = "text-to-image2"
|
|
||||||
image = image_url
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
image,
|
image,
|
||||||
@@ -621,17 +625,13 @@ async def chat(
|
|||||||
metadata={"conversation_command": conversation_commands[0].value},
|
metadata={"conversation_command": conversation_commands[0].value},
|
||||||
**common.__dict__,
|
**common.__dict__,
|
||||||
)
|
)
|
||||||
intent_type = "text-to-image"
|
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
|
||||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||||
)
|
)
|
||||||
if image is None:
|
if image is None:
|
||||||
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
||||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||||
|
|
||||||
if image_url:
|
|
||||||
intent_type = "text-to-image2"
|
|
||||||
image = image_url
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
image,
|
image,
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
@@ -18,6 +20,7 @@ from typing import (
|
|||||||
|
|
||||||
import openai
|
import openai
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
from PIL import Image
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
|
|
||||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||||
@@ -46,6 +49,7 @@ from khoj.utils import state
|
|||||||
from khoj.utils.config import OfflineChatProcessorModel
|
from khoj.utils.config import OfflineChatProcessorModel
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
|
ImageIntentType,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_valid_url,
|
is_valid_url,
|
||||||
log_telemetry,
|
log_telemetry,
|
||||||
@@ -79,9 +83,10 @@ async def is_ready_to_chat(user: KhojUser):
|
|||||||
|
|
||||||
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
|
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
|
max_tokens = user_conversation_config.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
logger.info("Loading Offline Chat Model...")
|
logger.info("Loading Offline Chat Model...")
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
ready = has_openai_config or has_offline_config
|
ready = has_openai_config or has_offline_config
|
||||||
@@ -382,10 +387,11 @@ async def send_message_to_model_wrapper(
|
|||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
|
max_tokens = conversation_config.max_prompt_size
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == "offline":
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
@@ -452,7 +458,9 @@ def generate_chat_response(
|
|||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == "offline":
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
|
chat_model = conversation_config.chat_model
|
||||||
|
max_tokens = conversation_config.max_prompt_size
|
||||||
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
chat_response = converse_offline(
|
chat_response = converse_offline(
|
||||||
@@ -508,18 +516,19 @@ async def text_to_image(
|
|||||||
references: List[str],
|
references: List[str],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
|
) -> Tuple[Optional[str], int, Optional[str], str]:
|
||||||
status_code = 200
|
status_code = 200
|
||||||
image = None
|
image = None
|
||||||
response = None
|
response = None
|
||||||
image_url = None
|
image_url = None
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
return image, status_code, message, image_url
|
return image_url or image, status_code, message, intent_type.value
|
||||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
logger.info("Generating image with OpenAI")
|
logger.info("Generating image with OpenAI")
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
@@ -550,21 +559,38 @@ async def text_to_image(
|
|||||||
)
|
)
|
||||||
image = response.data[0].b64_json
|
image = response.data[0].b64_json
|
||||||
|
|
||||||
|
with timer("Convert image to webp", logger):
|
||||||
|
# Convert png to webp for faster loading
|
||||||
|
decoded_image = base64.b64decode(image)
|
||||||
|
image_io = io.BytesIO(decoded_image)
|
||||||
|
png_image = Image.open(image_io)
|
||||||
|
webp_image_io = io.BytesIO()
|
||||||
|
png_image.save(webp_image_io, "WEBP")
|
||||||
|
webp_image_bytes = webp_image_io.getvalue()
|
||||||
|
webp_image_io.close()
|
||||||
|
image_io.close()
|
||||||
|
|
||||||
with timer("Upload image to S3", logger):
|
with timer("Upload image to S3", logger):
|
||||||
image_url = upload_image(image, user.uuid)
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
return image, status_code, improved_image_prompt, image_url
|
if image_url:
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
||||||
|
else:
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
if "content_policy_violation" in e.message:
|
if "content_policy_violation" in e.message:
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||||
return image, status_code, message, image_url
|
return image_url or image, status_code, message, intent_type.value
|
||||||
else:
|
else:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
return image, status_code, message, image_url
|
return image_url or image, status_code, message, intent_type.value
|
||||||
return image, status_code, response, image_url
|
return image_url or image, status_code, response, intent_type.value
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@@ -17,16 +16,15 @@ if aws_enabled:
|
|||||||
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
|
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
|
||||||
|
|
||||||
|
|
||||||
def upload_image(image: str, user_id: uuid.UUID):
|
def upload_image(image: bytes, user_id: uuid.UUID):
|
||||||
"""Upload the image to the S3 bucket"""
|
"""Upload the image to the S3 bucket"""
|
||||||
if not aws_enabled:
|
if not aws_enabled:
|
||||||
logger.info("AWS is not enabled. Skipping image upload")
|
logger.info("AWS is not enabled. Skipping image upload")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
decoded_image = base64.b64decode(image)
|
image_key = f"{user_id}/{uuid.uuid4()}.webp"
|
||||||
image_key = f"{user_id}/{uuid.uuid4()}.png"
|
|
||||||
try:
|
try:
|
||||||
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
|
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
|
||||||
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
|
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
|
|||||||
|
|
||||||
|
|
||||||
class OfflineChatProcessorModel:
|
class OfflineChatProcessorModel:
|
||||||
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
|
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
|
||||||
self.chat_model = chat_model
|
self.chat_model = chat_model
|
||||||
self.loaded_model = None
|
self.loaded_model = None
|
||||||
try:
|
try:
|
||||||
self.loaded_model = download_model(self.chat_model)
|
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.loaded_model = None
|
self.loaded_model = None
|
||||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from time import perf_counter
|
|||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
@@ -233,6 +234,10 @@ def get_server_id():
|
|||||||
return server_id
|
return server_id
|
||||||
|
|
||||||
|
|
||||||
|
def telemetry_disabled(app_config: AppConfig):
|
||||||
|
return not app_config or not app_config.should_log_telemetry
|
||||||
|
|
||||||
|
|
||||||
def log_telemetry(
|
def log_telemetry(
|
||||||
telemetry_type: str,
|
telemetry_type: str,
|
||||||
api: str = None,
|
api: str = None,
|
||||||
@@ -242,7 +247,7 @@ def log_telemetry(
|
|||||||
):
|
):
|
||||||
"""Log basic app usage telemetry like client, os, api called"""
|
"""Log basic app usage telemetry like client, os, api called"""
|
||||||
# Do not log usage telemetry, if telemetry is disabled via app config
|
# Do not log usage telemetry, if telemetry is disabled via app config
|
||||||
if not app_config or not app_config.should_log_telemetry:
|
if telemetry_disabled(app_config):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if properties.get("server_id") is None:
|
if properties.get("server_id") is None:
|
||||||
@@ -267,6 +272,17 @@ def log_telemetry(
|
|||||||
return request_body
|
return request_body
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_memory() -> int:
|
||||||
|
"""Get device memory in bytes"""
|
||||||
|
device = get_device()
|
||||||
|
if device.type == "cuda":
|
||||||
|
return torch.cuda.get_device_properties(device).total_memory
|
||||||
|
elif device.type == "mps":
|
||||||
|
return torch.mps.driver_allocated_memory()
|
||||||
|
else:
|
||||||
|
return psutil.virtual_memory().total
|
||||||
|
|
||||||
|
|
||||||
def get_device() -> torch.device:
|
def get_device() -> torch.device:
|
||||||
"""Get device to run model on"""
|
"""Get device to run model on"""
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -313,6 +329,20 @@ mode_descriptions_for_llm = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageIntentType(Enum):
|
||||||
|
"""
|
||||||
|
Chat message intent by Khoj for image responses.
|
||||||
|
Marks the schema used to reference image in chat messages
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Images as Inline PNG
|
||||||
|
TEXT_TO_IMAGE = "text-to-image"
|
||||||
|
# Images as URLs
|
||||||
|
TEXT_TO_IMAGE2 = "text-to-image2"
|
||||||
|
# Images as Inline WebP
|
||||||
|
TEXT_TO_IMAGE_V3 = "text-to-image-v3"
|
||||||
|
|
||||||
|
|
||||||
def generate_random_name():
|
def generate_random_name():
|
||||||
# List of adjectives and nouns to choose from
|
# List of adjectives and nouns to choose from
|
||||||
adjectives = [
|
adjectives = [
|
||||||
|
|||||||
Reference in New Issue
Block a user