diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 45da7dda..641bf68b 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -18,8 +18,8 @@ ENV PATH="/opt/venv/bin:${PATH}" COPY pyproject.toml README.md ./ # Setup python environment - # Use the pre-built llama-cpp-python, torch cpu wheel -ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://abetlen.github.io/llama-cpp-python/whl/cpu" \ + # Use the pre-built torch cpu wheel +ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" \ # Avoid downloading unused cuda specific python packages CUDA_VISIBLE_DEVICES="" \ # Use static version to build app without git dependency diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e8db2c89..8e04a8fc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -64,8 +64,6 @@ jobs: DEBIAN_FRONTEND: noninteractive run: | apt update && apt install -y git libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6 - # required by llama-cpp-python prebuilt wheels - apt install -y musl-dev && ln -s /usr/lib/x86_64-linux-musl/libc.so /lib/libc.musl-x86_64.so.1 - name: ⬇️ Install Postgres env: diff --git a/documentation/docs/advanced/admin.md b/documentation/docs/advanced/admin.md index e04f81e2..961056b8 100644 --- a/documentation/docs/advanced/admin.md +++ b/documentation/docs/advanced/admin.md @@ -20,7 +20,7 @@ Add all the agents you want to use for your different use-cases like Writer, Res ### Chat Model Options Add all the chat models you want to try, use and switch between for your different use-cases. For each chat model you add: - `Chat model`: The name of an [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/en/docs/about-claude/models#model-names), [Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) or [Offline](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf) chat model. 
-- `Model type`: The chat model provider like `OpenAI`, `Offline`. +- `Model type`: The chat model provider like `OpenAI`, `Google`. - `Vision enabled`: Set to `true` if your model supports vision. This is currently only supported for vision capable OpenAI models like `gpt-4o` - `Max prompt size`, `Subscribed max prompt size`: These are optional fields. They are used to truncate the context to the maximum context size that can be passed to the model. This can help with accuracy and cost-saving.
- `Tokenizer`: This is an optional field. It is used to accurately count tokens and truncate context passed to the chat model to stay within the models max prompt size. diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx index e3c1e12f..01d60496 100644 --- a/documentation/docs/get-started/setup.mdx +++ b/documentation/docs/get-started/setup.mdx @@ -18,10 +18,6 @@ import TabItem from '@theme/TabItem'; These are the general setup instructions for self-hosted Khoj. You can install the Khoj server using either [Docker](?server=docker) or [Pip](?server=pip). -:::info[Offline Model + GPU] -To use the offline chat model with your GPU, we recommend using the Docker setup with Ollama . You can also use the local Khoj setup via the Python package directly. -::: - :::info[First Run] Restart your Khoj server after the first run to ensure all settings are applied correctly. ::: @@ -225,10 +221,6 @@ To start Khoj automatically in the background use [Task scheduler](https://www.w You can now open the web app at http://localhost:42110 and start interacting!
Nothing else is necessary, but you can customize your setup further by following the steps below. -:::info[First Message to Offline Chat Model] -The offline chat model gets downloaded when you first send a message to it. The download can take a few minutes! Subsequent messages should be faster. -::: - ### Add Chat Models

Login to the Khoj Admin Panel

Go to http://localhost:42110/server/admin and login with the admin credentials you setup during installation. @@ -301,13 +293,14 @@ Offline chat stays completely private and can work without internet using any op - A Nvidia, AMD GPU or a Mac M1+ machine would significantly speed up chat responses ::: -1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*. -2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel -3. Set the `chat-model` field to the name of your preferred chat model - - Make sure the `model-type` is set to `Offline` -4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/). -5. Restart the Khoj server and [start chatting](http://localhost:42110) with your new offline model! - +1. Install any OpenAI API-compatible local AI model server like [llama-cpp-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server), Ollama, vLLM etc. +2. Add an [ai model api](http://localhost:42110/server/admin/database/aimodelapi/add/) on the admin panel - Set the `api url` field to the URL of your local AI model provider like `http://localhost:11434/v1/` for Ollama +3. Restart the Khoj server to load models available on your local AI model provider - If that doesn't work, you'll need to manually add each available [chat model](http://localhost:42110/server/admin/database/chatmodel/add) in the admin panel. +4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) +5. [Start chatting](http://localhost:42110) with your local AI! 
+ :::tip[Multiple Chat Models] diff --git a/pyproject.toml b/pyproject.toml index 1a580dae..6491dc06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ dependencies = [ "django == 5.1.10", "django-unfold == 0.42.0", "authlib == 1.2.1", - "llama-cpp-python == 0.2.88", "itsdangerous == 2.1.2", "httpx == 0.28.1", "pgvector == 0.2.4", diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a72f15d5..40d1eeb5 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -50,13 +50,11 @@ from khoj.database.adapters import ( ) from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel -from khoj.routers.api_content import configure_content, configure_search +from khoj.routers.api_content import configure_content from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state from khoj.utils.config import SearchType -from khoj.utils.fs_syncer import collect_files -from khoj.utils.helpers import is_none_or_empty, telemetry_disabled -from khoj.utils.rawconfig import FullConfig +from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) @@ -232,14 +230,6 @@ class UserAuthenticationBackend(AuthenticationBackend): return AuthCredentials(), UnauthenticatedUser() -def initialize_server(config: Optional[FullConfig]): - try: - configure_server(config, init=True) - except Exception as e: - logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True) - raise e - - def clean_connections(func): """ A decorator that ensures that Django database connections that have become unusable, or are obsolete, are closed @@ -260,19 +250,7 @@ def clean_connections(func): return func_wrapper -def configure_server( - config: FullConfig, - regenerate: bool = False, - search_type: Optional[SearchType] = None, - init=False, - user: KhojUser = None, -): - # Update Config - if config == None: - 
logger.info(f"Initializing with default config.") - config = FullConfig() - state.config = config - +def initialize_server(): if ConversationAdapters.has_valid_ai_model_api(): ai_model_api = ConversationAdapters.get_ai_model_api() state.openai_client = openai.OpenAI(api_key=ai_model_api.api_key, base_url=ai_model_api.api_base_url) @@ -309,43 +287,33 @@ def configure_server( ) state.SearchType = configure_search_types() - state.search_models = configure_search(state.search_models, state.config.search_type) - setup_default_agent(user) + setup_default_agent() - message = ( - "📡 Telemetry disabled" - if telemetry_disabled(state.config.app, state.telemetry_disabled) - else "📡 Telemetry enabled" - ) + message = "📡 Telemetry disabled" if state.telemetry_disabled else "📡 Telemetry enabled" logger.info(message) - if not init: - initialize_content(user, regenerate, search_type) - except Exception as e: logger.error(f"Failed to load some search models: {e}", exc_info=True) -def setup_default_agent(user: KhojUser): - AgentAdapters.create_default_agent(user) +def setup_default_agent(): + AgentAdapters.create_default_agent() def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None): # Initialize Content from Config - if state.search_models: - try: - logger.info("📬 Updating content index...") - all_files = collect_files(user=user) - status = configure_content( - user, - all_files, - regenerate, - search_type, - ) - if not status: - raise RuntimeError("Failed to update content index") - except Exception as e: - raise e + try: + logger.info("📬 Updating content index...") + status = configure_content( + user, + {}, + regenerate, + search_type, + ) + if not status: + raise RuntimeError("Failed to update content index") + except Exception as e: + raise e def configure_routes(app): @@ -438,8 +406,7 @@ def configure_middleware(app, ssl_enabled: bool = False): def update_content_index(): for user in get_all_users(): - all_files = 
collect_files(user=user) - success = configure_content(user, all_files) + success = configure_content(user, {}) if not success: raise RuntimeError("Failed to update content index") logger.info("📪 Content index updated via Scheduler") @@ -464,7 +431,7 @@ def configure_search_types(): @schedule.repeat(schedule.every(2).minutes) @clean_connections def upload_telemetry(): - if telemetry_disabled(state.config.app, state.telemetry_disabled) or not state.telemetry: + if state.telemetry_disabled or not state.telemetry: return try: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 14eaadfe..6d92b8e9 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -72,7 +72,6 @@ from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.utils import state -from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( clean_object_for_db, clean_text_for_db, @@ -789,8 +788,8 @@ class AgentAdapters: return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() @staticmethod - def create_default_agent(user: KhojUser): - default_chat_model = ConversationAdapters.get_default_chat_model(user) + def create_default_agent(): + default_chat_model = ConversationAdapters.get_default_chat_model(user=None) if default_chat_model is None: logger.info("No default conversation config found, skipping default agent creation") return None @@ -1553,14 +1552,6 @@ class ConversationAdapters: if chat_model is None: chat_model = await ConversationAdapters.aget_default_chat_model() - if chat_model.model_type == ChatModel.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - chat_model_name = chat_model.name - max_tokens = chat_model.max_prompt_size - state.offline_chat_processor_config = 
OfflineChatProcessorModel(chat_model_name, max_tokens) - - return chat_model - if ( chat_model.model_type in [ diff --git a/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py b/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py new file mode 100644 index 00000000..256aa93d --- /dev/null +++ b/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py @@ -0,0 +1,36 @@ +# Generated by Django 5.1.10 on 2025-07-19 21:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0091_chatmodel_friendly_name_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="chatmodel", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("anthropic", "Anthropic"), ("google", "Google")], + default="google", + max_length=200, + ), + ), + migrations.AlterField( + model_name="chatmodel", + name="name", + field=models.CharField(default="gemini-2.5-flash", max_length=200), + ), + migrations.AlterField( + model_name="speechtotextmodeloptions", + name="model_name", + field=models.CharField(default="whisper-1", max_length=200), + ), + migrations.AlterField( + model_name="speechtotextmodeloptions", + name="model_type", + field=models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200), + ), + ] diff --git a/src/khoj/database/migrations/0093_remove_localorgconfig_user_and_more.py b/src/khoj/database/migrations/0093_remove_localorgconfig_user_and_more.py new file mode 100644 index 00000000..ad3409cf --- /dev/null +++ b/src/khoj/database/migrations/0093_remove_localorgconfig_user_and_more.py @@ -0,0 +1,36 @@ +# Generated by Django 5.1.10 on 2025-07-25 23:30 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more"), + ] + + 
operations = [ + migrations.RemoveField( + model_name="localorgconfig", + name="user", + ), + migrations.RemoveField( + model_name="localpdfconfig", + name="user", + ), + migrations.RemoveField( + model_name="localplaintextconfig", + name="user", + ), + migrations.DeleteModel( + name="LocalMarkdownConfig", + ), + migrations.DeleteModel( + name="LocalOrgConfig", + ), + migrations.DeleteModel( + name="LocalPdfConfig", + ), + migrations.DeleteModel( + name="LocalPlaintextConfig", + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 7f80459a..1ed58572 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -220,16 +220,15 @@ class PriceTier(models.TextChoices): class ChatModel(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" - OFFLINE = "offline" ANTHROPIC = "anthropic" GOOGLE = "google" max_prompt_size = models.IntegerField(default=None, null=True, blank=True) subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) - name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") + name = models.CharField(max_length=200, default="gemini-2.5-flash") friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True) - model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.GOOGLE) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) vision_enabled = models.BooleanField(default=False) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) @@ -489,34 +488,6 @@ class ServerChatSettings(DbBaseModel): super().save(*args, **kwargs) -class LocalOrgConfig(DbBaseModel): - input_files 
= models.JSONField(default=list, null=True) - input_filter = models.JSONField(default=list, null=True) - index_heading_entries = models.BooleanField(default=False) - user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) - - -class LocalMarkdownConfig(DbBaseModel): - input_files = models.JSONField(default=list, null=True) - input_filter = models.JSONField(default=list, null=True) - index_heading_entries = models.BooleanField(default=False) - user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) - - -class LocalPdfConfig(DbBaseModel): - input_files = models.JSONField(default=list, null=True) - input_filter = models.JSONField(default=list, null=True) - index_heading_entries = models.BooleanField(default=False) - user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) - - -class LocalPlaintextConfig(DbBaseModel): - input_files = models.JSONField(default=list, null=True) - input_filter = models.JSONField(default=list, null=True) - index_heading_entries = models.BooleanField(default=False) - user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) - - class SearchModelConfig(DbBaseModel): class ModelType(models.TextChoices): TEXT = "text" @@ -605,11 +576,10 @@ class TextToImageModelConfig(DbBaseModel): class SpeechToTextModelOptions(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" - OFFLINE = "offline" - model_name = models.CharField(max_length=200, default="base") + model_name = models.CharField(max_length=200, default="whisper-1") friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True) - model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) diff --git 
a/src/khoj/main.py b/src/khoj/main.py index 50da2624..f42ae135 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -138,10 +138,10 @@ def run(should_start_server=True): initialization(not args.non_interactive) # Create app directory, if it doesn't exist - state.config_file.parent.mkdir(parents=True, exist_ok=True) + state.log_file.parent.mkdir(parents=True, exist_ok=True) # Set Log File - fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8") + fh = logging.FileHandler(state.log_file, encoding="utf-8") fh.setLevel(logging.DEBUG) logger.addHandler(fh) @@ -194,7 +194,7 @@ def run(should_start_server=True): # Configure Middleware configure_middleware(app, state.ssl_config) - initialize_server(args.config) + initialize_server() # If the server is started through gunicorn (external to the script), don't start the server if should_start_server: @@ -204,8 +204,7 @@ def run(should_start_server=True): def set_state(args): - state.config_file = args.config_file - state.config = args.config + state.log_file = args.log_file state.verbose = args.verbose state.host = args.host state.port = args.port @@ -214,7 +213,6 @@ def set_state(args): ) state.anonymous_mode = args.anonymous_mode state.khoj_version = version("khoj") - state.chat_on_gpu = args.chat_on_gpu def start_server(app, host=None, port=None, socket=None): diff --git a/src/khoj/migrations/__init__.py b/src/khoj/migrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/khoj/migrations/migrate_offline_chat_default_model.py b/src/khoj/migrations/migrate_offline_chat_default_model.py deleted file mode 100644 index 831f2d9d..00000000 --- a/src/khoj/migrations/migrate_offline_chat_default_model.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Current format of khoj.yml ---- -app: - ... -content-type: - ... -processor: - conversation: - offline-chat: - enable-offline-chat: false - chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin - ... -search-type: - ... 
- -New format of khoj.yml ---- -app: - ... -content-type: - ... -processor: - conversation: - offline-chat: - enable-offline-chat: false - chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf - ... -search-type: - ... -""" -import logging - -from packaging import version - -from khoj.utils.yaml import load_config_from_file, save_config_to_file - -logger = logging.getLogger(__name__) - - -def migrate_offline_chat_default_model(args): - schema_version = "0.12.4" - raw_config = load_config_from_file(args.config_file) - previous_version = raw_config.get("version") - - if "processor" not in raw_config: - return args - if raw_config["processor"] is None: - return args - if "conversation" not in raw_config["processor"]: - return args - if "offline-chat" not in raw_config["processor"]["conversation"]: - return args - if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]: - return args - - if previous_version is None or version.parse(previous_version) < version.parse("0.12.4"): - logger.info( - f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF" - ) - raw_config["version"] = schema_version - - # Update offline chat model to mistral in GGUF format to use latest GPT4All - offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"] - if offline_chat_model.endswith(".bin"): - raw_config["processor"]["conversation"]["offline-chat"]["chat-model"] = "mistral-7b-instruct-v0.1.Q4_0.gguf" - - save_config_to_file(raw_config, args.config_file) - return args diff --git a/src/khoj/migrations/migrate_offline_chat_default_model_2.py b/src/khoj/migrations/migrate_offline_chat_default_model_2.py deleted file mode 100644 index 107b7130..00000000 --- a/src/khoj/migrations/migrate_offline_chat_default_model_2.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Current format of khoj.yml ---- -app: - ... -content-type: - ... 
-processor: - conversation: - offline-chat: - enable-offline-chat: false - chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf - ... -search-type: - ... - -New format of khoj.yml ---- -app: - ... -content-type: - ... -processor: - conversation: - offline-chat: - enable-offline-chat: false - chat-model: NousResearch/Hermes-2-Pro-Mistral-7B-GGUF - ... -search-type: - ... -""" -import logging - -from packaging import version - -from khoj.utils.yaml import load_config_from_file, save_config_to_file - -logger = logging.getLogger(__name__) - - -def migrate_offline_chat_default_model(args): - schema_version = "1.7.0" - raw_config = load_config_from_file(args.config_file) - previous_version = raw_config.get("version") - - if "processor" not in raw_config: - return args - if raw_config["processor"] is None: - return args - if "conversation" not in raw_config["processor"]: - return args - if "offline-chat" not in raw_config["processor"]["conversation"]: - return args - if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]: - return args - - if previous_version is None or version.parse(previous_version) < version.parse(schema_version): - logger.info( - f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF" - ) - raw_config["version"] = schema_version - - # Update offline chat model to use Nous Research's Hermes-2-Pro GGUF in path format suitable for llama-cpp - offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"] - if offline_chat_model == "mistral-7b-instruct-v0.1.Q4_0.gguf": - raw_config["processor"]["conversation"]["offline-chat"][ - "chat-model" - ] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF" - - save_config_to_file(raw_config, args.config_file) - return args diff --git a/src/khoj/migrations/migrate_offline_chat_schema.py b/src/khoj/migrations/migrate_offline_chat_schema.py deleted file mode 100644 index 0c221652..00000000 --- 
a/src/khoj/migrations/migrate_offline_chat_schema.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Current format of khoj.yml ---- -app: - ... -content-type: - ... -processor: - conversation: - enable-offline-chat: false - conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json - openai: - ... -search-type: - ... - -New format of khoj.yml ---- -app: - ... -content-type: - ... -processor: - conversation: - offline-chat: - enable-offline-chat: false - chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin - tokenizer: null - max_prompt_size: null - conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json - openai: - ... -search-type: - ... -""" -import logging - -from packaging import version - -from khoj.utils.yaml import load_config_from_file, save_config_to_file - -logger = logging.getLogger(__name__) - - -def migrate_offline_chat_schema(args): - schema_version = "0.12.3" - raw_config = load_config_from_file(args.config_file) - previous_version = raw_config.get("version") - - if "processor" not in raw_config: - return args - if raw_config["processor"] is None: - return args - if "conversation" not in raw_config["processor"]: - return args - - if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"): - logger.info( - f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration" - ) - raw_config["version"] = schema_version - - # Create max-prompt-size field in conversation processor schema - raw_config["processor"]["conversation"]["max-prompt-size"] = None - raw_config["processor"]["conversation"]["tokenizer"] = None - - # Create offline chat schema based on existing enable_offline_chat field in khoj config schema - offline_chat_model = ( - raw_config["processor"]["conversation"] - .get("offline-chat", {}) - .get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin") - ) - raw_config["processor"]["conversation"]["offline-chat"] = { - "enable-offline-chat": 
raw_config["processor"]["conversation"].get("enable-offline-chat", False), - "chat-model": offline_chat_model, - } - - # Delete old enable-offline-chat field from conversation processor schema - if "enable-offline-chat" in raw_config["processor"]["conversation"]: - del raw_config["processor"]["conversation"]["enable-offline-chat"] - - save_config_to_file(raw_config, args.config_file) - return args diff --git a/src/khoj/migrations/migrate_offline_model.py b/src/khoj/migrations/migrate_offline_model.py deleted file mode 100644 index 6294a4e8..00000000 --- a/src/khoj/migrations/migrate_offline_model.py +++ /dev/null @@ -1,29 +0,0 @@ -import logging -import os - -from packaging import version - -from khoj.utils.yaml import load_config_from_file, save_config_to_file - -logger = logging.getLogger(__name__) - - -def migrate_offline_model(args): - schema_version = "0.10.1" - raw_config = load_config_from_file(args.config_file) - previous_version = raw_config.get("version") - - if previous_version is None or version.parse(previous_version) < version.parse("0.10.1"): - logger.info( - f"Migrating offline model used for version {previous_version} to latest version for {args.version_no}" - ) - raw_config["version"] = schema_version - - # If the user has downloaded the offline model, remove it from the cache. - offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin") - if os.path.exists(offline_model_path): - os.remove(offline_model_path) - - save_config_to_file(raw_config, args.config_file) - - return args diff --git a/src/khoj/migrations/migrate_processor_config_openai.py b/src/khoj/migrations/migrate_processor_config_openai.py deleted file mode 100644 index c25e5306..00000000 --- a/src/khoj/migrations/migrate_processor_config_openai.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Current format of khoj.yml ---- -app: - should-log-telemetry: true -content-type: - ... 
-processor: - conversation: - chat-model: gpt-3.5-turbo - conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json - model: text-davinci-003 - openai-api-key: sk-secret-key -search-type: - ... - -New format of khoj.yml ---- -app: - should-log-telemetry: true -content-type: - ... -processor: - conversation: - openai: - chat-model: gpt-3.5-turbo - openai-api-key: sk-secret-key - conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json - enable-offline-chat: false -search-type: - ... -""" -from khoj.utils.yaml import load_config_from_file, save_config_to_file - - -def migrate_processor_conversation_schema(args): - schema_version = "0.10.0" - raw_config = load_config_from_file(args.config_file) - - if "processor" not in raw_config: - return args - if raw_config["processor"] is None: - return args - if "conversation" not in raw_config["processor"]: - return args - - current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None) - current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None) - if current_openai_api_key is None and current_chat_model is None: - return args - - raw_config["version"] = schema_version - - # Add enable_offline_chat to khoj config schema - if "enable-offline-chat" not in raw_config["processor"]["conversation"]: - raw_config["processor"]["conversation"]["enable-offline-chat"] = False - - # Update conversation processor schema - conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None) - raw_config["processor"]["conversation"] = { - "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key}, - "conversation-logfile": conversation_logfile, - "enable-offline-chat": False, - } - - save_config_to_file(raw_config, args.config_file) - return args diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py deleted file mode 100644 index 316704b9..00000000 --- 
a/src/khoj/migrations/migrate_server_pg.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -The application config currently looks like this: -app: - should-log-telemetry: true -content-type: - ... -processor: - conversation: - conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json - max-prompt-size: null - offline-chat: - chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf - enable-offline-chat: false - openai: - api-key: sk-blah - chat-model: gpt-3.5-turbo - tokenizer: null -search-type: - asymmetric: - cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2 - encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1 - encoder-type: null - model-directory: /Users/si/.khoj/search/asymmetric - image: - encoder: sentence-transformers/clip-ViT-B-32 - encoder-type: null - model-directory: /Users/si/.khoj/search/image - symmetric: - cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2 - encoder: sentence-transformers/all-MiniLM-L6-v2 - encoder-type: null - model-directory: ~/.khoj/search/symmetric -version: 0.14.0 - - -The new version will looks like this: -app: - should-log-telemetry: true -processor: - conversation: - offline-chat: - enabled: false - openai: - api-key: sk-blah - chat-model-options: - - chat-model: gpt-3.5-turbo - tokenizer: null - type: openai - - chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf - tokenizer: null - type: offline -search-type: - asymmetric: - cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2 - encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1 -version: 0.15.0 -""" - -import logging - -from packaging import version - -from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig -from khoj.utils.yaml import load_config_from_file, save_config_to_file - -logger = logging.getLogger(__name__) - - -def migrate_server_pg(args): - schema_version = "0.15.0" - raw_config = load_config_from_file(args.config_file) - previous_version = raw_config.get("version") - - if previous_version is None or version.parse(previous_version) < 
version.parse(schema_version): - logger.info( - f"Migrating configuration used for version {previous_version} to latest version for server with postgres in {args.version_no}" - ) - raw_config["version"] = schema_version - - if raw_config is None: - return args - - if "search-type" in raw_config and raw_config["search-type"]: - if "asymmetric" in raw_config["search-type"]: - # Delete all existing search models - SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete() - # Create new search model from existing Khoj YAML config - asymmetric_search = raw_config["search-type"]["asymmetric"] - SearchModelConfig.objects.create( - name="default", - model_type=SearchModelConfig.ModelType.TEXT, - bi_encoder=asymmetric_search.get("encoder"), - cross_encoder=asymmetric_search.get("cross-encoder"), - ) - - if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]: - processor_conversation = raw_config["processor"]["conversation"] - - if "offline-chat" in raw_config["processor"]["conversation"]: - offline_chat = raw_config["processor"]["conversation"]["offline-chat"] - ChatModel.objects.create( - name=offline_chat.get("chat-model"), - tokenizer=processor_conversation.get("tokenizer"), - max_prompt_size=processor_conversation.get("max-prompt-size"), - model_type=ChatModel.ModelType.OFFLINE, - ) - - if ( - "openai" in raw_config["processor"]["conversation"] - and raw_config["processor"]["conversation"]["openai"] - ): - openai = raw_config["processor"]["conversation"]["openai"] - - if openai.get("api-key") is None: - logger.error("OpenAI API Key is not set. 
Will not be migrating OpenAI config.") - else: - if openai.get("chat-model") is None: - openai["chat-model"] = "gpt-3.5-turbo" - - openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default") - - ChatModel.objects.create( - name=openai.get("chat-model"), - tokenizer=processor_conversation.get("tokenizer"), - max_prompt_size=processor_conversation.get("max-prompt-size"), - model_type=ChatModel.ModelType.OPENAI, - ai_model_api=openai_model_api, - ) - - save_config_to_file(raw_config, args.config_file) - - return args diff --git a/src/khoj/migrations/migrate_version.py b/src/khoj/migrations/migrate_version.py deleted file mode 100644 index de8b9571..00000000 --- a/src/khoj/migrations/migrate_version.py +++ /dev/null @@ -1,17 +0,0 @@ -from khoj.utils.yaml import load_config_from_file, save_config_to_file - - -def migrate_config_to_version(args): - schema_version = "0.9.0" - raw_config = load_config_from_file(args.config_file) - - # Add version to khoj config schema - if "version" not in raw_config: - raw_config["version"] = schema_version - save_config_to_file(raw_config, args.config_file) - - # regenerate khoj index on first start of this version - # this should refresh index and apply index corruption fixes from #325 - args.regenerate = True - - return args diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 31f99f84..63ed50c6 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -20,7 +20,6 @@ magika = Magika() class GithubToEntries(TextToEntries): def __init__(self, config: GithubConfig): - super().__init__(config) raw_repos = config.githubrepoconfig.all() repos = [] for repo in raw_repos: diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 1e1ab4d3..23b96f63 100644 --- 
a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -47,7 +47,6 @@ class NotionBlockType(Enum): class NotionToEntries(TextToEntries): def __init__(self, config: NotionConfig): - super().__init__(config) self.config = NotionContentConfig( token=config.token, ) diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 0ceda11d..0369d273 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -27,7 +27,6 @@ logger = logging.getLogger(__name__) class TextToEntries(ABC): def __init__(self, config: Any = None): self.embeddings_model = state.embeddings_model - self.config = config self.date_filter = DateFilter() @abstractmethod diff --git a/src/khoj/processor/conversation/offline/__init__.py b/src/khoj/processor/conversation/offline/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py deleted file mode 100644 index b117a48d..00000000 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ /dev/null @@ -1,224 +0,0 @@ -import asyncio -import logging -import os -from datetime import datetime -from threading import Thread -from time import perf_counter -from typing import Any, AsyncGenerator, Dict, List, Union - -from langchain_core.messages.chat import ChatMessage -from llama_cpp import Llama - -from khoj.database.models import Agent, ChatMessageModel, ChatModel -from khoj.processor.conversation import prompts -from khoj.processor.conversation.offline.utils import download_model -from khoj.processor.conversation.utils import ( - ResponseWithThought, - commit_conversation_trace, - generate_chatml_messages_with_context, - messages_to_print, -) -from khoj.utils import state -from khoj.utils.helpers import ( - is_none_or_empty, - is_promptrace_enabled, - 
truncate_code_context, -) -from khoj.utils.rawconfig import FileAttachment, LocationData -from khoj.utils.yaml import yaml_dump - -logger = logging.getLogger(__name__) - - -async def converse_offline( - # Query - user_query: str, - # Context - references: list[dict] = [], - online_results={}, - code_results={}, - query_files: str = None, - generated_files: List[FileAttachment] = None, - additional_context: List[str] = None, - generated_asset_results: Dict[str, Dict] = {}, - location_data: LocationData = None, - user_name: str = None, - chat_history: list[ChatMessageModel] = [], - # Model - model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - loaded_model: Union[Any, None] = None, - max_prompt_size=None, - tokenizer_name=None, - agent: Agent = None, - tracer: dict = {}, -) -> AsyncGenerator[ResponseWithThought, None]: - """ - Converse with user using Llama (Async Version) - """ - # Initialize Variables - assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size) - tracer["chat_model"] = model_name - current_date = datetime.now() - - if agent and agent.personality: - system_prompt = prompts.custom_system_prompt_offline_chat.format( - name=agent.name, - bio=agent.personality, - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - else: - system_prompt = prompts.system_prompt_offline_chat.format( - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - - if location_data: - location_prompt = prompts.user_location.format(location=f"{location_data}") - system_prompt = f"{system_prompt}\n{location_prompt}" - - if user_name: - user_name_prompt = prompts.user_name.format(name=user_name) - system_prompt = f"{system_prompt}\n{user_name_prompt}" - - # Get Conversation Primer appropriate to Conversation Type - context_message = "" - 
if not is_none_or_empty(references): - context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n" - if not is_none_or_empty(online_results): - simplified_online_results = online_results.copy() - for result in online_results: - if online_results[result].get("webpages"): - simplified_online_results[result] = online_results[result]["webpages"] - - context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n" - if not is_none_or_empty(code_results): - context_message += ( - f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" - ) - context_message = context_message.strip() - - # Setup Prompt with Primer or Conversation History - messages = generate_chatml_messages_with_context( - user_query, - system_prompt, - chat_history, - context_message=context_message, - model_name=model_name, - loaded_model=offline_chat_model, - max_prompt_size=max_prompt_size, - tokenizer_name=tokenizer_name, - model_type=ChatModel.ModelType.OFFLINE, - query_files=query_files, - generated_files=generated_files, - generated_asset_results=generated_asset_results, - program_execution_context=additional_context, - ) - - logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") - - # Use asyncio.Queue and a thread to bridge sync iterator - queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue() - stop_phrases = ["", "INST]", "Notes:"] - - def _sync_llm_thread(): - """Synchronous function to run in a separate thread.""" - aggregated_response = "" - start_time = perf_counter() - state.chat_lock.acquire() - try: - response_iterator = send_message_to_model_offline( - messages, - loaded_model=offline_chat_model, - stop=stop_phrases, - max_prompt_size=max_prompt_size, - streaming=True, - tracer=tracer, - ) - for response in response_iterator: - response_delta: str = response["choices"][0]["delta"].get("content", "") - # 
Log the time taken to start response - if aggregated_response == "" and response_delta != "": - logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") - # Handle response chunk - aggregated_response += response_delta - # Put chunk into the asyncio queue (non-blocking) - try: - queue.put_nowait(ResponseWithThought(text=response_delta)) - except asyncio.QueueFull: - # Should not happen with default queue size unless consumer is very slow - logger.warning("Asyncio queue full during offline LLM streaming.") - # Potentially block here or handle differently if needed - asyncio.run(queue.put(ResponseWithThought(text=response_delta))) - - # Log the time taken to stream the entire response - logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - - # Save conversation trace - tracer["chat_model"] = model_name - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) - - except Exception as e: - logger.error(f"Error in offline LLM thread: {e}", exc_info=True) - finally: - state.chat_lock.release() - # Signal end of stream - queue.put_nowait(None) - - # Start the synchronous thread - thread = Thread(target=_sync_llm_thread) - thread.start() - - # Asynchronously consume from the queue - while True: - chunk = await queue.get() - if chunk is None: # End of stream signal - queue.task_done() - break - yield chunk - queue.task_done() - - # Wait for the thread to finish (optional, ensures cleanup) - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, thread.join) - - -def send_message_to_model_offline( - messages: List[ChatMessage], - loaded_model=None, - model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - temperature: float = 0.2, - streaming=False, - stop=[], - max_prompt_size: int = None, - response_type: str = "text", - tracer: dict = {}, -): - assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - 
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size) - messages_dict = [{"role": message.role, "content": message.content} for message in messages] - seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None - response = offline_chat_model.create_chat_completion( - messages_dict, - stop=stop, - stream=streaming, - temperature=temperature, - response_format={"type": response_type}, - seed=seed, - ) - - if streaming: - return response - - response_text: str = response["choices"][0]["message"].get("content", "") - - # Save conversation trace for non-streaming responses - # Streamed responses need to be saved by the calling function - tracer["chat_model"] = model_name - tracer["temperature"] = temperature - if is_promptrace_enabled(): - commit_conversation_trace(messages, response_text, tracer) - - return ResponseWithThought(text=response_text) diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py deleted file mode 100644 index 88082ad1..00000000 --- a/src/khoj/processor/conversation/offline/utils.py +++ /dev/null @@ -1,80 +0,0 @@ -import glob -import logging -import math -import os -from typing import Any, Dict - -from huggingface_hub.constants import HF_HUB_CACHE - -from khoj.utils import state -from khoj.utils.helpers import get_device_memory - -logger = logging.getLogger(__name__) - - -def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None): - # Initialize Model Parameters - # Use n_ctx=0 to get context size from the model - kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False} - - # Decide whether to load model to GPU or CPU - device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu" - kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0 - - # Add chat format if known - if "llama-3" in repo_id.lower(): - kwargs["chat_format"] = "llama-3" - elif "gemma-2" in repo_id.lower(): - 
kwargs["chat_format"] = "gemma" - - # Check if the model is already downloaded - model_path = load_model_from_cache(repo_id, filename) - chat_model = None - try: - chat_model = load_model(model_path, repo_id, filename, kwargs) - except: - # Load model on CPU if GPU is not available - kwargs["n_gpu_layers"], device = 0, "cpu" - chat_model = load_model(model_path, repo_id, filename, kwargs) - - # Now load the model with context size set based on: - # 1. context size supported by model and - # 2. configured size or machine (V)RAM - kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens) - chat_model = load_model(model_path, repo_id, filename, kwargs) - - logger.debug( - f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window." - ) - return chat_model - - -def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}): - from llama_cpp.llama import Llama - - if model_path: - return Llama(model_path, **kwargs) - else: - return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) - - -def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): - # Construct the path to the model file in the cache directory - repo_org, repo_name = repo_id.split("/") - object_id = "--".join([repo_type, repo_org, repo_name]) - model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename]) - - # Check if the model file exists - paths = glob.glob(model_path) - if paths: - return paths[0] - else: - return None - - -def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int: - """Infer max prompt size based on device memory and max context window supported by the model""" - configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens - vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic - configured_max_tokens = configured_max_tokens or math.inf # do not use if 
set to None - return min(configured_max_tokens, vram_based_n_ctx, model_context_window) diff --git a/src/khoj/processor/conversation/offline/whisper.py b/src/khoj/processor/conversation/offline/whisper.py deleted file mode 100644 index d8dd4457..00000000 --- a/src/khoj/processor/conversation/offline/whisper.py +++ /dev/null @@ -1,15 +0,0 @@ -import whisper -from asgiref.sync import sync_to_async - -from khoj.utils import state - - -async def transcribe_audio_offline(audio_filename: str, model: str) -> str: - """ - Transcribe audio file offline using Whisper - """ - # Send the audio data to the Whisper API - if not state.whisper_model: - state.whisper_model = whisper.load_model(model) - response = await sync_to_async(state.whisper_model.transcribe)(audio_filename) - return response["text"] diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 47589e93..fcf30d07 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -78,38 +78,6 @@ no_entries_found = PromptTemplate.from_template( """.strip() ) -## Conversation Prompts for Offline Chat Models -## -- -system_prompt_offline_chat = PromptTemplate.from_template( - """ -You are Khoj, a smart, inquisitive and helpful personal assistant. -- Use your general knowledge and past conversation with the user as context to inform your responses. -- If you do not know the answer, say 'I don't know.' -- Think step-by-step and ask questions to get the necessary information to answer the user's question. -- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations. -- Do not print verbatim Notes unless necessary. - -Note: More information about you, the company or Khoj apps can be found at https://khoj.dev. -Today is {day_of_week}, {current_date} in UTC. 
-""".strip() -) - -custom_system_prompt_offline_chat = PromptTemplate.from_template( - """ -You are {name}, a personal agent on Khoj. -- Use your general knowledge and past conversation with the user as context to inform your responses. -- If you do not know the answer, say 'I don't know.' -- Think step-by-step and ask questions to get the necessary information to answer the user's question. -- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations. -- Do not print verbatim Notes unless necessary. - -Note: More information about you, the company or Khoj apps can be found at https://khoj.dev. -Today is {day_of_week}, {current_date} in UTC. - -Instructions:\n{bio} -""".strip() -) - ## Notes Conversation ## -- notes_conversation = PromptTemplate.from_template( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 3bd17856..6e60c786 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -18,8 +18,6 @@ import requests import tiktoken import yaml from langchain_core.messages.chat import ChatMessage -from llama_cpp import LlamaTokenizer -from llama_cpp.llama import Llama from pydantic import BaseModel, ConfigDict, ValidationError, create_model from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -32,7 +30,6 @@ from khoj.database.models import ( KhojUser, ) from khoj.processor.conversation import prompts -from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -85,12 +82,6 @@ model_to_prompt_size = { "claude-sonnet-4-20250514": 60000, "claude-opus-4-0": 60000, "claude-opus-4-20250514": 60000, - # Offline Models - "bartowski/Qwen2.5-14B-Instruct-GGUF": 20000, - 
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000, - "bartowski/Llama-3.2-3B-Instruct-GGUF": 20000, - "bartowski/gemma-2-9b-it-GGUF": 6000, - "bartowski/gemma-2-2b-it-GGUF": 6000, } model_to_tokenizer: Dict[str, str] = {} @@ -573,7 +564,6 @@ def generate_chatml_messages_with_context( system_message: str = None, chat_history: list[ChatMessageModel] = [], model_name="gpt-4o-mini", - loaded_model: Optional[Llama] = None, max_prompt_size=None, tokenizer_name=None, query_images=None, @@ -588,10 +578,7 @@ def generate_chatml_messages_with_context( """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and machine specs if not max_prompt_size: - if loaded_model: - max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf)) - else: - max_prompt_size = model_to_prompt_size.get(model_name, 10000) + max_prompt_size = model_to_prompt_size.get(model_name, 10000) # Scale lookback turns proportional to max prompt size supported by model lookback_turns = max_prompt_size // 750 @@ -735,7 +722,7 @@ def generate_chatml_messages_with_context( message.content = [{"type": "text", "text": message.content}] # Truncate oldest messages from conversation history until under max supported prompt size by model - messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name) + messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) # Return message in chronological order return messages[::-1] @@ -743,25 +730,20 @@ def generate_chatml_messages_with_context( def get_encoder( model_name: str, - loaded_model: Optional[Llama] = None, tokenizer_name=None, -) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer: +) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast: default_tokenizer = "gpt-4o" try: - if loaded_model: 
- encoder = loaded_model.tokenizer() - elif model_name.startswith("gpt-") or model_name.startswith("o1"): - # as tiktoken doesn't recognize o1 model series yet - encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) - elif tokenizer_name: + if tokenizer_name: if tokenizer_name in state.pretrained_tokenizers: encoder = state.pretrained_tokenizers[tokenizer_name] else: encoder = AutoTokenizer.from_pretrained(tokenizer_name) state.pretrained_tokenizers[tokenizer_name] = encoder else: - encoder = download_model(model_name).tokenizer() + # as tiktoken doesn't recognize o1 model series yet + encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) except: encoder = tiktoken.encoding_for_model(default_tokenizer) if state.verbose > 2: @@ -773,7 +755,7 @@ def get_encoder( def count_tokens( message_content: str | list[str | dict], - encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding, + encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding, ) -> int: """ Count the total number of tokens in a list of messages. 
@@ -825,11 +807,10 @@ def truncate_messages( messages: list[ChatMessage], max_prompt_size: int, model_name: str, - loaded_model: Optional[Llama] = None, tokenizer_name=None, ) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" - encoder = get_encoder(model_name, loaded_model, tokenizer_name) + encoder = get_encoder(model_name, tokenizer_name) # Extract system message from messages system_message = None diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 979205a0..2b63c40d 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -235,7 +235,6 @@ def is_operator_model(model: str) -> ChatModel.ModelType | None: "claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC, "claude-sonnet-4": ChatModel.ModelType.ANTHROPIC, "claude-opus-4": ChatModel.ModelType.ANTHROPIC, - "ui-tars-1.5": ChatModel.ModelType.OFFLINE, } for operator_model in operator_models: if model.startswith(operator_model): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 23805200..44c2f2b7 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -15,7 +15,6 @@ from khoj.configure import initialize_content from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo from khoj.database.models import KhojUser, SpeechToTextModelOptions -from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, @@ -88,22 +87,14 @@ def update( force: Optional[bool] = False, ): user = request.user.object - if not state.config: - error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/settings, plugins or by editing {state.config_file}." 
- logger.warning(error_msg) - raise HTTPException(status_code=500, detail=error_msg) try: initialize_content(user=user, regenerate=force, search_type=t) except Exception as e: - error_msg = f"🚨 Failed to update server via API: {e}" + error_msg = f"🚨 Failed to update server indexed content via API: {e}" logger.error(error_msg, exc_info=True) raise HTTPException(status_code=500, detail=error_msg) else: - components = [] - if state.search_models: - components.append("Search models") - components_msg = ", ".join(components) - logger.info(f"📪 {components_msg} updated via API") + logger.info(f"📪 Server indexed content updated via API") update_telemetry_state( request=request, @@ -150,9 +141,6 @@ async def transcribe( if not speech_to_text_config: # If the user has not configured a speech to text model, return an unsupported on server error status_code = 501 - elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE: - speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio_offline(audio_filename, speech2text_model) elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI: speech2text_model = speech_to_text_config.model_name if speech_to_text_config.ai_model_api: diff --git a/src/khoj/routers/api_content.py b/src/khoj/routers/api_content.py index 4f9cc407..c2732ec8 100644 --- a/src/khoj/routers/api_content.py +++ b/src/khoj/routers/api_content.py @@ -27,16 +27,7 @@ from khoj.database.adapters import ( get_user_notion_config, ) from khoj.database.models import Entry as DbEntry -from khoj.database.models import ( - GithubConfig, - GithubRepoConfig, - KhojUser, - LocalMarkdownConfig, - LocalOrgConfig, - LocalPdfConfig, - LocalPlaintextConfig, - NotionConfig, -) +from khoj.database.models import GithubConfig, GithubRepoConfig, NotionConfig from khoj.processor.content.docx.docx_to_entries import DocxToEntries from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries from 
khoj.routers.helpers import ( @@ -47,17 +38,9 @@ from khoj.routers.helpers import ( get_user_config, update_telemetry_state, ) -from khoj.utils import constants, state -from khoj.utils.config import SearchModels -from khoj.utils.rawconfig import ( - ContentConfig, - FullConfig, - GithubContentConfig, - NotionContentConfig, - SearchConfig, -) +from khoj.utils import state +from khoj.utils.rawconfig import GithubContentConfig, NotionContentConfig from khoj.utils.state import SearchType -from khoj.utils.yaml import save_config_to_file_updated_state logger = logging.getLogger(__name__) @@ -192,8 +175,6 @@ async def set_content_github( updated_config: Union[GithubContentConfig, None], client: Optional[str] = None, ): - _initialize_config() - user = request.user.object try: @@ -225,8 +206,6 @@ async def set_content_notion( updated_config: Union[NotionContentConfig, None], client: Optional[str] = None, ): - _initialize_config() - user = request.user.object try: @@ -323,10 +302,6 @@ def get_content_types(request: Request, client: Optional[str] = None): configured_content_types = set(EntryAdapters.get_unique_file_types(user)) configured_content_types |= {"all"} - if state.config and state.config.content_type: - for ctype in state.config.content_type.model_dump(exclude_none=True): - configured_content_types.add(ctype) - return list(configured_content_types & all_content_types) @@ -606,28 +581,6 @@ async def indexer( docx=index_files["docx"], ) - if state.config == None: - logger.info("📬 Initializing content index on first run.") - default_full_config = FullConfig( - content_type=None, - search_type=SearchConfig.model_validate(constants.default_config["search-type"]), - processor=None, - ) - state.config = default_full_config - default_content_config = ContentConfig( - org=None, - markdown=None, - pdf=None, - docx=None, - image=None, - github=None, - notion=None, - plaintext=None, - ) - state.config.content_type = default_content_config - save_config_to_file_updated_state() - 
configure_search(state.search_models, state.config.search_type) - loop = asyncio.get_event_loop() success = await loop.run_in_executor( None, @@ -674,14 +627,6 @@ async def indexer( return Response(content=indexed_filenames, status_code=200) -def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]: - # Run Validation Checks - if search_models is None: - search_models = SearchModels() - - return search_models - - def map_config_to_object(content_source: str): if content_source == DbEntry.EntrySource.GITHUB: return GithubConfig @@ -689,56 +634,3 @@ def map_config_to_object(content_source: str): return NotionConfig if content_source == DbEntry.EntrySource.COMPUTER: return "Computer" - - -async def map_config_to_db(config: FullConfig, user: KhojUser): - if config.content_type: - if config.content_type.org: - await LocalOrgConfig.objects.filter(user=user).adelete() - await LocalOrgConfig.objects.acreate( - input_files=config.content_type.org.input_files, - input_filter=config.content_type.org.input_filter, - index_heading_entries=config.content_type.org.index_heading_entries, - user=user, - ) - if config.content_type.markdown: - await LocalMarkdownConfig.objects.filter(user=user).adelete() - await LocalMarkdownConfig.objects.acreate( - input_files=config.content_type.markdown.input_files, - input_filter=config.content_type.markdown.input_filter, - index_heading_entries=config.content_type.markdown.index_heading_entries, - user=user, - ) - if config.content_type.pdf: - await LocalPdfConfig.objects.filter(user=user).adelete() - await LocalPdfConfig.objects.acreate( - input_files=config.content_type.pdf.input_files, - input_filter=config.content_type.pdf.input_filter, - index_heading_entries=config.content_type.pdf.index_heading_entries, - user=user, - ) - if config.content_type.plaintext: - await LocalPlaintextConfig.objects.filter(user=user).adelete() - await LocalPlaintextConfig.objects.acreate( - 
input_files=config.content_type.plaintext.input_files, - input_filter=config.content_type.plaintext.input_filter, - index_heading_entries=config.content_type.plaintext.index_heading_entries, - user=user, - ) - if config.content_type.github: - await adapters.set_user_github_config( - user=user, - pat_token=config.content_type.github.pat_token, - repos=config.content_type.github.repos, - ) - if config.content_type.notion: - await adapters.set_notion_config( - user=user, - token=config.content_type.notion.token, - ) - - -def _initialize_config(): - if state.config is None: - state.config = FullConfig() - state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"]) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d3554f7f..8dcda86b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -89,10 +89,6 @@ from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, gemini_send_message_to_model, ) -from khoj.processor.conversation.offline.chat_model import ( - converse_offline, - send_message_to_model_offline, -) from khoj.processor.conversation.openai.gpt import ( converse_openai, send_message_to_model, @@ -117,7 +113,6 @@ from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.search_type import text_search from khoj.utils import state -from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( LRU, ConversationCommand, @@ -168,14 +163,6 @@ async def is_ready_to_chat(user: KhojUser): if user_chat_model == None: user_chat_model = await ConversationAdapters.aget_default_chat_model(user) - if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE: - chat_model_name = user_chat_model.name - max_tokens = user_chat_model.max_prompt_size - if state.offline_chat_processor_config is None: - logger.info("Loading Offline Chat Model...") - 
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - return True - if ( user_chat_model and ( @@ -231,7 +218,6 @@ def update_telemetry_state( telemetry_type=telemetry_type, api=api, client=client, - app_config=state.config.app, disable_telemetry_env=state.telemetry_disabled, properties=user_state, ) @@ -1470,12 +1456,6 @@ async def send_message_to_model_wrapper( vision_available = chat_model.vision_enabled api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - loaded_model = None - - if model_type == ChatModel.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( user_message=query, @@ -1483,7 +1463,6 @@ async def send_message_to_model_wrapper( system_message=system_message, chat_history=chat_history, model_name=chat_model_name, - loaded_model=loaded_model, tokenizer_name=tokenizer, max_prompt_size=max_tokens, vision_enabled=vision_available, @@ -1492,18 +1471,7 @@ async def send_message_to_model_wrapper( query_files=query_files, ) - if model_type == ChatModel.ModelType.OFFLINE: - return send_message_to_model_offline( - messages=truncated_messages, - loaded_model=loaded_model, - model_name=chat_model_name, - max_prompt_size=max_tokens, - streaming=False, - response_type=response_type, - tracer=tracer, - ) - - elif model_type == ChatModel.ModelType.OPENAI: + if model_type == ChatModel.ModelType.OPENAI: return send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1565,19 +1533,12 @@ def send_message_to_model_wrapper_sync( vision_available = chat_model.vision_enabled api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - loaded_model = 
None - - if model_type == ChatModel.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, chat_history=chat_history, model_name=chat_model_name, - loaded_model=loaded_model, max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=model_type, @@ -1585,18 +1546,7 @@ def send_message_to_model_wrapper_sync( query_files=query_files, ) - if model_type == ChatModel.ModelType.OFFLINE: - return send_message_to_model_offline( - messages=truncated_messages, - loaded_model=loaded_model, - model_name=chat_model_name, - max_prompt_size=max_tokens, - streaming=False, - response_type=response_type, - tracer=tracer, - ) - - elif model_type == ChatModel.ModelType.OPENAI: + if model_type == ChatModel.ModelType.OPENAI: return send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1678,30 +1628,7 @@ async def agenerate_chat_response( chat_model = vision_enabled_config vision_available = True - if chat_model.model_type == "offline": - loaded_model = state.offline_chat_processor_config.loaded_model - chat_response_generator = converse_offline( - # Query - user_query=query_to_run, - # Context - references=compiled_references, - online_results=online_results, - generated_files=raw_generated_files, - generated_asset_results=generated_asset_results, - location_data=location_data, - user_name=user_name, - query_files=query_files, - chat_history=chat_history, - # Model - loaded_model=loaded_model, - model_name=chat_model.name, - max_prompt_size=chat_model.max_prompt_size, - tokenizer_name=chat_model.tokenizer, - agent=agent, - tracer=tracer, - ) - - elif chat_model.model_type == ChatModel.ModelType.OPENAI: + if 
chat_model.model_type == ChatModel.ModelType.OPENAI: openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key chat_model_name = chat_model.name @@ -2798,7 +2725,8 @@ def configure_content( search_type = t.value if t else None - no_documents = all([not files.get(file_type) for file_type in files]) + # Check if client sent any documents of the supported types + no_client_sent_documents = all([not files.get(file_type) for file_type in files]) if files is None: logger.warning(f"🚨 No files to process for {search_type} search.") @@ -2872,7 +2800,8 @@ def configure_content( success = False try: - if no_documents: + # Run server side indexing of user Github docs if no client sent documents + if no_client_sent_documents: github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() if ( search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value @@ -2892,7 +2821,8 @@ def configure_content( success = False try: - if no_documents: + # Run server side indexing of user Notion docs if no client sent documents + if no_client_sent_documents: # Initialize Notion Search notion_config = NotionConfig.objects.filter(user=user).first() if ( diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py index 14581f41..a016f9a4 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -1,36 +1,19 @@ import argparse import logging -import os import pathlib from importlib.metadata import version logger = logging.getLogger(__name__) -from khoj.migrations.migrate_offline_chat_default_model import ( - migrate_offline_chat_default_model, -) -from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema -from khoj.migrations.migrate_offline_model import migrate_offline_model -from khoj.migrations.migrate_processor_config_openai import ( - migrate_processor_conversation_schema, -) -from khoj.migrations.migrate_server_pg import migrate_server_pg -from 
khoj.migrations.migrate_version import migrate_config_to_version -from khoj.utils.helpers import is_env_var_true, resolve_absolute_path -from khoj.utils.yaml import parse_config_from_file - def cli(args=None): # Setup Argument Parser for the Commandline Interface parser = argparse.ArgumentParser(description="Start Khoj; An AI personal assistant for your Digital Brain") parser.add_argument( - "--config-file", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj" - ) - parser.add_argument( - "--regenerate", - action="store_true", - default=False, - help="Regenerate model embeddings from source files. Default: false", + "--log-file", + default="~/.khoj/khoj.log", + type=pathlib.Path, + help="File path for server logs. Default: ~/.khoj/khoj.log", ) parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0") parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1") @@ -43,14 +26,11 @@ def cli(args=None): parser.add_argument("--sslcert", type=str, help="Path to SSL certificate file") parser.add_argument("--sslkey", type=str, help="Path to SSL key file") parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit") - parser.add_argument( - "--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model" - ) parser.add_argument( "--anonymous-mode", action="store_true", default=False, - help="Run Khoj in anonymous mode. This does not require any login for connecting users.", + help="Run Khoj in single user mode with no login required. 
Useful for personal use or testing.", ) parser.add_argument( "--non-interactive", @@ -64,38 +44,10 @@ def cli(args=None): if len(remaining_args) > 0: logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}") - # Set default values for arguments - args.chat_on_gpu = not args.disable_chat_on_gpu - args.version_no = version("khoj") if args.version: # Show version of khoj installed and exit print(args.version_no) exit(0) - # Normalize config_file path to absolute path - args.config_file = resolve_absolute_path(args.config_file) - - if not args.config_file.exists(): - args.config = None - else: - args = run_migrations(args) - args.config = parse_config_from_file(args.config_file) - if is_env_var_true("KHOJ_TELEMETRY_DISABLE"): - args.config.app.should_log_telemetry = False - - return args - - -def run_migrations(args): - migrations = [ - migrate_config_to_version, - migrate_processor_conversation_schema, - migrate_offline_model, - migrate_offline_chat_schema, - migrate_offline_chat_default_model, - migrate_server_pg, - ] - for migration in migrations: - args = migration(args) return args diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 03dad75c..c9cd1c43 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -1,22 +1,7 @@ # System Packages from __future__ import annotations # to avoid quoting type hints -import logging -from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, List, Optional, Union - -import torch - -from khoj.processor.conversation.offline.utils import download_model - -logger = logging.getLogger(__name__) - - -if TYPE_CHECKING: - from sentence_transformers import CrossEncoder - - from khoj.utils.models import BaseEncoder class SearchType(str, Enum): @@ -29,53 +14,3 @@ class SearchType(str, Enum): Notion = "notion" Plaintext = "plaintext" Docx = "docx" - - -class ProcessorType(str, Enum): - Conversation = "conversation" - - -@dataclass -class TextContent: - 
enabled: bool - - -@dataclass -class ImageContent: - image_names: List[str] - image_embeddings: torch.Tensor - image_metadata_embeddings: torch.Tensor - - -@dataclass -class TextSearchModel: - bi_encoder: BaseEncoder - cross_encoder: Optional[CrossEncoder] = None - top_k: Optional[int] = 15 - - -@dataclass -class ImageSearchModel: - image_encoder: BaseEncoder - - -@dataclass -class SearchModels: - text_search: Optional[TextSearchModel] = None - - -@dataclass -class OfflineChatProcessorConfig: - loaded_model: Union[Any, None] = None - - -class OfflineChatProcessorModel: - def __init__(self, chat_model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", max_tokens: int = None): - self.chat_model = chat_model - self.loaded_model = None - try: - self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens) - except ValueError as e: - self.loaded_model = None - logger.error(f"Error while loading offline chat model: {e}", exc_info=True) - raise e diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 4f128aa7..b2ed49de 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -10,13 +10,6 @@ empty_escape_sequences = "\n|\r|\t| " app_env_filepath = "~/.khoj/env" telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" content_directory = "~/.khoj/content/" -default_offline_chat_models = [ - "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - "bartowski/Llama-3.2-3B-Instruct-GGUF", - "bartowski/gemma-2-9b-it-GGUF", - "bartowski/gemma-2-2b-it-GGUF", - "bartowski/Qwen2.5-14B-Instruct-GGUF", -] default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1", "o3", "o4-mini"] default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"] default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"] diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py deleted file mode 100644 index 67e91bc9..00000000 --- a/src/khoj/utils/fs_syncer.py +++ 
/dev/null @@ -1,252 +0,0 @@ -import glob -import logging -import os -from pathlib import Path -from typing import Optional - -from bs4 import BeautifulSoup -from magika import Magika - -from khoj.database.models import ( - KhojUser, - LocalMarkdownConfig, - LocalOrgConfig, - LocalPdfConfig, - LocalPlaintextConfig, -) -from khoj.utils.config import SearchType -from khoj.utils.helpers import get_absolute_path, is_none_or_empty -from khoj.utils.rawconfig import TextContentConfig - -logger = logging.getLogger(__name__) -magika = Magika() - - -def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict: - files: dict[str, dict] = {"docx": {}, "image": {}} - - if search_type == SearchType.All or search_type == SearchType.Org: - org_config = LocalOrgConfig.objects.filter(user=user).first() - files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {} - if search_type == SearchType.All or search_type == SearchType.Markdown: - markdown_config = LocalMarkdownConfig.objects.filter(user=user).first() - files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {} - if search_type == SearchType.All or search_type == SearchType.Plaintext: - plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first() - files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {} - if search_type == SearchType.All or search_type == SearchType.Pdf: - pdf_config = LocalPdfConfig.objects.filter(user=user).first() - files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {} - files["image"] = {} - files["docx"] = {} - return files - - -def construct_config_from_db(db_config) -> TextContentConfig: - return TextContentConfig( - input_files=db_config.input_files, - input_filter=db_config.input_filter, - index_heading_entries=db_config.index_heading_entries, - ) - - -def get_plaintext_files(config: 
TextContentConfig) -> dict[str, str]: - def is_plaintextfile(file: str): - "Check if file is plaintext file" - # Check if file path exists - content_group = magika.identify_path(Path(file)).output.group - # Use file extension to decide plaintext if file content is not identifiable - valid_text_file_extensions = ("txt", "md", "markdown", "org" "mbox", "rst", "html", "htm", "xml") - return file.endswith(valid_text_file_extensions) or content_group in ["text", "code"] - - def extract_html_content(html_content: str): - "Extract content from HTML" - soup = BeautifulSoup(html_content, "html.parser") - return soup.get_text(strip=True, separator="\n") - - # Extract required fields from config - input_files, input_filters = ( - config.input_files, - config.input_filter, - ) - - # Input Validation - if is_none_or_empty(input_files) and is_none_or_empty(input_filters): - logger.debug("At least one of input-files or input-file-filter is required to be specified") - return {} - - # Get all plain text files to process - absolute_plaintext_files, filtered_plaintext_files = set(), set() - if input_files: - absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in input_files} - if input_filters: - filtered_plaintext_files = { - filtered_file - for plaintext_file_filter in input_filters - for filtered_file in glob.glob(get_absolute_path(plaintext_file_filter), recursive=True) - if os.path.isfile(filtered_file) - } - - all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files) - - files_with_no_plaintext_extensions = { - target_files for target_files in all_target_files if not is_plaintextfile(target_files) - } - if any(files_with_no_plaintext_extensions): - logger.warning(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}") - all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions) - - logger.debug(f"Processing files: {all_target_files}") - - filename_to_content_map = {} - for 
file in all_target_files: - with open(file, "r", encoding="utf8") as f: - try: - plaintext_content = f.read() - if file.endswith(("html", "htm", "xml")): - plaintext_content = extract_html_content(plaintext_content) - filename_to_content_map[file] = plaintext_content - except Exception as e: - logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.") - logger.warning(e, exc_info=True) - - return filename_to_content_map - - -def get_org_files(config: TextContentConfig): - # Extract required fields from config - org_files, org_file_filters = ( - config.input_files, - config.input_filter, - ) - - # Input Validation - if is_none_or_empty(org_files) and is_none_or_empty(org_file_filters): - logger.debug("At least one of org-files or org-file-filter is required to be specified") - return {} - - # Get Org files to process - absolute_org_files, filtered_org_files = set(), set() - if org_files: - absolute_org_files = {get_absolute_path(org_file) for org_file in org_files} - if org_file_filters: - filtered_org_files = { - filtered_file - for org_file_filter in org_file_filters - for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True) - if os.path.isfile(filtered_file) - } - - all_org_files = sorted(absolute_org_files | filtered_org_files) - - files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")} - if any(files_with_non_org_extensions): - logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}") - - logger.debug(f"Processing files: {all_org_files}") - - filename_to_content_map = {} - for file in all_org_files: - with open(file, "r", encoding="utf8") as f: - try: - filename_to_content_map[file] = f.read() - except Exception as e: - logger.warning(f"Unable to read file: {file} as org. 
Skipping file.") - logger.warning(e, exc_info=True) - - return filename_to_content_map - - -def get_markdown_files(config: TextContentConfig): - # Extract required fields from config - markdown_files, markdown_file_filters = ( - config.input_files, - config.input_filter, - ) - - # Input Validation - if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filters): - logger.debug("At least one of markdown-files or markdown-file-filter is required to be specified") - return {} - - # Get markdown files to process - absolute_markdown_files, filtered_markdown_files = set(), set() - if markdown_files: - absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files} - - if markdown_file_filters: - filtered_markdown_files = { - filtered_file - for markdown_file_filter in markdown_file_filters - for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True) - if os.path.isfile(filtered_file) - } - - all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files) - - files_with_non_markdown_extensions = { - md_file for md_file in all_markdown_files if not md_file.endswith(".md") and not md_file.endswith(".markdown") - } - - if any(files_with_non_markdown_extensions): - logger.warning( - f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}" - ) - - logger.debug(f"Processing files: {all_markdown_files}") - - filename_to_content_map = {} - for file in all_markdown_files: - with open(file, "r", encoding="utf8") as f: - try: - filename_to_content_map[file] = f.read() - except Exception as e: - logger.warning(f"Unable to read file: {file} as markdown. 
Skipping file.") - logger.warning(e, exc_info=True) - - return filename_to_content_map - - -def get_pdf_files(config: TextContentConfig): - # Extract required fields from config - pdf_files, pdf_file_filters = ( - config.input_files, - config.input_filter, - ) - - # Input Validation - if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filters): - logger.debug("At least one of pdf-files or pdf-file-filter is required to be specified") - return {} - - # Get PDF files to process - absolute_pdf_files, filtered_pdf_files = set(), set() - if pdf_files: - absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files} - if pdf_file_filters: - filtered_pdf_files = { - filtered_file - for pdf_file_filter in pdf_file_filters - for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True) - if os.path.isfile(filtered_file) - } - - all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files) - - files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")} - - if any(files_with_non_pdf_extensions): - logger.warning(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}") - - logger.debug(f"Processing files: {all_pdf_files}") - - filename_to_content_map = {} - for file in all_pdf_files: - with open(file, "rb") as f: - try: - filename_to_content_map[file] = f.read() - except Exception as e: - logger.warning(f"Unable to read file: {file} as PDF. 
Skipping file.") - logger.warning(e, exc_info=True) - - return filename_to_content_map diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 508901de..523ec007 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -47,7 +47,6 @@ if TYPE_CHECKING: from sentence_transformers import CrossEncoder, SentenceTransformer from khoj.utils.models import BaseEncoder - from khoj.utils.rawconfig import AppConfig logger = logging.getLogger(__name__) @@ -267,23 +266,16 @@ def get_server_id(): return server_id -def telemetry_disabled(app_config: AppConfig, telemetry_disable_env) -> bool: - if telemetry_disable_env is True: - return True - return not app_config or not app_config.should_log_telemetry - - def log_telemetry( telemetry_type: str, api: str = None, client: Optional[str] = None, - app_config: Optional[AppConfig] = None, disable_telemetry_env: bool = False, properties: dict = None, ): """Log basic app usage telemetry like client, os, api called""" # Do not log usage telemetry, if telemetry is disabled via app config - if telemetry_disabled(app_config, disable_telemetry_env): + if disable_telemetry_env: return [] if properties.get("server_id") is None: diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 9ed1bdff..8023b3ed 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -16,7 +16,6 @@ from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tok from khoj.utils.constants import ( default_anthropic_chat_models, default_gemini_chat_models, - default_offline_chat_models, default_openai_chat_models, ) @@ -72,7 +71,6 @@ def initialization(interactive: bool = True): default_api_key=openai_api_key, api_base_url=openai_base_url, vision_enabled=True, - is_offline=False, interactive=interactive, provider_name=provider, ) @@ -118,7 +116,6 @@ def initialization(interactive: bool = True): default_gemini_chat_models, 
default_api_key=os.getenv("GEMINI_API_KEY"), vision_enabled=True, - is_offline=False, interactive=interactive, provider_name="Google Gemini", ) @@ -145,40 +142,11 @@ def initialization(interactive: bool = True): default_anthropic_chat_models, default_api_key=os.getenv("ANTHROPIC_API_KEY"), vision_enabled=True, - is_offline=False, - interactive=interactive, - ) - - # Set up offline chat models - _setup_chat_model_provider( - ChatModel.ModelType.OFFLINE, - default_offline_chat_models, - default_api_key=None, - vision_enabled=False, - is_offline=True, interactive=interactive, ) logger.info("🗣️ Chat model configuration complete") - # Set up offline speech to text model - use_offline_speech2text_model = "n" if not interactive else input("Use offline speech to text model? (y/n): ") - if use_offline_speech2text_model == "y": - logger.info("🗣️ Setting up offline speech to text model") - # Delete any existing speech to text model options. There can only be one. - SpeechToTextModelOptions.objects.all().delete() - - default_offline_speech2text_model = "base" - offline_speech2text_model = input( - f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): " - ) - offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model - SpeechToTextModelOptions.objects.create( - model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE - ) - - logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}") - def _setup_chat_model_provider( model_type: ChatModel.ModelType, default_chat_models: list, @@ -186,7 +154,6 @@ def initialization(interactive: bool = True): interactive: bool, api_base_url: str = None, vision_enabled: bool = False, - is_offline: bool = False, provider_name: str = None, ) -> Tuple[bool, AiModelApi]: supported_vision_models = ( @@ -195,11 +162,6 @@ def initialization(interactive: bool = True): provider_name = provider_name or 
model_type.name.capitalize() default_use_model = default_api_key is not None - # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. - # This usually means there's not enough RAM to load offline models directly within the application. - # In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case. - if is_offline: - default_use_model = False use_model_provider = ( default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y" @@ -211,13 +173,12 @@ def initialization(interactive: bool = True): logger.info(f"️💬 Setting up your {provider_name} chat configuration") ai_model_api = None - if not is_offline: - if interactive: - user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ") - api_key = user_api_key if user_api_key != "" else default_api_key - else: - api_key = default_api_key - ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url) + if interactive: + user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ") + api_key = user_api_key if user_api_key != "" else default_api_key + else: + api_key = default_api_key + ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url) if interactive: user_chat_models = input( diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index df7f3334..5377577b 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -48,17 +48,6 @@ class FilesFilterRequest(BaseModel): conversation_id: str -class TextConfigBase(ConfigBase): - compressed_jsonl: Path - embeddings_file: Path - - -class TextContentConfig(ConfigBase): - input_files: Optional[List[Path]] = None - input_filter: Optional[List[str]] = None - index_heading_entries: Optional[bool] = False - - 
class GithubRepoConfig(ConfigBase): name: str owner: str @@ -74,62 +63,6 @@ class NotionContentConfig(ConfigBase): token: str -class ContentConfig(ConfigBase): - org: Optional[TextContentConfig] = None - markdown: Optional[TextContentConfig] = None - pdf: Optional[TextContentConfig] = None - plaintext: Optional[TextContentConfig] = None - github: Optional[GithubContentConfig] = None - notion: Optional[NotionContentConfig] = None - image: Optional[TextContentConfig] = None - docx: Optional[TextContentConfig] = None - - -class ImageSearchConfig(ConfigBase): - encoder: str - encoder_type: Optional[str] = None - model_directory: Optional[Path] = None - - class Config: - protected_namespaces = () - - -class SearchConfig(ConfigBase): - image: Optional[ImageSearchConfig] = None - - -class OpenAIProcessorConfig(ConfigBase): - api_key: str - chat_model: Optional[str] = "gpt-4o-mini" - - -class OfflineChatProcessorConfig(ConfigBase): - chat_model: Optional[str] = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" - - -class ConversationProcessorConfig(ConfigBase): - openai: Optional[OpenAIProcessorConfig] = None - offline_chat: Optional[OfflineChatProcessorConfig] = None - max_prompt_size: Optional[int] = None - tokenizer: Optional[str] = None - - -class ProcessorConfig(ConfigBase): - conversation: Optional[ConversationProcessorConfig] = None - - -class AppConfig(ConfigBase): - should_log_telemetry: bool = True - - -class FullConfig(ConfigBase): - content_type: Optional[ContentConfig] = None - search_type: Optional[SearchConfig] = None - processor: Optional[ProcessorConfig] = None - app: Optional[AppConfig] = AppConfig() - version: Optional[str] = None - - class SearchResponse(ConfigBase): entry: str score: float diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index f96409c2..3b65a85b 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -12,19 +12,14 @@ from whisper import Whisper from khoj.database.models import ProcessLock from 
khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config -from khoj.utils.config import OfflineChatProcessorModel, SearchModels from khoj.utils.helpers import LRU, get_device, is_env_var_true -from khoj.utils.rawconfig import FullConfig # Application Global State -config = FullConfig() -search_models = SearchModels() embeddings_model: Dict[str, EmbeddingsModel] = None cross_encoder_model: Dict[str, CrossEncoderModel] = None openai_client: OpenAI = None -offline_chat_processor_config: OfflineChatProcessorModel = None whisper_model: Whisper = None -config_file: Path = None +log_file: Path = None verbose: int = 0 host: str = None port: int = None @@ -39,7 +34,6 @@ telemetry: List[Dict[str, str]] = [] telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE") khoj_version: str = None device = get_device() -chat_on_gpu: bool = True anonymous_mode: bool = False pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict() billing_enabled: bool = ( diff --git a/src/khoj/utils/yaml.py b/src/khoj/utils/yaml.py index f658e1eb..43b139e5 100644 --- a/src/khoj/utils/yaml.py +++ b/src/khoj/utils/yaml.py @@ -1,47 +1,8 @@ -from pathlib import Path - import yaml -from khoj.utils import state -from khoj.utils.rawconfig import FullConfig - # Do not emit tags when dumping to YAML yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None # type: ignore[assignment] -def save_config_to_file_updated_state(): - with open(state.config_file, "w") as outfile: - yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile) - outfile.close() - return state.config - - -def save_config_to_file(yaml_config: dict, yaml_config_file: Path): - "Write config to YML file" - # Create output directory, if it doesn't exist - yaml_config_file.parent.mkdir(parents=True, exist_ok=True) - - with open(yaml_config_file, "w", encoding="utf-8") as config_file: - yaml.safe_dump(yaml_config, config_file, 
allow_unicode=True) - - -def load_config_from_file(yaml_config_file: Path) -> dict: - "Read config from YML file" - config_from_file = None - with open(yaml_config_file, "r", encoding="utf-8") as config_file: - config_from_file = yaml.safe_load(config_file) - return config_from_file - - -def parse_config_from_string(yaml_config: dict) -> FullConfig: - "Parse and validate config in YML string" - return FullConfig.model_validate(yaml_config) - - -def parse_config_from_file(yaml_config_file): - "Parse and validate config in YML file" - return parse_config_from_string(load_config_from_file(yaml_config_file)) - - def yaml_dump(data): return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False) diff --git a/tests/conftest.py b/tests/conftest.py index 25876d61..dd448bd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,3 @@ -import os -from pathlib import Path - import pytest from fastapi import FastAPI from fastapi.staticfiles import StaticFiles @@ -11,6 +8,7 @@ from khoj.configure import ( configure_routes, configure_search_types, ) +from khoj.database.adapters import get_default_search_model from khoj.database.models import ( Agent, ChatModel, @@ -19,21 +17,14 @@ from khoj.database.models import ( GithubRepoConfig, KhojApiUser, KhojUser, - LocalMarkdownConfig, - LocalOrgConfig, - LocalPdfConfig, - LocalPlaintextConfig, ) from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.api_content import configure_content from khoj.search_type import text_search -from khoj.utils import fs_syncer, state -from khoj.utils.config import SearchModels +from khoj.utils import state from khoj.utils.constants import web_directory -from khoj.utils.helpers import resolve_absolute_path -from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, 
SearchConfig from tests.helpers import ( AiModelApiFactory, ChatModelFactory, @@ -43,6 +34,8 @@ from tests.helpers import ( UserFactory, get_chat_api_key, get_chat_provider, + get_index_files, + get_sample_data, ) @@ -59,23 +52,16 @@ def django_db_setup(django_db_setup, django_db_blocker): @pytest.fixture(scope="session") -def search_config() -> SearchConfig: +def search_config(): + search_model = get_default_search_model() state.embeddings_model = dict() - state.embeddings_model["default"] = EmbeddingsModel() - state.cross_encoder_model = dict() - state.cross_encoder_model["default"] = CrossEncoderModel() - - model_dir = resolve_absolute_path("~/.khoj/search") - model_dir.mkdir(parents=True, exist_ok=True) - search_config = SearchConfig() - - search_config.image = ImageSearchConfig( - encoder="sentence-transformers/clip-ViT-B-32", - model_directory=model_dir / "image/", - encoder_type=None, + state.embeddings_model["default"] = EmbeddingsModel( + model_name=search_model.bi_encoder, model_kwargs=search_model.bi_encoder_model_config + ) + state.cross_encoder_model = dict() + state.cross_encoder_model["default"] = CrossEncoderModel( + model_name=search_model.cross_encoder, model_kwargs=search_model.cross_encoder_model_config ) - - return search_config @pytest.mark.django_db @@ -196,17 +182,6 @@ def default_openai_chat_model_option(): return chat_model -@pytest.mark.django_db -@pytest.fixture -def offline_agent(): - chat_model = ChatModelFactory() - return Agent.objects.create( - name="Accountant", - chat_model=chat_model, - personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. 
ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.", - ) - - @pytest.mark.django_db @pytest.fixture def openai_agent(): @@ -218,13 +193,6 @@ def openai_agent(): ) -@pytest.fixture(scope="session") -def search_models(search_config: SearchConfig): - search_models = SearchModels() - - return search_models - - @pytest.mark.django_db @pytest.fixture def default_process_lock(): @@ -236,72 +204,23 @@ def anyio_backend(): return "asyncio" -@pytest.mark.django_db @pytest.fixture(scope="function") -def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser): - content_dir = tmp_path_factory.mktemp("content") - - # Generate Image Embeddings from Test Images - content_config = ContentConfig() - - LocalOrgConfig.objects.create( - input_files=None, - input_filter=["tests/data/org/*.org"], - index_heading_entries=False, - user=default_user, - ) - - text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=False, user=default_user) - - if os.getenv("GITHUB_PAT_TOKEN"): - GithubConfig.objects.create( - pat_token=os.getenv("GITHUB_PAT_TOKEN"), - user=default_user, - ) - - GithubRepoConfig.objects.create( - owner="khoj-ai", - name="lantern", - branch="master", - github_config=GithubConfig.objects.get(user=default_user), - ) - - LocalPlaintextConfig.objects.create( - input_files=None, - input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"], - user=default_user, - ) - - return content_config - - -@pytest.fixture(scope="session") -def md_content_config(): - markdown_config = LocalMarkdownConfig.objects.create( - input_files=None, - input_filter=["tests/data/markdown/*.markdown"], - ) - - return markdown_config - - -@pytest.fixture(scope="function") -def chat_client(search_config: SearchConfig, default_user2: KhojUser): +def chat_client(search_config, default_user2: KhojUser): return chat_client_builder(search_config, default_user2, require_auth=False) 
@pytest.fixture(scope="function") -def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser): +def chat_client_with_auth(search_config, default_user2: KhojUser): return chat_client_builder(search_config, default_user2, require_auth=True) @pytest.fixture(scope="function") -def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser): +def chat_client_no_background(search_config, default_user2: KhojUser): return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False) @pytest.fixture(scope="function") -def chat_client_with_large_kb(search_config: SearchConfig, default_user2: KhojUser): +def chat_client_with_large_kb(search_config, default_user2: KhojUser): """ Chat client fixture that creates a large knowledge base with many files for stress testing atomic agent updates. @@ -312,19 +231,14 @@ def chat_client_with_large_kb(search_config: SearchConfig, default_user2: KhojUs @pytest.mark.django_db def chat_client_builder(search_config, user, index_content=True, require_auth=False): # Initialize app state - state.config.search_type = search_config state.SearchType = configure_search_types() if index_content: - LocalMarkdownConfig.objects.create( - input_files=None, - input_filter=["tests/data/markdown/*.markdown"], - user=user, - ) + file_type = "markdown" + files_to_index = {file_type: get_index_files(input_filters=[f"tests/data/{file_type}/*.{file_type}"])} # Index Markdown Content for Search - all_files = fs_syncer.collect_files(user=user) - configure_content(user, all_files) + configure_content(user, files_to_index) # Initialize Processor from Config chat_provider = get_chat_provider() @@ -360,17 +274,17 @@ def large_kb_chat_client_builder(search_config, user): import tempfile # Initialize app state - state.config.search_type = search_config state.SearchType = configure_search_types() # Create temporary directory for large number of test files temp_dir = 
tempfile.mkdtemp(prefix="khoj_test_large_kb_") + file_type = "markdown" large_file_list = [] try: # Generate 200 test files with substantial content for i in range(300): - file_path = os.path.join(temp_dir, f"test_file_{i:03d}.markdown") + file_path = os.path.join(temp_dir, f"test_file_{i:03d}.{file_type}") content = f""" # Test File {i} @@ -420,16 +334,9 @@ End of file {i}. f.write(content) large_file_list.append(file_path) - # Create LocalMarkdownConfig with all the generated files - LocalMarkdownConfig.objects.create( - input_files=large_file_list, - input_filter=None, - user=user, - ) - - # Index all the files into the user's knowledge base - all_files = fs_syncer.collect_files(user=user) - configure_content(user, all_files) + # Index all generated files into the user's knowledge base + files_to_index = {file_type: get_index_files(input_files=large_file_list, input_filters=None)} + configure_content(user, files_to_index) # Verify we have a substantial knowledge base file_count = FileObject.objects.filter(user=user, agent=None).count() @@ -481,12 +388,8 @@ def fastapi_app(): @pytest.fixture(scope="function") def client( - content_config: ContentConfig, - search_config: SearchConfig, api_user: KhojApiUser, ): - state.config.content_type = content_config - state.config.search_type = search_config state.SearchType = configure_search_types() state.embeddings_model = dict() state.embeddings_model["default"] = EmbeddingsModel() @@ -516,173 +419,18 @@ def client( return TestClient(app) -@pytest.fixture(scope="function") -def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): - # Initialize app state - state.config.search_type = search_config - state.SearchType = configure_search_types() - - LocalMarkdownConfig.objects.create( - input_files=None, - input_filter=["tests/data/markdown/*.markdown"], - user=default_user2, - ) - - all_files = fs_syncer.collect_files(user=default_user2) - configure_content(default_user2, all_files) - - # Initialize 
Processor from Config - ChatModelFactory( - name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF", - tokenizer=None, - max_prompt_size=None, - model_type="offline", - ) - UserConversationProcessorConfigFactory(user=default_user2) - - state.anonymous_mode = True - - app = FastAPI() - - configure_routes(app) - configure_middleware(app) - app.mount("/static", StaticFiles(directory=web_directory), name="static") - return TestClient(app) - - -@pytest.fixture(scope="function") -def new_org_file(default_user: KhojUser, content_config: ContentConfig): - # Setup - org_config = LocalOrgConfig.objects.filter(user=default_user).first() - input_filters = org_config.input_filter - new_org_file = Path(input_filters[0]).parent / "new_file.org" - new_org_file.touch() - - yield new_org_file - - # Cleanup - if new_org_file.exists(): - new_org_file.unlink() - - -@pytest.fixture(scope="function") -def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser): - LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None) - return LocalOrgConfig.objects.filter(user=default_user).first() - - @pytest.fixture(scope="function") def pdf_configured_user1(default_user: KhojUser): - LocalPdfConfig.objects.create( - input_files=None, - input_filter=["tests/data/pdf/singlepage.pdf"], - user=default_user, - ) - # Index Markdown Content for Search - all_files = fs_syncer.collect_files(user=default_user) - configure_content(default_user, all_files) + # Read data from pdf file at tests/data/pdf/singlepage.pdf + pdf_file_path = "tests/data/pdf/singlepage.pdf" + with open(pdf_file_path, "rb") as pdf_file: + pdf_data = pdf_file.read() + + knowledge_base = {"pdf": {"singlepage.pdf": pdf_data}} + # Index Content for Search + configure_content(default_user, knowledge_base) @pytest.fixture(scope="function") def sample_org_data(): return get_sample_data("org") - - -def get_sample_data(type): - sample_data = { - "org": { - "elisp.org": """ -* Emacs Khoj - /An Emacs interface 
for [[https://github.com/khoj-ai/khoj][khoj]]/ - -** Requirements - - Install and Run [[https://github.com/khoj-ai/khoj][khoj]] - -** Installation -*** Direct - - Put ~khoj.el~ in your Emacs load path. For e.g. ~/.emacs.d/lisp - - Load via ~use-package~ in your ~/.emacs.d/init.el or .emacs file by adding below snippet - #+begin_src elisp - ;; Khoj Package - (use-package khoj - :load-path "~/.emacs.d/lisp/khoj.el" - :bind ("C-c s" . 'khoj)) - #+end_src - -*** Using [[https://github.com/quelpa/quelpa#installation][Quelpa]] - - Ensure [[https://github.com/quelpa/quelpa#installation][Quelpa]], [[https://github.com/quelpa/quelpa-use-package#installation][quelpa-use-package]] are installed - - Add below snippet to your ~/.emacs.d/init.el or .emacs config file and execute it. - #+begin_src elisp - ;; Khoj Package - (use-package khoj - :quelpa (khoj :fetcher url :url "https://raw.githubusercontent.com/khoj-ai/khoj/master/interface/emacs/khoj.el") - :bind ("C-c s" . 'khoj)) - #+end_src - -** Usage - 1. Call ~khoj~ using keybinding ~C-c s~ or ~M-x khoj~ - 2. Enter Query in Natural Language - e.g. "What is the meaning of life?" "What are my life goals?" - 3. Wait for results - *Note: It takes about 15s on a Mac M1 and a ~100K lines corpus of org-mode files* - 4. (Optional) Narrow down results further - Include/Exclude specific words from results by adding to query - e.g. "What is the meaning of life? -god +none" - -""", - "readme.org": """ -* Khoj - /Allow natural language search on user content like notes, images using transformer based models/ - - All data is processed locally. 
User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline - -** Dependencies - - Python3 - - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]] - -** Install - #+begin_src shell - git clone https://github.com/khoj-ai/khoj && cd khoj - conda env create -f environment.yml - conda activate khoj - #+end_src""", - }, - "markdown": { - "readme.markdown": """ -# Khoj -Allow natural language search on user content like notes, images using transformer based models - -All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline - -## Dependencies -- Python3 -- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links) - -## Install -```shell -git clone -conda env create -f environment.yml -conda activate khoj -``` -""" - }, - "plaintext": { - "readme.txt": """ -Khoj -Allow natural language search on user content like notes, images using transformer based models - -All data is processed locally. 
User can interface with khoj app via Emacs, API or Commandline - -Dependencies -- Python3 -- Miniconda - -Install -git clone -conda env create -f environment.yml -conda activate khoj -""" - }, - } - - return sample_data[type] diff --git a/tests/helpers.py b/tests/helpers.py index d3c94abc..6edb0946 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,3 +1,5 @@ +import glob +import logging import os from datetime import datetime @@ -17,9 +19,12 @@ from khoj.database.models import ( UserConversationConfig, ) from khoj.processor.conversation.utils import message_to_log +from khoj.utils.helpers import get_absolute_path, is_none_or_empty + +logger = logging.getLogger(__name__) -def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE): +def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE): provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER") if provider and provider in ChatModel.ModelType: return ChatModel.ModelType(provider) @@ -61,6 +66,140 @@ def generate_chat_history(message_list): return chat_history +def get_sample_data(type): + sample_data = { + "org": { + "elisp.org": """ +* Emacs Khoj + /An Emacs interface for [[https://github.com/khoj-ai/khoj][khoj]]/ + +** Requirements + - Install and Run [[https://github.com/khoj-ai/khoj][khoj]] + +** Installation +*** Direct + - Put ~khoj.el~ in your Emacs load path. For e.g. ~/.emacs.d/lisp + - Load via ~use-package~ in your ~/.emacs.d/init.el or .emacs file by adding below snippet + #+begin_src elisp + ;; Khoj Package + (use-package khoj + :load-path "~/.emacs.d/lisp/khoj.el" + :bind ("C-c s" . 'khoj)) + #+end_src + +*** Using [[https://github.com/quelpa/quelpa#installation][Quelpa]] + - Ensure [[https://github.com/quelpa/quelpa#installation][Quelpa]], [[https://github.com/quelpa/quelpa-use-package#installation][quelpa-use-package]] are installed + - Add below snippet to your ~/.emacs.d/init.el or .emacs config file and execute it. 
+ #+begin_src elisp + ;; Khoj Package + (use-package khoj + :quelpa (khoj :fetcher url :url "https://raw.githubusercontent.com/khoj-ai/khoj/master/interface/emacs/khoj.el") + :bind ("C-c s" . 'khoj)) + #+end_src + +** Usage + 1. Call ~khoj~ using keybinding ~C-c s~ or ~M-x khoj~ + 2. Enter Query in Natural Language + e.g. "What is the meaning of life?" "What are my life goals?" + 3. Wait for results + *Note: It takes about 15s on a Mac M1 and a ~100K lines corpus of org-mode files* + 4. (Optional) Narrow down results further + Include/Exclude specific words from results by adding to query + e.g. "What is the meaning of life? -god +none" + +""", + "readme.org": """ +* Khoj + /Allow natural language search on user content like notes, images using transformer based models/ + + All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline + +** Dependencies + - Python3 + - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]] + +** Install + #+begin_src shell + git clone https://github.com/khoj-ai/khoj && cd khoj + conda env create -f environment.yml + conda activate khoj + #+end_src""", + }, + "markdown": { + "readme.markdown": """ +# Khoj +Allow natural language search on user content like notes, images using transformer based models + +All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline + +## Dependencies +- Python3 +- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links) + +## Install +```shell +git clone +conda env create -f environment.yml +conda activate khoj +``` +""" + }, + "plaintext": { + "readme.txt": """ +Khoj +Allow natural language search on user content like notes, images using transformer based models + +All data is processed locally. 
User can interface with khoj app via Emacs, API or Commandline + +Dependencies +- Python3 +- Miniconda + +Install +git clone +conda env create -f environment.yml +conda activate khoj +""" + }, + } + + return sample_data[type] + + +def get_index_files( + input_files: list[str] = None, input_filters: list[str] | None = ["tests/data/org/*.org"] +) -> dict[str, str]: + # Input Validation + if is_none_or_empty(input_files) and is_none_or_empty(input_filters): + logger.debug("At least one of input_files or input_filter is required to be specified") + return {} + + # Get files to process + absolute_files, filtered_files = set(), set() + if input_files: + absolute_files = {get_absolute_path(input_file) for input_file in input_files} + if input_filters: + filtered_files = { + filtered_file + for file_filter in input_filters + for filtered_file in glob.glob(get_absolute_path(file_filter), recursive=True) + if os.path.isfile(filtered_file) + } + + all_files = sorted(absolute_files | filtered_files) + + filename_to_content_map = {} + for file in all_files: + with open(file, "r", encoding="utf8") as f: + try: + filename_to_content_map[file] = f.read() + except Exception as e: + logger.warning(f"Unable to read file: {file}. 
Skipping file.") + logger.warning(e, exc_info=True) + + return filename_to_content_map + + class UserFactory(factory.django.DjangoModelFactory): class Meta: model = KhojUser @@ -93,7 +232,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory): max_prompt_size = 20000 tokenizer = None - name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF" + name = "gemini-2.0-flash" model_type = get_chat_provider() ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None) diff --git a/tests/test_agents.py b/tests/test_agents.py index 21a242ef..1d3b96ec 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -15,7 +15,7 @@ from tests.helpers import ChatModelFactory def test_create_default_agent(default_user: KhojUser): ChatModelFactory() - agent = AgentAdapters.create_default_agent(default_user) + agent = AgentAdapters.create_default_agent() assert agent is not None assert agent.input_tools == [] assert agent.output_modes == [] diff --git a/tests/test_cli.py b/tests/test_cli.py index 211ff38e..15908653 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,49 +1,15 @@ # Standard Modules from pathlib import Path -from random import random from khoj.utils.cli import cli -from khoj.utils.helpers import resolve_absolute_path # Test # ---------------------------------------------------------------------------------------------------- def test_cli_minimal_default(): # Act - actual_args = cli([]) + actual_args = cli(["-vvv"]) # Assert - assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml")) - assert actual_args.regenerate == False - assert actual_args.verbose == 0 - - -# ---------------------------------------------------------------------------------------------------- -def test_cli_invalid_config_file_path(): - # Arrange - non_existent_config_file = f"non-existent-khoj-{random()}.yml" - - # Act - actual_args = cli([f"--config-file={non_existent_config_file}"]) - - # Assert - assert 
actual_args.config_file == resolve_absolute_path(non_existent_config_file) - assert actual_args.config == None - - -# ---------------------------------------------------------------------------------------------------- -def test_cli_config_from_file(): - # Act - actual_args = cli(["--config-file=tests/data/config.yml", "--regenerate", "-vvv"]) - - # Assert - assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml")) - assert actual_args.regenerate == True - assert actual_args.config is not None + assert actual_args.log_file == Path("~/.khoj/khoj.log") assert actual_args.verbose == 3 - - # Ensure content config is loaded from file - assert actual_args.config.content_type.org.input_files == [ - Path("~/first_from_config.org"), - Path("~/second_from_config.org"), - ] diff --git a/tests/test_client.py b/tests/test_client.py index 00507851..46732a86 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,7 +13,6 @@ from khoj.database.models import KhojApiUser, KhojUser from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.search_type import text_search from khoj.utils import state -from khoj.utils.rawconfig import ContentConfig, SearchConfig # Test @@ -283,10 +282,6 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser): def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): # Arrange state.anonymous_mode = True - if state.config and state.config.content_type: - state.config.content_type = None - state.search_models = configure_search_types() - configure_routes(fastapi_app) client = TestClient(fastapi_app) @@ -300,7 +295,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser): +def 
test_notes_search(client, search_config, sample_org_data, default_user: KhojUser): # Arrange headers = {"Authorization": "Bearer kk-secret"} text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user) @@ -319,7 +314,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser): +def test_notes_search_no_results(client, search_config, sample_org_data, default_user: KhojUser): # Arrange headers = {"Authorization": "Bearer kk-secret"} text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user) @@ -335,9 +330,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_with_only_filters( - client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser -): +def test_notes_search_with_only_filters(client, sample_org_data, default_user: KhojUser): # Arrange headers = {"Authorization": "Bearer kk-secret"} text_search.setup( @@ -401,9 +394,7 @@ def test_notes_search_with_exclude_filter(client, sample_org_data, default_user: # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) -def test_notes_search_requires_parent_context( - client, search_config: SearchConfig, sample_org_data, default_user: KhojUser -): +def test_notes_search_requires_parent_context(client, search_config, sample_org_data, default_user: KhojUser): # Arrange headers = {"Authorization": "Bearer kk-secret"} text_search.setup(OrgToEntries, sample_org_data, 
regenerate=False, user=default_user) diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 9a36bd57..21f198ea 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -1,6 +1,13 @@ # Application Packages from khoj.search_filter.file_filter import FileFilter -from khoj.utils.rawconfig import Entry + + +# Mock Entry class for testing +class Entry: + def __init__(self, compiled="", raw="", file=""): + self.compiled = compiled + self.raw = raw + self.file = file def test_can_filter_no_file_filter(): diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py index 30813555..b8ab37a7 100644 --- a/tests/test_markdown_to_entries.py +++ b/tests/test_markdown_to_entries.py @@ -3,8 +3,6 @@ import re from pathlib import Path from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries -from khoj.utils.fs_syncer import get_markdown_files -from khoj.utils.rawconfig import TextContentConfig def test_extract_markdown_with_no_headings(tmp_path): @@ -212,43 +210,6 @@ longer body line 2.1 ), "Third entry is second entries child heading" -def test_get_markdown_files(tmp_path): - "Ensure Markdown files specified via input-filter, input-files extracted" - # Arrange - # Include via input-filter globs - group1_file1 = create_file(tmp_path, filename="group1-file1.md") - group1_file2 = create_file(tmp_path, filename="group1-file2.md") - group2_file1 = create_file(tmp_path, filename="group2-file1.markdown") - group2_file2 = create_file(tmp_path, filename="group2-file2.markdown") - # Include via input-file field - file1 = create_file(tmp_path, filename="notes.md") - # Not included by any filter - create_file(tmp_path, filename="not-included-markdown.md") - create_file(tmp_path, filename="not-included-text.txt") - - expected_files = set( - [os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]] - ) - - # Setup input-files, input-filters - input_files = 
[tmp_path / "notes.md"] - input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"] - - markdown_config = TextContentConfig( - input_files=input_files, - input_filter=[str(filter) for filter in input_filter], - compressed_jsonl=tmp_path / "test.jsonl", - embeddings_file=tmp_path / "test_embeddings.jsonl", - ) - - # Act - extracted_org_files = get_markdown_files(markdown_config) - - # Assert - assert len(extracted_org_files) == 5 - assert set(extracted_org_files.keys()) == expected_files - - def test_line_number_tracking_in_recursive_split(): "Ensure line numbers in URIs are correct after recursive splitting by checking against the actual file." # Arrange diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py deleted file mode 100644 index 979710b6..00000000 --- a/tests/test_offline_chat_actors.py +++ /dev/null @@ -1,610 +0,0 @@ -from datetime import datetime - -import pytest - -from khoj.database.models import ChatModel -from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions -from khoj.utils.helpers import ConversationCommand -from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider - -SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE -pytestmark = pytest.mark.skipif( - SKIP_TESTS, - reason="Disable in CI to avoid long test runs.", -) - -import freezegun -from freezegun import freeze_time - -from khoj.processor.conversation.offline.chat_model import converse_offline -from khoj.processor.conversation.offline.utils import download_model -from khoj.utils.constants import default_offline_chat_models - - -@pytest.fixture(scope="session") -def loaded_model(): - return download_model(default_offline_chat_models[0], max_tokens=5000) - - -freezegun.configure(extend_ignore_list=["transformers"]) - - -# Test -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality 
-@freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2): - # Act - response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model) - - assert len(response) >= 1 - - assert any( - [ - "dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0], - "dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0], - 'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0], - 'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0], - ] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting") -@pytest.mark.chatquality -@freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2): - # Act - response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model) - - # Assert - assert len(response) >= 1 - # The user query should be the last question in the response - assert response[-1] == ["Which countries did I visit last month?"] - assert any( - [ - "dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0], - "dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0], - 'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0], - 'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0], - ] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting") -@pytest.mark.chatquality -@freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2): - # 
Act - response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model) - - # Assert - expected_responses = [ - ("dt>='1984-01-01'", ""), - ("dt>='1984-01-01'", "dt<'1985-01-01'"), - ("dt>='1984-01-01'", "dt<='1984-12-31'"), - ] - assert len(response) == 1 - assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( - "Expected date filter to limit to 1984 in response but got: " + response[0] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2): - # Act - responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model) - - # Assert - assert len(responses) >= 2 - assert ["the Sun" in response for response in responses] - assert ["the Moon" in response for response in responses] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2): - # Act - response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model) - - # Assert - expected_responses = ["height", "taller", "shorter", "heights", "who"] - assert len(response) <= 3 - - for question in response: - assert any([expected_response in question.lower() for expected_response in expected_responses]), ( - "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2): - # Arrange - message_list = [ - ("What is the name of Mr. 
Anderson's daughter?", "Miss Barbara", []), - ] - query = "Does he have any sons?" - - # Act - response = extract_questions( - query, - default_user2, - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - use_history=True, - ) - - any_expected_with_barbara = [ - "sibling", - "brother", - ] - - any_expected_with_anderson = [ - "son", - "sons", - "children", - "family", - ] - - # Assert - assert len(response) >= 1 - # Ensure the remaining generated search queries use proper nouns and chat history context - for question in response: - if "Barbara" in question: - assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), ( - "Expected search queries using proper nouns and chat history for context, but got: " + question - ) - elif "Anderson" in question: - assert any([expected_response in question for expected_response in any_expected_with_anderson]), ( - "Expected search queries using proper nouns and chat history for context, but got: " + question - ) - else: - assert False, ( - "Expected search queries using proper nouns and chat history for context, but got: " + question - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2): - # Arrange - message_list = [ - ("What is the name of Mr. 
Anderson's daughter?", "Miss Barbara", []), - ] - - # Act - response = extract_questions( - "Is she a Doctor?", - default_user2, - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - use_history=True, - ) - - expected_responses = [ - "Barbara", - "Anderson", - ] - - # Assert - assert len(response) >= 1 - assert any([expected_response in response[0] for expected_response in expected_responses]), ( - "Expected chat actor to mention person's by name, but got: " + response[0] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context") -@pytest.mark.chatquality -def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2): - # Arrange - message_list = [ - ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []), - ] - - # Act - response = extract_questions( - "What was the Pizza place we ate at over there?", - default_user2, - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - - # Assert - expected_responses = [ - ("dt>='2000-04-01'", "dt<'2000-05-01'"), - ("dt>='2000-04-01'", "dt<='2000-04-30'"), - ('dt>="2000-04-01"', 'dt<"2000-05-01"'), - ('dt>="2000-04-01"', 'dt<="2000-04-30"'), - ] - assert len(response) == 1 - assert "Masai Mara" in response[0] - assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( - "Expected date filter to limit to April 2000 in response but got: " + response[0] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -@pytest.mark.parametrize( - "user_query, expected_conversation_commands", - [ - ( - "Where did I learn to swim?", - {"sources": [ConversationCommand.Notes], "output": 
ConversationCommand.Text}, - ), - ( - "Where is the nearest hospital?", - {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text}, - ), - ( - "Summarize the wikipedia page on the history of the internet", - {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text}, - ), - ( - "How many noble gases are there?", - {"sources": [ConversationCommand.General], "output": ConversationCommand.Text}, - ), - ( - "Make a painting incorporating my past diving experiences", - {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image}, - ), - ( - "Create a chart of the weather over the next 7 days in Timbuktu", - {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text}, - ), - ( - "What's the highest point in this country and have I been there?", - {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text}, - ), - ], -) -async def test_select_data_sources_actor_chooses_to_search_notes( - client_offline_chat, user_query, expected_conversation_commands, default_user2 -): - # Act - selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2) - - # Assert - assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"]) - assert expected_conversation_commands["output"] == selected_conversation_commands["output"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2): - # Arrange - user_query = "What's the latest in the Israel/Palestine conflict?" - chat_log = [ - ( - "Let's talk about the current events around the world.", - "Sure, let's discuss the current events. 
What would you like to know?", - [], - ), - ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), - ] - chat_history = ConversationFactory(user=default_user2, conversation_log=generate_chat_history(chat_log)) - - # Act - tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["online"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_chat_with_no_chat_history_or_retrieved_content(loaded_model): - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Hello, my name is Testatron. Who are you?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = ["Khoj", "khoj", "KHOJ"] - assert len(response) > 0 - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model): - "Chat actor needs to use context in previous notes and chat history to answer question" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ( - "When was I born?", - "You were born on 1st April 1984.", - ["Testatron was born on 1st April 1984 in Testville."], - ), - ] - - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Where was I born?", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert len(response) > 0 - # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes - assert "Testville" in response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): - "Chat actor needs to use context across currently retrieved notes and chat history to answer question" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - - # Act - response_gen = converse_offline( - references=[ - {"compiled": "Testatron was born on 1st April 1984 in Testville."} - ], # Assume context retrieved from notes for the user_query - user_query="Where was I born?", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert len(response) > 0 - assert "Testville" in response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer") -@pytest.mark.chatquality -def test_refuse_answering_unanswerable_question(loaded_model): - "Chat actor should not try make up answers to unanswerable questions." 
- # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Where was I born?", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = [ - "don't know", - "do not know", - "no information", - "do not have", - "don't have", - "cannot answer", - "I'm sorry", - ] - assert len(response) > 0 - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected chat actor to say they don't know in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_requires_current_date_awareness(loaded_model): - "Chat actor should be able to answer questions relative to current date using provided notes" - # Arrange - context = [ - { - "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""2020-04-01 "SuperMercado" "Bananas" -Expenses:Food:Groceries 10.00 USD""" - }, - { - "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What did I have for Dinner today?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses 
= ["tacos", "Tacos"] - assert len(response) > 0 - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected [T|t]acos in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model): - "Chat actor should be able to answer questions that require date aware aggregation across multiple notes" - # Arrange - context = [ - { - "compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""2020-04-01 "SuperMercado" "Bananas" -Expenses:Food:Groceries 10.00 USD""" - }, - { - "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="How much did I spend on dining this year?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert len(response) > 0 - assert "20" in response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model): - "Chat actor should be able to answer general questions not requiring looking at chat history or notes" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("Where was I born?", "You were born Testville.", []), - ] - - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Write a haiku about unit testing in 3 lines", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = ["test", "testing"] - assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines. - assert any([expected_response in response.lower() for expected_response in expected_responses]), ( - "Expected [T|t]est in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model): - "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" - # Arrange - context = [ - { - "compiled": f"""# Ramya -My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""" - }, - { - "compiled": f"""# Fang -My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""" - }, - { - "compiled": f"""# Aiyla -My sister, Aiyla is married to Tolga. 
They have 3 kids, Yildiz, Ali and Ahmet.""" - }, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="How many kids does my older sister have?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister", "Which one"] - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected chat actor to ask for clarification in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_agent_prompt_should_be_used(loaded_model, offline_agent): - "Chat actor should ask be tuned to think like an accountant based on the agent definition" - # Arrange - context = [ - {"compiled": f"""I went to the store and bought some bananas for 2.20"""}, - {"compiled": f"""I went to the store and bought some apples for 1.30"""}, - {"compiled": f"""I went to the store and bought some oranges for 6.00"""}, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What did I buy?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert that the model without the agent prompt does not include the summary of purchases - expected_responses = ["9.50", "9.5"] - assert all([expected_response not in response for expected_response in expected_responses]), ( - "Expected chat actor to summarize values of purchases" + response - ) - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What did I buy?", - loaded_model=loaded_model, - agent=offline_agent, - ) - response = 
"".join([response_chunk for response_chunk in response_gen]) - - # Assert that the model with the agent prompt does include the summary of purchases - expected_responses = ["9.50", "9.5"] - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected chat actor to summarize values of purchases" + response - ) - - -# ---------------------------------------------------------------------------------------------------- -def test_chat_does_not_exceed_prompt_size(loaded_model): - "Ensure chat context and response together do not exceed max prompt size for the model" - # Arrange - prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed" - context = [{"compiled": " ".join([f"{number}" for number in range(2043)])}] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What numbers come after these?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert prompt_size_exceeded_error not in response, ( - "Expected chat response to be within prompt limits, but got exceeded error: " + response - ) diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py deleted file mode 100644 index 0eb4a0dc..00000000 --- a/tests/test_offline_chat_director.py +++ /dev/null @@ -1,726 +0,0 @@ -import urllib.parse - -import pytest -from faker import Faker -from freezegun import freeze_time - -from khoj.database.models import Agent, ChatModel, Entry, KhojUser -from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import message_to_log -from tests.helpers import ConversationFactory, get_chat_provider - -SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE -pytestmark = pytest.mark.skipif( - SKIP_TESTS, - reason="Disable in CI to avoid long test runs.", -) - -fake = 
Faker() - - -# Helpers -# ---------------------------------------------------------------------------------------------------- -def generate_history(message_list): - # Generate conversation logs - conversation_log = {"chat": []} - for user_message, gpt_message, context in message_list: - message_to_log( - user_message, - gpt_message, - {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - chat_history=conversation_log.get("chat", []), - ) - return conversation_log - - -def create_conversation(message_list, user, agent=None): - # Generate conversation logs - conversation_log = generate_history(message_list) - # Update Conversation Metadata Logs in Database - return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent) - - -# Tests -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat): - # Act - query = "Hello, my name is Testatron. Who are you?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Khoj", "khoj"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_chat_with_online_content(client_offline_chat): - # Act - q = "/online give me the link to paul graham's essay how to do great work" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.json()["response"] - - # Assert - expected_responses = [ - "https://paulgraham.com/greatwork.html", - "https://www.paulgraham.com/greatwork.html", - "http://www.paulgraham.com/greatwork.html", - ] - assert response.status_code == 200 - assert any( - [expected_response in response_message for expected_response in expected_responses] - ), f"Expected links: {expected_responses}. Actual response: {response_message}" - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_chat_with_online_webpage_content(client_offline_chat): - # Act - q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.json()["response"] - - # Assert - expected_responses = ["185", "1871", "horse"] - assert response.status_code == 200 - assert any( - [expected_response in response_message for expected_response in expected_responses] - ), f"Expected response with {expected_responses}. 
But actual response had: {response_message}" - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_chat_history(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - create_conversation(message_list, default_user2) - - # Act - q = "What is my name?" - response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Testatron", "testatron"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected [T|t]estatron in response but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ( - "When was I born?", - "You were born on 1st April 1984.", - ["Testatron was born on 1st April 1984 in Testville."], - ), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was Xi Li born?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "Fujiang" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ( - "When was I born?", - "You were born on 1st April 1984.", - ["Testatron was born on 1st April 1984 in Testville."], - ), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was I born?" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - # 1. Infer who I am from chat history - # 2. Infer I was born in Testville from previously retrieved notes - assert "Testville" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was I born?" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - # Inference in a multi-turn conversation - # 1. 
Infer who I am from chat history - # 2. Search for notes about when was born - # 3. Extract where I was born from currently retrieved notes - assert "Fujiang" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2): - "Chat director should say don't know as not enough contexts in chat history or retrieved to answer question" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was I born?" - response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected chat director to say they don't know in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_using_general_command(client_offline_chat, default_user2): - # Arrange - message_list = [] - create_conversation(message_list, default_user2) - - # Act - query = "/general Where was Xi Li born?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "Fujiang" not in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2): - # Arrange - message_list = [] - create_conversation(message_list, default_user2) - - # Act - query = "/notes Where was Xi Li born?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "Fujiang" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_using_file_filter(client_offline_chat, default_user2): - # Arrange - no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"') - answer_query = urllib.parse.quote('Where was Xi Li born? 
file:"Xi Li.markdown"') - message_list = [] - create_conversation(message_list, default_user2) - - # Act - no_answer_response = client_offline_chat.post( - f"/api/chat", json={"q": no_answer_query, "stream": True} - ).content.decode("utf-8") - answer_response = client_offline_chat.post(f"/api/chat", json={"q": answer_query, "stream": True}).content.decode( - "utf-8" - ) - - # Assert - assert "Fujiang" not in no_answer_response - assert "Fujiang" in answer_response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_not_known_using_notes_command(client_offline_chat, default_user2): - # Arrange - message_list = [] - create_conversation(message_list, default_user2) - - # Act - query = urllib.parse.quote("/notes Where was Testatron born?") - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert response_message == prompts.no_notes_found.format() - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_one_file(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "Xi Li.markdown" file to the file filters - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - # pick the file that has "Xi Li.markdown" in the name - summarization_file = "" - for file in file_list: - if "Birthday Gift for Xiu turning 4.markdown" in file: - summarization_file = file - break - assert summarization_file != "" - - response = client_offline_chat.post( - 
"api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation.id)}, - ) - - # Act - query = "/summarize" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert - assert response_message != "" - assert response_message != "No files selected for summarization. Please add files using the section on the left." - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "Xi Li.markdown" file to the file filters - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - # pick the file that has "Xi Li.markdown" in the name - summarization_file = "" - for file in file_list: - if "Birthday Gift for Xiu turning 4.markdown" in file: - summarization_file = file - break - assert summarization_file != "" - - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation.id)}, - ) - - # Act - query = "/summarize tell me about Xiu" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert - assert response_message != "" - assert response_message != "No files selected for summarization. Please add files using the section on the left." 
- - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "Xi Li.markdown" file to the file filters - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - - response = client_offline_chat.post( - "api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)} - ) - response = client_offline_chat.post( - "api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)} - ) - - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - - # Assert - assert response_message is not None - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_no_files(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." 
- - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser): - message_list = [] - conversation1 = create_conversation(message_list, default_user2) - conversation2 = create_conversation(message_list, default_user2) - - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - summarization_file = "" - for file in file_list: - if "Birthday Gift for Xiu turning 4.markdown" in file: - summarization_file = file - break - assert summarization_file != "" - - # add file filter to conversation 1. - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation1.id)}, - ) - - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id}) - response_message = response.json()["response"] - - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." - - # now make sure that the file filter is still in conversation 1 - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id}) - response_message = response.json()["response"] - - # Assert - assert response_message != "" - assert response_message != "No files selected for summarization. Please add files using the section on the left." 
- - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "imaginary.markdown" file to the file filters - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)}, - ) - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_diff_user_file( - client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser -): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # Get the pdf file called singlepage.pdf - file_list = ( - Entry.objects.filter(user=default_user, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - summarization_file = "" - for file in file_list: - if "singlepage.pdf" in file: - summarization_file = file - break - assert summarization_file != "" - # add singlepage.pdf to the file filters - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation.id)}, - ) - - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." 
- - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -@freeze_time("2023-04-01", ignore=["transformers"]) -def test_answer_requires_current_date_awareness(client_offline_chat): - "Chat actor should be able to answer questions relative to current date using provided notes" - # Act - query = "Where did I have lunch today?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Arak", "Medellin"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected chat director to say Arak, Medellin, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -@freeze_time("2023-04-01", ignore=["transformers"]) -def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat): - "Chat director should be able to answer questions that require date aware aggregation across multiple notes" - # Act - query = "How much did I spend on dining this year?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "26" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("Where was I born?", "You were born Testville.", []), - ] - create_conversation(message_list, default_user2) - - # Act - query = "Write a haiku about unit testing. Do not say anything else." - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["test", "Test"] - assert response.status_code == 200 - assert len(response_message.splitlines()) == 3 # haikus are 3 lines long - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected [T|t]est in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2): - # Act - query = "What is the name of Namitas older son" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses 
= [ - "which of them is the older", - "which one is older", - "which of them is older", - "which one is the older", - ] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected chat director to ask for clarification in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat director not capable of answering this question yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("Where was I born?", "You were born Testville.", []), - ] - create_conversation(message_list, default_user2) - - # Act - query = "What is my name?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Testatron", "testatron"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected [T|t]estatron in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("What's my favorite color", "Your favorite color is green.", []), - ("Where was I born?", "You were born Testville.", []), - ] - message_list2 = [ - ("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 14th August 1947.", []), - ("What's my favorite color", "Your favorite color is maroon.", []), - ("Where was I born?", "You were born in a potato farm.", []), - ] - conversation = create_conversation(message_list, default_user2) - create_conversation(message_list2, default_user2) - - # Act - query = "What is my favorite color?" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["green"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected green in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_by_conversation_id_with_agent( - client_offline_chat, default_user2: KhojUser, offline_agent: Agent -): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("What's my favorite color", "Your favorite color is green.", []), - ("Where was I born?", "You were born Testville.", []), - ("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []), - ] - conversation = create_conversation(message_list, default_user2, offline_agent) - - # Act - query = "/general What did I eat for breakfast?" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert that agent only responds with the summary of spending - expected_responses = ["13.00", "13", "13.0", "thirteen"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected green in response, but got: " + response_message - ) - - -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_chat_history_very_long(client_offline_chat, default_user2): - # Arrange - message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)] - create_conversation(message_list, default_user2) - - # Act - query = "What is my name?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert len(response_message) > 0 - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_requires_multiple_independent_searches(client_offline_chat): - "Chat director should be able to answer by doing multiple independent searches for required information" - # Act - query = "Is Xi older than Namita?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected Xi is older than Namita, but got: " + response_message - ) diff --git a/tests/test_org_to_entries.py b/tests/test_org_to_entries.py index d5dcdbd2..0196ef6c 100644 --- a/tests/test_org_to_entries.py +++ b/tests/test_org_to_entries.py @@ -4,9 +4,8 @@ import time from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.text_to_entries import TextToEntries -from khoj.utils.fs_syncer import get_org_files from khoj.utils.helpers import is_none_or_empty -from khoj.utils.rawconfig import Entry, TextContentConfig +from khoj.utils.rawconfig import Entry def test_configure_indexing_heading_only_entries(tmp_path): @@ -330,46 +329,6 @@ def test_file_with_no_headings_to_entry(tmp_path): assert len(entries[1]) == 1 -def test_get_org_files(tmp_path): - "Ensure Org files specified via input-filter, input-files extracted" - # Arrange - # Include via input-filter globs - group1_file1 = create_file(tmp_path, filename="group1-file1.org") - group1_file2 = create_file(tmp_path, filename="group1-file2.org") - group2_file1 = create_file(tmp_path, filename="group2-file1.org") - group2_file2 = create_file(tmp_path, filename="group2-file2.org") - # Include via input-file field - orgfile1 = create_file(tmp_path, filename="orgfile1.org") - # Not included by any filter - create_file(tmp_path, filename="orgfile2.org") - create_file(tmp_path, filename="text1.txt") - - expected_files = set( - [ - os.path.join(tmp_path, file.name) - for file in [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1] - ] - ) - - # Setup input-files, 
input-filters - input_files = [tmp_path / "orgfile1.org"] - input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"] - - org_config = TextContentConfig( - input_files=input_files, - input_filter=[str(filter) for filter in input_filter], - compressed_jsonl=tmp_path / "test.jsonl", - embeddings_file=tmp_path / "test_embeddings.jsonl", - ) - - # Act - extracted_org_files = get_org_files(org_config) - - # Assert - assert len(extracted_org_files) == 5 - assert set(extracted_org_files.keys()) == expected_files - - def test_extract_entries_with_different_level_headings(tmp_path): "Extract org entries with different level headings." # Arrange diff --git a/tests/test_pdf_to_entries.py b/tests/test_pdf_to_entries.py index a62eca8b..d7336fdc 100644 --- a/tests/test_pdf_to_entries.py +++ b/tests/test_pdf_to_entries.py @@ -4,8 +4,6 @@ import re import pytest from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries -from khoj.utils.fs_syncer import get_pdf_files -from khoj.utils.rawconfig import TextContentConfig def test_single_page_pdf_to_jsonl(): @@ -61,43 +59,6 @@ def test_ocr_page_pdf_to_jsonl(): assert re.search(expected_str_with_variable_spaces, raw_entry) is not None -def test_get_pdf_files(tmp_path): - "Ensure Pdf files specified via input-filter, input-files extracted" - # Arrange - # Include via input-filter globs - group1_file1 = create_file(tmp_path, filename="group1-file1.pdf") - group1_file2 = create_file(tmp_path, filename="group1-file2.pdf") - group2_file1 = create_file(tmp_path, filename="group2-file1.pdf") - group2_file2 = create_file(tmp_path, filename="group2-file2.pdf") - # Include via input-file field - file1 = create_file(tmp_path, filename="document.pdf") - # Not included by any filter - create_file(tmp_path, filename="not-included-document.pdf") - create_file(tmp_path, filename="not-included-text.txt") - - expected_files = set( - [os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, 
file1]] - ) - - # Setup input-files, input-filters - input_files = [tmp_path / "document.pdf"] - input_filter = [tmp_path / "group1*.pdf", tmp_path / "group2*.pdf"] - - pdf_config = TextContentConfig( - input_files=input_files, - input_filter=[str(path) for path in input_filter], - compressed_jsonl=tmp_path / "test.jsonl", - embeddings_file=tmp_path / "test_embeddings.jsonl", - ) - - # Act - extracted_pdf_files = get_pdf_files(pdf_config) - - # Assert - assert len(extracted_pdf_files) == 5 - assert set(extracted_pdf_files.keys()) == expected_files - - # Helper Functions def create_file(tmp_path, entry=None, filename="document.pdf"): pdf_file = tmp_path / filename diff --git a/tests/test_plaintext_to_entries.py b/tests/test_plaintext_to_entries.py index a085b2b5..558832d3 100644 --- a/tests/test_plaintext_to_entries.py +++ b/tests/test_plaintext_to_entries.py @@ -1,27 +1,20 @@ -import os from pathlib import Path +from textwrap import dedent -from khoj.database.models import KhojUser, LocalPlaintextConfig from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries -from khoj.utils.fs_syncer import get_plaintext_files -from khoj.utils.rawconfig import TextContentConfig -def test_plaintext_file(tmp_path): +def test_plaintext_file(): "Convert files with no heading to jsonl." # Arrange raw_entry = f""" Hi, I am a plaintext file and I have some plaintext words. 
""" - plaintextfile = create_file(tmp_path, raw_entry) + plaintextfile = "test.txt" + data = {plaintextfile: raw_entry} # Act # Extract Entries from specified plaintext files - - data = { - f"{plaintextfile}": raw_entry, - } - entries = PlaintextToEntries.extract_plaintext_entries(data) # Convert each entry.file to absolute path to make them JSON serializable @@ -37,59 +30,20 @@ def test_plaintext_file(tmp_path): assert entries[1][0].compiled == f"{plaintextfile}\n{raw_entry}" -def test_get_plaintext_files(tmp_path): - "Ensure Plaintext files specified via input-filter, input-files extracted" - # Arrange - # Include via input-filter globs - group1_file1 = create_file(tmp_path, filename="group1-file1.md") - group1_file2 = create_file(tmp_path, filename="group1-file2.md") - - group2_file1 = create_file(tmp_path, filename="group2-file1.markdown") - group2_file2 = create_file(tmp_path, filename="group2-file2.markdown") - group2_file4 = create_file(tmp_path, filename="group2-file4.html") - # Include via input-file field - file1 = create_file(tmp_path, filename="notes.txt") - # Include unsupported file types - create_file(tmp_path, filename="group2-unincluded.py") - create_file(tmp_path, filename="group2-unincluded.csv") - create_file(tmp_path, filename="group2-unincluded.csv") - create_file(tmp_path, filename="group2-file3.mbox") - # Not included by any filter - create_file(tmp_path, filename="not-included-markdown.md") - create_file(tmp_path, filename="not-included-text.txt") - - expected_files = set( - [ - os.path.join(tmp_path, file.name) - for file in [group1_file1, group1_file2, group2_file1, group2_file2, group2_file4, file1] - ] - ) - - # Setup input-files, input-filters - input_files = [tmp_path / "notes.txt"] - input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.*"] - - plaintext_config = TextContentConfig( - input_files=input_files, - input_filter=[str(filter) for filter in input_filter], - compressed_jsonl=tmp_path / "test.jsonl", - 
embeddings_file=tmp_path / "test_embeddings.jsonl", - ) - - # Act - extracted_plaintext_files = get_plaintext_files(plaintext_config) - - # Assert - assert len(extracted_plaintext_files) == len(expected_files) - assert set(extracted_plaintext_files.keys()) == set(expected_files) - - -def test_parse_html_plaintext_file(content_config, default_user: KhojUser): +def test_parse_html_plaintext_file(tmp_path): "Ensure HTML files are parsed correctly" # Arrange - # Setup input-files, input-filters - config = LocalPlaintextConfig.objects.filter(user=default_user).first() - extracted_plaintext_files = get_plaintext_files(config=config) + raw_entry = dedent( + f""" + + Test HTML + +
Test content
+ + + """ + ) + extracted_plaintext_files = {"test.html": raw_entry} # Act entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 712f4aba..9e532429 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -2,23 +2,16 @@ import asyncio import logging import os -from pathlib import Path import pytest from khoj.database.adapters import EntryAdapters -from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig -from khoj.processor.content.docx.docx_to_entries import DocxToEntries +from khoj.database.models import Entry, GithubConfig, KhojUser from khoj.processor.content.github.github_to_entries import GithubToEntries -from khoj.processor.content.images.image_to_entries import ImageToEntries -from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries -from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries -from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.content.text_to_entries import TextToEntries from khoj.search_type import text_search -from khoj.utils.fs_syncer import collect_files, get_org_files -from khoj.utils.rawconfig import ContentConfig, SearchConfig +from tests.helpers import get_index_files, get_sample_data logger = logging.getLogger(__name__) @@ -26,53 +19,20 @@ logger = logging.getLogger(__name__) # Test # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig): - # Arrange - # Ensure file mentioned in org.input-files is missing - single_new_file = Path(org_config_with_only_new_file.input_files[0]) - single_new_file.unlink() - - # Act - # Generate notes embeddings during asymmetric setup - 
with pytest.raises(FileNotFoundError): - get_org_files(org_config_with_only_new_file) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.django_db -def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser): - # Arrange - orgfile = tmp_path / "directory.org" / "file.org" - orgfile.parent.mkdir() - with open(orgfile, "w") as f: - f.write("* Heading\n- List item\n") - - LocalOrgConfig.objects.create( - input_filter=[f"{tmp_path}/**/*"], - input_files=None, - user=default_user, - ) - - # Act - org_files = collect_files(user=default_user)["org"] - - # Assert - # should return orgfile and not raise IsADirectoryError - assert org_files == {f"{orgfile}": "* Heading\n- List item\n"} - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.django_db -def test_text_search_setup_with_empty_file_creates_no_entries( - org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser -): +def test_text_search_setup_with_empty_file_creates_no_entries(search_config, default_user: KhojUser): # Arrange + initial_data = { + "test.org": "* First heading\nFirst content", + "test2.org": "* Second heading\nSecond content", + } + text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) existing_entries = Entry.objects.filter(user=default_user).count() - data = get_org_files(org_config_with_only_new_file) + + final_data = {"new_file.org": ""} # Act # Generate notes embeddings during asymmetric setup - text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) + text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user) # Assert updated_entries = Entry.objects.filter(user=default_user).count() @@ -84,13 +44,14 @@ def test_text_search_setup_with_empty_file_creates_no_entries( # 
---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_indexer_deletes_embedding_before_regenerate( - content_config: ContentConfig, default_user: KhojUser, caplog -): +def test_text_indexer_deletes_embedding_before_regenerate(search_config, default_user: KhojUser, caplog): # Arrange + data = { + "test1.org": "* Test heading\nTest content", + "test2.org": "* Another heading\nAnother content", + } + text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) existing_entries = Entry.objects.filter(user=default_user).count() - org_config = LocalOrgConfig.objects.filter(user=default_user).first() - data = get_org_files(org_config) # Act # Generate notes embeddings during asymmetric setup @@ -107,11 +68,10 @@ def test_text_indexer_deletes_embedding_before_regenerate( # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog): +def test_text_index_same_if_content_unchanged(search_config, default_user: KhojUser, caplog): # Arrange existing_entries = Entry.objects.filter(user=default_user) - org_config = LocalOrgConfig.objects.filter(user=default_user).first() - data = get_org_files(org_config) + data = {"test.org": "* Test heading\nTest content"} # Act # Generate initial notes embeddings during asymmetric setup @@ -136,20 +96,14 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -@pytest.mark.anyio -# @pytest.mark.asyncio -async def test_text_search(search_config: SearchConfig): +@pytest.mark.asyncio +async def test_text_search(search_config): # Arrange - default_user = await KhojUser.objects.acreate( + default_user, _ = await 
KhojUser.objects.aget_or_create( username="test_user", password="test_password", email="test@example.com" ) - org_config = await LocalOrgConfig.objects.acreate( - input_files=None, - input_filter=["tests/data/org/*.org"], - index_heading_entries=False, - user=default_user, - ) - data = get_org_files(org_config) + # Get some sample org data to index + data = get_sample_data("org") loop = asyncio.get_event_loop() await loop.run_in_executor( @@ -175,17 +129,15 @@ async def test_text_search(search_config: SearchConfig): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog): +def test_entry_chunking_by_max_tokens(tmp_path, search_config, default_user: KhojUser, caplog): # Arrange # Insert org-mode entry with size exceeding max token limit to new org file max_tokens = 256 - new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) - with open(new_file_to_index, "w") as f: - f.write(f"* Entry more than {max_tokens} words\n") - for index in range(max_tokens + 1): - f.write(f"{index} ") - - data = get_org_files(org_config_with_only_new_file) + new_file_to_index = tmp_path / "test.org" + content = f"* Entry more than {max_tokens} words\n" + for index in range(max_tokens + 1): + content += f"{index} " + data = {str(new_file_to_index): content} # Act # reload embeddings, entries, notes model after adding new org-mode file @@ -200,9 +152,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_entry_chunking_by_max_tokens_not_full_corpus( - org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog -): +def test_entry_chunking_by_max_tokens_not_full_corpus(tmp_path, search_config, default_user: 
KhojUser, caplog): # Arrange # Insert org-mode entry with size exceeding max token limit to new org file data = { @@ -231,13 +181,11 @@ conda activate khoj ) max_tokens = 256 - new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) - with open(new_file_to_index, "w") as f: - f.write(f"* Entry more than {max_tokens} words\n") - for index in range(max_tokens + 1): - f.write(f"{index} ") - - data = get_org_files(org_config_with_only_new_file) + new_file_to_index = tmp_path / "test.org" + content = f"* Entry more than {max_tokens} words\n" + for index in range(max_tokens + 1): + content += f"{index} " + data = {str(new_file_to_index): content} # Act # reload embeddings, entries, notes model after adding new org-mode file @@ -257,34 +205,34 @@ conda activate khoj # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser): +def test_regenerate_index_with_new_entry(search_config, default_user: KhojUser): # Arrange + # Initial indexed files + text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=True, user=default_user) existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) - org_config = LocalOrgConfig.objects.filter(user=default_user).first() - initial_data = get_org_files(org_config) - # append org-mode entry to first org input file in config - org_config.input_files = [f"{new_org_file}"] - with open(new_org_file, "w") as f: - f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") - - final_data = get_org_files(org_config) - - # Act - text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) + # Regenerate index with only files from test data set + files_to_index = get_index_files() + text_search.setup(OrgToEntries, files_to_index, 
regenerate=True, user=default_user) updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) + # Act + # Update index with the new file + new_file = "test.org" + new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" + files_to_index[new_file] = new_entry + # regenerate notes jsonl, model embeddings and model to include entry from new file - text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user) + text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user) updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # Assert for entry in updated_entries1: assert entry in updated_entries2 - assert not any([new_org_file.name in entry for entry in updated_entries1]) - assert not any([new_org_file.name in entry for entry in existing_entries]) - assert any([new_org_file.name in entry for entry in updated_entries2]) + assert not any([new_file in entry for entry in updated_entries1]) + assert not any([new_file in entry for entry in existing_entries]) + assert any([new_file in entry for entry in updated_entries2]) assert any( ["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2] @@ -294,28 +242,24 @@ def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_ # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_update_index_with_duplicate_entries_in_stable_order( - org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser -): +def test_update_index_with_duplicate_entries_in_stable_order(tmp_path, search_config, default_user: KhojUser): # Arrange + initial_data = get_sample_data("org") + text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user) existing_entries = 
list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) - new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) # Insert org-mode entries with same compiled form into new org file + new_file_to_index = tmp_path / "test.org" new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" - with open(new_file_to_index, "w") as f: - f.write(f"{new_entry}{new_entry}") - - data = get_org_files(org_config_with_only_new_file) + # Initial data with duplicate entries + data = {str(new_file_to_index): f"{new_entry}{new_entry}"} # Act # generate embeddings, entries, notes model from scratch after adding new org-mode file text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) - data = get_org_files(org_config_with_only_new_file) - - # update embeddings, entries, notes model with no new changes + # idempotent indexing when data unchanged text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) @@ -324,6 +268,7 @@ def test_update_index_with_duplicate_entries_in_stable_order( for entry in existing_entries: assert entry not in updated_entries1 + # verify the second indexing update has same entries and ordering as first for entry in updated_entries1: assert entry in updated_entries2 @@ -334,22 +279,17 @@ def test_update_index_with_duplicate_entries_in_stable_order( # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser): +def test_update_index_with_deleted_entry(tmp_path, search_config, default_user: KhojUser): # Arrange existing_entries = 
list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) - new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) - # Insert org-mode entries with same compiled form into new org file + new_file_to_index = tmp_path / "test.org" new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" - with open(new_file_to_index, "w") as f: - f.write(f"{new_entry}{new_entry} -- Tatooine") - initial_data = get_org_files(org_config_with_only_new_file) - # update embeddings, entries, notes model after removing an entry from the org file - with open(new_file_to_index, "w") as f: - f.write(f"{new_entry}") - - final_data = get_org_files(org_config_with_only_new_file) + # Initial data with two entries + initial_data = {str(new_file_to_index): f"{new_entry}{new_entry} -- Tatooine"} + # Final data with only first entry, with second entry removed + final_data = {str(new_file_to_index): f"{new_entry}"} # Act # load embeddings, entries, notes model after adding new org file with 2 entries @@ -375,29 +315,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser): +def test_update_index_with_new_entry(search_config, default_user: KhojUser): # Arrange - existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) - org_config = LocalOrgConfig.objects.filter(user=default_user).first() - data = get_org_files(org_config) - text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) + # Initial indexed files + text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=True, user=default_user) + old_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", 
flat=True)) - # append org-mode entry to first org input file in config - with open(new_org_file, "w") as f: - new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" - f.write(new_entry) - - data = get_org_files(org_config) + # Regenerate index with only files from test data set + files_to_index = get_index_files() + new_entries = text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user) # Act - # update embeddings, entries with the newly added note - text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) - updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) + # Update index with the new file + new_file = "test.org" + new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" + final_data = {new_file: new_entry} + + text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user) + updated_new_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True)) # Assert - for entry in existing_entries: - assert entry not in updated_entries1 - assert len(updated_entries1) == len(existing_entries) + 1 + for old_entry in old_entries: + assert old_entry not in updated_new_entries + assert len(updated_new_entries) == len(new_entries) + 1 verify_embeddings(3, default_user) @@ -409,9 +349,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file (OrgToEntries), ], ) -def test_update_index_with_deleted_file( - org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser -): +def test_update_index_with_deleted_file(text_to_entries: TextToEntries, search_config, default_user: KhojUser): "Delete entries associated with new file when file path with empty content passed." 
# Arrange file_to_index = "test" @@ -446,7 +384,7 @@ def test_update_index_with_deleted_file( # ---------------------------------------------------------------------------------------------------- @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") -def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser): +def test_text_search_setup_github(search_config, default_user: KhojUser): # Arrange github_config = GithubConfig.objects.filter(user=default_user).first() diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index ebd6cccf..5333e17f 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -1,6 +1,12 @@ # Application Packages from khoj.search_filter.word_filter import WordFilter -from khoj.utils.rawconfig import Entry + + +# Mock Entry class for testing +class Entry: + def __init__(self, compiled="", raw=""): + self.compiled = compiled + self.raw = raw # Test