From 382507051ffbc85617cfcceb49e2daa259864b7f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 11:02:06 +0530 Subject: [PATCH 01/15] Fix get_user_photo to only return photo, not user name from DB --- src/khoj/database/adapters/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 95a75b7d..fd1a2314 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -197,9 +197,6 @@ def get_user_name(user: KhojUser): def get_user_photo(user: KhojUser): - full_name = user.get_full_name() - if not is_none_or_empty(full_name): - return full_name google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first() if google_profile: return google_profile.picture From b8bc6bee83137be09d2563ea0e18b5f76c34aefa Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 11:02:44 +0530 Subject: [PATCH 02/15] Always remove loading animation on Desktop app if can't login to server --- src/interface/desktop/chat.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 627ce3de..58fe0d57 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -726,9 +726,9 @@ // If the server returns a 500 error with detail, render a setup hint. 
if (!firstRunSetupMessageRendered) { renderFirstRunSetupMessage(); - fadeOutLoadingAnimation(loadingScreen); } - return; + fadeOutLoadingAnimation(loadingScreen); + return; }); await refreshChatSessionsPanel(); From b820daf38ffc622a378a0440f69ec85572f4e5e3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 11:22:58 +0530 Subject: [PATCH 03/15] Makes logs less noisy - Show telemetry enabled/disabled state on init, not every 2 minutes - Convert no docs synced logs to debug level instead of warning Having synced docs isn't as important to use Khoj now, unlike before --- src/khoj/configure.py | 9 +++++---- src/khoj/routers/api.py | 4 +--- src/khoj/utils/helpers.py | 6 +++++- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 60aaf658..9f8abeb4 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state from khoj.utils.config import SearchType from khoj.utils.fs_syncer import collect_files -from khoj.utils.helpers import is_none_or_empty +from khoj.utils.helpers import is_none_or_empty, telemetry_disabled from khoj.utils.rawconfig import FullConfig logger = logging.getLogger(__name__) @@ -232,6 +232,9 @@ def configure_server( state.search_models = configure_search(state.search_models, state.config.search_type) setup_default_agent() + message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled" + logger.info(message) + if not init: initialize_content(regenerate, search_type, user) @@ -329,9 +332,7 @@ def configure_search_types(): @schedule.repeat(schedule.every(2).minutes) def upload_telemetry(): - if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry: - message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled" - logger.debug(message) + if 
telemetry_disabled(state.config.app) or not state.telemetry: return try: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index cf84e724..7f546832 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -289,9 +289,7 @@ async def extract_references_and_questions( return compiled_references, inferred_queries, q if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): - logger.warning( - "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." - ) + logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") return compiled_references, inferred_queries, q # Extract filter terms from user message diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index f6a66b4f..a61387d5 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -233,6 +233,10 @@ def get_server_id(): return server_id +def telemetry_disabled(app_config: AppConfig): + return not app_config or not app_config.should_log_telemetry + + def log_telemetry( telemetry_type: str, api: str = None, @@ -242,7 +246,7 @@ def log_telemetry( ): """Log basic app usage telemetry like client, os, api called""" # Do not log usage telemetry, if telemetry is disabled via app config - if not app_config or not app_config.should_log_telemetry: + if telemetry_disabled(app_config): return [] if properties.get("server_id") is None: From d21f22ffa1517dfe079a7f50f3c84699217c131e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 13:03:32 +0530 Subject: [PATCH 04/15] Store Khoj generated images as webp instead of png for faster loading --- src/khoj/routers/api_chat.py | 12 ++---------- src/khoj/routers/helpers.py | 35 ++++++++++++++++++++++++++++------- src/khoj/routers/storage.py | 8 +++----- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/khoj/routers/api_chat.py 
b/src/khoj/routers/api_chat.py index 4e7a8cc9..76cf6d12 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -425,8 +425,7 @@ async def websocket_endpoint( api="chat", metadata={"conversation_command": conversation_commands[0].value}, ) - intent_type = "text-to-image" - image, status_code, improved_image_prompt, image_url = await text_to_image( + image, status_code, improved_image_prompt, intent_type = await text_to_image( q, user, meta_log, @@ -445,9 +444,6 @@ async def websocket_endpoint( await send_complete_llm_response(json.dumps(content_obj)) continue - if image_url: - intent_type = "text-to-image2" - image = image_url await sync_to_async(save_to_conversation_log)( q, image, @@ -621,17 +617,13 @@ async def chat( metadata={"conversation_command": conversation_commands[0].value}, **common.__dict__, ) - intent_type = "text-to-image" - image, status_code, improved_image_prompt, image_url = await text_to_image( + image, status_code, improved_image_prompt, intent_type = await text_to_image( q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results ) if image is None: content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt} return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - if image_url: - intent_type = "text-to-image2" - image = image_url await sync_to_async(save_to_conversation_log)( q, image, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f3be3162..cbf29c02 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,6 @@ import asyncio +import base64 +import io import json import logging from concurrent.futures import ThreadPoolExecutor @@ -18,6 +20,7 @@ from typing import ( import openai from fastapi import Depends, Header, HTTPException, Request, UploadFile +from PIL import Image from starlette.authentication import has_required_scope from 
khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters @@ -508,18 +511,19 @@ async def text_to_image( references: List[str], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, -) -> Tuple[Optional[str], int, Optional[str], Optional[str]]: +) -> Tuple[Optional[str], int, Optional[str], str]: status_code = 200 image = None response = None image_url = None + intent_type = "text-to-image-v3" text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." - return image, status_code, message, image_url + return image_url or image, status_code, message, intent_type elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name @@ -550,21 +554,38 @@ async def text_to_image( ) image = response.data[0].b64_json + with timer("Convert image to webp", logger): + # Convert png to webp for faster loading + decoded_image = base64.b64decode(image) + image_io = io.BytesIO(decoded_image) + png_image = Image.open(image_io) + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + webp_image_bytes = webp_image_io.getvalue() + webp_image_io.close() + image_io.close() + with timer("Upload image to S3", logger): - image_url = upload_image(image, user.uuid) - return image, status_code, improved_image_prompt, image_url + image_url = upload_image(webp_image_bytes, user.uuid) + if image_url: + intent_type = "text-to-image-v2" + else: + intent_type = "text-to-image-v3" + image = base64.b64encode(webp_image_bytes).decode("utf-8") + + return image_url or image, status_code, improved_image_prompt, intent_type except openai.OpenAIError or openai.BadRequestError or 
openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image, status_code, message, image_url + return image_url or image, status_code, message, intent_type else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image, status_code, message, image_url - return image, status_code, response, image_url + return image_url or image, status_code, message, intent_type + return image_url or image, status_code, response, intent_type class ApiUserRateLimiter: diff --git a/src/khoj/routers/storage.py b/src/khoj/routers/storage.py index 57c28c5a..9a5d448f 100644 --- a/src/khoj/routers/storage.py +++ b/src/khoj/routers/storage.py @@ -1,4 +1,3 @@ -import base64 import logging import os import uuid @@ -17,16 +16,15 @@ if aws_enabled: s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY) -def upload_image(image: str, user_id: uuid.UUID): +def upload_image(image: bytes, user_id: uuid.UUID): """Upload the image to the S3 bucket""" if not aws_enabled: logger.info("AWS is not enabled. 
Skipping image upload") return None - decoded_image = base64.b64decode(image) - image_key = f"{user_id}/{uuid.uuid4()}.png" + image_key = f"{user_id}/{uuid.uuid4()}.webp" try: - s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read") + s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read") url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}" return url except Exception as e: From c6e844363179ebb63006360e94e2bb49d6892368 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 13:11:18 +0530 Subject: [PATCH 05/15] Update clients to support rendering webp images inline This is for self-hosted scenarios where AWS S3 uploads is not enabled --- src/interface/desktop/chat.html | 6 ++++++ src/interface/obsidian/src/chat_modal.ts | 4 ++++ src/khoj/interface/web/chat.html | 6 ++++++ 3 files changed, 16 insertions(+) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 58fe0d57..655aeff7 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -214,6 +214,8 @@ imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; @@ -288,6 +290,8 @@ imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { @@ -509,6 +513,8 @@ rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } else if (responseAsJson.intentType === "text-to-image2") { rawResponse += 
`![${query}](${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image-v3") { + rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`; } const inferredQueries = responseAsJson.inferredQueries?.[0]; if (inferredQueries) { diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 328ce299..504ce4db 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -156,6 +156,8 @@ export class KhojChatModal extends Modal { imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } if (inferredQueries) { imageMarkdown += "\n\n**Inferred Query**:"; @@ -429,6 +431,8 @@ export class KhojChatModal extends Modal { responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`; } else if (responseAsJson.intentType === "text-to-image2") { responseText += `![${query}](${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image-v3") { + responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`; } const inferredQuery = responseAsJson.inferredQueries?.[0]; if (inferredQuery) { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 87d42fd6..10a962c4 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -244,6 +244,8 @@ To get started, just start typing below. 
You can also type / to see a list of co imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { @@ -312,6 +314,8 @@ To get started, just start typing below. You can also type / to see a list of co imageMarkdown = `![](data:image/png;base64,${message})`; } else if (intentType === "text-to-image2") { imageMarkdown = `![](${message})`; + } else if (intentType === "text-to-image-v3") { + imageMarkdown = `![](data:image/webp;base64,${message})`; } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { @@ -619,6 +623,8 @@ To get started, just start typing below. You can also type / to see a list of co rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`; } else if (imageJson.intentType === "text-to-image2") { rawResponse += `![generated_image](${imageJson.image})`; + } else if (imageJson.intentType === "text-to-image-v3") { + rawResponse = `![](data:image/webp;base64,${imageJson.image})`; } if (inferredQuery) { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; From 78bac4ae05cf8056bcffc168541280b9a27f3435 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 19:06:28 +0530 Subject: [PATCH 06/15] Add migration script to convert PNG to WebP references in database --- .../migrations/0035_convert_png_to_webp.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 src/khoj/database/migrations/0035_convert_png_to_webp.py diff --git a/src/khoj/database/migrations/0035_convert_png_to_webp.py b/src/khoj/database/migrations/0035_convert_png_to_webp.py new file mode 100644 index 00000000..7d28a07d --- /dev/null +++ b/src/khoj/database/migrations/0035_convert_png_to_webp.py @@ -0,0 +1,77 @@ +# Generated by Django 4.2.10 on 2024-04-13 
17:54 + +import base64 +import io + +from django.db import migrations +from PIL import Image + + +def convert_png_images_to_webp(apps, schema_editor): + # Get the model from the versioned app registry to ensure the correct version is used + Conversations = apps.get_model("database", "Conversation") + for conversation in Conversations.objects.all(): + for chat in conversation.conversation_log["chat"]: + if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image": + # Decode the base64 encoded PNG image + decoded_image = base64.b64decode(chat["message"]) + + # Convert images from PNG to WebP format + image_io = io.BytesIO(decoded_image) + with Image.open(image_io) as png_image: + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + + # Encode the WebP image back to base64 + webp_image_bytes = webp_image_io.getvalue() + chat["message"] = base64.b64encode(webp_image_bytes).decode() + chat["intent"]["type"] = "text-to-image-v3" + webp_image_io.close() + + if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2": + print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP format.") + # Convert PNG url to WebP url + chat["message"] = chat["message"].replace(".png", ".webp") + + # Save the updated conversation history + conversation.save() + + +def convert_webp_images_to_png(apps, schema_editor): + # Get the model from the versioned app registry to ensure the correct version is used + Conversations = apps.get_model("database", "Conversation") + for conversation in Conversations.objects.all(): + for chat in conversation.conversation_log["chat"]: + if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image": + # Decode the base64 encoded PNG image + decoded_image = base64.b64decode(chat["message"]) + + # Convert images from PNG to WebP format + image_io = io.BytesIO(decoded_image) + with Image.open(image_io) as png_image: + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "PNG") + 
+ # Encode the WebP image back to base64 + webp_image_bytes = webp_image_io.getvalue() + chat["message"] = base64.b64encode(webp_image_bytes).decode() + chat["intent"]["type"] = "text-to-image" + webp_image_io.close() + + if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2": + # Convert WebP url to PNG url + print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.") + chat["message"] = chat["message"].replace(".webp", ".png") + + # Save the updated conversation history + conversation.save() + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0034_alter_chatmodeloptions_chat_model"), + ] + + operations = [ + migrations.RunPython(convert_png_images_to_webp, reverse_code=convert_webp_images_to_png), + ] From 148923c13a8ef619359661a4f7ed2b245d889569 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 22:09:13 +0530 Subject: [PATCH 07/15] Fix to raise error on hitting rate limit during Github indexing --- src/khoj/processor/content/github/github_to_entries.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 02fa4cf0..2aa63d4e 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries): markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo) except ConnectionAbortedError as e: logger.error(f"Github rate limit reached. 
Skip indexing github repo {repo_shorthand}") + raise e except Exception as e: logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True) raise e From 689202e00eaf4cd8fdbc9b05b3024512e72f9a08 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 22:11:10 +0530 Subject: [PATCH 08/15] Update recommended CMAKE flag to enable using CUDA on linux in Docs --- documentation/docs/get-started/setup.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx index 4aa2f960..7b2866f4 100644 --- a/documentation/docs/get-started/setup.mdx +++ b/documentation/docs/get-started/setup.mdx @@ -134,7 +134,7 @@ python -m pip install khoj-assistant # CPU python -m pip install khoj-assistant # NVIDIA (CUDA) GPU - CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant + CMAKE_ARGS="-DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant # AMD (ROCm) GPU CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant # VULCAN GPU From 4977b55106e7aac94a5d618fae990b71fb9bcc1a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 22:15:34 +0530 Subject: [PATCH 09/15] Use offline chat prompt config to set context window of loaded chat model Previously you couldn't configure the n_ctx of the loaded offline chat model.
This made it hard to use good offline chat model (which these days also have larger context) on machines with lower VRAM --- pyproject.toml | 2 +- .../conversation/offline/chat_model.py | 20 ++++---- .../processor/conversation/offline/utils.py | 46 +++++++++++++------ src/khoj/processor/conversation/utils.py | 27 +++++------ src/khoj/routers/api.py | 4 +- src/khoj/routers/helpers.py | 10 ++-- src/khoj/utils/config.py | 4 +- src/khoj/utils/helpers.py | 12 +++++ 8 files changed, 81 insertions(+), 44 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0b7483a1..c9c96691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "phonenumbers == 8.13.27", "markdownify ~= 0.11.6", "websockets == 12.0", + "psutil >= 5.8.0", ] dynamic = ["version"] @@ -105,7 +106,6 @@ dev = [ "pytest-asyncio == 0.21.1", "freezegun >= 1.2.0", "factory-boy >= 3.2.1", - "psutil >= 5.8.0", "mypy >= 1.0.1", "black >= 23.1.0", "pre-commit >= 3.0.4", diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 10dc08fa..a559df22 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -30,6 +30,7 @@ def extract_questions_offline( use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, + max_prompt_size: int = None, ) -> List[str]: """ Infer search queries to retrieve relevant notes to answer user query @@ -41,7 +42,7 @@ def extract_questions_offline( return all_questions assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model) + offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" @@ -67,12 +68,14 @@ def 
extract_questions_offline( location=location, ) messages = generate_chatml_messages_with_context( - example_questions, model_name=model, loaded_model=offline_chat_model + example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size ) state.chat_lock.acquire() try: - response = send_message_to_model_offline(messages, loaded_model=offline_chat_model) + response = send_message_to_model_offline( + messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size + ) finally: state.chat_lock.release() @@ -138,7 +141,7 @@ def converse_offline( """ # Initialize Variables assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model) + offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) compiled_references_message = "\n\n".join({f"{item}" for item in references}) current_date = datetime.now().strftime("%Y-%m-%d") @@ -190,18 +193,18 @@ def converse_offline( ) g = ThreadedGenerator(references, online_results, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, offline_chat_model)) + t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size)) t.start() return g -def llm_thread(g, messages: List[ChatMessage], model: Any): +def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None): stop_phrases = ["", "INST]", "Notes:"] state.chat_lock.acquire() try: response_iterator = send_message_to_model_offline( - messages, loaded_model=model, stop=stop_phrases, streaming=True + messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True ) for response in response_iterator: g.send(response["choices"][0]["delta"].get("content", "")) @@ -216,9 +219,10 @@ def send_message_to_model_offline( model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", streaming=False, stop=[], + 
max_prompt_size: int = None, ): assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model) + offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) messages_dict = [{"role": message.role, "content": message.content} for message in messages] response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming) if streaming: diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py index b711c11a..c2b08bfa 100644 --- a/src/khoj/processor/conversation/offline/utils.py +++ b/src/khoj/processor/conversation/offline/utils.py @@ -1,18 +1,19 @@ import glob import logging +import math import os from huggingface_hub.constants import HF_HUB_CACHE from khoj.utils import state +from khoj.utils.helpers import get_device_memory logger = logging.getLogger(__name__) -def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"): - from llama_cpp.llama import Llama - - # Initialize Model Parameters. 
Use n_ctx=0 to get context size from the model +def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None): + # Initialize Model Parameters + # Use n_ctx=0 to get context size from the model kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False} # Decide whether to load model to GPU or CPU @@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"): model_path = load_model_from_cache(repo_id, filename) chat_model = None try: - if model_path: - chat_model = Llama(model_path, **kwargs) - else: - Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + chat_model = load_model(model_path, repo_id, filename, kwargs) except: # Load model on CPU if GPU is not available kwargs["n_gpu_layers"], device = 0, "cpu" - if model_path: - chat_model = Llama(model_path, **kwargs) - else: - chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + chat_model = load_model(model_path, repo_id, filename, kwargs) - logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}") + # Now load the model with context size set based on: + # 1. context size supported by model and + # 2. configured size or machine (V)RAM + kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens) + chat_model = load_model(model_path, repo_id, filename, kwargs) + logger.debug( + f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window." 
+ ) return chat_model +def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}): + from llama_cpp.llama import Llama + + if model_path: + return Llama(model_path, **kwargs) + else: + return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + + def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): # Construct the path to the model file in the cache directory repo_org, repo_name = repo_id.split("/") @@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): return paths[0] else: return None + + +def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int: + """Infer max prompt size based on device memory and max context window supported by the model""" + vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic + if configured_max_tokens: + return min(configured_max_tokens, model_context_window) + else: + return min(vram_based_n_ctx, model_context_window) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 845ccb48..e787eedf 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,5 +1,6 @@ import json import logging +import math import queue from datetime import datetime from time import perf_counter @@ -141,14 +142,12 @@ def generate_chatml_messages_with_context( tokenizer_name=None, ): """Generate messages for ChatGPT with context from previous conversation""" - # Set max prompt size from user config, pre-configured for model or to default prompt size - try: - max_prompt_size = max_prompt_size or model_to_prompt_size[model_name] - except: - max_prompt_size = 2000 - logger.warning( - f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window." 
- ) + # Set max prompt size from user config or based on pre-configured for model and machine specs + if not max_prompt_size: + if loaded_model: + max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf)) + else: + max_prompt_size = model_to_prompt_size.get(model_name, 2000) # Scale lookback turns proportional to max prompt size supported by model lookback_turns = max_prompt_size // 750 @@ -187,7 +186,7 @@ def truncate_messages( max_prompt_size, model_name: str, loaded_model: Optional[Llama] = None, - tokenizer_name=None, + tokenizer_name="hf-internal-testing/llama-tokenizer", ) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" @@ -197,15 +196,11 @@ def truncate_messages( elif model_name.startswith("gpt-"): encoder = tiktoken.encoding_for_model(model_name) else: - try: - encoder = download_model(model_name).tokenizer() - except: - encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) + encoder = download_model(model_name).tokenizer() except: - default_tokenizer = "hf-internal-testing/llama-tokenizer" - encoder = AutoTokenizer.from_pretrained(default_tokenizer) + encoder = AutoTokenizer.from_pretrained(tokenizer_name) logger.warning( - f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." + f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." 
) # Extract system message from messages diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 7f546832..c511b6d9 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -315,8 +315,9 @@ async def extract_references_and_questions( using_offline_chat = True default_offline_llm = await ConversationAdapters.get_default_offline_llm() chat_model = default_offline_llm.chat_model + max_tokens = default_offline_llm.max_prompt_size if state.offline_chat_processor_config is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model @@ -326,6 +327,7 @@ async def extract_references_and_questions( conversation_log=meta_log, should_extract_questions=True, location_data=location_data, + max_prompt_size=conversation_config.max_prompt_size, ) elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = await ConversationAdapters.get_openai_chat_config() diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index cbf29c02..06d849ca 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -82,9 +82,10 @@ async def is_ready_to_chat(user: KhojUser): if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline": chat_model = user_conversation_config.chat_model + max_tokens = user_conversation_config.max_prompt_size if state.offline_chat_processor_config is None: logger.info("Loading Offline Chat Model...") - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) return True ready = has_openai_config or has_offline_config @@ -385,10 +386,11 @@ async def send_message_to_model_wrapper( raise 
HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") chat_model = conversation_config.chat_model + max_tokens = conversation_config.max_prompt_size if conversation_config.model_type == "offline": if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( @@ -455,7 +457,9 @@ def generate_chat_response( conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) if conversation_config.model_type == "offline": if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model) + chat_model = conversation_config.chat_model + max_tokens = conversation_config.max_prompt_size + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model chat_response = converse_offline( diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 3f95030f..1732271a 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -69,11 +69,11 @@ class OfflineChatProcessorConfig: class OfflineChatProcessorModel: - def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"): + def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None): self.chat_model = chat_model self.loaded_model = None try: - self.loaded_model = download_model(self.chat_model) + self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens) except ValueError as e: self.loaded_model = 
None logger.error(f"Error while loading offline chat model: {e}", exc_info=True) diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index a61387d5..04974b7d 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -17,6 +17,7 @@ from time import perf_counter from typing import TYPE_CHECKING, Optional, Union from urllib.parse import urlparse +import psutil import torch from asgiref.sync import sync_to_async from magika import Magika @@ -271,6 +272,17 @@ def log_telemetry( return request_body +def get_device_memory() -> int: + """Get device memory in GB""" + device = get_device() + if device.type == "cuda": + return torch.cuda.get_device_properties(device).total_memory + elif device.type == "mps": + return torch.mps.driver_allocated_memory() + else: + return psutil.virtual_memory().total + + def get_device() -> torch.device: """Get device to run model on""" if torch.cuda.is_available(): From d5de59d411ffb0c10e0d68d9dfe33e08ebff5933 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 08:01:36 +0530 Subject: [PATCH 10/15] Do not assume results key present in notion content when indexing --- src/khoj/processor/content/notion/notion_to_entries.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 6e078f07..57456ed5 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries): for response in responses: with timer("Processing response", logger=logger): - pages_or_databases = response["results"] if response.get("results") else [] + pages_or_databases = response.get("results", []) # Get all pages content for p_or_d in pages_or_databases: @@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries): current_entries = [] curr_heading = "" - for block in 
content["results"]: + for block in content.get("results", []): block_type = block.get("type") if block_type == None: @@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries): return f"\n{heading}\n" def process_nested_children(self, children, raw_content, block_type=None): - results = children["results"] if children.get("results") else [] + results = children.get("results", []) for child in results: child_type = child.get("type") if child_type == None: From e5ff85f6fb5058391b0cc3443651f73757725cc7 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 14:14:05 +0530 Subject: [PATCH 11/15] Start fetching khoj css before icons to reduce time with no styling This should reduce frequency of page load jitter when icons are loaded before style is applied --- src/interface/desktop/chat.html | 2 +- src/khoj/interface/web/chat.html | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 655aeff7..c5b25318 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -4,9 +4,9 @@ Khoj - Chat + - diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 10a962c4..79df4f2c 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -4,10 +4,10 @@ Khoj - Chat + - From 9e5585776c8621545239b29fced83fb88e275ec4 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 13:59:59 +0530 Subject: [PATCH 12/15] Support getting latest N chat messages via chat history API Get latest N if N > 0, else return all messages except latest N from the conversation --- src/khoj/routers/api_chat.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 76cf6d12..9af00053 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -76,6 +76,7 @@ def chat_history( request: Request, common: CommonQueryParams, 
conversation_id: Optional[int] = None, + n: Optional[int] = None, ): user = request.user.object validate_conversation_config() @@ -109,6 +110,13 @@ def chat_history( } ) + # Get latest N messages if N > 0 + if n > 0: + meta_log["chat"] = meta_log["chat"][-n:] + # Else return all messages except latest N + else: + meta_log["chat"] = meta_log["chat"][:n] + update_telemetry_state( request=request, telemetry_type="api", From 128829c4770e9580f585b5a22baa8d0614047e4c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 13:44:01 +0530 Subject: [PATCH 13/15] Show latest msgs on chat session load. Fetch rest as they near viewport - Reduces time to first render when loading long chat sessions - Limits size of first page load, when loading long chat sessions These performance improvements are maximally felt for large chat sessions with lots of images generated by Khoj Updated web and desktop app to support these changes for now --- src/interface/desktop/chat.html | 129 +++++++++++++++++++++++++----- src/khoj/interface/web/chat.html | 130 ++++++++++++++++++++++++++----- 2 files changed, 217 insertions(+), 42 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index c5b25318..fc7ecc2e 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -130,7 +130,7 @@ return referenceButton; } - function renderMessage(message, by, dt=null, annotations=null, raw=false) { + function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") { let message_time = formatDate(dt ?? new Date()); let by_name = by == "khoj" ? 
"🏮 Khoj" : "🤔 You"; let formattedMessage = formatHTMLMessage(message, raw); @@ -153,10 +153,15 @@ // Append chat message div to chat body let chatBody = document.getElementById("chat-body"); - chatBody.appendChild(chatMessage); - - // Scroll to bottom of chat-body element - chatBody.scrollTop = chatBody.scrollHeight; + if (renderType === "append") { + chatBody.appendChild(chatMessage); + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + } else if (renderType === "prepend") { + chatBody.insertBefore(chatMessage, chatBody.firstChild); + } else if (renderType === "return") { + return chatMessage; + } let chatBodyWrapper = document.getElementById("chat-body-wrapper"); chatBodyWrapperHeight = chatBodyWrapper.clientHeight; @@ -207,6 +212,7 @@ } function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { + // If no document or online context is provided, render the message as is if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { if (intentType?.includes("text-to-image")) { let imageMarkdown; @@ -222,24 +228,21 @@ if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt); - return; + return renderMessage(imageMarkdown, by, dt, null, false, "return"); } - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } if (context == null && onlineContext == null) { - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } + // If document or online context is provided, render the 
message with its references let references = document.createElement('div'); let referenceExpandButton = document.createElement('button'); @@ -297,11 +300,10 @@ if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt, references); - return; + return renderMessage(imageMarkdown, by, dt, references, false, "return"); } - renderMessage(message, by, dt, references); + return renderMessage(message, by, dt, references, false, "return"); } function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) { @@ -677,7 +679,7 @@ let firstRunSetupMessageRendered = false; let chatBody = document.getElementById("chat-body"); chatBody.innerHTML = ""; - let chatHistoryUrl = `/api/chat/history?client=desktop`; + let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`; if (chatBody.dataset.conversationId) { chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; } @@ -689,7 +691,8 @@ loadingScreen.appendChild(yellowOrb); chatBody.appendChild(loadingScreen); - fetch(`${hostURL}${chatHistoryUrl}`, { headers }) + // Get the most recent 10 chat messages from conversation history + fetch(`${chatHistoryUrl}&n=10`, { headers }) .then(response => response.json()) .then(data => { if (data.detail) { @@ -709,11 +712,21 @@ chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; - const fullChatLog = response.chat || []; + // Create a new IntersectionObserver + let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, fetch the remaining message and unobserve the element + if (entry.isIntersecting) { + fetchRemainingChatMessages(chatHistoryUrl); + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 0px 0px'}); - fullChatLog.forEach(chat_log => { + const fullChatLog = response.chat || []; + 
fullChatLog.forEach((chat_log, index) => { if (chat_log.message != null) { - renderMessageWithReference( + let messageElement = renderMessageWithReference( chat_log.message, chat_log.by, chat_log.context, @@ -721,10 +734,25 @@ chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]); + chatBody.appendChild(messageElement); + + // When the 4th oldest message is within viewing distance (~60% scrolled up) + // Fetch the remaining chat messages + if (index === 4) { + fetchRemainingMessagesObserver.observe(messageElement); + } } loadingScreen.style.height = chatBody.scrollHeight + 'px'; }) + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + + // Set height of chat-body element to the height of the chat-body-wrapper + let chatBodyWrapper = document.getElementById("chat-body-wrapper"); + let chatBodyWrapperHeight = chatBodyWrapper.clientHeight; + chatBody.style.height = chatBodyWrapperHeight; + // Add fade out animation to loading screen and remove it after the animation ends fadeOutLoadingAnimation(loadingScreen); }) @@ -784,6 +812,65 @@ } } + function fetchRemainingChatMessages(chatHistoryUrl) { + // Create a new IntersectionObserver + let observer = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, render the message and unobserve the element + if (entry.isIntersecting) { + let chat_log = entry.target.chat_log; + let messageElement = renderMessageWithReference( + chat_log.message, + chat_log.by, + chat_log.context, + new Date(chat_log.created), + chat_log.onlineContext, + chat_log.intent?.type, + chat_log.intent?.["inferred-queries"] + ); + entry.target.replaceWith(messageElement); + + // Remove the observer after the element has been rendered + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport + + // Fetch remaining chat messages from 
conversation history + fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" }) + .then(response => response.json()) + .then(data => { + if (data.status != "ok") { + throw new Error(data.message); + } + return data.response; + }) + .then(response => { + const fullChatLog = response.chat || []; + let chatBody = document.getElementById("chat-body"); + fullChatLog + .reverse() + .forEach(chat_log => { + if (chat_log.message != null) { + // Create a new element for each chat log + let placeholder = document.createElement('div'); + placeholder.chat_log = chat_log; + + // Insert the message placeholder as the first child of chat body after the welcome message + chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling); + + // Observe the element + placeholder.style.height = "20px"; + observer.observe(placeholder); + } + }); + }) + .catch(err => { + console.log(err); + return; + }); + } + function fadeOutLoadingAnimation(loadingScreen) { let chatBody = document.getElementById("chat-body"); let chatBodyWrapper = document.getElementById("chat-body-wrapper"); diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 79df4f2c..81946fa7 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co return referenceButton; } - function renderMessage(message, by, dt=null, annotations=null, raw=false) { + function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") { let message_time = formatDate(dt ?? new Date()); let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You"; let formattedMessage = formatHTMLMessage(message, raw); @@ -183,10 +183,16 @@ To get started, just start typing below. 
You can also type / to see a list of co // Append chat message div to chat body let chatBody = document.getElementById("chat-body"); - chatBody.appendChild(chatMessage); - - // Scroll to bottom of chat-body element - chatBody.scrollTop = chatBody.scrollHeight; + if (renderType === "append") { + chatBody.appendChild(chatMessage); + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + } else if (renderType === "prepend"){ + let chatBody = document.getElementById("chat-body"); + chatBody.insertBefore(chatMessage, chatBody.firstChild); + } else if (renderType === "return") { + return chatMessage; + } let chatBodyWrapper = document.getElementById("chat-body-wrapper"); chatBodyWrapperHeight = chatBodyWrapper.clientHeight; @@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co } function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { + // If no document or online context is provided, render the message as is if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { if (intentType?.includes("text-to-image")) { let imageMarkdown; @@ -251,19 +258,17 @@ To get started, just start typing below. 
You can also type / to see a list of co if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt); - return; + return renderMessage(imageMarkdown, by, dt, null, false, "return"); } - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { - renderMessage(message, by, dt); - return; + return renderMessage(message, by, dt, null, false, "return"); } + // If document or online context is provided, render the message with its references let references = document.createElement('div'); let referenceExpandButton = document.createElement('button'); @@ -321,11 +326,10 @@ To get started, just start typing below. You can also type / to see a list of co if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; } - renderMessage(imageMarkdown, by, dt, references); - return; + return renderMessage(imageMarkdown, by, dt, references, false, "return"); } - renderMessage(message, by, dt, references); + return renderMessage(message, by, dt, references, false, "return"); } function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) { @@ -1068,7 +1072,8 @@ To get started, just start typing below. You can also type / to see a list of co loadingScreen.appendChild(yellowOrb); chatBody.appendChild(loadingScreen); - fetch(chatHistoryUrl, { method: "GET" }) + // Get the most recent 10 chat messages from conversation history + fetch(`${chatHistoryUrl}&n=10`, { method: "GET" }) .then(response => response.json()) .then(data => { if (data.detail) { @@ -1121,11 +1126,22 @@ To get started, just start typing below. 
You can also type / to see a list of co agentMetadataElement.style.display = "none"; } - const fullChatLog = response.chat || []; + // Create a new IntersectionObserver + let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, fetch the remaining message and unobserve the element + if (entry.isIntersecting) { + fetchRemainingChatMessages(chatHistoryUrl); + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 0px 0px'}); - fullChatLog.forEach(chat_log => { - if (chat_log.message != null){ - renderMessageWithReference( + const fullChatLog = response.chat || []; + fullChatLog.forEach((chat_log, index) => { + // Render the last 10 messages immediately + if (chat_log.message != null) { + let messageElement = renderMessageWithReference( chat_log.message, chat_log.by, chat_log.context, @@ -1133,14 +1149,26 @@ To get started, just start typing below. You can also type / to see a list of co chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]); + chatBody.appendChild(messageElement); + + // When the 4th oldest message is within viewing distance (~60% scroll up) + // Fetch the remaining chat messages + if (index === 4) { + fetchRemainingMessagesObserver.observe(messageElement); + } } loadingScreen.style.height = chatBody.scrollHeight + 'px'; }); - // Add fade out animation to loading screen and remove it after the animation ends + // Scroll to bottom of chat-body element + chatBody.scrollTop = chatBody.scrollHeight; + + // Set height of chat-body element to the height of the chat-body-wrapper let chatBodyWrapper = document.getElementById("chat-body-wrapper"); - chatBodyWrapperHeight = chatBodyWrapper.clientHeight; + let chatBodyWrapperHeight = chatBodyWrapper.clientHeight; chatBody.style.height = chatBodyWrapperHeight; + + // Add fade out animation to loading screen and remove it after the animation ends setTimeout(() => { 
loadingScreen.remove(); chatBody.classList.remove("relative-position"); @@ -1198,6 +1226,66 @@ To get started, just start typing below. You can also type / to see a list of co document.getElementById("chat-input").value = query_via_url; chat(); } + + } + + function fetchRemainingChatMessages(chatHistoryUrl) { + // Create a new IntersectionObserver + let observer = new IntersectionObserver((entries, observer) => { + entries.forEach(entry => { + // If the element is in the viewport, render the message and unobserve the element + if (entry.isIntersecting) { + let chat_log = entry.target.chat_log; + let messageElement = renderMessageWithReference( + chat_log.message, + chat_log.by, + chat_log.context, + new Date(chat_log.created), + chat_log.onlineContext, + chat_log.intent?.type, + chat_log.intent?.["inferred-queries"] + ); + entry.target.replaceWith(messageElement); + + // Remove the observer after the element has been rendered + observer.unobserve(entry.target); + } + }); + }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport + + // Fetch remaining chat messages from conversation history + fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" }) + .then(response => response.json()) + .then(data => { + if (data.status != "ok") { + throw new Error(data.message); + } + return data.response; + }) + .then(response => { + const fullChatLog = response.chat || []; + let chatBody = document.getElementById("chat-body"); + fullChatLog + .reverse() + .forEach(chat_log => { + if (chat_log.message != null) { + // Create a new element for each chat log + let placeholder = document.createElement('div'); + placeholder.chat_log = chat_log; + + // Insert the message placeholder as the first child of chat body after the welcome message + chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling); + + // Observe the element + placeholder.style.height = "20px"; + observer.observe(placeholder); + } + }); + }) + .catch(err => { + 
console.log(err); + return; + }); } function flashStatusInChatInput(message) { From 7d8e8eb0cf831556c34c4e47ab28ccbbab03623b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 16:44:56 +0530 Subject: [PATCH 14/15] Use Enum to type text-to-image intent of Khoj chat response --- src/khoj/database/admin.py | 15 +++++++++++---- .../migrations/0035_convert_png_to_webp.py | 14 ++++++++------ src/khoj/routers/helpers.py | 17 +++++++++-------- src/khoj/utils/helpers.py | 14 ++++++++++++++ 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 97a0f3ed..9b82029b 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -23,6 +23,7 @@ from khoj.database.models import ( TextToImageModelConfig, UserSearchModelConfig, ) +from khoj.utils.helpers import ImageIntentType class KhojUserAdmin(UserAdmin): @@ -104,9 +105,12 @@ class ConversationAdmin(admin.ModelAdmin): log["by"] == "khoj" and log["intent"] and log["intent"]["type"] - and log["intent"]["type"] == "text-to-image" + and ( + log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value + or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value + ) ): - log["message"] = "image redacted for space" + log["message"] = "inline image redacted for space" chat_log[idx] = log modified_log["chat"] = chat_log @@ -144,9 +148,12 @@ class ConversationAdmin(admin.ModelAdmin): log["by"] == "khoj" and log["intent"] and log["intent"]["type"] - and log["intent"]["type"] == "text-to-image" + and ( + log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value + or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value + ) ): - updated_log["message"] = "image redacted for space" + updated_log["message"] = "inline image redacted for space" chat_log[idx] = updated_log return_log["chat"] = chat_log diff --git a/src/khoj/database/migrations/0035_convert_png_to_webp.py 
b/src/khoj/database/migrations/0035_convert_png_to_webp.py index 7d28a07d..6ffa024b 100644 --- a/src/khoj/database/migrations/0035_convert_png_to_webp.py +++ b/src/khoj/database/migrations/0035_convert_png_to_webp.py @@ -6,13 +6,15 @@ import io from django.db import migrations from PIL import Image +from khoj.utils.helpers import ImageIntentType + def convert_png_images_to_webp(apps, schema_editor): # Get the model from the versioned app registry to ensure the correct version is used Conversations = apps.get_model("database", "Conversation") for conversation in Conversations.objects.all(): for chat in conversation.conversation_log["chat"]: - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value: # Decode the base64 encoded PNG image decoded_image = base64.b64decode(chat["message"]) @@ -25,10 +27,10 @@ def convert_png_images_to_webp(apps, schema_editor): # Encode the WebP image back to base64 webp_image_bytes = webp_image_io.getvalue() chat["message"] = base64.b64encode(webp_image_bytes).decode() - chat["intent"]["type"] = "text-to-image-v3" + chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value webp_image_io.close() - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP format.") # Convert PNG url to WebP url chat["message"] = chat["message"].replace(".png", ".webp") @@ -42,7 +44,7 @@ def convert_webp_images_to_png(apps, schema_editor): Conversations = apps.get_model("database", "Conversation") for conversation in Conversations.objects.all(): for chat in conversation.conversation_log["chat"]: - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image": + if chat["by"] == "khoj" and chat["intent"]["type"] == 
ImageIntentType.TEXT_TO_IMAGE_V3.value: # Decode the base64 encoded PNG image decoded_image = base64.b64decode(chat["message"]) @@ -55,10 +57,10 @@ def convert_webp_images_to_png(apps, schema_editor): # Encode the WebP image back to base64 webp_image_bytes = webp_image_io.getvalue() chat["message"] = base64.b64encode(webp_image_bytes).decode() - chat["intent"]["type"] = "text-to-image" + chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value webp_image_io.close() - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: # Convert WebP url to PNG url print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.") chat["message"] = chat["message"].replace(".webp", ".png") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 06d849ca..3c93385d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -49,6 +49,7 @@ from khoj.utils import state from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( ConversationCommand, + ImageIntentType, is_none_or_empty, is_valid_url, log_telemetry, @@ -520,14 +521,14 @@ async def text_to_image( image = None response = None image_url = None - intent_type = "text-to-image-v3" + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." 
- return image_url or image, status_code, message, intent_type + return image_url or image, status_code, message, intent_type.value elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name @@ -572,24 +573,24 @@ async def text_to_image( with timer("Upload image to S3", logger): image_url = upload_image(webp_image_bytes, user.uuid) if image_url: - intent_type = "text-to-image-v2" + intent_type = ImageIntentType.TEXT_TO_IMAGE2 else: - intent_type = "text-to-image-v3" + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 image = base64.b64encode(webp_image_bytes).decode("utf-8") - return image_url or image, status_code, improved_image_prompt, intent_type + return image_url or image, status_code, improved_image_prompt, intent_type.value except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type + return image_url or image, status_code, message, intent_type.value else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type - return image_url or image, status_code, response, intent_type + return image_url or image, status_code, message, intent_type.value + return image_url or image, status_code, response, intent_type.value class ApiUserRateLimiter: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 04974b7d..e621f53e 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -329,6 +329,20 @@ 
mode_descriptions_for_llm = { } +class ImageIntentType(Enum): + """ + Chat message intent by Khoj for image responses. + Marks the schema used to reference image in chat messages + """ + + # Images as Inline PNG + TEXT_TO_IMAGE = "text-to-image" + # Images as URLs + TEXT_TO_IMAGE2 = "text-to-image2" + # Images as Inline WebP + TEXT_TO_IMAGE_V3 = "text-to-image-v3" + + def generate_random_name(): # List of adjectives and nouns to choose from adjectives = [ From a352940dfd67813eadbf1825956370e20b1ff20c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 17:44:05 +0530 Subject: [PATCH 15/15] Use Django management command to update images URL in DB to WebP This provides Khoj server admins more control on migrating their S3 images to WebP format from PNG --- src/khoj/database/management/__init__.py | 0 .../database/management/commands/__init__.py | 0 .../commands/convert_images_png_to_webp.py | 40 +++++++++++++++++++ .../migrations/0035_convert_png_to_webp.py | 10 ----- 4 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 src/khoj/database/management/__init__.py create mode 100644 src/khoj/database/management/commands/__init__.py create mode 100644 src/khoj/database/management/commands/convert_images_png_to_webp.py diff --git a/src/khoj/database/management/__init__.py b/src/khoj/database/management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/database/management/commands/__init__.py b/src/khoj/database/management/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/database/management/commands/convert_images_png_to_webp.py b/src/khoj/database/management/commands/convert_images_png_to_webp.py new file mode 100644 index 00000000..b1ad8615 --- /dev/null +++ b/src/khoj/database/management/commands/convert_images_png_to_webp.py @@ -0,0 +1,40 @@ +from django.core.management.base import BaseCommand + +from khoj.database.models import Conversation +from 
khoj.utils.helpers import ImageIntentType + + +class Command(BaseCommand): + help = "Convert all images to WebP format or reverse." + + def add_arguments(self, parser): + # Add a new argument 'reverse' to the command + parser.add_argument( + "--reverse", + action="store_true", + help="Convert from WebP to PNG instead of PNG to WebP", + ) + + def handle(self, *args, **options): + updated_count = 0 + for conversation in Conversation.objects.all(): + conversation_updated = False + for chat in conversation.conversation_log["chat"]: + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: + if options["reverse"] and chat["message"].endswith(".webp"): + # Reverse mode only: convert WebP url to PNG url + chat["message"] = chat["message"].replace(".webp", ".png") + conversation_updated = True + updated_count += 1 + elif not options["reverse"] and chat["message"].endswith(".png"): + # Forward mode only: convert PNG url to WebP url (never when --reverse is set) + chat["message"] = chat["message"].replace(".png", ".webp") + conversation_updated = True + updated_count += 1 + if conversation_updated: + conversation.save() + + if updated_count > 0 and options["reverse"]: + self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format.")) + elif updated_count > 0: + self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format.")) diff --git a/src/khoj/database/migrations/0035_convert_png_to_webp.py b/src/khoj/database/migrations/0035_convert_png_to_webp.py index 6ffa024b..35495629 100644 --- a/src/khoj/database/migrations/0035_convert_png_to_webp.py +++ b/src/khoj/database/migrations/0035_convert_png_to_webp.py @@ -30,11 +30,6 @@ def convert_png_images_to_webp(apps, schema_editor): chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value webp_image_io.close() - if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: - print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP
format.") - # Convert PNG url to WebP url - chat["message"] = chat["message"].replace(".png", ".webp") - # Save the updated conversation history conversation.save() @@ -60,11 +55,6 @@ def convert_webp_images_to_png(apps, schema_editor): chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value webp_image_io.close() - if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: - # Convert WebP url to PNG url - print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.") - chat["message"] = chat["message"].replace(".webp", ".png") - # Save the updated conversation history conversation.save()