mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Improve Chat Page Load Perf, Offline Chat Perf and Miscellaneous Fixes (#703)
### Store Generated Images as WebP -78bac4aeAdd migration script to convert PNG to WebP references in database -c6e84436Update clients to support rendering webp images inline -d21f22ffStore Khoj generated images as webp instead of png for faster loading ### Lazy Fetch Chat Messages to Improve Time, Data to First Render This is especially helpful for long conversations with lots of images -128829c4Render latest msgs on chat session load. Fetch, render rest as they near viewport -9e558577Support getting latest N chat messages via chat history API ### Intelligently set Context Window of Offline Chat to Improve Performance -4977b551Use offline chat prompt config to set context window of loaded chat model ### Fixes -148923c1Fix to raise error on hitting rate limit during Github indexing -b8bc6beeAlways remove loading animation on Desktop app if can't login to server -38250705Fix `get_user_photo` to only return photo, not user name from DB ### Miscellaneous Improvements -689202e0Update recommended CMAKE flag to enable using CUDA on linux in Docs -b820daf3Makes logs less noisy
This commit is contained in:
@@ -134,7 +134,7 @@ python -m pip install khoj-assistant
|
||||
# CPU
|
||||
python -m pip install khoj-assistant
|
||||
# NVIDIA (CUDA) GPU
|
||||
CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
||||
CMAKE_ARGS="DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
||||
# AMD (ROCm) GPU
|
||||
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
|
||||
# VULCAN GPU
|
||||
|
||||
@@ -78,6 +78,7 @@ dependencies = [
|
||||
"phonenumbers == 8.13.27",
|
||||
"markdownify ~= 0.11.6",
|
||||
"websockets == 12.0",
|
||||
"psutil >= 5.8.0",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
@@ -105,7 +106,6 @@ dev = [
|
||||
"pytest-asyncio == 0.21.1",
|
||||
"freezegun >= 1.2.0",
|
||||
"factory-boy >= 3.2.1",
|
||||
"psutil >= 5.8.0",
|
||||
"mypy >= 1.0.1",
|
||||
"black >= 23.1.0",
|
||||
"pre-commit >= 3.0.4",
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||
<title>Khoj - Chat</title>
|
||||
|
||||
<link rel="stylesheet" href="./assets/khoj.css">
|
||||
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
|
||||
<link rel="manifest" href="/static/khoj.webmanifest">
|
||||
<link rel="stylesheet" href="./assets/khoj.css">
|
||||
</head>
|
||||
<script type="text/javascript" src="./assets/markdown-it.min.js"></script>
|
||||
<script src="./utils.js"></script>
|
||||
@@ -130,7 +130,7 @@
|
||||
return referenceButton;
|
||||
}
|
||||
|
||||
function renderMessage(message, by, dt=null, annotations=null, raw=false) {
|
||||
function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
|
||||
let message_time = formatDate(dt ?? new Date());
|
||||
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
|
||||
let formattedMessage = formatHTMLMessage(message, raw);
|
||||
@@ -153,10 +153,15 @@
|
||||
|
||||
// Append chat message div to chat body
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.appendChild(chatMessage);
|
||||
|
||||
// Scroll to bottom of chat-body element
|
||||
chatBody.scrollTop = chatBody.scrollHeight;
|
||||
if (renderType === "append") {
|
||||
chatBody.appendChild(chatMessage);
|
||||
// Scroll to bottom of chat-body element
|
||||
chatBody.scrollTop = chatBody.scrollHeight;
|
||||
} else if (renderType === "prepend") {
|
||||
chatBody.insertBefore(chatMessage, chatBody.firstChild);
|
||||
} else if (renderType === "return") {
|
||||
return chatMessage;
|
||||
}
|
||||
|
||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||
@@ -207,6 +212,7 @@
|
||||
}
|
||||
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||
// If no document or online context is provided, render the message as is
|
||||
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||
if (intentType?.includes("text-to-image")) {
|
||||
let imageMarkdown;
|
||||
@@ -214,30 +220,29 @@
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image2") {
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image-v3") {
|
||||
imageMarkdown = ``;
|
||||
}
|
||||
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
return renderMessage(imageMarkdown, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
return renderMessage(message, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
if (context == null && onlineContext == null) {
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
return renderMessage(message, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
return renderMessage(message, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
// If document or online context is provided, render the message with its references
|
||||
let references = document.createElement('div');
|
||||
|
||||
let referenceExpandButton = document.createElement('button');
|
||||
@@ -288,16 +293,17 @@
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image2") {
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image-v3") {
|
||||
imageMarkdown = ``;
|
||||
}
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt, references);
|
||||
return;
|
||||
return renderMessage(imageMarkdown, by, dt, references, false, "return");
|
||||
}
|
||||
|
||||
renderMessage(message, by, dt, references);
|
||||
return renderMessage(message, by, dt, references, false, "return");
|
||||
}
|
||||
|
||||
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
|
||||
@@ -509,6 +515,8 @@
|
||||
rawResponse += ``;
|
||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
||||
rawResponse += ``;
|
||||
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
||||
rawResponse += ``;
|
||||
}
|
||||
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
||||
if (inferredQueries) {
|
||||
@@ -671,7 +679,7 @@
|
||||
let firstRunSetupMessageRendered = false;
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.innerHTML = "";
|
||||
let chatHistoryUrl = `/api/chat/history?client=desktop`;
|
||||
let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`;
|
||||
if (chatBody.dataset.conversationId) {
|
||||
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
|
||||
}
|
||||
@@ -683,7 +691,8 @@
|
||||
loadingScreen.appendChild(yellowOrb);
|
||||
chatBody.appendChild(loadingScreen);
|
||||
|
||||
fetch(`${hostURL}${chatHistoryUrl}`, { headers })
|
||||
// Get the most recent 10 chat messages from conversation history
|
||||
fetch(`${chatHistoryUrl}&n=10`, { headers })
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
@@ -703,11 +712,21 @@
|
||||
chatBody.dataset.conversationId = response.conversation_id;
|
||||
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
|
||||
|
||||
const fullChatLog = response.chat || [];
|
||||
// Create a new IntersectionObserver
|
||||
let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
|
||||
entries.forEach(entry => {
|
||||
// If the element is in the viewport, fetch the remaining message and unobserve the element
|
||||
if (entry.isIntersecting) {
|
||||
fetchRemainingChatMessages(chatHistoryUrl);
|
||||
observer.unobserve(entry.target);
|
||||
}
|
||||
});
|
||||
}, {rootMargin: '0px 0px 0px 0px'});
|
||||
|
||||
fullChatLog.forEach(chat_log => {
|
||||
const fullChatLog = response.chat || [];
|
||||
fullChatLog.forEach((chat_log, index) => {
|
||||
if (chat_log.message != null) {
|
||||
renderMessageWithReference(
|
||||
let messageElement = renderMessageWithReference(
|
||||
chat_log.message,
|
||||
chat_log.by,
|
||||
chat_log.context,
|
||||
@@ -715,10 +734,25 @@
|
||||
chat_log.onlineContext,
|
||||
chat_log.intent?.type,
|
||||
chat_log.intent?.["inferred-queries"]);
|
||||
chatBody.appendChild(messageElement);
|
||||
|
||||
// When the 4th oldest message is within viewing distance (~60% scrolled up)
|
||||
// Fetch the remaining chat messages
|
||||
if (index === 4) {
|
||||
fetchRemainingMessagesObserver.observe(messageElement);
|
||||
}
|
||||
}
|
||||
loadingScreen.style.height = chatBody.scrollHeight + 'px';
|
||||
})
|
||||
|
||||
// Scroll to bottom of chat-body element
|
||||
chatBody.scrollTop = chatBody.scrollHeight;
|
||||
|
||||
// Set height of chat-body element to the height of the chat-body-wrapper
|
||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||
let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||
chatBody.style.height = chatBodyWrapperHeight;
|
||||
|
||||
// Add fade out animation to loading screen and remove it after the animation ends
|
||||
fadeOutLoadingAnimation(loadingScreen);
|
||||
})
|
||||
@@ -726,9 +760,9 @@
|
||||
// If the server returns a 500 error with detail, render a setup hint.
|
||||
if (!firstRunSetupMessageRendered) {
|
||||
renderFirstRunSetupMessage();
|
||||
fadeOutLoadingAnimation(loadingScreen);
|
||||
}
|
||||
return;
|
||||
fadeOutLoadingAnimation(loadingScreen);
|
||||
return;
|
||||
});
|
||||
|
||||
await refreshChatSessionsPanel();
|
||||
@@ -778,6 +812,65 @@
|
||||
}
|
||||
}
|
||||
|
||||
function fetchRemainingChatMessages(chatHistoryUrl) {
|
||||
// Create a new IntersectionObserver
|
||||
let observer = new IntersectionObserver((entries, observer) => {
|
||||
entries.forEach(entry => {
|
||||
// If the element is in the viewport, render the message and unobserve the element
|
||||
if (entry.isIntersecting) {
|
||||
let chat_log = entry.target.chat_log;
|
||||
let messageElement = renderMessageWithReference(
|
||||
chat_log.message,
|
||||
chat_log.by,
|
||||
chat_log.context,
|
||||
new Date(chat_log.created),
|
||||
chat_log.onlineContext,
|
||||
chat_log.intent?.type,
|
||||
chat_log.intent?.["inferred-queries"]
|
||||
);
|
||||
entry.target.replaceWith(messageElement);
|
||||
|
||||
// Remove the observer after the element has been rendered
|
||||
observer.unobserve(entry.target);
|
||||
}
|
||||
});
|
||||
}, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport
|
||||
|
||||
// Fetch remaining chat messages from conversation history
|
||||
fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status != "ok") {
|
||||
throw new Error(data.message);
|
||||
}
|
||||
return data.response;
|
||||
})
|
||||
.then(response => {
|
||||
const fullChatLog = response.chat || [];
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
fullChatLog
|
||||
.reverse()
|
||||
.forEach(chat_log => {
|
||||
if (chat_log.message != null) {
|
||||
// Create a new element for each chat log
|
||||
let placeholder = document.createElement('div');
|
||||
placeholder.chat_log = chat_log;
|
||||
|
||||
// Insert the message placeholder as the first child of chat body after the welcome message
|
||||
chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
|
||||
|
||||
// Observe the element
|
||||
placeholder.style.height = "20px";
|
||||
observer.observe(placeholder);
|
||||
}
|
||||
});
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
return;
|
||||
});
|
||||
}
|
||||
|
||||
function fadeOutLoadingAnimation(loadingScreen) {
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||
|
||||
@@ -156,6 +156,8 @@ export class KhojChatModal extends Modal {
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image2") {
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image-v3") {
|
||||
imageMarkdown = ``;
|
||||
}
|
||||
if (inferredQueries) {
|
||||
imageMarkdown += "\n\n**Inferred Query**:";
|
||||
@@ -429,6 +431,8 @@ export class KhojChatModal extends Modal {
|
||||
responseText += ``;
|
||||
} else if (responseAsJson.intentType === "text-to-image2") {
|
||||
responseText += ``;
|
||||
} else if (responseAsJson.intentType === "text-to-image-v3") {
|
||||
responseText += ``;
|
||||
}
|
||||
const inferredQuery = responseAsJson.inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
|
||||
@@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchType
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -232,6 +232,9 @@ def configure_server(
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
setup_default_agent()
|
||||
|
||||
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
|
||||
logger.info(message)
|
||||
|
||||
if not init:
|
||||
initialize_content(regenerate, search_type, user)
|
||||
|
||||
@@ -329,9 +332,7 @@ def configure_search_types():
|
||||
|
||||
@schedule.repeat(schedule.every(2).minutes)
|
||||
def upload_telemetry():
|
||||
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
|
||||
message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
|
||||
logger.debug(message)
|
||||
if telemetry_disabled(state.config.app) or not state.telemetry:
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -197,9 +197,6 @@ def get_user_name(user: KhojUser):
|
||||
|
||||
|
||||
def get_user_photo(user: KhojUser):
|
||||
full_name = user.get_full_name()
|
||||
if not is_none_or_empty(full_name):
|
||||
return full_name
|
||||
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
||||
if google_profile:
|
||||
return google_profile.picture
|
||||
|
||||
@@ -23,6 +23,7 @@ from khoj.database.models import (
|
||||
TextToImageModelConfig,
|
||||
UserSearchModelConfig,
|
||||
)
|
||||
from khoj.utils.helpers import ImageIntentType
|
||||
|
||||
|
||||
class KhojUserAdmin(UserAdmin):
|
||||
@@ -114,9 +115,12 @@ class ConversationAdmin(admin.ModelAdmin):
|
||||
log["by"] == "khoj"
|
||||
and log["intent"]
|
||||
and log["intent"]["type"]
|
||||
and log["intent"]["type"] == "text-to-image"
|
||||
and (
|
||||
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
|
||||
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||
)
|
||||
):
|
||||
log["message"] = "image redacted for space"
|
||||
log["message"] = "inline image redacted for space"
|
||||
chat_log[idx] = log
|
||||
modified_log["chat"] = chat_log
|
||||
|
||||
@@ -154,9 +158,12 @@ class ConversationAdmin(admin.ModelAdmin):
|
||||
log["by"] == "khoj"
|
||||
and log["intent"]
|
||||
and log["intent"]["type"]
|
||||
and log["intent"]["type"] == "text-to-image"
|
||||
and (
|
||||
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
|
||||
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||
)
|
||||
):
|
||||
updated_log["message"] = "image redacted for space"
|
||||
updated_log["message"] = "inline image redacted for space"
|
||||
chat_log[idx] = updated_log
|
||||
return_log["chat"] = chat_log
|
||||
|
||||
|
||||
0
src/khoj/database/management/__init__.py
Normal file
0
src/khoj/database/management/__init__.py
Normal file
0
src/khoj/database/management/commands/__init__.py
Normal file
0
src/khoj/database/management/commands/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from khoj.database.models import Conversation
|
||||
from khoj.utils.helpers import ImageIntentType
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Convert all images to WebP format or reverse."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
# Add a new argument 'reverse' to the command
|
||||
parser.add_argument(
|
||||
"--reverse",
|
||||
action="store_true",
|
||||
help="Convert from WebP to PNG instead of PNG to WebP",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
updated_count = 0
|
||||
for conversation in Conversation.objects.all():
|
||||
conversation_updated = False
|
||||
for chat in conversation.conversation_log["chat"]:
|
||||
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
|
||||
if options["reverse"] and chat["message"].endswith(".webp"):
|
||||
# Convert WebP url to PNG url
|
||||
chat["message"] = chat["message"].replace(".webp", ".png")
|
||||
conversation_updated = True
|
||||
updated_count += 1
|
||||
elif chat["message"].endswith(".png"):
|
||||
# Convert PNG url to WebP url
|
||||
chat["message"] = chat["message"].replace(".png", ".webp")
|
||||
conversation_updated = True
|
||||
updated_count += 1
|
||||
if conversation_updated:
|
||||
conversation.save()
|
||||
|
||||
if updated_count > 0 and options["reverse"]:
|
||||
self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format."))
|
||||
elif updated_count > 0:
|
||||
self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format."))
|
||||
69
src/khoj/database/migrations/0035_convert_png_to_webp.py
Normal file
69
src/khoj/database/migrations/0035_convert_png_to_webp.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Generated by Django 4.2.10 on 2024-04-13 17:54
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
from django.db import migrations
|
||||
from PIL import Image
|
||||
|
||||
from khoj.utils.helpers import ImageIntentType
|
||||
|
||||
|
||||
def convert_png_images_to_webp(apps, schema_editor):
|
||||
# Get the model from the versioned app registry to ensure the correct version is used
|
||||
Conversations = apps.get_model("database", "Conversation")
|
||||
for conversation in Conversations.objects.all():
|
||||
for chat in conversation.conversation_log["chat"]:
|
||||
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value:
|
||||
# Decode the base64 encoded PNG image
|
||||
decoded_image = base64.b64decode(chat["message"])
|
||||
|
||||
# Convert images from PNG to WebP format
|
||||
image_io = io.BytesIO(decoded_image)
|
||||
with Image.open(image_io) as png_image:
|
||||
webp_image_io = io.BytesIO()
|
||||
png_image.save(webp_image_io, "WEBP")
|
||||
|
||||
# Encode the WebP image back to base64
|
||||
webp_image_bytes = webp_image_io.getvalue()
|
||||
chat["message"] = base64.b64encode(webp_image_bytes).decode()
|
||||
chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||
webp_image_io.close()
|
||||
|
||||
# Save the updated conversation history
|
||||
conversation.save()
|
||||
|
||||
|
||||
def convert_webp_images_to_png(apps, schema_editor):
|
||||
# Get the model from the versioned app registry to ensure the correct version is used
|
||||
Conversations = apps.get_model("database", "Conversation")
|
||||
for conversation in Conversations.objects.all():
|
||||
for chat in conversation.conversation_log["chat"]:
|
||||
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value:
|
||||
# Decode the base64 encoded PNG image
|
||||
decoded_image = base64.b64decode(chat["message"])
|
||||
|
||||
# Convert images from PNG to WebP format
|
||||
image_io = io.BytesIO(decoded_image)
|
||||
with Image.open(image_io) as png_image:
|
||||
webp_image_io = io.BytesIO()
|
||||
png_image.save(webp_image_io, "PNG")
|
||||
|
||||
# Encode the WebP image back to base64
|
||||
webp_image_bytes = webp_image_io.getvalue()
|
||||
chat["message"] = base64.b64encode(webp_image_bytes).decode()
|
||||
chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
|
||||
webp_image_io.close()
|
||||
|
||||
# Save the updated conversation history
|
||||
conversation.save()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0034_alter_chatmodeloptions_chat_model"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(convert_png_images_to_webp, reverse_code=convert_webp_images_to_png),
|
||||
]
|
||||
@@ -4,10 +4,10 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
|
||||
<title>Khoj - Chat</title>
|
||||
|
||||
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
|
||||
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
|
||||
<link rel="apple-touch-icon" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
|
||||
<link rel="manifest" href="/static/khoj.webmanifest?v={{ khoj_version }}">
|
||||
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
|
||||
</head>
|
||||
<script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
|
||||
<script type="text/javascript" src="/static/assets/markdown-it.min.js?v={{ khoj_version }}"></script>
|
||||
@@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
return referenceButton;
|
||||
}
|
||||
|
||||
function renderMessage(message, by, dt=null, annotations=null, raw=false) {
|
||||
function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
|
||||
let message_time = formatDate(dt ?? new Date());
|
||||
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
|
||||
let formattedMessage = formatHTMLMessage(message, raw);
|
||||
@@ -183,10 +183,16 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
|
||||
// Append chat message div to chat body
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.appendChild(chatMessage);
|
||||
|
||||
// Scroll to bottom of chat-body element
|
||||
chatBody.scrollTop = chatBody.scrollHeight;
|
||||
if (renderType === "append") {
|
||||
chatBody.appendChild(chatMessage);
|
||||
// Scroll to bottom of chat-body element
|
||||
chatBody.scrollTop = chatBody.scrollHeight;
|
||||
} else if (renderType === "prepend"){
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
chatBody.insertBefore(chatMessage, chatBody.firstChild);
|
||||
} else if (renderType === "return") {
|
||||
return chatMessage;
|
||||
}
|
||||
|
||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||
@@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
}
|
||||
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||
// If no document or online context is provided, render the message as is
|
||||
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||
if (intentType?.includes("text-to-image")) {
|
||||
let imageMarkdown;
|
||||
@@ -244,24 +251,24 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image2") {
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image-v3") {
|
||||
imageMarkdown = ``;
|
||||
}
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
return renderMessage(imageMarkdown, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
return renderMessage(message, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
return renderMessage(message, by, dt, null, false, "return");
|
||||
}
|
||||
|
||||
// If document or online context is provided, render the message with its references
|
||||
let references = document.createElement('div');
|
||||
|
||||
let referenceExpandButton = document.createElement('button');
|
||||
@@ -312,16 +319,17 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image2") {
|
||||
imageMarkdown = ``;
|
||||
} else if (intentType === "text-to-image-v3") {
|
||||
imageMarkdown = ``;
|
||||
}
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt, references);
|
||||
return;
|
||||
return renderMessage(imageMarkdown, by, dt, references, false, "return");
|
||||
}
|
||||
|
||||
renderMessage(message, by, dt, references);
|
||||
return renderMessage(message, by, dt, references, false, "return");
|
||||
}
|
||||
|
||||
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
|
||||
@@ -619,6 +627,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
rawResponse += ``;
|
||||
} else if (imageJson.intentType === "text-to-image2") {
|
||||
rawResponse += ``;
|
||||
} else if (imageJson.intentType === "text-to-image-v3") {
|
||||
rawResponse = ``;
|
||||
}
|
||||
if (inferredQuery) {
|
||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
@@ -1064,7 +1074,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
loadingScreen.appendChild(yellowOrb);
|
||||
chatBody.appendChild(loadingScreen);
|
||||
|
||||
fetch(chatHistoryUrl, { method: "GET" })
|
||||
// Get the most recent 10 chat messages from conversation history
|
||||
fetch(`${chatHistoryUrl}&n=10`, { method: "GET" })
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
@@ -1134,11 +1145,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
agentMetadataElement.style.display = "none";
|
||||
}
|
||||
|
||||
const fullChatLog = response.chat || [];
|
||||
// Create a new IntersectionObserver
|
||||
let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
|
||||
entries.forEach(entry => {
|
||||
// If the element is in the viewport, fetch the remaining message and unobserve the element
|
||||
if (entry.isIntersecting) {
|
||||
fetchRemainingChatMessages(chatHistoryUrl);
|
||||
observer.unobserve(entry.target);
|
||||
}
|
||||
});
|
||||
}, {rootMargin: '0px 0px 0px 0px'});
|
||||
|
||||
fullChatLog.forEach(chat_log => {
|
||||
if (chat_log.message != null){
|
||||
renderMessageWithReference(
|
||||
const fullChatLog = response.chat || [];
|
||||
fullChatLog.forEach((chat_log, index) => {
|
||||
// Render the last 10 messages immediately
|
||||
if (chat_log.message != null) {
|
||||
let messageElement = renderMessageWithReference(
|
||||
chat_log.message,
|
||||
chat_log.by,
|
||||
chat_log.context,
|
||||
@@ -1146,14 +1168,26 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
chat_log.onlineContext,
|
||||
chat_log.intent?.type,
|
||||
chat_log.intent?.["inferred-queries"]);
|
||||
chatBody.appendChild(messageElement);
|
||||
|
||||
// When the 4th oldest message is within viewing distance (~60% scroll up)
|
||||
// Fetch the remaining chat messages
|
||||
if (index === 4) {
|
||||
fetchRemainingMessagesObserver.observe(messageElement);
|
||||
}
|
||||
}
|
||||
loadingScreen.style.height = chatBody.scrollHeight + 'px';
|
||||
});
|
||||
|
||||
// Add fade out animation to loading screen and remove it after the animation ends
|
||||
// Scroll to bottom of chat-body element
|
||||
chatBody.scrollTop = chatBody.scrollHeight;
|
||||
|
||||
// Set height of chat-body element to the height of the chat-body-wrapper
|
||||
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
|
||||
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||
let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
|
||||
chatBody.style.height = chatBodyWrapperHeight;
|
||||
|
||||
// Add fade out animation to loading screen and remove it after the animation ends
|
||||
setTimeout(() => {
|
||||
loadingScreen.remove();
|
||||
chatBody.classList.remove("relative-position");
|
||||
@@ -1211,6 +1245,66 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
document.getElementById("chat-input").value = query_via_url;
|
||||
chat();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
function fetchRemainingChatMessages(chatHistoryUrl) {
|
||||
// Create a new IntersectionObserver
|
||||
let observer = new IntersectionObserver((entries, observer) => {
|
||||
entries.forEach(entry => {
|
||||
// If the element is in the viewport, render the message and unobserve the element
|
||||
if (entry.isIntersecting) {
|
||||
let chat_log = entry.target.chat_log;
|
||||
let messageElement = renderMessageWithReference(
|
||||
chat_log.message,
|
||||
chat_log.by,
|
||||
chat_log.context,
|
||||
new Date(chat_log.created),
|
||||
chat_log.onlineContext,
|
||||
chat_log.intent?.type,
|
||||
chat_log.intent?.["inferred-queries"]
|
||||
);
|
||||
entry.target.replaceWith(messageElement);
|
||||
|
||||
// Remove the observer after the element has been rendered
|
||||
observer.unobserve(entry.target);
|
||||
}
|
||||
});
|
||||
}, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport
|
||||
|
||||
// Fetch remaining chat messages from conversation history
|
||||
fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status != "ok") {
|
||||
throw new Error(data.message);
|
||||
}
|
||||
return data.response;
|
||||
})
|
||||
.then(response => {
|
||||
const fullChatLog = response.chat || [];
|
||||
let chatBody = document.getElementById("chat-body");
|
||||
fullChatLog
|
||||
.reverse()
|
||||
.forEach(chat_log => {
|
||||
if (chat_log.message != null) {
|
||||
// Create a new element for each chat log
|
||||
let placeholder = document.createElement('div');
|
||||
placeholder.chat_log = chat_log;
|
||||
|
||||
// Insert the message placeholder as the first child of chat body after the welcome message
|
||||
chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
|
||||
|
||||
// Observe the element
|
||||
placeholder.style.height = "20px";
|
||||
observer.observe(placeholder);
|
||||
}
|
||||
});
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
return;
|
||||
});
|
||||
}
|
||||
|
||||
function flashStatusInChatInput(message) {
|
||||
|
||||
@@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries):
|
||||
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
|
||||
except ConnectionAbortedError as e:
|
||||
logger.error(f"Github rate limit reached. Skip indexing github repo {repo_shorthand}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
|
||||
raise e
|
||||
|
||||
@@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries):
|
||||
|
||||
for response in responses:
|
||||
with timer("Processing response", logger=logger):
|
||||
pages_or_databases = response["results"] if response.get("results") else []
|
||||
pages_or_databases = response.get("results", [])
|
||||
|
||||
# Get all pages content
|
||||
for p_or_d in pages_or_databases:
|
||||
@@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries):
|
||||
|
||||
current_entries = []
|
||||
curr_heading = ""
|
||||
for block in content["results"]:
|
||||
for block in content.get("results", []):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == None:
|
||||
@@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries):
|
||||
return f"\n<b>{heading}</b>\n"
|
||||
|
||||
def process_nested_children(self, children, raw_content, block_type=None):
|
||||
results = children["results"] if children.get("results") else []
|
||||
results = children.get("results", [])
|
||||
for child in results:
|
||||
child_type = child.get("type")
|
||||
if child_type == None:
|
||||
|
||||
@@ -30,6 +30,7 @@ def extract_questions_offline(
|
||||
use_history: bool = True,
|
||||
should_extract_questions: bool = True,
|
||||
location_data: LocationData = None,
|
||||
max_prompt_size: int = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
@@ -41,7 +42,7 @@ def extract_questions_offline(
|
||||
return all_questions
|
||||
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model)
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
|
||||
@@ -67,12 +68,14 @@ def extract_questions_offline(
|
||||
location=location,
|
||||
)
|
||||
messages = generate_chatml_messages_with_context(
|
||||
example_questions, model_name=model, loaded_model=offline_chat_model
|
||||
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
||||
)
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
|
||||
response = send_message_to_model_offline(
|
||||
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
||||
)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
|
||||
@@ -138,7 +141,7 @@ def converse_offline(
|
||||
"""
|
||||
# Initialize Variables
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model)
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
@@ -190,18 +193,18 @@ def converse_offline(
|
||||
)
|
||||
|
||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any):
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response_iterator = send_message_to_model_offline(
|
||||
messages, loaded_model=model, stop=stop_phrases, streaming=True
|
||||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||
)
|
||||
for response in response_iterator:
|
||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
||||
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
|
||||
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
|
||||
streaming=False,
|
||||
stop=[],
|
||||
max_prompt_size: int = None,
|
||||
):
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model)
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
||||
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
|
||||
if streaming:
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_device_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
|
||||
from llama_cpp.llama import Llama
|
||||
|
||||
# Initialize Model Parameters. Use n_ctx=0 to get context size from the model
|
||||
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
|
||||
# Initialize Model Parameters
|
||||
# Use n_ctx=0 to get context size from the model
|
||||
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
||||
|
||||
# Decide whether to load model to GPU or CPU
|
||||
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
|
||||
model_path = load_model_from_cache(repo_id, filename)
|
||||
chat_model = None
|
||||
try:
|
||||
if model_path:
|
||||
chat_model = Llama(model_path, **kwargs)
|
||||
else:
|
||||
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
||||
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||
except:
|
||||
# Load model on CPU if GPU is not available
|
||||
kwargs["n_gpu_layers"], device = 0, "cpu"
|
||||
if model_path:
|
||||
chat_model = Llama(model_path, **kwargs)
|
||||
else:
|
||||
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
||||
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||
|
||||
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
|
||||
# Now load the model with context size set based on:
|
||||
# 1. context size supported by model and
|
||||
# 2. configured size or machine (V)RAM
|
||||
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
|
||||
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||
|
||||
logger.debug(
|
||||
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
|
||||
)
|
||||
return chat_model
|
||||
|
||||
|
||||
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
|
||||
from llama_cpp.llama import Llama
|
||||
|
||||
if model_path:
|
||||
return Llama(model_path, **kwargs)
|
||||
else:
|
||||
return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
||||
|
||||
|
||||
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
||||
# Construct the path to the model file in the cache directory
|
||||
repo_org, repo_name = repo_id.split("/")
|
||||
@@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
||||
return paths[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
|
||||
"""Infer max prompt size based on device memory and max context window supported by the model"""
|
||||
vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
|
||||
if configured_max_tokens:
|
||||
return min(configured_max_tokens, model_context_window)
|
||||
else:
|
||||
return min(vram_based_n_ctx, model_context_window)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import queue
|
||||
from datetime import datetime
|
||||
from time import perf_counter
|
||||
@@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
|
||||
tokenizer_name=None,
|
||||
):
|
||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||
# Set max prompt size from user config, pre-configured for model or to default prompt size
|
||||
try:
|
||||
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
|
||||
except:
|
||||
max_prompt_size = 2000
|
||||
logger.warning(
|
||||
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
|
||||
)
|
||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||
if not max_prompt_size:
|
||||
if loaded_model:
|
||||
max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
|
||||
else:
|
||||
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
|
||||
|
||||
# Scale lookback turns proportional to max prompt size supported by model
|
||||
lookback_turns = max_prompt_size // 750
|
||||
@@ -187,7 +186,7 @@ def truncate_messages(
|
||||
max_prompt_size,
|
||||
model_name: str,
|
||||
loaded_model: Optional[Llama] = None,
|
||||
tokenizer_name=None,
|
||||
tokenizer_name="hf-internal-testing/llama-tokenizer",
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
@@ -197,15 +196,11 @@ def truncate_messages(
|
||||
elif model_name.startswith("gpt-"):
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
try:
|
||||
encoder = download_model(model_name).tokenizer()
|
||||
except:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
|
||||
encoder = download_model(model_name).tokenizer()
|
||||
except:
|
||||
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
logger.warning(
|
||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
)
|
||||
|
||||
# Extract system message from messages
|
||||
|
||||
@@ -289,9 +289,7 @@ async def extract_references_and_questions(
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||
logger.warning(
|
||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||
)
|
||||
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
# Extract filter terms from user message
|
||||
@@ -317,8 +315,9 @@ async def extract_references_and_questions(
|
||||
using_offline_chat = True
|
||||
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
||||
chat_model = default_offline_llm.chat_model
|
||||
max_tokens = default_offline_llm.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
|
||||
@@ -328,6 +327,7 @@ async def extract_references_and_questions(
|
||||
conversation_log=meta_log,
|
||||
should_extract_questions=True,
|
||||
location_data=location_data,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
)
|
||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
|
||||
@@ -76,6 +76,7 @@ def chat_history(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
conversation_id: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
validate_conversation_config()
|
||||
@@ -109,6 +110,13 @@ def chat_history(
|
||||
}
|
||||
)
|
||||
|
||||
# Get latest N messages if N > 0
|
||||
if n > 0:
|
||||
meta_log["chat"] = meta_log["chat"][-n:]
|
||||
# Else return all messages except latest N
|
||||
else:
|
||||
meta_log["chat"] = meta_log["chat"][:n]
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
@@ -425,8 +433,7 @@ async def websocket_endpoint(
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
intent_type = "text-to-image"
|
||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
||||
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||
q,
|
||||
user,
|
||||
meta_log,
|
||||
@@ -445,9 +452,6 @@ async def websocket_endpoint(
|
||||
await send_complete_llm_response(json.dumps(content_obj))
|
||||
continue
|
||||
|
||||
if image_url:
|
||||
intent_type = "text-to-image2"
|
||||
image = image_url
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
@@ -621,17 +625,13 @@ async def chat(
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
intent_type = "text-to-image"
|
||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
||||
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||
)
|
||||
if image is None:
|
||||
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
if image_url:
|
||||
intent_type = "text-to-image2"
|
||||
image = image_url
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -18,6 +20,7 @@ from typing import (
|
||||
|
||||
import openai
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from PIL import Image
|
||||
from starlette.authentication import has_required_scope
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||
@@ -46,6 +49,7 @@ from khoj.utils import state
|
||||
from khoj.utils.config import OfflineChatProcessorModel
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
ImageIntentType,
|
||||
is_none_or_empty,
|
||||
is_valid_url,
|
||||
log_telemetry,
|
||||
@@ -79,9 +83,10 @@ async def is_ready_to_chat(user: KhojUser):
|
||||
|
||||
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
|
||||
chat_model = user_conversation_config.chat_model
|
||||
max_tokens = user_conversation_config.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
logger.info("Loading Offline Chat Model...")
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
return True
|
||||
|
||||
ready = has_openai_config or has_offline_config
|
||||
@@ -382,10 +387,11 @@ async def send_message_to_model_wrapper(
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
|
||||
if conversation_config.model_type == "offline":
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
@@ -452,7 +458,9 @@ def generate_chat_response(
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
if conversation_config.model_type == "offline":
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
chat_response = converse_offline(
|
||||
@@ -508,18 +516,19 @@ async def text_to_image(
|
||||
references: List[str],
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
|
||||
) -> Tuple[Optional[str], int, Optional[str], str]:
|
||||
status_code = 200
|
||||
image = None
|
||||
response = None
|
||||
image_url = None
|
||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||
|
||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||
if not text_to_image_config:
|
||||
# If the user has not configured a text to image model, return an unsupported on server error
|
||||
status_code = 501
|
||||
message = "Failed to generate image. Setup image generation on the server."
|
||||
return image, status_code, message, image_url
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
logger.info("Generating image with OpenAI")
|
||||
text2image_model = text_to_image_config.model_name
|
||||
@@ -550,21 +559,38 @@ async def text_to_image(
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
|
||||
with timer("Convert image to webp", logger):
|
||||
# Convert png to webp for faster loading
|
||||
decoded_image = base64.b64decode(image)
|
||||
image_io = io.BytesIO(decoded_image)
|
||||
png_image = Image.open(image_io)
|
||||
webp_image_io = io.BytesIO()
|
||||
png_image.save(webp_image_io, "WEBP")
|
||||
webp_image_bytes = webp_image_io.getvalue()
|
||||
webp_image_io.close()
|
||||
image_io.close()
|
||||
|
||||
with timer("Upload image to S3", logger):
|
||||
image_url = upload_image(image, user.uuid)
|
||||
return image, status_code, improved_image_prompt, image_url
|
||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||
if image_url:
|
||||
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
||||
else:
|
||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||
|
||||
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||
if "content_policy_violation" in e.message:
|
||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
status_code = e.status_code # type: ignore
|
||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||
return image, status_code, message, image_url
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
else:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||
status_code = e.status_code # type: ignore
|
||||
return image, status_code, message, image_url
|
||||
return image, status_code, response, image_url
|
||||
return image_url or image, status_code, message, intent_type.value
|
||||
return image_url or image, status_code, response, intent_type.value
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
@@ -17,16 +16,15 @@ if aws_enabled:
|
||||
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
|
||||
|
||||
|
||||
def upload_image(image: str, user_id: uuid.UUID):
|
||||
def upload_image(image: bytes, user_id: uuid.UUID):
|
||||
"""Upload the image to the S3 bucket"""
|
||||
if not aws_enabled:
|
||||
logger.info("AWS is not enabled. Skipping image upload")
|
||||
return None
|
||||
|
||||
decoded_image = base64.b64decode(image)
|
||||
image_key = f"{user_id}/{uuid.uuid4()}.png"
|
||||
image_key = f"{user_id}/{uuid.uuid4()}.webp"
|
||||
try:
|
||||
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
|
||||
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
|
||||
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
|
||||
return url
|
||||
except Exception as e:
|
||||
|
||||
@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
|
||||
|
||||
|
||||
class OfflineChatProcessorModel:
|
||||
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
|
||||
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
|
||||
self.chat_model = chat_model
|
||||
self.loaded_model = None
|
||||
try:
|
||||
self.loaded_model = download_model(self.chat_model)
|
||||
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
|
||||
except ValueError as e:
|
||||
self.loaded_model = None
|
||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||
|
||||
@@ -17,6 +17,7 @@ from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from magika import Magika
|
||||
@@ -233,6 +234,10 @@ def get_server_id():
|
||||
return server_id
|
||||
|
||||
|
||||
def telemetry_disabled(app_config: AppConfig):
|
||||
return not app_config or not app_config.should_log_telemetry
|
||||
|
||||
|
||||
def log_telemetry(
|
||||
telemetry_type: str,
|
||||
api: str = None,
|
||||
@@ -242,7 +247,7 @@ def log_telemetry(
|
||||
):
|
||||
"""Log basic app usage telemetry like client, os, api called"""
|
||||
# Do not log usage telemetry, if telemetry is disabled via app config
|
||||
if not app_config or not app_config.should_log_telemetry:
|
||||
if telemetry_disabled(app_config):
|
||||
return []
|
||||
|
||||
if properties.get("server_id") is None:
|
||||
@@ -267,6 +272,17 @@ def log_telemetry(
|
||||
return request_body
|
||||
|
||||
|
||||
def get_device_memory() -> int:
|
||||
"""Get device memory in GB"""
|
||||
device = get_device()
|
||||
if device.type == "cuda":
|
||||
return torch.cuda.get_device_properties(device).total_memory
|
||||
elif device.type == "mps":
|
||||
return torch.mps.driver_allocated_memory()
|
||||
else:
|
||||
return psutil.virtual_memory().total
|
||||
|
||||
|
||||
def get_device() -> torch.device:
|
||||
"""Get device to run model on"""
|
||||
if torch.cuda.is_available():
|
||||
@@ -313,6 +329,20 @@ mode_descriptions_for_llm = {
|
||||
}
|
||||
|
||||
|
||||
class ImageIntentType(Enum):
|
||||
"""
|
||||
Chat message intent by Khoj for image responses.
|
||||
Marks the schema used to reference image in chat messages
|
||||
"""
|
||||
|
||||
# Images as Inline PNG
|
||||
TEXT_TO_IMAGE = "text-to-image"
|
||||
# Images as URLs
|
||||
TEXT_TO_IMAGE2 = "text-to-image2"
|
||||
# Images as Inline WebP
|
||||
TEXT_TO_IMAGE_V3 = "text-to-image-v3"
|
||||
|
||||
|
||||
def generate_random_name():
|
||||
# List of adjectives and nouns to choose from
|
||||
adjectives = [
|
||||
|
||||
Reference in New Issue
Block a user