diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx index 4aa2f960..7b2866f4 100644 --- a/documentation/docs/get-started/setup.mdx +++ b/documentation/docs/get-started/setup.mdx @@ -134,7 +134,7 @@ python -m pip install khoj-assistant # CPU python -m pip install khoj-assistant # NVIDIA (CUDA) GPU - CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant + CMAKE_ARGS="-DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant # AMD (ROCm) GPU CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant # VULKAN GPU diff --git a/pyproject.toml b/pyproject.toml index 0b7483a1..c9c96691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "phonenumbers == 8.13.27", "markdownify ~= 0.11.6", "websockets == 12.0", + "psutil >= 5.8.0", ] dynamic = ["version"] @@ -105,7 +106,6 @@ dev = [ "pytest-asyncio == 0.21.1", "freezegun >= 1.2.0", "factory-boy >= 3.2.1", - "psutil >= 5.8.0", "mypy >= 1.0.1", "black >= 23.1.0", "pre-commit >= 3.0.4", diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 627ce3de..fc7ecc2e 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -4,9 +4,9 @@ Khoj - Chat + - @@ -130,7 +130,7 @@ return referenceButton; } - function renderMessage(message, by, dt=null, annotations=null, raw=false) { + function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") { let message_time = formatDate(dt ?? new Date()); let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You"; let formattedMessage = formatHTMLMessage(message, raw); @@ -153,10 +153,15 @@ // Append chat message div to chat body let chatBody = document.getElementById("chat-body"); - chatBody.appendChild(chatMessage); - - // Scroll to bottom of chat-body element - chatBody.scrollTop = chatBody.scrollHeight; + if (renderType === "append") { + chatBody.appendChild(chatMessage); + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + } else if (renderType === "prepend") { + chatBody.insertBefore(chatMessage, chatBody.firstChild); + } else if (renderType === "return") { + return chatMessage; + } let chatBodyWrapper = document.getElementById("chat-body-wrapper"); chatBodyWrapperHeight = chatBodyWrapper.clientHeight; @@ -207,6 +212,7 @@ } function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { + // If no document or online context is provided, render the message as is if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { if (intentType?.includes("text-to-image")) { let imageMarkdown; @@ -214,30 +220,29 @@ imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt); - return; + return renderMessage(imageMarkdown, by, dt, null, false, "return"); } - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } if (context == null && onlineContext == null) { - renderMessage(message, by, dt); - return; + return renderMessage(message, by, 
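+            // The new "return" renderType hands the message element back to the
+            // caller instead of appending it to the DOM; loadChat and the lazy
+            // history loader use this to place messages and placeholders themselves.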
dt, null, false, "return"); } if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } + // If document or online context is provided, render the message with its references let references = document.createElement('div'); let referenceExpandButton = document.createElement('button'); @@ -288,16 +293,17 @@ imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt, references); - return; + return renderMessage(imageMarkdown, by, dt, references, false, "return"); } - renderMessage(message, by, dt, references); + return renderMessage(message, by, dt, references, false, "return"); } function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) { @@ -509,6 +515,8 @@ rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } else if (responseAsJson.intentType === "text-to-image2") { rawResponse += `![${query}](${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image-v3") { + rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; } const inferredQueries = responseAsJson.inferredQueries?.[0]; if (inferredQueries) { @@ -671,7 +679,7 @@ let firstRunSetupMessageRendered = false; let chatBody = document.getElementById("chat-body"); chatBody.innerHTML = ""; - let chatHistoryUrl = `/api/chat/history?client=desktop`; + let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; } @@ -683,7 +691,8 @@ loadingScreen.appendChild(yellowOrb); chatBody.appendChild(loadingScreen); - fetch(`${hostURL}${chatHistoryUrl}`, { headers }) + // Get the most recent 10 chat messages from conversation history + fetch(`${chatHistoryUrl}&n=10`, { headers }) .then(response => response.json()) .then(data => { if (data.detail) { @@ -703,11 +712,21 @@ chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; - const fullChatLog = response.chat || []; + // Create a new IntersectionObserver + let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, fetch the remaining message and unobserve the element + if (entry.isIntersecting) { + fetchRemainingChatMessages(chatHistoryUrl); + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 0px 0px'}); - fullChatLog.forEach(chat_log => { + const fullChatLog = response.chat || []; + fullChatLog.forEach((chat_log, index) => { if (chat_log.message != null) { - renderMessageWithReference( + let messageElement = renderMessageWithReference( chat_log.message, chat_log.by, chat_log.context, @@ -715,10 +734,25 @@ chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]); + chatBody.appendChild(messageElement); + + // When the 4th oldest message is within viewing distance (~60% scrolled up) + // Fetch the remaining chat messages + if (index === 4) { + 
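+                        // index 4 is the fifth-oldest message (zero-based) of the 10
+                        // fetched, so the remaining history starts loading once the user
+                        // has scrolled roughly 60% of the way up through this batch.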
fetchRemainingMessagesObserver.observe(messageElement); + } } loadingScreen.style.height = chatBody.scrollHeight + 'px'; }) + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + + // Set height of chat-body element to the height of the chat-body-wrapper + let chatBodyWrapper = document.getElementById("chat-body-wrapper"); + let chatBodyWrapperHeight = chatBodyWrapper.clientHeight; + chatBody.style.height = chatBodyWrapperHeight; + // Add fade out animation to loading screen and remove it after the animation ends fadeOutLoadingAnimation(loadingScreen); }) @@ -726,9 +760,9 @@ // If the server returns a 500 error with detail, render a setup hint. if (!firstRunSetupMessageRendered) { renderFirstRunSetupMessage(); - fadeOutLoadingAnimation(loadingScreen); } - return; + fadeOutLoadingAnimation(loadingScreen); + return; }); await refreshChatSessionsPanel(); @@ -778,6 +812,65 @@ } } + function fetchRemainingChatMessages(chatHistoryUrl) { + // Create a new IntersectionObserver + let observer = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, render the message and unobserve the element + if (entry.isIntersecting) { + let chat_log = entry.target.chat_log; + let messageElement = renderMessageWithReference( + chat_log.message, + chat_log.by, + chat_log.context, + new Date(chat_log.created), + chat_log.onlineContext, + chat_log.intent?.type, + chat_log.intent?.["inferred-queries"] + ); + entry.target.replaceWith(messageElement); + + // Remove the observer after the element has been rendered + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport + + // Fetch remaining chat messages from conversation history + fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" }) + .then(response => response.json()) + .then(data => { + if (data.status != "ok") { + throw new Error(data.message); + } + return data.response; + }) + .then(response => { + const fullChatLog = response.chat || []; + let chatBody = document.getElementById("chat-body"); + fullChatLog + .reverse() + .forEach(chat_log => { + if (chat_log.message != null) { + // Create a new element for each chat log + let placeholder = document.createElement('div'); + placeholder.chat_log = chat_log; + + // Insert the message placeholder as the first child of chat body after the welcome message + chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling); + + // Observe the element + placeholder.style.height = "20px"; + observer.observe(placeholder); + } + }); + }) + .catch(err => { + console.log(err); + return; + }); + } + function fadeOutLoadingAnimation(loadingScreen) { let chatBody = document.getElementById("chat-body"); let chatBodyWrapper = document.getElementById("chat-body-wrapper"); diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 328ce299..504ce4db 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -156,6 +156,8 @@ export class KhojChatModal extends Modal { imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } if (inferredQueries) { imageMarkdown += "\n\n**Inferred Query**:"; @@ -429,6 +431,8 @@ export class KhojChatModal extends Modal { responseText += 
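// The three intent types mark how the image payload is encoded:
// "text-to-image" = inline base64 PNG, "text-to-image2" = hosted image URL,
// "text-to-image-v3" = inline base64 WebP.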
`![${query}](data:image/png;base64,${responseAsJson.image})`; } else if (responseAsJson.intentType === "text-to-image2") { responseText += `![${query}](${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image-v3") { + responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`; } const inferredQuery = responseAsJson.inferredQueries?.[0]; if (inferredQuery) { diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 60aaf658..9f8abeb4 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state from khoj.utils.config import SearchType from khoj.utils.fs_syncer import collect_files -from khoj.utils.helpers import is_none_or_empty +from khoj.utils.helpers import is_none_or_empty, telemetry_disabled from khoj.utils.rawconfig import FullConfig logger = logging.getLogger(__name__) @@ -232,6 +232,9 @@ def configure_server( state.search_models = configure_search(state.search_models, state.config.search_type) setup_default_agent() + message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled" + logger.info(message) + if not init: initialize_content(regenerate, search_type, user) @@ -329,9 +332,7 @@ def configure_search_types(): @schedule.repeat(schedule.every(2).minutes) def upload_telemetry(): - if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry: - message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled" - logger.debug(message) + if telemetry_disabled(state.config.app) or not state.telemetry: return try: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 95a75b7d..fd1a2314 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -197,9 +197,6 @@ def get_user_name(user: KhojUser): def get_user_photo(user: KhojUser): - full_name = user.get_full_name() - if not is_none_or_empty(full_name): - return full_name google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() if google_profile: return google_profile.picture diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 00e065e5..98f9f38d 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -23,6 +23,7 @@ from khoj.database.models import ( TextToImageModelConfig, UserSearchModelConfig, ) +from khoj.utils.helpers import ImageIntentType class KhojUserAdmin(UserAdmin): @@ -114,9 +115,12 @@ class ConversationAdmin(admin.ModelAdmin): log["by"] == "khoj" and log["intent"] and log["intent"]["type"] - and log["intent"]["type"] == "text-to-image" + and ( + log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value + or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value + ) ): - log["message"] = "image redacted for space" + log["message"] = "inline image redacted for space" chat_log[idx] = log modified_log["chat"] = chat_log @@ -154,9 +158,12 @@ class ConversationAdmin(admin.ModelAdmin): log["by"] == "khoj" and log["intent"] and log["intent"]["type"] - and log["intent"]["type"] == "text-to-image" + and ( + log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value + or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value + ) ): - updated_log["message"] = "image redacted for space" + updated_log["message"] = "inline image redacted for space" chat_log[idx] = updated_log return_log["chat"] = 
chat_log diff --git a/src/khoj/database/management/__init__.py b/src/khoj/database/management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/database/management/commands/__init__.py b/src/khoj/database/management/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/database/management/commands/convert_images_png_to_webp.py b/src/khoj/database/management/commands/convert_images_png_to_webp.py new file mode 100644 index 00000000..b1ad8615 --- /dev/null +++ b/src/khoj/database/management/commands/convert_images_png_to_webp.py @@ -0,0 +1,40 @@ +from django.core.management.base import BaseCommand + +from khoj.database.models import Conversation +from khoj.utils.helpers import ImageIntentType + + +class Command(BaseCommand): + help = "Convert all images to WebP format or reverse." + + def add_arguments(self, parser): + # Add a new argument 'reverse' to the command + parser.add_argument( + "--reverse", + action="store_true", + help="Convert from WebP to PNG instead of PNG to WebP", + ) + + def handle(self, *args, **options): + updated_count = 0 + for conversation in Conversation.objects.all(): + conversation_updated = False + for chat in conversation.conversation_log["chat"]: + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: + if options["reverse"] and chat["message"].endswith(".webp"): + # Convert WebP url to PNG url + chat["message"] = chat["message"].replace(".webp", ".png") + conversation_updated = True + updated_count += 1 + elif chat["message"].endswith(".png"): + # Convert PNG url to WebP url + chat["message"] = chat["message"].replace(".png", ".webp") + conversation_updated = True + updated_count += 1 + if conversation_updated: + conversation.save() + + if updated_count > 0 and options["reverse"]: + self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format.")) + elif updated_count > 0: + self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format.")) diff --git a/src/khoj/database/migrations/0035_convert_png_to_webp.py b/src/khoj/database/migrations/0035_convert_png_to_webp.py new file mode 100644 index 00000000..35495629 --- /dev/null +++ b/src/khoj/database/migrations/0035_convert_png_to_webp.py @@ -0,0 +1,69 @@ +# Generated by Django 4.2.10 on 2024-04-13 17:54 + +import base64 +import io + +from django.db import migrations +from PIL import Image + +from khoj.utils.helpers import ImageIntentType + + +def convert_png_images_to_webp(apps, schema_editor): + # Get the model from the versioned app registry to ensure the correct version is used + Conversations = apps.get_model("database", "Conversation") + for conversation in Conversations.objects.all(): + for chat in conversation.conversation_log["chat"]: + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value: + # Decode the base64 encoded PNG image + decoded_image = base64.b64decode(chat["message"]) + + # Convert images from PNG to WebP format + image_io = io.BytesIO(decoded_image) + with Image.open(image_io) as png_image: + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + + # Encode the WebP image back to base64 + webp_image_bytes = webp_image_io.getvalue() + chat["message"] = base64.b64encode(webp_image_bytes).decode() + chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value + webp_image_io.close() + + # Save the updated conversation history + conversation.save() + + +def 
convert_webp_images_to_png(apps, schema_editor): + # Get the model from the versioned app registry to ensure the correct version is used + Conversations = apps.get_model("database", "Conversation") + for conversation in Conversations.objects.all(): + for chat in conversation.conversation_log["chat"]: + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value: + # Decode the base64 encoded WebP image + decoded_image = base64.b64decode(chat["message"]) + + # Convert images from WebP back to PNG format + image_io = io.BytesIO(decoded_image) + with Image.open(image_io) as webp_image: + png_image_io = io.BytesIO() + webp_image.save(png_image_io, "PNG") + + # Encode the PNG image back to base64 + png_image_bytes = png_image_io.getvalue() + chat["message"] = base64.b64encode(png_image_bytes).decode() + chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value + png_image_io.close() + + # Save the updated conversation history + conversation.save() + + class Migration(migrations.Migration): dependencies = [ ("database", "0034_alter_chatmodeloptions_chat_model"), ] operations = [ migrations.RunPython(convert_png_images_to_webp, reverse_code=convert_webp_images_to_png), ] diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index c2f51ca5..c96fb773 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -4,10 +4,10 @@ Khoj - Chat + - @@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co return referenceButton; } - function renderMessage(message, by, dt=null, annotations=null, raw=false) { + function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") { let message_time = formatDate(dt ?? new Date()); let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You"; let formattedMessage = formatHTMLMessage(message, raw); @@ -183,10 +183,16 @@ To get started, just start typing below. You can also type / to see a list of co // Append chat message div to chat body let chatBody = document.getElementById("chat-body"); - chatBody.appendChild(chatMessage); - - // Scroll to bottom of chat-body element - chatBody.scrollTop = chatBody.scrollHeight; + if (renderType === "append") { + chatBody.appendChild(chatMessage); + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + } else if (renderType === "prepend") { + let chatBody = document.getElementById("chat-body"); + chatBody.insertBefore(chatMessage, chatBody.firstChild); + } else if (renderType === "return") { + return chatMessage; + } let chatBodyWrapper = document.getElementById("chat-body-wrapper"); chatBodyWrapperHeight = chatBodyWrapper.clientHeight; @@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co } function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { + // If no document or online context is provided, render the message as is if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { if (intentType?.includes("text-to-image")) { let imageMarkdown; @@ -244,24 +251,24 @@ To get started, just start typing below. 
You can also type / to see a list of co imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt); - return; + return renderMessage(imageMarkdown, by, dt, null, false, "return"); } - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } + // If document or online context is provided, render the message with its references let references = document.createElement('div'); let referenceExpandButton = document.createElement('button'); @@ -312,16 +319,17 @@ To get started, just start typing below. You can also type / to see a list of co imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt, references); - return; + return renderMessage(imageMarkdown, by, dt, references, false, "return"); } - renderMessage(message, by, dt, references); + return renderMessage(message, by, dt, references, false, "return"); } function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) { @@ -619,6 +627,8 @@ To get started, just start typing below. You can also type / to see a list of co rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`; } else if (imageJson.intentType === "text-to-image2") { rawResponse += `![generated_image](${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image-v3") { + rawResponse += `![generated_image](data:image/webp;base64,${imageJson.image})`; } if (inferredQuery) { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } @@ -1064,7 +1074,8 @@ To get started, just start typing below. You can also type / to see a list of co loadingScreen.appendChild(yellowOrb); chatBody.appendChild(loadingScreen); - fetch(chatHistoryUrl, { method: "GET" }) + // Get the most recent 10 chat messages from conversation history + fetch(`${chatHistoryUrl}&n=10`, { method: "GET" }) .then(response => response.json()) .then(data => { if (data.detail) { @@ -1134,11 +1145,22 @@ To get started, just start typing below. 
You can also type / to see a list of co agentMetadataElement.style.display = "none"; } - const fullChatLog = response.chat || []; + // Create a new IntersectionObserver + let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, fetch the remaining message and unobserve the element + if (entry.isIntersecting) { + fetchRemainingChatMessages(chatHistoryUrl); + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 0px 0px'}); - fullChatLog.forEach(chat_log => { - if (chat_log.message != null){ - renderMessageWithReference( + const fullChatLog = response.chat || []; + fullChatLog.forEach((chat_log, index) => { + // Render the last 10 messages immediately + if (chat_log.message != null) { + let messageElement = renderMessageWithReference( chat_log.message, chat_log.by, chat_log.context, @@ -1146,14 +1168,26 @@ To get started, just start typing below. You can also type / to see a list of co chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]); + chatBody.appendChild(messageElement); + + // When the 4th oldest message is within viewing distance (~60% scroll up) + // Fetch the remaining chat messages + if (index === 4) { + fetchRemainingMessagesObserver.observe(messageElement); + } } loadingScreen.style.height = chatBody.scrollHeight + 'px'; }); - // Add fade out animation to loading screen and remove it after the animation ends + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + + // Set height of chat-body element to the height of the chat-body-wrapper let chatBodyWrapper = document.getElementById("chat-body-wrapper"); - chatBodyWrapperHeight = chatBodyWrapper.clientHeight; + let chatBodyWrapperHeight = chatBodyWrapper.clientHeight; chatBody.style.height = chatBodyWrapperHeight; + + // Add fade out animation to loading screen and remove it after the animation ends setTimeout(() => { loadingScreen.remove(); chatBody.classList.remove("relative-position"); @@ -1211,6 +1245,66 @@ To get started, just start typing below. 
You can also type / to see a list of co document.getElementById("chat-input").value = query_via_url; chat(); } + + } + + function fetchRemainingChatMessages(chatHistoryUrl) { + // Create a new IntersectionObserver + let observer = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, render the message and unobserve the element + if (entry.isIntersecting) { + let chat_log = entry.target.chat_log; + let messageElement = renderMessageWithReference( + chat_log.message, + chat_log.by, + chat_log.context, + new Date(chat_log.created), + chat_log.onlineContext, + chat_log.intent?.type, + chat_log.intent?.["inferred-queries"] + ); + entry.target.replaceWith(messageElement); + + // Remove the observer after the element has been rendered + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport + + // Fetch remaining chat messages from conversation history + fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" }) + .then(response => response.json()) + .then(data => { + if (data.status != "ok") { + throw new Error(data.message); + } + return data.response; + }) + .then(response => { + const fullChatLog = response.chat || []; + let chatBody = document.getElementById("chat-body"); + fullChatLog + .reverse() + .forEach(chat_log => { + if (chat_log.message != null) { + // Create a new element for each chat log + let placeholder = document.createElement('div'); + placeholder.chat_log = chat_log; + + // Insert the message placeholder as the first child of chat body after the welcome message + chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling); + + // Observe the element + placeholder.style.height = "20px"; + observer.observe(placeholder); + } + }); + }) + .catch(err => { + console.log(err); + return; + }); } function flashStatusInChatInput(message) { diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 02fa4cf0..2aa63d4e 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries): markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo) except ConnectionAbortedError as e: logger.error(f"Github rate limit reached. 
Skip indexing github repo {repo_shorthand}") + raise e except Exception as e: logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True) raise e diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 6e078f07..57456ed5 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries): for response in responses: with timer("Processing response", logger=logger): - pages_or_databases = response["results"] if response.get("results") else [] + pages_or_databases = response.get("results", []) # Get all pages content for p_or_d in pages_or_databases: @@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries): current_entries = [] curr_heading = "" - for block in content["results"]: + for block in content.get("results", []): block_type = block.get("type") if block_type == None: @@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries): return f"\n{heading}\n" def process_nested_children(self, children, raw_content, block_type=None): - results = children["results"] if children.get("results") else [] + results = children.get("results", []) for child in results: child_type = child.get("type") if child_type == None: diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 10dc08fa..a559df22 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -30,6 +30,7 @@ def extract_questions_offline( use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, + max_prompt_size: int = None, ) -> List[str]: """ Infer search queries to retrieve relevant notes to answer user query @@ -41,7 +42,7 @@ def extract_questions_offline( return all_questions assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model) + offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" @@ -67,12 +68,14 @@ def extract_questions_offline( location=location, ) messages = generate_chatml_messages_with_context( - example_questions, model_name=model, loaded_model=offline_chat_model + example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size ) state.chat_lock.acquire() try: - response = send_message_to_model_offline(messages, loaded_model=offline_chat_model) + response = send_message_to_model_offline( + messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size + ) finally: state.chat_lock.release() @@ -138,7 +141,7 @@ def converse_offline( """ # Initialize Variables assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model) + offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) compiled_references_message = "\n\n".join({f"{item}" for item in references}) current_date = datetime.now().strftime("%Y-%m-%d") @@ -190,18 +193,18 @@ def converse_offline( ) g = ThreadedGenerator(references, online_results, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, 
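+    # max_prompt_size is threaded through to the worker so the streaming
+    # thread loads the offline model with the same clamped context window.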
offline_chat_model)) + t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size)) t.start() return g -def llm_thread(g, messages: List[ChatMessage], model: Any): +def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None): stop_phrases = ["", "INST]", "Notes:"] state.chat_lock.acquire() try: response_iterator = send_message_to_model_offline( - messages, loaded_model=model, stop=stop_phrases, streaming=True + messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True ) for response in response_iterator: g.send(response["choices"][0]["delta"].get("content", "")) @@ -216,9 +219,10 @@ def send_message_to_model_offline( model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", streaming=False, stop=[], + max_prompt_size: int = None, ): assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model) + offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) messages_dict = [{"role": message.role, "content": message.content} for message in messages] response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming) if streaming: diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py index b711c11a..c2b08bfa 100644 --- a/src/khoj/processor/conversation/offline/utils.py +++ b/src/khoj/processor/conversation/offline/utils.py @@ -1,18 +1,19 @@ import glob import logging +import math import os from huggingface_hub.constants import HF_HUB_CACHE from khoj.utils import state +from khoj.utils.helpers import get_device_memory logger = logging.getLogger(__name__) -def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"): - from llama_cpp.llama import Llama - - # Initialize Model Parameters. Use n_ctx=0 to get context size from the model +def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None): + # Initialize Model Parameters + # Use n_ctx=0 to get context size from the model kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False} # Decide whether to load model to GPU or CPU @@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"): model_path = load_model_from_cache(repo_id, filename) chat_model = None try: - if model_path: - chat_model = Llama(model_path, **kwargs) - else: - Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + chat_model = load_model(model_path, repo_id, filename, kwargs) except: # Load model on CPU if GPU is not available kwargs["n_gpu_layers"], device = 0, "cpu" - if model_path: - chat_model = Llama(model_path, **kwargs) - else: - chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + chat_model = load_model(model_path, repo_id, filename, kwargs) - logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}") + # Now load the model with context size set based on: + # 1. context size supported by model and + # 2. configured size or machine (V)RAM + kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens) + chat_model = load_model(model_path, repo_id, filename, kwargs) + logger.debug( + f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window." 
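+        # n_ctx here is the final context window: the model is first loaded with
+        # n_ctx=0 to probe its maximum context, then reloaded with the window
+        # clamped by infer_max_tokens() to the configured size or device memory.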
+ ) return chat_model +def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}): + from llama_cpp.llama import Llama + + if model_path: + return Llama(model_path, **kwargs) + else: + return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + + def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): # Construct the path to the model file in the cache directory repo_org, repo_name = repo_id.split("/") @@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): return paths[0] else: return None + + +def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int: + """Infer max prompt size based on device memory and max context window supported by the model""" + vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic + if configured_max_tokens: + return min(configured_max_tokens, model_context_window) + else: + return min(vram_based_n_ctx, model_context_window) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 845ccb48..e787eedf 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,5 +1,6 @@ import json import logging +import math import queue from datetime import datetime from time import perf_counter @@ -141,14 +142,12 @@ def generate_chatml_messages_with_context( tokenizer_name=None, ): """Generate messages for ChatGPT with context from previous conversation""" - # Set max prompt size from user config, pre-configured for model or to default prompt size - try: - max_prompt_size = max_prompt_size or model_to_prompt_size[model_name] - except: - max_prompt_size = 2000 - logger.warning( - f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window." - ) + # Set max prompt size from user config or based on pre-configured for model and machine specs + if not max_prompt_size: + if loaded_model: + max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf)) + else: + max_prompt_size = model_to_prompt_size.get(model_name, 2000) # Scale lookback turns proportional to max prompt size supported by model lookback_turns = max_prompt_size // 750 @@ -187,7 +186,7 @@ def truncate_messages( max_prompt_size, model_name: str, loaded_model: Optional[Llama] = None, - tokenizer_name=None, + tokenizer_name="hf-internal-testing/llama-tokenizer", ) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" @@ -197,15 +196,11 @@ def truncate_messages( elif model_name.startswith("gpt-"): encoder = tiktoken.encoding_for_model(model_name) else: - try: - encoder = download_model(model_name).tokenizer() - except: - encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) + encoder = download_model(model_name).tokenizer() except: - default_tokenizer = "hf-internal-testing/llama-tokenizer" - encoder = AutoTokenizer.from_pretrained(default_tokenizer) + encoder = AutoTokenizer.from_pretrained(tokenizer_name) logger.warning( - f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." + f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." 
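+            # The model's own tokenizer is tried first above; AutoTokenizer with
+            # this default tokenizer is only a rough fallback for token counting
+            # on unsupported models.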
) # Extract system message from messages diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index cf84e724..c511b6d9 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -289,9 +289,7 @@ async def extract_references_and_questions( return compiled_references, inferred_queries, q if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): - logger.warning( - "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." - ) + logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") return compiled_references, inferred_queries, q # Extract filter terms from user message @@ -317,8 +315,9 @@ using_offline_chat = True default_offline_llm = await ConversationAdapters.get_default_offline_llm() chat_model = default_offline_llm.chat_model + max_tokens = default_offline_llm.max_prompt_size if state.offline_chat_processor_config is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model @@ -328,6 +327,7 @@ conversation_log=meta_log, should_extract_questions=True, location_data=location_data, + max_prompt_size=conversation_config.max_prompt_size, ) elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = await ConversationAdapters.get_openai_chat_config() diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4e7a8cc9..9af00053 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -76,6 +76,7 @@ def chat_history( request: Request, common: CommonQueryParams, conversation_id: Optional[int] = None, + n: Optional[int] = None, ): user = request.user.object validate_conversation_config() @@ -109,6 +110,13 @@ } ) + # Get latest N messages if N > 0 + if n is not None and n > 0: + meta_log["chat"] = meta_log["chat"][-n:] + # Else return all messages except latest N + elif n is not None: + meta_log["chat"] = meta_log["chat"][:n] + update_telemetry_state( request=request, telemetry_type="api", @@ -425,8 +433,7 @@ async def websocket_endpoint( api="chat", metadata={"conversation_command": conversation_commands[0].value}, ) - intent_type = "text-to-image" - image, status_code, improved_image_prompt, image_url = await text_to_image( + image, status_code, improved_image_prompt, intent_type = await text_to_image( q, user, meta_log, @@ -445,9 +452,6 @@ await send_complete_llm_response(json.dumps(content_obj)) continue - if image_url: - intent_type = "text-to-image2" - image = image_url await sync_to_async(save_to_conversation_log)( q, image, @@ -621,17 +625,13 @@ metadata={"conversation_command": conversation_commands[0].value}, **common.__dict__, ) - intent_type = "text-to-image" - image, status_code, improved_image_prompt, image_url = await text_to_image( + image, status_code, improved_image_prompt, intent_type = await text_to_image( q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results ) if image is None: content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt} return Response(content=json.dumps(content_obj), media_type="application/json", 
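# image is None here, so this returns the failure detail; intentType still
# tells the client which image encoding (inline PNG, hosted URL, or inline
# WebP) was attempted.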
status_code=status_code) - if image_url: - intent_type = "text-to-image2" - image = image_url await sync_to_async(save_to_conversation_log)( q, image, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f3be3162..3c93385d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,6 @@ import asyncio +import base64 +import io import json import logging from concurrent.futures import ThreadPoolExecutor @@ -18,6 +20,7 @@ from typing import ( import openai from fastapi import Depends, Header, HTTPException, Request, UploadFile +from PIL import Image from starlette.authentication import has_required_scope from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters @@ -46,6 +49,7 @@ from khoj.utils import state from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( ConversationCommand, + ImageIntentType, is_none_or_empty, is_valid_url, log_telemetry, @@ -79,9 +83,10 @@ async def is_ready_to_chat(user: KhojUser): if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline": chat_model = user_conversation_config.chat_model + max_tokens = user_conversation_config.max_prompt_size if state.offline_chat_processor_config is None: logger.info("Loading Offline Chat Model...") - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) return True ready = has_openai_config or has_offline_config @@ -382,10 +387,11 @@ async def send_message_to_model_wrapper( raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") chat_model = conversation_config.chat_model + max_tokens = conversation_config.max_prompt_size if conversation_config.model_type == "offline": if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( @@ -452,7 +458,9 @@ def generate_chat_response( conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) if conversation_config.model_type == "offline": if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model) + chat_model = conversation_config.chat_model + max_tokens = conversation_config.max_prompt_size + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model chat_response = converse_offline( @@ -508,18 +516,19 @@ async def text_to_image( references: List[str], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, -) -> Tuple[Optional[str], int, Optional[str], Optional[str]]: +) -> Tuple[Optional[str], int, Optional[str], str]: status_code = 200 image = None response = None image_url = None + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error 
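# HTTP 501 Not Implemented: no TextToImageModelConfig exists on the server,
# so image generation cannot be served until an admin configures one.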
status_code = 501 message = "Failed to generate image. Setup image generation on the server." - return image, status_code, message, image_url + return image_url or image, status_code, message, intent_type.value elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name @@ -550,21 +559,38 @@ async def text_to_image( ) image = response.data[0].b64_json + with timer("Convert image to webp", logger): + # Convert png to webp for faster loading + decoded_image = base64.b64decode(image) + image_io = io.BytesIO(decoded_image) + png_image = Image.open(image_io) + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + webp_image_bytes = webp_image_io.getvalue() + webp_image_io.close() + image_io.close() + with timer("Upload image to S3", logger): - image_url = upload_image(image, user.uuid) - return image, status_code, improved_image_prompt, image_url + image_url = upload_image(webp_image_bytes, user.uuid) + if image_url: + intent_type = ImageIntentType.TEXT_TO_IMAGE2 + else: + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 + image = base64.b64encode(webp_image_bytes).decode("utf-8") + + return image_url or image, status_code, improved_image_prompt, intent_type.value except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image, status_code, message, image_url + return image_url or image, status_code, message, intent_type.value else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image, status_code, message, image_url - return image, status_code, response, image_url + return image_url or image, status_code, message, intent_type.value + return image_url or image, status_code, response, intent_type.value class ApiUserRateLimiter: diff --git a/src/khoj/routers/storage.py b/src/khoj/routers/storage.py index 57c28c5a..9a5d448f 100644 --- a/src/khoj/routers/storage.py +++ b/src/khoj/routers/storage.py @@ -1,4 +1,3 @@ -import base64 import logging import os import uuid @@ -17,16 +16,15 @@ if aws_enabled: s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY) -def upload_image(image: str, user_id: uuid.UUID): +def upload_image(image: bytes, user_id: uuid.UUID): """Upload the image to the S3 bucket""" if not aws_enabled: logger.info("AWS is not enabled. 
Skipping image upload") return None - decoded_image = base64.b64decode(image) - image_key = f"{user_id}/{uuid.uuid4()}.png" + image_key = f"{user_id}/{uuid.uuid4()}.webp" try: - s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read") + s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read") url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}" return url except Exception as e: diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 3f95030f..1732271a 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -69,11 +69,11 @@ class OfflineChatProcessorConfig: class OfflineChatProcessorModel: - def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"): + def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None): self.chat_model = chat_model self.loaded_model = None try: - self.loaded_model = download_model(self.chat_model) + self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens) except ValueError as e: self.loaded_model = None logger.error(f"Error while loading offline chat model: {e}", exc_info=True) diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index f6a66b4f..e621f53e 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -17,6 +17,7 @@ from time import perf_counter from typing import TYPE_CHECKING, Optional, Union from urllib.parse import urlparse +import psutil import torch from asgiref.sync import sync_to_async from magika import Magika @@ -233,6 +234,10 @@ def get_server_id(): return server_id +def telemetry_disabled(app_config: AppConfig): + return not app_config or not app_config.should_log_telemetry + + def log_telemetry( telemetry_type: str, api: str = None, @@ -242,7 +247,7 @@ ): """Log basic app usage telemetry like client, os, api called""" # Do not log usage telemetry, if telemetry is disabled via app config - if not app_config or not app_config.should_log_telemetry: + if telemetry_disabled(app_config): return [] if properties.get("server_id") is None: @@ -267,6 +272,17 @@ return request_body +def get_device_memory() -> int: + """Get device memory in bytes""" + device = get_device() + if device.type == "cuda": + return torch.cuda.get_device_properties(device).total_memory + elif device.type == "mps": + return torch.mps.driver_allocated_memory() + else: + return psutil.virtual_memory().total + + def get_device() -> torch.device: """Get device to run model on""" if torch.cuda.is_available(): @@ -313,6 +329,20 @@ mode_descriptions_for_llm = { } +class ImageIntentType(Enum): + """ + Chat message intent by Khoj for image responses. + Marks the schema used to reference images in chat messages + """ + + # Images as Inline PNG + TEXT_TO_IMAGE = "text-to-image" + # Images as URLs + TEXT_TO_IMAGE2 = "text-to-image2" + # Images as Inline WebP + TEXT_TO_IMAGE_V3 = "text-to-image-v3" + + def generate_random_name(): # List of adjectives and nouns to choose from adjectives = [