Improve Chat Page Load Perf, Offline Chat Perf and Miscellaneous Fixes (#703)

### Store Generated Images as WebP
- 78bac4ae Add migration script to convert PNG to WebP references in database
- c6e84436 Update clients to support rendering webp images inline
- d21f22ff Store Khoj generated images as webp instead of png for faster loading
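
Under the hood this is a straightforward re-encode: decode the base64 PNG, save it as WebP with Pillow, and re-encode to base64. A minimal sketch of that round trip (the helper name is illustrative, not part of the commit; requires the `pillow` package):

```python
import base64
import io

from PIL import Image


def png_b64_to_webp_b64(png_b64: str) -> str:
    # Decode the base64 encoded PNG image
    png_bytes = base64.b64decode(png_b64)
    # Re-encode it as WebP, which is typically much smaller than PNG
    with Image.open(io.BytesIO(png_bytes)) as png_image:
        webp_io = io.BytesIO()
        png_image.save(webp_io, "WEBP")
    # Return the WebP image as a base64 encoded string
    return base64.b64encode(webp_io.getvalue()).decode()
```

The smaller WebP payloads are what make the inline chat images cheaper to store and faster to load.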

### Lazily Fetch Chat Messages to Reduce Time and Data to First Render
This is especially helpful for long conversations with lots of images; the sketch below illustrates the new history API's paging semantics.
- 128829c4 Render latest msgs on chat session load. Fetch, render rest as they near viewport
- 9e558577 Support getting latest N chat messages via chat history API
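
The new `n` parameter on the chat history API does the windowing server side: a positive `n` returns only the latest `n` messages, while a negative `n` returns everything except the latest `|n|` messages. A small sketch of those semantics (hypothetical helper, mirroring the slice in the chat history API diff below):

```python
def slice_chat_history(chat: list, n: int) -> list:
    # n > 0: latest n messages; n < 0: all but the latest |n| messages
    return chat[-n:] if n > 0 else chat[:n]


# Clients first request ?n=10 for a fast first render,
# then lazily request ?n=-10 for the rest as the user scrolls up
messages = list(range(25))
assert slice_chat_history(messages, 10) == list(range(15, 25))
assert slice_chat_history(messages, -10) == list(range(15))
```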

### Intelligently Set Context Window of Offline Chat to Improve Performance
- 4977b551 Use offline chat prompt config to set context window of loaded chat model
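
The loaded model's context window is now derived from three inputs: the user-configured max prompt size, the context length the model itself supports, and available device memory. A self-contained sketch of the heuristic (the `device_memory_bytes` argument stands in for the commit's `get_device_memory()` helper so the sketch runs standalone):

```python
def infer_max_tokens(model_context_window: int, configured_max_tokens: int = None,
                     device_memory_bytes: int = 8_000_000_000) -> int:
    # Heuristic from the model loader diff below: ~1 token of context per MB of (V)RAM
    memory_based_n_ctx = int(device_memory_bytes / 1e6)
    if configured_max_tokens:
        # Respect the user's configured prompt size, capped by the model's own window
        return min(configured_max_tokens, model_context_window)
    # Otherwise size the context window to fit in device memory
    return min(memory_based_n_ctx, model_context_window)


# e.g. an 8192-token model on a machine with 6 GB of free memory
assert infer_max_tokens(8192, device_memory_bytes=6_000_000_000) == 6000
assert infer_max_tokens(8192, configured_max_tokens=4000) == 4000
```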

### Fixes
- 148923c1 Fix to raise error on hitting rate limit during Github indexing
- b8bc6bee Always remove loading animation on Desktop app if can't login to server
- 38250705 Fix `get_user_photo` to only return photo, not user name from DB

### Miscellaneous Improvements
- 689202e0 Update recommended CMAKE flag to enable using CUDA on linux in Docs
- b820daf3 Makes logs less noisy
Authored by Debanjum on 2024-04-15 18:34:29 +05:30; committed via GitHub. 23 changed files with 510 additions and 131 deletions.


```diff
@@ -134,7 +134,7 @@ python -m pip install khoj-assistant
 # CPU
 python -m pip install khoj-assistant
 # NVIDIA (CUDA) GPU
-CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
+CMAKE_ARGS="-DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
 # AMD (ROCm) GPU
 CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
 # VULCAN GPU
```


```diff
@@ -78,6 +78,7 @@ dependencies = [
     "phonenumbers == 8.13.27",
     "markdownify ~= 0.11.6",
     "websockets == 12.0",
+    "psutil >= 5.8.0",
 ]
 dynamic = ["version"]
@@ -105,7 +106,6 @@ dev = [
     "pytest-asyncio == 0.21.1",
     "freezegun >= 1.2.0",
     "factory-boy >= 3.2.1",
-    "psutil >= 5.8.0",
     "mypy >= 1.0.1",
     "black >= 23.1.0",
     "pre-commit >= 3.0.4",
```


```diff
@@ -4,9 +4,9 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
     <title>Khoj - Chat</title>
-    <link rel="stylesheet" href="./assets/khoj.css">
     <link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
     <link rel="manifest" href="/static/khoj.webmanifest">
+    <link rel="stylesheet" href="./assets/khoj.css">
 </head>
 <script type="text/javascript" src="./assets/markdown-it.min.js"></script>
 <script src="./utils.js"></script>
@@ -130,7 +130,7 @@
         return referenceButton;
     }

-    function renderMessage(message, by, dt=null, annotations=null, raw=false) {
+    function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
         let message_time = formatDate(dt ?? new Date());
         let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
         let formattedMessage = formatHTMLMessage(message, raw);
@@ -153,10 +153,15 @@
         // Append chat message div to chat body
         let chatBody = document.getElementById("chat-body");
+        if (renderType === "append") {
             chatBody.appendChild(chatMessage);
             // Scroll to bottom of chat-body element
             chatBody.scrollTop = chatBody.scrollHeight;
+        } else if (renderType === "prepend") {
+            chatBody.insertBefore(chatMessage, chatBody.firstChild);
+        } else if (renderType === "return") {
+            return chatMessage;
+        }

         let chatBodyWrapper = document.getElementById("chat-body-wrapper");
         chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
@@ -207,6 +212,7 @@
     }

     function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
+        // If no document or online context is provided, render the message as is
         if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
             if (intentType?.includes("text-to-image")) {
                 let imageMarkdown;
@@ -214,30 +220,29 @@
                 imageMarkdown = `![](data:image/png;base64,${message})`;
             } else if (intentType === "text-to-image2") {
                 imageMarkdown = `![](${message})`;
+            } else if (intentType === "text-to-image-v3") {
+                imageMarkdown = `![](data:image/webp;base64,${message})`;
             }
             const inferredQuery = inferredQueries?.[0];
             if (inferredQuery) {
                 imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
             }
-                renderMessage(imageMarkdown, by, dt);
-                return;
+                return renderMessage(imageMarkdown, by, dt, null, false, "return");
             }
-            renderMessage(message, by, dt);
-            return;
+            return renderMessage(message, by, dt, null, false, "return");
         }
         if (context == null && onlineContext == null) {
-            renderMessage(message, by, dt);
-            return;
+            return renderMessage(message, by, dt, null, false, "return");
         }
         if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
-            renderMessage(message, by, dt);
-            return;
+            return renderMessage(message, by, dt, null, false, "return");
         }

+        // If document or online context is provided, render the message with its references
         let references = document.createElement('div');
         let referenceExpandButton = document.createElement('button');
@@ -288,16 +293,17 @@
                 imageMarkdown = `![](data:image/png;base64,${message})`;
             } else if (intentType === "text-to-image2") {
                 imageMarkdown = `![](${message})`;
+            } else if (intentType === "text-to-image-v3") {
+                imageMarkdown = `![](data:image/webp;base64,${message})`;
             }
             const inferredQuery = inferredQueries?.[0];
             if (inferredQuery) {
                 imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
             }
-            renderMessage(imageMarkdown, by, dt, references);
-            return;
+            return renderMessage(imageMarkdown, by, dt, references, false, "return");
         }
-        renderMessage(message, by, dt, references);
+        return renderMessage(message, by, dt, references, false, "return");
     }

     function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
@@ -509,6 +515,8 @@
                 rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
             } else if (responseAsJson.intentType === "text-to-image2") {
                 rawResponse += `![${query}](${responseAsJson.image})`;
+            } else if (responseAsJson.intentType === "text-to-image-v3") {
+                rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
             }
             const inferredQueries = responseAsJson.inferredQueries?.[0];
             if (inferredQueries) {
@@ -671,7 +679,7 @@
     let firstRunSetupMessageRendered = false;
     let chatBody = document.getElementById("chat-body");
     chatBody.innerHTML = "";
-    let chatHistoryUrl = `/api/chat/history?client=desktop`;
+    let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`;
     if (chatBody.dataset.conversationId) {
         chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
     }
@@ -683,7 +691,8 @@
     loadingScreen.appendChild(yellowOrb);
     chatBody.appendChild(loadingScreen);

-    fetch(`${hostURL}${chatHistoryUrl}`, { headers })
+    // Get the most recent 10 chat messages from conversation history
+    fetch(`${chatHistoryUrl}&n=10`, { headers })
         .then(response => response.json())
         .then(data => {
             if (data.detail) {
@@ -703,11 +712,21 @@
             chatBody.dataset.conversationId = response.conversation_id;
             chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;

-            const fullChatLog = response.chat || [];
+            // Create a new IntersectionObserver
+            let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
+                entries.forEach(entry => {
+                    // If the element is in the viewport, fetch the remaining message and unobserve the element
+                    if (entry.isIntersecting) {
+                        fetchRemainingChatMessages(chatHistoryUrl);
+                        observer.unobserve(entry.target);
+                    }
+                });
+            }, {rootMargin: '0px 0px 0px 0px'});

-            fullChatLog.forEach(chat_log => {
+            const fullChatLog = response.chat || [];
+            fullChatLog.forEach((chat_log, index) => {
                 if (chat_log.message != null) {
-                    renderMessageWithReference(
+                    let messageElement = renderMessageWithReference(
                         chat_log.message,
                         chat_log.by,
                         chat_log.context,
@@ -715,10 +734,25 @@
                         chat_log.onlineContext,
                         chat_log.intent?.type,
                         chat_log.intent?.["inferred-queries"]);
+                    chatBody.appendChild(messageElement);
+
+                    // When the 4th oldest message is within viewing distance (~60% scrolled up)
+                    // Fetch the remaining chat messages
+                    if (index === 4) {
+                        fetchRemainingMessagesObserver.observe(messageElement);
+                    }
                 }
                 loadingScreen.style.height = chatBody.scrollHeight + 'px';
             })
+
+            // Scroll to bottom of chat-body element
+            chatBody.scrollTop = chatBody.scrollHeight;
+
+            // Set height of chat-body element to the height of the chat-body-wrapper
+            let chatBodyWrapper = document.getElementById("chat-body-wrapper");
+            let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
+            chatBody.style.height = chatBodyWrapperHeight;

             // Add fade out animation to loading screen and remove it after the animation ends
             fadeOutLoadingAnimation(loadingScreen);
         })
@@ -726,8 +760,8 @@
             // If the server returns a 500 error with detail, render a setup hint.
             if (!firstRunSetupMessageRendered) {
                 renderFirstRunSetupMessage();
-                fadeOutLoadingAnimation(loadingScreen);
             }
+            fadeOutLoadingAnimation(loadingScreen);
             return;
         });
@@ -778,6 +812,65 @@
         }
     }

+    function fetchRemainingChatMessages(chatHistoryUrl) {
+        // Create a new IntersectionObserver
+        let observer = new IntersectionObserver((entries, observer) => {
+            entries.forEach(entry => {
+                // If the element is in the viewport, render the message and unobserve the element
+                if (entry.isIntersecting) {
+                    let chat_log = entry.target.chat_log;
+                    let messageElement = renderMessageWithReference(
+                        chat_log.message,
+                        chat_log.by,
+                        chat_log.context,
+                        new Date(chat_log.created),
+                        chat_log.onlineContext,
+                        chat_log.intent?.type,
+                        chat_log.intent?.["inferred-queries"]
+                    );
+                    entry.target.replaceWith(messageElement);
+
+                    // Remove the observer after the element has been rendered
+                    observer.unobserve(entry.target);
+                }
+            });
+        }, {rootMargin: '0px 0px 200px 0px'});  // Trigger when the element is within 200px of the viewport
+
+        // Fetch remaining chat messages from conversation history
+        fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
+            .then(response => response.json())
+            .then(data => {
+                if (data.status != "ok") {
+                    throw new Error(data.message);
+                }
+                return data.response;
+            })
+            .then(response => {
+                const fullChatLog = response.chat || [];
+                let chatBody = document.getElementById("chat-body");
+                fullChatLog
+                    .reverse()
+                    .forEach(chat_log => {
+                        if (chat_log.message != null) {
+                            // Create a new element for each chat log
+                            let placeholder = document.createElement('div');
+                            placeholder.chat_log = chat_log;
+
+                            // Insert the message placeholder as the first child of chat body after the welcome message
+                            chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
+
+                            // Observe the element
+                            placeholder.style.height = "20px";
+                            observer.observe(placeholder);
+                        }
+                    });
+            })
+            .catch(err => {
+                console.log(err);
+                return;
+            });
+    }
+
     function fadeOutLoadingAnimation(loadingScreen) {
         let chatBody = document.getElementById("chat-body");
         let chatBodyWrapper = document.getElementById("chat-body-wrapper");
```


```diff
@@ -156,6 +156,8 @@ export class KhojChatModal extends Modal {
                 imageMarkdown = `![](data:image/png;base64,${message})`;
             } else if (intentType === "text-to-image2") {
                 imageMarkdown = `![](${message})`;
+            } else if (intentType === "text-to-image-v3") {
+                imageMarkdown = `![](data:image/webp;base64,${message})`;
             }
             if (inferredQueries) {
                 imageMarkdown += "\n\n**Inferred Query**:";
@@ -429,6 +431,8 @@ export class KhojChatModal extends Modal {
                 responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
             } else if (responseAsJson.intentType === "text-to-image2") {
                 responseText += `![${query}](${responseAsJson.image})`;
+            } else if (responseAsJson.intentType === "text-to-image-v3") {
+                responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
             }
             const inferredQuery = responseAsJson.inferredQueries?.[0];
             if (inferredQuery) {
```


```diff
@@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled
 from khoj.utils import constants, state
 from khoj.utils.config import SearchType
 from khoj.utils.fs_syncer import collect_files
-from khoj.utils.helpers import is_none_or_empty
+from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
 from khoj.utils.rawconfig import FullConfig

 logger = logging.getLogger(__name__)
@@ -232,6 +232,9 @@ def configure_server(
             state.search_models = configure_search(state.search_models, state.config.search_type)
         setup_default_agent()

+        message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
+        logger.info(message)
+
         if not init:
             initialize_content(regenerate, search_type, user)
@@ -329,9 +332,7 @@ def configure_search_types():
 @schedule.repeat(schedule.every(2).minutes)
 def upload_telemetry():
-    if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
-        message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
-        logger.debug(message)
+    if telemetry_disabled(state.config.app) or not state.telemetry:
         return

     try:
```


```diff
@@ -197,9 +197,6 @@ def get_user_name(user: KhojUser):

 def get_user_photo(user: KhojUser):
-    full_name = user.get_full_name()
-    if not is_none_or_empty(full_name):
-        return full_name
     google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
     if google_profile:
         return google_profile.picture
```


```diff
@@ -23,6 +23,7 @@ from khoj.database.models import (
     TextToImageModelConfig,
     UserSearchModelConfig,
 )
+from khoj.utils.helpers import ImageIntentType


 class KhojUserAdmin(UserAdmin):
@@ -114,9 +115,12 @@ class ConversationAdmin(admin.ModelAdmin):
                     log["by"] == "khoj"
                     and log["intent"]
                     and log["intent"]["type"]
-                    and log["intent"]["type"] == "text-to-image"
+                    and (
+                        log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
+                        or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
+                    )
                 ):
-                    log["message"] = "image redacted for space"
+                    log["message"] = "inline image redacted for space"
                 chat_log[idx] = log
             modified_log["chat"] = chat_log
@@ -154,9 +158,12 @@ class ConversationAdmin(admin.ModelAdmin):
                     log["by"] == "khoj"
                     and log["intent"]
                     and log["intent"]["type"]
-                    and log["intent"]["type"] == "text-to-image"
+                    and (
+                        log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
+                        or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
+                    )
                 ):
-                    updated_log["message"] = "image redacted for space"
+                    updated_log["message"] = "inline image redacted for space"
                 chat_log[idx] = updated_log
             return_log["chat"] = chat_log
```


```diff
@@ -0,0 +1,40 @@
+from django.core.management.base import BaseCommand
+
+from khoj.database.models import Conversation
+from khoj.utils.helpers import ImageIntentType
+
+
+class Command(BaseCommand):
+    help = "Convert all images to WebP format or reverse."
+
+    def add_arguments(self, parser):
+        # Add a new argument 'reverse' to the command
+        parser.add_argument(
+            "--reverse",
+            action="store_true",
+            help="Convert from WebP to PNG instead of PNG to WebP",
+        )
+
+    def handle(self, *args, **options):
+        updated_count = 0
+        for conversation in Conversation.objects.all():
+            conversation_updated = False
+            for chat in conversation.conversation_log["chat"]:
+                if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
+                    if options["reverse"] and chat["message"].endswith(".webp"):
+                        # Convert WebP url to PNG url
+                        chat["message"] = chat["message"].replace(".webp", ".png")
+                        conversation_updated = True
+                        updated_count += 1
+                    elif chat["message"].endswith(".png"):
+                        # Convert PNG url to WebP url
+                        chat["message"] = chat["message"].replace(".png", ".webp")
+                        conversation_updated = True
+                        updated_count += 1
+            if conversation_updated:
+                conversation.save()
+
+        if updated_count > 0 and options["reverse"]:
+            self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format."))
+        elif updated_count > 0:
+            self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format."))
```


```diff
@@ -0,0 +1,69 @@
+# Generated by Django 4.2.10 on 2024-04-13 17:54
+
+import base64
+import io
+
+from django.db import migrations
+from PIL import Image
+
+from khoj.utils.helpers import ImageIntentType
+
+
+def convert_png_images_to_webp(apps, schema_editor):
+    # Get the model from the versioned app registry to ensure the correct version is used
+    Conversations = apps.get_model("database", "Conversation")
+    for conversation in Conversations.objects.all():
+        for chat in conversation.conversation_log["chat"]:
+            if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value:
+                # Decode the base64 encoded PNG image
+                decoded_image = base64.b64decode(chat["message"])
+
+                # Convert images from PNG to WebP format
+                image_io = io.BytesIO(decoded_image)
+                with Image.open(image_io) as png_image:
+                    webp_image_io = io.BytesIO()
+                    png_image.save(webp_image_io, "WEBP")
+
+                    # Encode the WebP image back to base64
+                    webp_image_bytes = webp_image_io.getvalue()
+                    chat["message"] = base64.b64encode(webp_image_bytes).decode()
+                    chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
+                    webp_image_io.close()
+
+        # Save the updated conversation history
+        conversation.save()
+
+
+def convert_webp_images_to_png(apps, schema_editor):
+    # Get the model from the versioned app registry to ensure the correct version is used
+    Conversations = apps.get_model("database", "Conversation")
+    for conversation in Conversations.objects.all():
+        for chat in conversation.conversation_log["chat"]:
+            if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value:
+                # Decode the base64 encoded WebP image
+                decoded_image = base64.b64decode(chat["message"])
+
+                # Convert images from WebP back to PNG format
+                image_io = io.BytesIO(decoded_image)
+                with Image.open(image_io) as webp_image:
+                    png_image_io = io.BytesIO()
+                    webp_image.save(png_image_io, "PNG")
+
+                    # Encode the PNG image back to base64
+                    png_image_bytes = png_image_io.getvalue()
+                    chat["message"] = base64.b64encode(png_image_bytes).decode()
+                    chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
+                    png_image_io.close()
+
+        # Save the updated conversation history
+        conversation.save()
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0034_alter_chatmodeloptions_chat_model"),
+    ]
+
+    operations = [
+        migrations.RunPython(convert_png_images_to_webp, reverse_code=convert_webp_images_to_png),
+    ]
```


```diff
@@ -4,10 +4,10 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
     <title>Khoj - Chat</title>
-    <link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
     <link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
     <link rel="apple-touch-icon" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
     <link rel="manifest" href="/static/khoj.webmanifest?v={{ khoj_version }}">
+    <link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
 </head>
 <script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
 <script type="text/javascript" src="/static/assets/markdown-it.min.js?v={{ khoj_version }}"></script>
@@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co
         return referenceButton;
     }

-    function renderMessage(message, by, dt=null, annotations=null, raw=false) {
+    function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
         let message_time = formatDate(dt ?? new Date());
         let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
         let formattedMessage = formatHTMLMessage(message, raw);
@@ -183,10 +183,16 @@ To get started, just start typing below. You can also type / to see a list of co
         // Append chat message div to chat body
         let chatBody = document.getElementById("chat-body");
+        if (renderType === "append") {
             chatBody.appendChild(chatMessage);
             // Scroll to bottom of chat-body element
             chatBody.scrollTop = chatBody.scrollHeight;
+        } else if (renderType === "prepend"){
+            let chatBody = document.getElementById("chat-body");
+            chatBody.insertBefore(chatMessage, chatBody.firstChild);
+        } else if (renderType === "return") {
+            return chatMessage;
+        }

         let chatBodyWrapper = document.getElementById("chat-body-wrapper");
         chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
@@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co
     }

     function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
+        // If no document or online context is provided, render the message as is
         if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
             if (intentType?.includes("text-to-image")) {
                 let imageMarkdown;
@@ -244,24 +251,24 @@ To get started, just start typing below. You can also type / to see a list of co
                 imageMarkdown = `![](data:image/png;base64,${message})`;
             } else if (intentType === "text-to-image2") {
                 imageMarkdown = `![](${message})`;
+            } else if (intentType === "text-to-image-v3") {
+                imageMarkdown = `![](data:image/webp;base64,${message})`;
             }
             const inferredQuery = inferredQueries?.[0];
             if (inferredQuery) {
                 imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
             }
-                renderMessage(imageMarkdown, by, dt);
-                return;
+                return renderMessage(imageMarkdown, by, dt, null, false, "return");
             }
-            renderMessage(message, by, dt);
-            return;
+            return renderMessage(message, by, dt, null, false, "return");
         }
         if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
-            renderMessage(message, by, dt);
-            return;
+            return renderMessage(message, by, dt, null, false, "return");
         }

+        // If document or online context is provided, render the message with its references
         let references = document.createElement('div');
         let referenceExpandButton = document.createElement('button');
@@ -312,16 +319,17 @@ To get started, just start typing below. You can also type / to see a list of co
                 imageMarkdown = `![](data:image/png;base64,${message})`;
             } else if (intentType === "text-to-image2") {
                 imageMarkdown = `![](${message})`;
+            } else if (intentType === "text-to-image-v3") {
+                imageMarkdown = `![](data:image/webp;base64,${message})`;
             }
             const inferredQuery = inferredQueries?.[0];
             if (inferredQuery) {
                 imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
             }
-            renderMessage(imageMarkdown, by, dt, references);
-            return;
+            return renderMessage(imageMarkdown, by, dt, references, false, "return");
         }
-        renderMessage(message, by, dt, references);
+        return renderMessage(message, by, dt, references, false, "return");
     }

     function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
@@ -619,6 +627,8 @@ To get started, just start typing below. You can also type / to see a list of co
                 rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
             } else if (imageJson.intentType === "text-to-image2") {
                 rawResponse += `![generated_image](${imageJson.image})`;
+            } else if (imageJson.intentType === "text-to-image-v3") {
+                rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
             }
             if (inferredQuery) {
                 rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -1064,7 +1074,8 @@ To get started, just start typing below. You can also type / to see a list of co
     loadingScreen.appendChild(yellowOrb);
     chatBody.appendChild(loadingScreen);

-    fetch(chatHistoryUrl, { method: "GET" })
+    // Get the most recent 10 chat messages from conversation history
+    fetch(`${chatHistoryUrl}&n=10`, { method: "GET" })
         .then(response => response.json())
         .then(data => {
             if (data.detail) {
@@ -1134,11 +1145,22 @@ To get started, just start typing below. You can also type / to see a list of co
                 agentMetadataElement.style.display = "none";
             }

-            const fullChatLog = response.chat || [];
+            // Create a new IntersectionObserver
+            let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
+                entries.forEach(entry => {
+                    // If the element is in the viewport, fetch the remaining message and unobserve the element
+                    if (entry.isIntersecting) {
+                        fetchRemainingChatMessages(chatHistoryUrl);
+                        observer.unobserve(entry.target);
+                    }
+                });
+            }, {rootMargin: '0px 0px 0px 0px'});

-            fullChatLog.forEach(chat_log => {
+            const fullChatLog = response.chat || [];
+            fullChatLog.forEach((chat_log, index) => {
+                // Render the last 10 messages immediately
                 if (chat_log.message != null) {
-                    renderMessageWithReference(
+                    let messageElement = renderMessageWithReference(
                         chat_log.message,
                         chat_log.by,
                         chat_log.context,
@@ -1146,14 +1168,26 @@ To get started, just start typing below. You can also type / to see a list of co
                         chat_log.onlineContext,
                         chat_log.intent?.type,
                         chat_log.intent?.["inferred-queries"]);
+                    chatBody.appendChild(messageElement);
+
+                    // When the 4th oldest message is within viewing distance (~60% scroll up)
+                    // Fetch the remaining chat messages
+                    if (index === 4) {
+                        fetchRemainingMessagesObserver.observe(messageElement);
+                    }
                 }
                 loadingScreen.style.height = chatBody.scrollHeight + 'px';
             });

-            // Add fade out animation to loading screen and remove it after the animation ends
+            // Scroll to bottom of chat-body element
+            chatBody.scrollTop = chatBody.scrollHeight;
+
+            // Set height of chat-body element to the height of the chat-body-wrapper
             let chatBodyWrapper = document.getElementById("chat-body-wrapper");
-            chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
+            let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
             chatBody.style.height = chatBodyWrapperHeight;

+            // Add fade out animation to loading screen and remove it after the animation ends
             setTimeout(() => {
                 loadingScreen.remove();
                 chatBody.classList.remove("relative-position");
@@ -1211,6 +1245,66 @@ To get started, just start typing below. You can also type / to see a list of co
             document.getElementById("chat-input").value = query_via_url;
             chat();
         }
+    }
+
+    function fetchRemainingChatMessages(chatHistoryUrl) {
+        // Create a new IntersectionObserver
+        let observer = new IntersectionObserver((entries, observer) => {
+            entries.forEach(entry => {
+                // If the element is in the viewport, render the message and unobserve the element
+                if (entry.isIntersecting) {
+                    let chat_log = entry.target.chat_log;
+                    let messageElement = renderMessageWithReference(
+                        chat_log.message,
+                        chat_log.by,
+                        chat_log.context,
+                        new Date(chat_log.created),
+                        chat_log.onlineContext,
+                        chat_log.intent?.type,
+                        chat_log.intent?.["inferred-queries"]
+                    );
+                    entry.target.replaceWith(messageElement);
+
+                    // Remove the observer after the element has been rendered
+                    observer.unobserve(entry.target);
+                }
+            });
+        }, {rootMargin: '0px 0px 200px 0px'});  // Trigger when the element is within 200px of the viewport
+
+        // Fetch remaining chat messages from conversation history
+        fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
+            .then(response => response.json())
+            .then(data => {
+                if (data.status != "ok") {
+                    throw new Error(data.message);
+                }
+                return data.response;
+            })
+            .then(response => {
+                const fullChatLog = response.chat || [];
+                let chatBody = document.getElementById("chat-body");
+                fullChatLog
+                    .reverse()
+                    .forEach(chat_log => {
+                        if (chat_log.message != null) {
+                            // Create a new element for each chat log
+                            let placeholder = document.createElement('div');
+                            placeholder.chat_log = chat_log;
+
+                            // Insert the message placeholder as the first child of chat body after the welcome message
+                            chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
+
+                            // Observe the element
+                            placeholder.style.height = "20px";
+                            observer.observe(placeholder);
+                        }
+                    });
+            })
+            .catch(err => {
+                console.log(err);
+                return;
+            });
     }

     function flashStatusInChatInput(message) {
```


```diff
@@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries):
             markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
         except ConnectionAbortedError as e:
             logger.error(f"Github rate limit reached. Skip indexing github repo {repo_shorthand}")
+            raise e
         except Exception as e:
             logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
             raise e
```


```diff
@@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries):
         for response in responses:
             with timer("Processing response", logger=logger):
-                pages_or_databases = response["results"] if response.get("results") else []
+                pages_or_databases = response.get("results", [])

                 # Get all pages content
                 for p_or_d in pages_or_databases:
@@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries):
         current_entries = []
         curr_heading = ""
-        for block in content["results"]:
+        for block in content.get("results", []):
             block_type = block.get("type")
             if block_type == None:
@@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries):
         return f"\n<b>{heading}</b>\n"

     def process_nested_children(self, children, raw_content, block_type=None):
-        results = children["results"] if children.get("results") else []
+        results = children.get("results", [])
         for child in results:
             child_type = child.get("type")
             if child_type == None:
```


```diff
@@ -30,6 +30,7 @@ def extract_questions_offline(
     use_history: bool = True,
     should_extract_questions: bool = True,
     location_data: LocationData = None,
+    max_prompt_size: int = None,
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -41,7 +42,7 @@ def extract_questions_offline(
         return all_questions

     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model)
+    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)

     location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@@ -67,12 +68,14 @@ def extract_questions_offline(
         location=location,
     )
     messages = generate_chatml_messages_with_context(
-        example_questions, model_name=model, loaded_model=offline_chat_model
+        example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
     )

     state.chat_lock.acquire()
     try:
-        response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
+        response = send_message_to_model_offline(
+            messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+        )
     finally:
         state.chat_lock.release()
@@ -138,7 +141,7 @@ def converse_offline(
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model)
+    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)

     compiled_references_message = "\n\n".join({f"{item}" for item in references})

     current_date = datetime.now().strftime("%Y-%m-%d")
@@ -190,18 +193,18 @@ def converse_offline(
     )

     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
     t.start()
     return g


-def llm_thread(g, messages: List[ChatMessage], model: Any):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
     stop_phrases = ["<s>", "INST]", "Notes:"]

     state.chat_lock.acquire()
     try:
         response_iterator = send_message_to_model_offline(
-            messages, loaded_model=model, stop=stop_phrases, streaming=True
+            messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
             g.send(response["choices"][0]["delta"].get("content", ""))
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
     model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
     streaming=False,
     stop=[],
+    max_prompt_size: int = None,
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model)
+    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
     response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
     if streaming:
```


```diff
@@ -1,18 +1,19 @@
 import glob
 import logging
+import math
 import os

 from huggingface_hub.constants import HF_HUB_CACHE

 from khoj.utils import state
+from khoj.utils.helpers import get_device_memory

 logger = logging.getLogger(__name__)


-def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
-    from llama_cpp.llama import Llama
-
-    # Initialize Model Parameters
-    # Use n_ctx=0 to get context size from the model
+def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
+    # Initialize Model Parameters. Use n_ctx=0 to get context size from the model
     kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}

     # Decide whether to load model to GPU or CPU
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
     model_path = load_model_from_cache(repo_id, filename)
     chat_model = None
     try:
-        if model_path:
-            chat_model = Llama(model_path, **kwargs)
-        else:
-            Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+        chat_model = load_model(model_path, repo_id, filename, kwargs)
     except:
         # Load model on CPU if GPU is not available
         kwargs["n_gpu_layers"], device = 0, "cpu"
-        if model_path:
-            chat_model = Llama(model_path, **kwargs)
-        else:
-            chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+        chat_model = load_model(model_path, repo_id, filename, kwargs)

-    logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
+    # Now load the model with context size set based on:
+    # 1. context size supported by model and
+    # 2. configured size or machine (V)RAM
+    kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
+    chat_model = load_model(model_path, repo_id, filename, kwargs)
+
+    logger.debug(
+        f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
+    )

     return chat_model


+def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
+    from llama_cpp.llama import Llama
+
+    if model_path:
+        return Llama(model_path, **kwargs)
+    else:
+        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+
+
 def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
     # Construct the path to the model file in the cache directory
     repo_org, repo_name = repo_id.split("/")
@@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
         return paths[0]
     else:
         return None
+
+
+def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
+    """Infer max prompt size based on device memory and max context window supported by the model"""
+    vram_based_n_ctx = int(get_device_memory() / 1e6)  # based on heuristic
+    if configured_max_tokens:
+        return min(configured_max_tokens, model_context_window)
+    else:
+        return min(vram_based_n_ctx, model_context_window)
```


```diff
@@ -1,5 +1,6 @@
 import json
 import logging
+import math
 import queue
 from datetime import datetime
 from time import perf_counter
@@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
     tokenizer_name=None,
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
-    # Set max prompt size from user config, pre-configured for model or to default prompt size
-    try:
-        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
-    except:
-        max_prompt_size = 2000
-        logger.warning(
-            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
-        )
+    # Set max prompt size from user config or based on pre-configured for model and machine specs
+    if not max_prompt_size:
+        if loaded_model:
+            max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+        else:
+            max_prompt_size = model_to_prompt_size.get(model_name, 2000)

     # Scale lookback turns proportional to max prompt size supported by model
     lookback_turns = max_prompt_size // 750
@@ -187,7 +186,7 @@ def truncate_messages(
     max_prompt_size,
     model_name: str,
     loaded_model: Optional[Llama] = None,
-    tokenizer_name=None,
+    tokenizer_name="hf-internal-testing/llama-tokenizer",
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
@@ -197,15 +196,11 @@ def truncate_messages(
         elif model_name.startswith("gpt-"):
             encoder = tiktoken.encoding_for_model(model_name)
         else:
-            try:
-                encoder = download_model(model_name).tokenizer()
-            except:
-                encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+            encoder = download_model(model_name).tokenizer()
     except:
-        default_tokenizer = "hf-internal-testing/llama-tokenizer"
-        encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+        encoder = AutoTokenizer.from_pretrained(tokenizer_name)
         logger.warning(
-            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+            f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )

     # Extract system message from messages
```


```diff
@@ -289,9 +289,7 @@ async def extract_references_and_questions(
         return compiled_references, inferred_queries, q

     if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
-        logger.warning(
-            "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
-        )
+        logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
         return compiled_references, inferred_queries, q

     # Extract filter terms from user message
@@ -317,8 +315,9 @@ async def extract_references_and_questions(
             using_offline_chat = True
             default_offline_llm = await ConversationAdapters.get_default_offline_llm()
             chat_model = default_offline_llm.chat_model
+            max_tokens = default_offline_llm.max_prompt_size
             if state.offline_chat_processor_config is None:
-                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

             loaded_model = state.offline_chat_processor_config.loaded_model
@@ -328,6 +327,7 @@ async def extract_references_and_questions(
                 conversation_log=meta_log,
                 should_extract_questions=True,
                 location_data=location_data,
+                max_prompt_size=conversation_config.max_prompt_size,
             )
     elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
         openai_chat_config = await ConversationAdapters.get_openai_chat_config()
```


```diff
@@ -76,6 +76,7 @@ def chat_history(
     request: Request,
     common: CommonQueryParams,
     conversation_id: Optional[int] = None,
+    n: Optional[int] = None,
 ):
     user = request.user.object
     validate_conversation_config()
@@ -109,6 +110,13 @@ def chat_history(
             }
         )

+    # Get latest N messages if N > 0
+    if n > 0:
+        meta_log["chat"] = meta_log["chat"][-n:]
+    # Else return all messages except latest N
+    else:
+        meta_log["chat"] = meta_log["chat"][:n]
+
     update_telemetry_state(
         request=request,
         telemetry_type="api",
@@ -425,8 +433,7 @@ async def websocket_endpoint(
                     api="chat",
                     metadata={"conversation_command": conversation_commands[0].value},
                 )
-                intent_type = "text-to-image"
-                image, status_code, improved_image_prompt, image_url = await text_to_image(
+                image, status_code, improved_image_prompt, intent_type = await text_to_image(
                     q,
                     user,
                     meta_log,
@@ -445,9 +452,6 @@ async def websocket_endpoint(
                     await send_complete_llm_response(json.dumps(content_obj))
                     continue

-                if image_url:
-                    intent_type = "text-to-image2"
-                    image = image_url
                 await sync_to_async(save_to_conversation_log)(
                     q,
                     image,
@@ -621,17 +625,13 @@ async def chat(
             metadata={"conversation_command": conversation_commands[0].value},
             **common.__dict__,
         )
-        intent_type = "text-to-image"
-        image, status_code, improved_image_prompt, image_url = await text_to_image(
+        image, status_code, improved_image_prompt, intent_type = await text_to_image(
             q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
         )

        if image is None:
             content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
             return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)

-        if image_url:
-            intent_type = "text-to-image2"
-            image = image_url
         await sync_to_async(save_to_conversation_log)(
             q,
             image,
```


@@ -1,4 +1,6 @@
import asyncio import asyncio
import base64
import io
import json import json
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@@ -18,6 +20,7 @@ from typing import (
import openai import openai
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
@@ -46,6 +49,7 @@ from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import ( from khoj.utils.helpers import (
ConversationCommand, ConversationCommand,
ImageIntentType,
is_none_or_empty, is_none_or_empty,
is_valid_url, is_valid_url,
log_telemetry, log_telemetry,
@@ -79,9 +83,10 @@ async def is_ready_to_chat(user: KhojUser):
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline": if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None: if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...") logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True return True
ready = has_openai_config or has_offline_config ready = has_openai_config or has_offline_config
@@ -382,10 +387,11 @@ async def send_message_to_model_wrapper(
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

    chat_model = conversation_config.chat_model
+    max_tokens = conversation_config.max_prompt_size

    if conversation_config.model_type == "offline":
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

        loaded_model = state.offline_chat_processor_config.loaded_model
        truncated_messages = generate_chatml_messages_with_context(
@@ -452,7 +458,9 @@ def generate_chat_response(
        conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)

        if conversation_config.model_type == "offline":
            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+                chat_model = conversation_config.chat_model
+                max_tokens = conversation_config.max_prompt_size
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

            loaded_model = state.offline_chat_processor_config.loaded_model
            chat_response = converse_offline(
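
This is the third call site updated to lazy-load the offline model with the configured context window. A hedged sketch of a helper that could consolidate the repeated pattern; no such function exists in this commit:

```python
# Hypothetical consolidation of the lazy-load pattern repeated above; not part
# of this commit. Uses state and OfflineChatProcessorModel exactly as the diff does.
def ensure_offline_chat_model(chat_model: str, max_tokens: int = None):
    if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
        state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
    return state.offline_chat_processor_config.loaded_model
```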
@@ -508,18 +516,19 @@ async def text_to_image(
    references: List[str],
    online_results: Dict[str, Any],
    send_status_func: Optional[Callable] = None,
-) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
+) -> Tuple[Optional[str], int, Optional[str], str]:
    status_code = 200
    image = None
    response = None
    image_url = None
+    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
        message = "Failed to generate image. Setup image generation on the server."
-        return image, status_code, message, image_url
+        return image_url or image, status_code, message, intent_type.value
    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        logger.info("Generating image with OpenAI")
        text2image_model = text_to_image_config.model_name
@@ -550,21 +559,38 @@ async def text_to_image(
            )
            image = response.data[0].b64_json

+            with timer("Convert image to webp", logger):
+                # Convert png to webp for faster loading
+                decoded_image = base64.b64decode(image)
+                image_io = io.BytesIO(decoded_image)
+                png_image = Image.open(image_io)
+                webp_image_io = io.BytesIO()
+                png_image.save(webp_image_io, "WEBP")
+                webp_image_bytes = webp_image_io.getvalue()
+                webp_image_io.close()
+                image_io.close()
+
            with timer("Upload image to S3", logger):
-                image_url = upload_image(image, user.uuid)
+                image_url = upload_image(webp_image_bytes, user.uuid)

-            return image, status_code, improved_image_prompt, image_url
+            if image_url:
+                intent_type = ImageIntentType.TEXT_TO_IMAGE2
+            else:
+                intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+                image = base64.b64encode(webp_image_bytes).decode("utf-8")
+
+            return image_url or image, status_code, improved_image_prompt, intent_type.value
    except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
        if "content_policy_violation" in e.message:
            logger.error(f"Image Generation blocked by OpenAI: {e}")
            status_code = e.status_code  # type: ignore
            message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
-            return image, status_code, message, image_url
+            return image_url or image, status_code, message, intent_type.value
        else:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
            message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
            status_code = e.status_code  # type: ignore
-            return image, status_code, message, image_url
+            return image_url or image, status_code, message, intent_type.value

-    return image, status_code, response, image_url
+    return image_url or image, status_code, response, intent_type.value
class ApiUserRateLimiter:
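
The PNG to WebP re-encode above is the heart of the faster image loading change. A minimal standalone sketch of the same conversion, runnable outside the server and assuming Pillow's default WEBP encoder settings (the diff passes no quality parameter either):

```python
# Minimal sketch of the PNG -> WebP re-encode used in text_to_image above.
import base64
import io

from PIL import Image


def png_b64_to_webp_bytes(png_b64: str) -> bytes:
    png_image = Image.open(io.BytesIO(base64.b64decode(png_b64)))
    webp_io = io.BytesIO()
    png_image.save(webp_io, "WEBP")
    return webp_io.getvalue()
```

WebP's lossy encoder typically produces markedly smaller files than PNG for photographic content, which is what cuts time-to-first-render for image-heavy chats.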

View File

@@ -1,4 +1,3 @@
-import base64
import logging
import os
import uuid
@@ -17,16 +16,15 @@ if aws_enabled:
    s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)


-def upload_image(image: str, user_id: uuid.UUID):
+def upload_image(image: bytes, user_id: uuid.UUID):
    """Upload the image to the S3 bucket"""
    if not aws_enabled:
        logger.info("AWS is not enabled. Skipping image upload")
        return None

-    decoded_image = base64.b64decode(image)
-    image_key = f"{user_id}/{uuid.uuid4()}.png"
+    image_key = f"{user_id}/{uuid.uuid4()}.webp"
    try:
-        s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
+        s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
        url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
        return url
    except Exception as e:
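
upload_image now receives raw WebP bytes and drops the base64 decode it used to do. A usage sketch mirroring the caller in text_to_image, assuming the AWS_* config above is set:

```python
# Hedged usage sketch; webp_image_bytes comes from the conversion step in
# text_to_image, and user.uuid is the KhojUser identifier used above.
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url is None:
    # AWS disabled (or upload failed): fall back to inlining the image
    image = base64.b64encode(webp_image_bytes).decode("utf-8")
```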

View File

@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
class OfflineChatProcessorModel:
-    def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
+    def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
        self.chat_model = chat_model
        self.loaded_model = None
        try:
-            self.loaded_model = download_model(self.chat_model)
+            self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
        except ValueError as e:
            self.loaded_model = None
            logger.error(f"Error while loading offline chat model: {e}", exc_info=True)

View File

@@ -17,6 +17,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse

+import psutil
import torch
from asgiref.sync import sync_to_async
from magika import Magika
@@ -233,6 +234,10 @@ def get_server_id():
    return server_id


+def telemetry_disabled(app_config: AppConfig):
+    return not app_config or not app_config.should_log_telemetry
+

def log_telemetry(
    telemetry_type: str,
    api: str = None,
@@ -242,7 +247,7 @@ def log_telemetry(
):
    """Log basic app usage telemetry like client, os, api called"""
    # Do not log usage telemetry, if telemetry is disabled via app config
-    if not app_config or not app_config.should_log_telemetry:
+    if telemetry_disabled(app_config):
        return []

    if properties.get("server_id") is None:
@@ -267,6 +272,17 @@ def log_telemetry(
    return request_body


+def get_device_memory() -> int:
+    """Get device memory in bytes"""
+    device = get_device()
+    if device.type == "cuda":
+        return torch.cuda.get_device_properties(device).total_memory
+    elif device.type == "mps":
+        return torch.mps.driver_allocated_memory()
+    else:
+        return psutil.virtual_memory().total
+

def get_device() -> torch.device:
    """Get device to run model on"""
    if torch.cuda.is_available():
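
get_device_memory gives this release a signal for sizing the offline chat context window on the fly. An illustrative sizing heuristic follows; the thresholds are assumptions made for the sketch, and khoj's actual heuristic lives outside this hunk:

```python
# Illustrative only: thresholds are assumptions, not taken from this diff.
# get_device_memory() returns bytes, hence the GB conversion.
GB = 1024**3


def infer_max_tokens(configured_max_tokens: int = None) -> int:
    if configured_max_tokens:
        return configured_max_tokens
    memory_gb = get_device_memory() / GB
    # Rough assumption: afford ~2k tokens of context per 8 GB of device memory
    return max(2048, min(int(memory_gb / 8) * 2048, 8192))
```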
@@ -313,6 +329,20 @@ mode_descriptions_for_llm = {
}


+class ImageIntentType(Enum):
+    """
+    Chat message intent by Khoj for image responses.
+    Marks the schema used to reference images in chat messages.
+    """
+
+    # Images as Inline PNG
+    TEXT_TO_IMAGE = "text-to-image"
+    # Images as URLs
+    TEXT_TO_IMAGE2 = "text-to-image2"
+    # Images as Inline WebP
+    TEXT_TO_IMAGE_V3 = "text-to-image-v3"
+

def generate_random_name():
    # List of adjectives and nouns to choose from
    adjectives = [
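
The enum gives each image schema the clients must understand a stable name. A hedged sketch of the client-side branch on these values; image_html is a hypothetical helper, and the web clients' actual renderMessage implementation is not part of this hunk:

```python
# Hedged rendering sketch; the three data-URL/URL forms match the intent
# values defined above.
def image_html(message: str, intent_type: str) -> str:
    if intent_type == ImageIntentType.TEXT_TO_IMAGE.value:
        return f'<img src="data:image/png;base64,{message}">'
    if intent_type == ImageIntentType.TEXT_TO_IMAGE_V3.value:
        return f'<img src="data:image/webp;base64,{message}">'
    if intent_type == ImageIntentType.TEXT_TO_IMAGE2.value:
        return f'<img src="{message}">'  # message holds the uploaded image URL
    return message
```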