diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx
index 4aa2f960..7b2866f4 100644
--- a/documentation/docs/get-started/setup.mdx
+++ b/documentation/docs/get-started/setup.mdx
@@ -134,7 +134,7 @@ python -m pip install khoj-assistant
# CPU
python -m pip install khoj-assistant
# NVIDIA (CUDA) GPU
- CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
+ CMAKE_ARGS="DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
# AMD (ROCm) GPU
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
# VULCAN GPU
diff --git a/pyproject.toml b/pyproject.toml
index 0b7483a1..c9c96691 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,7 @@ dependencies = [
"phonenumbers == 8.13.27",
"markdownify ~= 0.11.6",
"websockets == 12.0",
+ "psutil >= 5.8.0",
]
dynamic = ["version"]
@@ -105,7 +106,6 @@ dev = [
"pytest-asyncio == 0.21.1",
"freezegun >= 1.2.0",
"factory-boy >= 3.2.1",
- "psutil >= 5.8.0",
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 627ce3de..fc7ecc2e 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -4,9 +4,9 @@
Khoj - Chat
+
-
@@ -130,7 +130,7 @@
return referenceButton;
}
- function renderMessage(message, by, dt=null, annotations=null, raw=false) {
+ function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
let formattedMessage = formatHTMLMessage(message, raw);
@@ -153,10 +153,15 @@
// Append chat message div to chat body
let chatBody = document.getElementById("chat-body");
- chatBody.appendChild(chatMessage);
-
- // Scroll to bottom of chat-body element
- chatBody.scrollTop = chatBody.scrollHeight;
+ if (renderType === "append") {
+ chatBody.appendChild(chatMessage);
+ // Scroll to bottom of chat-body element
+ chatBody.scrollTop = chatBody.scrollHeight;
+ } else if (renderType === "prepend") {
+ chatBody.insertBefore(chatMessage, chatBody.firstChild);
+ } else if (renderType === "return") {
+ return chatMessage;
+ }
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
@@ -207,6 +212,7 @@
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
+ // If no document or online context is provided, render the message as is
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
@@ -214,30 +220,29 @@
imageMarkdown = ``;
} else if (intentType === "text-to-image2") {
imageMarkdown = ``;
+ } else if (intentType === "text-to-image-v3") {
+ imageMarkdown = ``;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
- renderMessage(imageMarkdown, by, dt);
- return;
+ return renderMessage(imageMarkdown, by, dt, null, false, "return");
}
- renderMessage(message, by, dt);
- return;
+ return renderMessage(message, by, dt, null, false, "return");
}
if (context == null && onlineContext == null) {
- renderMessage(message, by, dt);
- return;
+ return renderMessage(message, by, dt, null, false, "return");
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
- renderMessage(message, by, dt);
- return;
+ return renderMessage(message, by, dt, null, false, "return");
}
+ // If document or online context is provided, render the message with its references
let references = document.createElement('div');
let referenceExpandButton = document.createElement('button');
@@ -288,16 +293,17 @@
imageMarkdown = ``;
} else if (intentType === "text-to-image2") {
imageMarkdown = ``;
+ } else if (intentType === "text-to-image-v3") {
+ imageMarkdown = ``;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
- renderMessage(imageMarkdown, by, dt, references);
- return;
+ return renderMessage(imageMarkdown, by, dt, references, false, "return");
}
- renderMessage(message, by, dt, references);
+ return renderMessage(message, by, dt, references, false, "return");
}
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
@@ -509,6 +515,8 @@
rawResponse += ``;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += ``;
+ } else if (responseAsJson.intentType === "text-to-image-v3") {
+ rawResponse += ``;
}
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
@@ -671,7 +679,7 @@
let firstRunSetupMessageRendered = false;
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
- let chatHistoryUrl = `/api/chat/history?client=desktop`;
+ let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`;
if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
}
@@ -683,7 +691,8 @@
loadingScreen.appendChild(yellowOrb);
chatBody.appendChild(loadingScreen);
- fetch(`${hostURL}${chatHistoryUrl}`, { headers })
+ // Get the most recent 10 chat messages from conversation history
+ fetch(`${chatHistoryUrl}&n=10`, { headers })
.then(response => response.json())
.then(data => {
if (data.detail) {
@@ -703,11 +712,21 @@
chatBody.dataset.conversationId = response.conversation_id;
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
- const fullChatLog = response.chat || [];
+ // Create a new IntersectionObserver
+ let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
+ entries.forEach(entry => {
+ // If the element is in the viewport, fetch the remaining message and unobserve the element
+ if (entry.isIntersecting) {
+ fetchRemainingChatMessages(chatHistoryUrl);
+ observer.unobserve(entry.target);
+ }
+ });
+ }, {rootMargin: '0px 0px 0px 0px'});
- fullChatLog.forEach(chat_log => {
+ const fullChatLog = response.chat || [];
+ fullChatLog.forEach((chat_log, index) => {
if (chat_log.message != null) {
- renderMessageWithReference(
+ let messageElement = renderMessageWithReference(
chat_log.message,
chat_log.by,
chat_log.context,
@@ -715,10 +734,25 @@
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]);
+ chatBody.appendChild(messageElement);
+
+ // When the 4th oldest message is within viewing distance (~60% scrolled up)
+ // Fetch the remaining chat messages
+ if (index === 4) {
+ fetchRemainingMessagesObserver.observe(messageElement);
+ }
}
loadingScreen.style.height = chatBody.scrollHeight + 'px';
})
+ // Scroll to bottom of chat-body element
+ chatBody.scrollTop = chatBody.scrollHeight;
+
+ // Set height of chat-body element to the height of the chat-body-wrapper
+ let chatBodyWrapper = document.getElementById("chat-body-wrapper");
+ let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
+ chatBody.style.height = chatBodyWrapperHeight;
+
// Add fade out animation to loading screen and remove it after the animation ends
fadeOutLoadingAnimation(loadingScreen);
})
@@ -726,9 +760,9 @@
// If the server returns a 500 error with detail, render a setup hint.
if (!firstRunSetupMessageRendered) {
renderFirstRunSetupMessage();
- fadeOutLoadingAnimation(loadingScreen);
}
- return;
+ fadeOutLoadingAnimation(loadingScreen);
+ return;
});
await refreshChatSessionsPanel();
@@ -778,6 +812,65 @@
}
}
+ function fetchRemainingChatMessages(chatHistoryUrl) {
+ // Create a new IntersectionObserver
+ let observer = new IntersectionObserver((entries, observer) => {
+ entries.forEach(entry => {
+ // If the element is in the viewport, render the message and unobserve the element
+ if (entry.isIntersecting) {
+ let chat_log = entry.target.chat_log;
+ let messageElement = renderMessageWithReference(
+ chat_log.message,
+ chat_log.by,
+ chat_log.context,
+ new Date(chat_log.created),
+ chat_log.onlineContext,
+ chat_log.intent?.type,
+ chat_log.intent?.["inferred-queries"]
+ );
+ entry.target.replaceWith(messageElement);
+
+ // Remove the observer after the element has been rendered
+ observer.unobserve(entry.target);
+ }
+ });
+ }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport
+
+ // Fetch remaining chat messages from conversation history
+ fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
+ .then(response => response.json())
+ .then(data => {
+ if (data.status != "ok") {
+ throw new Error(data.message);
+ }
+ return data.response;
+ })
+ .then(response => {
+ const fullChatLog = response.chat || [];
+ let chatBody = document.getElementById("chat-body");
+ fullChatLog
+ .reverse()
+ .forEach(chat_log => {
+ if (chat_log.message != null) {
+ // Create a new element for each chat log
+ let placeholder = document.createElement('div');
+ placeholder.chat_log = chat_log;
+
+ // Insert the message placeholder as the first child of chat body after the welcome message
+ chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
+
+ // Observe the element
+ placeholder.style.height = "20px";
+ observer.observe(placeholder);
+ }
+ });
+ })
+ .catch(err => {
+ console.log(err);
+ return;
+ });
+ }
+
function fadeOutLoadingAnimation(loadingScreen) {
let chatBody = document.getElementById("chat-body");
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts
index 328ce299..504ce4db 100644
--- a/src/interface/obsidian/src/chat_modal.ts
+++ b/src/interface/obsidian/src/chat_modal.ts
@@ -156,6 +156,8 @@ export class KhojChatModal extends Modal {
imageMarkdown = ``;
} else if (intentType === "text-to-image2") {
imageMarkdown = ``;
+ } else if (intentType === "text-to-image-v3") {
+ imageMarkdown = ``;
}
if (inferredQueries) {
imageMarkdown += "\n\n**Inferred Query**:";
@@ -429,6 +431,8 @@ export class KhojChatModal extends Modal {
responseText += ``;
} else if (responseAsJson.intentType === "text-to-image2") {
responseText += ``;
+ } else if (responseAsJson.intentType === "text-to-image-v3") {
+ responseText += ``;
}
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 60aaf658..9f8abeb4 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files
-from khoj.utils.helpers import is_none_or_empty
+from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__)
@@ -232,6 +232,9 @@ def configure_server(
state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent()
+ message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
+ logger.info(message)
+
if not init:
initialize_content(regenerate, search_type, user)
@@ -329,9 +332,7 @@ def configure_search_types():
@schedule.repeat(schedule.every(2).minutes)
def upload_telemetry():
- if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
- message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
- logger.debug(message)
+ if telemetry_disabled(state.config.app) or not state.telemetry:
return
try:
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 95a75b7d..fd1a2314 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -197,9 +197,6 @@ def get_user_name(user: KhojUser):
def get_user_photo(user: KhojUser):
- full_name = user.get_full_name()
- if not is_none_or_empty(full_name):
- return full_name
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
if google_profile:
return google_profile.picture
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 00e065e5..98f9f38d 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -23,6 +23,7 @@ from khoj.database.models import (
TextToImageModelConfig,
UserSearchModelConfig,
)
+from khoj.utils.helpers import ImageIntentType
class KhojUserAdmin(UserAdmin):
@@ -114,9 +115,12 @@ class ConversationAdmin(admin.ModelAdmin):
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
- and log["intent"]["type"] == "text-to-image"
+ and (
+ log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
+ or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
+ )
):
- log["message"] = "image redacted for space"
+ log["message"] = "inline image redacted for space"
chat_log[idx] = log
modified_log["chat"] = chat_log
@@ -154,9 +158,12 @@ class ConversationAdmin(admin.ModelAdmin):
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
- and log["intent"]["type"] == "text-to-image"
+ and (
+ log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
+ or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
+ )
):
- updated_log["message"] = "image redacted for space"
+ updated_log["message"] = "inline image redacted for space"
chat_log[idx] = updated_log
return_log["chat"] = chat_log
diff --git a/src/khoj/database/management/__init__.py b/src/khoj/database/management/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/khoj/database/management/commands/__init__.py b/src/khoj/database/management/commands/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/khoj/database/management/commands/convert_images_png_to_webp.py b/src/khoj/database/management/commands/convert_images_png_to_webp.py
new file mode 100644
index 00000000..b1ad8615
--- /dev/null
+++ b/src/khoj/database/management/commands/convert_images_png_to_webp.py
@@ -0,0 +1,40 @@
+from django.core.management.base import BaseCommand
+
+from khoj.database.models import Conversation
+from khoj.utils.helpers import ImageIntentType
+
+
+class Command(BaseCommand):
+ help = "Convert all images to WebP format or reverse."
+
+ def add_arguments(self, parser):
+ # Add a new argument 'reverse' to the command
+ parser.add_argument(
+ "--reverse",
+ action="store_true",
+ help="Convert from WebP to PNG instead of PNG to WebP",
+ )
+
+ def handle(self, *args, **options):
+ updated_count = 0
+ for conversation in Conversation.objects.all():
+ conversation_updated = False
+ for chat in conversation.conversation_log["chat"]:
+ if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
+ if options["reverse"] and chat["message"].endswith(".webp"):
+ # Convert WebP url to PNG url
+ chat["message"] = chat["message"].replace(".webp", ".png")
+ conversation_updated = True
+ updated_count += 1
+ elif chat["message"].endswith(".png"):
+ # Convert PNG url to WebP url
+ chat["message"] = chat["message"].replace(".png", ".webp")
+ conversation_updated = True
+ updated_count += 1
+ if conversation_updated:
+ conversation.save()
+
+ if updated_count > 0 and options["reverse"]:
+ self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format."))
+ elif updated_count > 0:
+ self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format."))
diff --git a/src/khoj/database/migrations/0035_convert_png_to_webp.py b/src/khoj/database/migrations/0035_convert_png_to_webp.py
new file mode 100644
index 00000000..35495629
--- /dev/null
+++ b/src/khoj/database/migrations/0035_convert_png_to_webp.py
@@ -0,0 +1,69 @@
+# Generated by Django 4.2.10 on 2024-04-13 17:54
+
+import base64
+import io
+
+from django.db import migrations
+from PIL import Image
+
+from khoj.utils.helpers import ImageIntentType
+
+
+def convert_png_images_to_webp(apps, schema_editor):
+ # Get the model from the versioned app registry to ensure the correct version is used
+ Conversations = apps.get_model("database", "Conversation")
+ for conversation in Conversations.objects.all():
+ for chat in conversation.conversation_log["chat"]:
+ if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value:
+ # Decode the base64 encoded PNG image
+ decoded_image = base64.b64decode(chat["message"])
+
+ # Convert images from PNG to WebP format
+ image_io = io.BytesIO(decoded_image)
+ with Image.open(image_io) as png_image:
+ webp_image_io = io.BytesIO()
+ png_image.save(webp_image_io, "WEBP")
+
+ # Encode the WebP image back to base64
+ webp_image_bytes = webp_image_io.getvalue()
+ chat["message"] = base64.b64encode(webp_image_bytes).decode()
+ chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
+ webp_image_io.close()
+
+ # Save the updated conversation history
+ conversation.save()
+
+
+def convert_webp_images_to_png(apps, schema_editor):
+ # Get the model from the versioned app registry to ensure the correct version is used
+ Conversations = apps.get_model("database", "Conversation")
+ for conversation in Conversations.objects.all():
+ for chat in conversation.conversation_log["chat"]:
+ if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value:
+ # Decode the base64 encoded PNG image
+ decoded_image = base64.b64decode(chat["message"])
+
+ # Convert images from PNG to WebP format
+ image_io = io.BytesIO(decoded_image)
+ with Image.open(image_io) as png_image:
+ webp_image_io = io.BytesIO()
+ png_image.save(webp_image_io, "PNG")
+
+ # Encode the WebP image back to base64
+ webp_image_bytes = webp_image_io.getvalue()
+ chat["message"] = base64.b64encode(webp_image_bytes).decode()
+ chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
+ webp_image_io.close()
+
+ # Save the updated conversation history
+ conversation.save()
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0034_alter_chatmodeloptions_chat_model"),
+ ]
+
+ operations = [
+ migrations.RunPython(convert_png_images_to_webp, reverse_code=convert_webp_images_to_png),
+ ]
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index c2f51ca5..c96fb773 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -4,10 +4,10 @@
Khoj - Chat
+
-
@@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co
return referenceButton;
}
- function renderMessage(message, by, dt=null, annotations=null, raw=false) {
+ function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
let formattedMessage = formatHTMLMessage(message, raw);
@@ -183,10 +183,16 @@ To get started, just start typing below. You can also type / to see a list of co
// Append chat message div to chat body
let chatBody = document.getElementById("chat-body");
- chatBody.appendChild(chatMessage);
-
- // Scroll to bottom of chat-body element
- chatBody.scrollTop = chatBody.scrollHeight;
+ if (renderType === "append") {
+ chatBody.appendChild(chatMessage);
+ // Scroll to bottom of chat-body element
+ chatBody.scrollTop = chatBody.scrollHeight;
+ } else if (renderType === "prepend"){
+ let chatBody = document.getElementById("chat-body");
+ chatBody.insertBefore(chatMessage, chatBody.firstChild);
+ } else if (renderType === "return") {
+ return chatMessage;
+ }
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
@@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
+ // If no document or online context is provided, render the message as is
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
@@ -244,24 +251,24 @@ To get started, just start typing below. You can also type / to see a list of co
imageMarkdown = ``;
} else if (intentType === "text-to-image2") {
imageMarkdown = ``;
+ } else if (intentType === "text-to-image-v3") {
+ imageMarkdown = ``;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
- renderMessage(imageMarkdown, by, dt);
- return;
+ return renderMessage(imageMarkdown, by, dt, null, false, "return");
}
- renderMessage(message, by, dt);
- return;
+ return renderMessage(message, by, dt, null, false, "return");
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
- renderMessage(message, by, dt);
- return;
+ return renderMessage(message, by, dt, null, false, "return");
}
+ // If document or online context is provided, render the message with its references
let references = document.createElement('div');
let referenceExpandButton = document.createElement('button');
@@ -312,16 +319,17 @@ To get started, just start typing below. You can also type / to see a list of co
imageMarkdown = ``;
} else if (intentType === "text-to-image2") {
imageMarkdown = ``;
+ } else if (intentType === "text-to-image-v3") {
+ imageMarkdown = ``;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
- renderMessage(imageMarkdown, by, dt, references);
- return;
+ return renderMessage(imageMarkdown, by, dt, references, false, "return");
}
- renderMessage(message, by, dt, references);
+ return renderMessage(message, by, dt, references, false, "return");
}
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
@@ -619,6 +627,8 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += ``;
} else if (imageJson.intentType === "text-to-image2") {
rawResponse += ``;
+ } else if (imageJson.intentType === "text-to-image-v3") {
+ rawResponse = ``;
}
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -1064,7 +1074,8 @@ To get started, just start typing below. You can also type / to see a list of co
loadingScreen.appendChild(yellowOrb);
chatBody.appendChild(loadingScreen);
- fetch(chatHistoryUrl, { method: "GET" })
+ // Get the most recent 10 chat messages from conversation history
+ fetch(`${chatHistoryUrl}&n=10`, { method: "GET" })
.then(response => response.json())
.then(data => {
if (data.detail) {
@@ -1134,11 +1145,22 @@ To get started, just start typing below. You can also type / to see a list of co
agentMetadataElement.style.display = "none";
}
- const fullChatLog = response.chat || [];
+ // Create a new IntersectionObserver
+ let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
+ entries.forEach(entry => {
+ // If the element is in the viewport, fetch the remaining message and unobserve the element
+ if (entry.isIntersecting) {
+ fetchRemainingChatMessages(chatHistoryUrl);
+ observer.unobserve(entry.target);
+ }
+ });
+ }, {rootMargin: '0px 0px 0px 0px'});
- fullChatLog.forEach(chat_log => {
- if (chat_log.message != null){
- renderMessageWithReference(
+ const fullChatLog = response.chat || [];
+ fullChatLog.forEach((chat_log, index) => {
+ // Render the last 10 messages immediately
+ if (chat_log.message != null) {
+ let messageElement = renderMessageWithReference(
chat_log.message,
chat_log.by,
chat_log.context,
@@ -1146,14 +1168,26 @@ To get started, just start typing below. You can also type / to see a list of co
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]);
+ chatBody.appendChild(messageElement);
+
+ // When the 4th oldest message is within viewing distance (~60% scroll up)
+ // Fetch the remaining chat messages
+ if (index === 4) {
+ fetchRemainingMessagesObserver.observe(messageElement);
+ }
}
loadingScreen.style.height = chatBody.scrollHeight + 'px';
});
- // Add fade out animation to loading screen and remove it after the animation ends
+ // Scroll to bottom of chat-body element
+ chatBody.scrollTop = chatBody.scrollHeight;
+
+ // Set height of chat-body element to the height of the chat-body-wrapper
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
- chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
+ let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
chatBody.style.height = chatBodyWrapperHeight;
+
+ // Add fade out animation to loading screen and remove it after the animation ends
setTimeout(() => {
loadingScreen.remove();
chatBody.classList.remove("relative-position");
@@ -1211,6 +1245,66 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-input").value = query_via_url;
chat();
}
+
+ }
+
+ function fetchRemainingChatMessages(chatHistoryUrl) {
+ // Create a new IntersectionObserver
+ let observer = new IntersectionObserver((entries, observer) => {
+ entries.forEach(entry => {
+ // If the element is in the viewport, render the message and unobserve the element
+ if (entry.isIntersecting) {
+ let chat_log = entry.target.chat_log;
+ let messageElement = renderMessageWithReference(
+ chat_log.message,
+ chat_log.by,
+ chat_log.context,
+ new Date(chat_log.created),
+ chat_log.onlineContext,
+ chat_log.intent?.type,
+ chat_log.intent?.["inferred-queries"]
+ );
+ entry.target.replaceWith(messageElement);
+
+ // Remove the observer after the element has been rendered
+ observer.unobserve(entry.target);
+ }
+ });
+ }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport
+
+ // Fetch remaining chat messages from conversation history
+ fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
+ .then(response => response.json())
+ .then(data => {
+ if (data.status != "ok") {
+ throw new Error(data.message);
+ }
+ return data.response;
+ })
+ .then(response => {
+ const fullChatLog = response.chat || [];
+ let chatBody = document.getElementById("chat-body");
+ fullChatLog
+ .reverse()
+ .forEach(chat_log => {
+ if (chat_log.message != null) {
+ // Create a new element for each chat log
+ let placeholder = document.createElement('div');
+ placeholder.chat_log = chat_log;
+
+ // Insert the message placeholder as the first child of chat body after the welcome message
+ chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
+
+ // Observe the element
+ placeholder.style.height = "20px";
+ observer.observe(placeholder);
+ }
+ });
+ })
+ .catch(err => {
+ console.log(err);
+ return;
+ });
}
function flashStatusInChatInput(message) {
diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py
index 02fa4cf0..2aa63d4e 100644
--- a/src/khoj/processor/content/github/github_to_entries.py
+++ b/src/khoj/processor/content/github/github_to_entries.py
@@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries):
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
except ConnectionAbortedError as e:
logger.error(f"Github rate limit reached. Skip indexing github repo {repo_shorthand}")
+ raise e
except Exception as e:
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
raise e
diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py
index 6e078f07..57456ed5 100644
--- a/src/khoj/processor/content/notion/notion_to_entries.py
+++ b/src/khoj/processor/content/notion/notion_to_entries.py
@@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries):
for response in responses:
with timer("Processing response", logger=logger):
- pages_or_databases = response["results"] if response.get("results") else []
+ pages_or_databases = response.get("results", [])
# Get all pages content
for p_or_d in pages_or_databases:
@@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries):
current_entries = []
curr_heading = ""
- for block in content["results"]:
+ for block in content.get("results", []):
block_type = block.get("type")
if block_type == None:
@@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries):
return f"\n{heading}\n"
def process_nested_children(self, children, raw_content, block_type=None):
- results = children["results"] if children.get("results") else []
+ results = children.get("results", [])
for child in results:
child_type = child.get("type")
if child_type == None:
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 10dc08fa..a559df22 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -30,6 +30,7 @@ def extract_questions_offline(
use_history: bool = True,
should_extract_questions: bool = True,
location_data: LocationData = None,
+ max_prompt_size: int = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -41,7 +42,7 @@ def extract_questions_offline(
return all_questions
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model)
+ offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@@ -67,12 +68,14 @@ def extract_questions_offline(
location=location,
)
messages = generate_chatml_messages_with_context(
- example_questions, model_name=model, loaded_model=offline_chat_model
+ example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
)
state.chat_lock.acquire()
try:
- response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
+ response = send_message_to_model_offline(
+ messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+ )
finally:
state.chat_lock.release()
@@ -138,7 +141,7 @@ def converse_offline(
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model)
+ offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references_message = "\n\n".join({f"{item}" for item in references})
current_date = datetime.now().strftime("%Y-%m-%d")
@@ -190,18 +193,18 @@ def converse_offline(
)
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
- t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
+ t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
t.start()
return g
-def llm_thread(g, messages: List[ChatMessage], model: Any):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
stop_phrases = ["", "INST]", "Notes:"]
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
- messages, loaded_model=model, stop=stop_phrases, streaming=True
+ messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", ""))
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
streaming=False,
stop=[],
+ max_prompt_size: int = None,
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model)
+ offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
if streaming:
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index b711c11a..c2b08bfa 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -1,18 +1,19 @@
import glob
import logging
+import math
import os
from huggingface_hub.constants import HF_HUB_CACHE
from khoj.utils import state
+from khoj.utils.helpers import get_device_memory
logger = logging.getLogger(__name__)
-def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
- from llama_cpp.llama import Llama
-
- # Initialize Model Parameters. Use n_ctx=0 to get context size from the model
+def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
+ # Initialize Model Parameters
+ # Use n_ctx=0 to get context size from the model
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
# Decide whether to load model to GPU or CPU
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
model_path = load_model_from_cache(repo_id, filename)
chat_model = None
try:
- if model_path:
- chat_model = Llama(model_path, **kwargs)
- else:
- Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+ chat_model = load_model(model_path, repo_id, filename, kwargs)
except:
# Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu"
- if model_path:
- chat_model = Llama(model_path, **kwargs)
- else:
- chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+ chat_model = load_model(model_path, repo_id, filename, kwargs)
- logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
+ # Now load the model with context size set based on:
+ # 1. context size supported by model and
+ # 2. configured size or machine (V)RAM
+ kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
+ chat_model = load_model(model_path, repo_id, filename, kwargs)
+ logger.debug(
+ f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
+ )
return chat_model
+def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
+ from llama_cpp.llama import Llama
+
+ if model_path:
+ return Llama(model_path, **kwargs)
+ else:
+ return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+
+
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/")
@@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
return paths[0]
else:
return None
+
+
+def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
+ """Infer max prompt size based on device memory and max context window supported by the model"""
+ vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
+ if configured_max_tokens:
+ return min(configured_max_tokens, model_context_window)
+ else:
+ return min(vram_based_n_ctx, model_context_window)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 845ccb48..e787eedf 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,5 +1,6 @@
import json
import logging
+import math
import queue
from datetime import datetime
from time import perf_counter
@@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
tokenizer_name=None,
):
"""Generate messages for ChatGPT with context from previous conversation"""
- # Set max prompt size from user config, pre-configured for model or to default prompt size
- try:
- max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
- except:
- max_prompt_size = 2000
- logger.warning(
- f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
- )
+ # Set max prompt size from user config or based on pre-configured for model and machine specs
+ if not max_prompt_size:
+ if loaded_model:
+ max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+ else:
+ max_prompt_size = model_to_prompt_size.get(model_name, 2000)
# Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750
@@ -187,7 +186,7 @@ def truncate_messages(
max_prompt_size,
model_name: str,
loaded_model: Optional[Llama] = None,
- tokenizer_name=None,
+ tokenizer_name="hf-internal-testing/llama-tokenizer",
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
@@ -197,15 +196,11 @@ def truncate_messages(
elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name)
else:
- try:
- encoder = download_model(model_name).tokenizer()
- except:
- encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+ encoder = download_model(model_name).tokenizer()
except:
- default_tokenizer = "hf-internal-testing/llama-tokenizer"
- encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+ encoder = AutoTokenizer.from_pretrained(tokenizer_name)
logger.warning(
- f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+ f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
# Extract system message from messages
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index cf84e724..c511b6d9 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -289,9 +289,7 @@ async def extract_references_and_questions(
return compiled_references, inferred_queries, q
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
- logger.warning(
- "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
- )
+ logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
return compiled_references, inferred_queries, q
# Extract filter terms from user message
@@ -317,8 +315,9 @@ async def extract_references_and_questions(
using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
chat_model = default_offline_llm.chat_model
+ max_tokens = default_offline_llm.max_prompt_size
if state.offline_chat_processor_config is None:
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
+ state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
@@ -328,6 +327,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
should_extract_questions=True,
location_data=location_data,
+ max_prompt_size=conversation_config.max_prompt_size,
)
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 4e7a8cc9..9af00053 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -76,6 +76,7 @@ def chat_history(
request: Request,
common: CommonQueryParams,
conversation_id: Optional[int] = None,
+ n: Optional[int] = None,
):
user = request.user.object
validate_conversation_config()
@@ -109,6 +110,13 @@ def chat_history(
}
)
+ # Get latest N messages if N > 0
+ if n > 0:
+ meta_log["chat"] = meta_log["chat"][-n:]
+ # Else return all messages except latest N
+ else:
+ meta_log["chat"] = meta_log["chat"][:n]
+
update_telemetry_state(
request=request,
telemetry_type="api",
@@ -425,8 +433,7 @@ async def websocket_endpoint(
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
- intent_type = "text-to-image"
- image, status_code, improved_image_prompt, image_url = await text_to_image(
+ image, status_code, improved_image_prompt, intent_type = await text_to_image(
q,
user,
meta_log,
@@ -445,9 +452,6 @@ async def websocket_endpoint(
await send_complete_llm_response(json.dumps(content_obj))
continue
- if image_url:
- intent_type = "text-to-image2"
- image = image_url
await sync_to_async(save_to_conversation_log)(
q,
image,
@@ -621,17 +625,13 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
- intent_type = "text-to-image"
- image, status_code, improved_image_prompt, image_url = await text_to_image(
+ image, status_code, improved_image_prompt, intent_type = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None:
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
- if image_url:
- intent_type = "text-to-image2"
- image = image_url
await sync_to_async(save_to_conversation_log)(
q,
image,
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index f3be3162..3c93385d 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1,4 +1,6 @@
import asyncio
+import base64
+import io
import json
import logging
from concurrent.futures import ThreadPoolExecutor
@@ -18,6 +20,7 @@ from typing import (
import openai
from fastapi import Depends, Header, HTTPException, Request, UploadFile
+from PIL import Image
from starlette.authentication import has_required_scope
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
@@ -46,6 +49,7 @@ from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
ConversationCommand,
+ ImageIntentType,
is_none_or_empty,
is_valid_url,
log_telemetry,
@@ -79,9 +83,10 @@ async def is_ready_to_chat(user: KhojUser):
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model
+ max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
+ state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True
ready = has_openai_config or has_offline_config
@@ -382,10 +387,11 @@ async def send_message_to_model_wrapper(
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model
+ max_tokens = conversation_config.max_prompt_size
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
+ state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
@@ -452,7 +458,9 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
- state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+ chat_model = conversation_config.chat_model
+ max_tokens = conversation_config.max_prompt_size
+ state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
@@ -508,18 +516,19 @@ async def text_to_image(
references: List[str],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
-) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
+) -> Tuple[Optional[str], int, Optional[str], str]:
status_code = 200
image = None
response = None
image_url = None
+ intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
- return image, status_code, message, image_url
+ return image_url or image, status_code, message, intent_type.value
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
@@ -550,21 +559,38 @@ async def text_to_image(
)
image = response.data[0].b64_json
+ with timer("Convert image to webp", logger):
+ # Convert png to webp for faster loading
+ decoded_image = base64.b64decode(image)
+ image_io = io.BytesIO(decoded_image)
+ png_image = Image.open(image_io)
+ webp_image_io = io.BytesIO()
+ png_image.save(webp_image_io, "WEBP")
+ webp_image_bytes = webp_image_io.getvalue()
+ webp_image_io.close()
+ image_io.close()
+
with timer("Upload image to S3", logger):
- image_url = upload_image(image, user.uuid)
- return image, status_code, improved_image_prompt, image_url
+ image_url = upload_image(webp_image_bytes, user.uuid)
+ if image_url:
+ intent_type = ImageIntentType.TEXT_TO_IMAGE2
+ else:
+ intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+ image = base64.b64encode(webp_image_bytes).decode("utf-8")
+
+ return image_url or image, status_code, improved_image_prompt, intent_type.value
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
- return image, status_code, message, image_url
+ return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
- return image, status_code, message, image_url
- return image, status_code, response, image_url
+ return image_url or image, status_code, message, intent_type.value
+ return image_url or image, status_code, response, intent_type.value
class ApiUserRateLimiter:
diff --git a/src/khoj/routers/storage.py b/src/khoj/routers/storage.py
index 57c28c5a..9a5d448f 100644
--- a/src/khoj/routers/storage.py
+++ b/src/khoj/routers/storage.py
@@ -1,4 +1,3 @@
-import base64
import logging
import os
import uuid
@@ -17,16 +16,15 @@ if aws_enabled:
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
-def upload_image(image: str, user_id: uuid.UUID):
+def upload_image(image: bytes, user_id: uuid.UUID):
"""Upload the image to the S3 bucket"""
if not aws_enabled:
logger.info("AWS is not enabled. Skipping image upload")
return None
- decoded_image = base64.b64decode(image)
- image_key = f"{user_id}/{uuid.uuid4()}.png"
+ image_key = f"{user_id}/{uuid.uuid4()}.webp"
try:
- s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
+ s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
return url
except Exception as e:
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 3f95030f..1732271a 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
class OfflineChatProcessorModel:
- def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
+ def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
self.chat_model = chat_model
self.loaded_model = None
try:
- self.loaded_model = download_model(self.chat_model)
+ self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index f6a66b4f..e621f53e 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -17,6 +17,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse
+import psutil
import torch
from asgiref.sync import sync_to_async
from magika import Magika
@@ -233,6 +234,10 @@ def get_server_id():
return server_id
+def telemetry_disabled(app_config: AppConfig):
+ return not app_config or not app_config.should_log_telemetry
+
+
def log_telemetry(
telemetry_type: str,
api: str = None,
@@ -242,7 +247,7 @@ def log_telemetry(
):
"""Log basic app usage telemetry like client, os, api called"""
# Do not log usage telemetry, if telemetry is disabled via app config
- if not app_config or not app_config.should_log_telemetry:
+ if telemetry_disabled(app_config):
return []
if properties.get("server_id") is None:
@@ -267,6 +272,17 @@ def log_telemetry(
return request_body
+def get_device_memory() -> int:
+ """Get device memory in GB"""
+ device = get_device()
+ if device.type == "cuda":
+ return torch.cuda.get_device_properties(device).total_memory
+ elif device.type == "mps":
+ return torch.mps.driver_allocated_memory()
+ else:
+ return psutil.virtual_memory().total
+
+
def get_device() -> torch.device:
"""Get device to run model on"""
if torch.cuda.is_available():
@@ -313,6 +329,20 @@ mode_descriptions_for_llm = {
}
+class ImageIntentType(Enum):
+ """
+ Chat message intent by Khoj for image responses.
+ Marks the schema used to reference image in chat messages
+ """
+
+ # Images as Inline PNG
+ TEXT_TO_IMAGE = "text-to-image"
+ # Images as URLs
+ TEXT_TO_IMAGE2 = "text-to-image2"
+ # Images as Inline WebP
+ TEXT_TO_IMAGE_V3 = "text-to-image-v3"
+
+
def generate_random_name():
# List of adjectives and nouns to choose from
adjectives = [