diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 45da7dda..641bf68b 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -18,8 +18,8 @@ ENV PATH="/opt/venv/bin:${PATH}"
COPY pyproject.toml README.md ./
# Setup python environment
- # Use the pre-built llama-cpp-python, torch cpu wheel
-ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://abetlen.github.io/llama-cpp-python/whl/cpu" \
+ # Use the pre-built torch cpu wheel
+ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" \
# Avoid downloading unused cuda specific python packages
CUDA_VISIBLE_DEVICES="" \
# Use static version to build app without git dependency
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e8db2c89..8e04a8fc 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -64,8 +64,6 @@ jobs:
DEBIAN_FRONTEND: noninteractive
run: |
apt update && apt install -y git libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
- # required by llama-cpp-python prebuilt wheels
- apt install -y musl-dev && ln -s /usr/lib/x86_64-linux-musl/libc.so /lib/libc.musl-x86_64.so.1
- name: ⬇️ Install Postgres
env:
diff --git a/documentation/docs/advanced/admin.md b/documentation/docs/advanced/admin.md
index e04f81e2..961056b8 100644
--- a/documentation/docs/advanced/admin.md
+++ b/documentation/docs/advanced/admin.md
@@ -20,7 +20,7 @@ Add all the agents you want to use for your different use-cases like Writer, Res
### Chat Model Options
Add all the chat models you want to try, use and switch between for your different use-cases. For each chat model you add:
- `Chat model`: The name of an [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/en/docs/about-claude/models#model-names), [Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) or [Offline](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf) chat model.
-- `Model type`: The chat model provider like `OpenAI`, `Offline`.
+- `Model type`: The chat model provider like `OpenAI`, `Google`.
- `Vision enabled`: Set to `true` if your model supports vision. This is currently only supported for vision capable OpenAI models like `gpt-4o`
- `Max prompt size`, `Subscribed max prompt size`: These are optional fields. They are used to truncate the context to the maximum context size that can be passed to the model. This can help with accuracy and cost-saving.
- `Tokenizer`: This is an optional field. It is used to accurately count tokens and truncate context passed to the chat model to stay within the models max prompt size.
diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx
index e3c1e12f..01d60496 100644
--- a/documentation/docs/get-started/setup.mdx
+++ b/documentation/docs/get-started/setup.mdx
@@ -18,10 +18,6 @@ import TabItem from '@theme/TabItem';
These are the general setup instructions for self-hosted Khoj.
You can install the Khoj server using either [Docker](?server=docker) or [Pip](?server=pip).
-:::info[Offline Model + GPU]
-To use the offline chat model with your GPU, we recommend using the Docker setup with Ollama . You can also use the local Khoj setup via the Python package directly.
-:::
-
:::info[First Run]
Restart your Khoj server after the first run to ensure all settings are applied correctly.
:::
@@ -225,10 +221,6 @@ To start Khoj automatically in the background use [Task scheduler](https://www.w
You can now open the web app at http://localhost:42110 and start interacting!
Nothing else is necessary, but you can customize your setup further by following the steps below.
-:::info[First Message to Offline Chat Model]
-The offline chat model gets downloaded when you first send a message to it. The download can take a few minutes! Subsequent messages should be faster.
-:::
-
### Add Chat Models
Login to the Khoj Admin Panel
Go to http://localhost:42110/server/admin and login with the admin credentials you setup during installation.
@@ -301,13 +293,14 @@ Offline chat stays completely private and can work without internet using any op
- A Nvidia, AMD GPU or a Mac M1+ machine would significantly speed up chat responses
:::
-1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
-2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
-3. Set the `chat-model` field to the name of your preferred chat model
- - Make sure the `model-type` is set to `Offline`
-4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
-5. Restart the Khoj server and [start chatting](http://localhost:42110) with your new offline model!
-
+1. Install any Openai API compatible local ai model server like [llama-cpp-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server), Ollama, vLLM etc.
+2. Add an [ai model api](http://localhost:42110/server/admin/database/aimodelapi/add/) on the admin panel
+ - Set the `api url` field to the url of your local ai model provider like `http://localhost:11434/v1/` for Ollama
+3. Restart the Khoj server to load models available on your local ai model provider
+ - If that doesn't work, you'll need to manually add available [chat model](http://localhost:42110/server/admin/database/chatmodel/add) in the admin panel.
+4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings)
+5. [Start chatting](http://localhost:42110) with your local AI!
+
:::tip[Multiple Chat Models]
diff --git a/pyproject.toml b/pyproject.toml
index 1a580dae..6491dc06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,6 @@ dependencies = [
"django == 5.1.10",
"django-unfold == 0.42.0",
"authlib == 1.2.1",
- "llama-cpp-python == 0.2.88",
"itsdangerous == 2.1.2",
"httpx == 0.28.1",
"pgvector == 0.2.4",
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index a72f15d5..40d1eeb5 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -50,13 +50,11 @@ from khoj.database.adapters import (
)
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
-from khoj.routers.api_content import configure_content, configure_search
+from khoj.routers.api_content import configure_content
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.config import SearchType
-from khoj.utils.fs_syncer import collect_files
-from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
-from khoj.utils.rawconfig import FullConfig
+from khoj.utils.helpers import is_none_or_empty
logger = logging.getLogger(__name__)
@@ -232,14 +230,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
return AuthCredentials(), UnauthenticatedUser()
-def initialize_server(config: Optional[FullConfig]):
- try:
- configure_server(config, init=True)
- except Exception as e:
- logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
- raise e
-
-
def clean_connections(func):
"""
A decorator that ensures that Django database connections that have become unusable, or are obsolete, are closed
@@ -260,19 +250,7 @@ def clean_connections(func):
return func_wrapper
-def configure_server(
- config: FullConfig,
- regenerate: bool = False,
- search_type: Optional[SearchType] = None,
- init=False,
- user: KhojUser = None,
-):
- # Update Config
- if config == None:
- logger.info(f"Initializing with default config.")
- config = FullConfig()
- state.config = config
-
+def initialize_server():
if ConversationAdapters.has_valid_ai_model_api():
ai_model_api = ConversationAdapters.get_ai_model_api()
state.openai_client = openai.OpenAI(api_key=ai_model_api.api_key, base_url=ai_model_api.api_base_url)
@@ -309,43 +287,33 @@ def configure_server(
)
state.SearchType = configure_search_types()
- state.search_models = configure_search(state.search_models, state.config.search_type)
- setup_default_agent(user)
+ setup_default_agent()
- message = (
- "📡 Telemetry disabled"
- if telemetry_disabled(state.config.app, state.telemetry_disabled)
- else "📡 Telemetry enabled"
- )
+ message = "📡 Telemetry disabled" if state.telemetry_disabled else "📡 Telemetry enabled"
logger.info(message)
- if not init:
- initialize_content(user, regenerate, search_type)
-
except Exception as e:
logger.error(f"Failed to load some search models: {e}", exc_info=True)
-def setup_default_agent(user: KhojUser):
- AgentAdapters.create_default_agent(user)
+def setup_default_agent():
+ AgentAdapters.create_default_agent()
def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
# Initialize Content from Config
- if state.search_models:
- try:
- logger.info("📬 Updating content index...")
- all_files = collect_files(user=user)
- status = configure_content(
- user,
- all_files,
- regenerate,
- search_type,
- )
- if not status:
- raise RuntimeError("Failed to update content index")
- except Exception as e:
- raise e
+ try:
+ logger.info("📬 Updating content index...")
+ status = configure_content(
+ user,
+ {},
+ regenerate,
+ search_type,
+ )
+ if not status:
+ raise RuntimeError("Failed to update content index")
+ except Exception as e:
+ raise e
def configure_routes(app):
@@ -438,8 +406,7 @@ def configure_middleware(app, ssl_enabled: bool = False):
def update_content_index():
for user in get_all_users():
- all_files = collect_files(user=user)
- success = configure_content(user, all_files)
+ success = configure_content(user, {})
if not success:
raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler")
@@ -464,7 +431,7 @@ def configure_search_types():
@schedule.repeat(schedule.every(2).minutes)
@clean_connections
def upload_telemetry():
- if telemetry_disabled(state.config.app, state.telemetry_disabled) or not state.telemetry:
+ if state.telemetry_disabled or not state.telemetry:
return
try:
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 14eaadfe..6d92b8e9 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -72,7 +72,6 @@ from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
-from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
clean_object_for_db,
clean_text_for_db,
@@ -789,8 +788,8 @@ class AgentAdapters:
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
@staticmethod
- def create_default_agent(user: KhojUser):
- default_chat_model = ConversationAdapters.get_default_chat_model(user)
+ def create_default_agent():
+ default_chat_model = ConversationAdapters.get_default_chat_model(user=None)
if default_chat_model is None:
logger.info("No default conversation config found, skipping default agent creation")
return None
@@ -1553,14 +1552,6 @@ class ConversationAdapters:
if chat_model is None:
chat_model = await ConversationAdapters.aget_default_chat_model()
- if chat_model.model_type == ChatModel.ModelType.OFFLINE:
- if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
- chat_model_name = chat_model.name
- max_tokens = chat_model.max_prompt_size
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
-
- return chat_model
-
if (
chat_model.model_type
in [
diff --git a/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py b/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py
new file mode 100644
index 00000000..256aa93d
--- /dev/null
+++ b/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py
@@ -0,0 +1,36 @@
+# Generated by Django 5.1.10 on 2025-07-19 21:33
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0091_chatmodel_friendly_name_and_more"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="chatmodel",
+ name="model_type",
+ field=models.CharField(
+ choices=[("openai", "Openai"), ("anthropic", "Anthropic"), ("google", "Google")],
+ default="google",
+ max_length=200,
+ ),
+ ),
+ migrations.AlterField(
+ model_name="chatmodel",
+ name="name",
+ field=models.CharField(default="gemini-2.5-flash", max_length=200),
+ ),
+ migrations.AlterField(
+ model_name="speechtotextmodeloptions",
+ name="model_name",
+ field=models.CharField(default="whisper-1", max_length=200),
+ ),
+ migrations.AlterField(
+ model_name="speechtotextmodeloptions",
+ name="model_type",
+ field=models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200),
+ ),
+ ]
diff --git a/src/khoj/database/migrations/0093_remove_localorgconfig_user_and_more.py b/src/khoj/database/migrations/0093_remove_localorgconfig_user_and_more.py
new file mode 100644
index 00000000..ad3409cf
--- /dev/null
+++ b/src/khoj/database/migrations/0093_remove_localorgconfig_user_and_more.py
@@ -0,0 +1,36 @@
+# Generated by Django 5.1.10 on 2025-07-25 23:30
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="localorgconfig",
+ name="user",
+ ),
+ migrations.RemoveField(
+ model_name="localpdfconfig",
+ name="user",
+ ),
+ migrations.RemoveField(
+ model_name="localplaintextconfig",
+ name="user",
+ ),
+ migrations.DeleteModel(
+ name="LocalMarkdownConfig",
+ ),
+ migrations.DeleteModel(
+ name="LocalOrgConfig",
+ ),
+ migrations.DeleteModel(
+ name="LocalPdfConfig",
+ ),
+ migrations.DeleteModel(
+ name="LocalPlaintextConfig",
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 7f80459a..1ed58572 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -220,16 +220,15 @@ class PriceTier(models.TextChoices):
class ChatModel(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
- OFFLINE = "offline"
ANTHROPIC = "anthropic"
GOOGLE = "google"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
- name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
+ name = models.CharField(max_length=200, default="gemini-2.5-flash")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
- model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+ model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.GOOGLE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -489,34 +488,6 @@ class ServerChatSettings(DbBaseModel):
super().save(*args, **kwargs)
-class LocalOrgConfig(DbBaseModel):
- input_files = models.JSONField(default=list, null=True)
- input_filter = models.JSONField(default=list, null=True)
- index_heading_entries = models.BooleanField(default=False)
- user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
-class LocalMarkdownConfig(DbBaseModel):
- input_files = models.JSONField(default=list, null=True)
- input_filter = models.JSONField(default=list, null=True)
- index_heading_entries = models.BooleanField(default=False)
- user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
-class LocalPdfConfig(DbBaseModel):
- input_files = models.JSONField(default=list, null=True)
- input_filter = models.JSONField(default=list, null=True)
- index_heading_entries = models.BooleanField(default=False)
- user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
-class LocalPlaintextConfig(DbBaseModel):
- input_files = models.JSONField(default=list, null=True)
- input_filter = models.JSONField(default=list, null=True)
- index_heading_entries = models.BooleanField(default=False)
- user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
@@ -605,11 +576,10 @@ class TextToImageModelConfig(DbBaseModel):
class SpeechToTextModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
- OFFLINE = "offline"
- model_name = models.CharField(max_length=200, default="base")
+ model_name = models.CharField(max_length=200, default="whisper-1")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
- model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+ model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
diff --git a/src/khoj/main.py b/src/khoj/main.py
index 50da2624..f42ae135 100644
--- a/src/khoj/main.py
+++ b/src/khoj/main.py
@@ -138,10 +138,10 @@ def run(should_start_server=True):
initialization(not args.non_interactive)
# Create app directory, if it doesn't exist
- state.config_file.parent.mkdir(parents=True, exist_ok=True)
+ state.log_file.parent.mkdir(parents=True, exist_ok=True)
# Set Log File
- fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
+ fh = logging.FileHandler(state.log_file, encoding="utf-8")
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
@@ -194,7 +194,7 @@ def run(should_start_server=True):
# Configure Middleware
configure_middleware(app, state.ssl_config)
- initialize_server(args.config)
+ initialize_server()
# If the server is started through gunicorn (external to the script), don't start the server
if should_start_server:
@@ -204,8 +204,7 @@ def run(should_start_server=True):
def set_state(args):
- state.config_file = args.config_file
- state.config = args.config
+ state.log_file = args.log_file
state.verbose = args.verbose
state.host = args.host
state.port = args.port
@@ -214,7 +213,6 @@ def set_state(args):
)
state.anonymous_mode = args.anonymous_mode
state.khoj_version = version("khoj")
- state.chat_on_gpu = args.chat_on_gpu
def start_server(app, host=None, port=None, socket=None):
diff --git a/src/khoj/migrations/__init__.py b/src/khoj/migrations/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/khoj/migrations/migrate_offline_chat_default_model.py b/src/khoj/migrations/migrate_offline_chat_default_model.py
deleted file mode 100644
index 831f2d9d..00000000
--- a/src/khoj/migrations/migrate_offline_chat_default_model.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""
-Current format of khoj.yml
----
-app:
- ...
-content-type:
- ...
-processor:
- conversation:
- offline-chat:
- enable-offline-chat: false
- chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
- ...
-search-type:
- ...
-
-New format of khoj.yml
----
-app:
- ...
-content-type:
- ...
-processor:
- conversation:
- offline-chat:
- enable-offline-chat: false
- chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
- ...
-search-type:
- ...
-"""
-import logging
-
-from packaging import version
-
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-logger = logging.getLogger(__name__)
-
-
-def migrate_offline_chat_default_model(args):
- schema_version = "0.12.4"
- raw_config = load_config_from_file(args.config_file)
- previous_version = raw_config.get("version")
-
- if "processor" not in raw_config:
- return args
- if raw_config["processor"] is None:
- return args
- if "conversation" not in raw_config["processor"]:
- return args
- if "offline-chat" not in raw_config["processor"]["conversation"]:
- return args
- if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]:
- return args
-
- if previous_version is None or version.parse(previous_version) < version.parse("0.12.4"):
- logger.info(
- f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF"
- )
- raw_config["version"] = schema_version
-
- # Update offline chat model to mistral in GGUF format to use latest GPT4All
- offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"]
- if offline_chat_model.endswith(".bin"):
- raw_config["processor"]["conversation"]["offline-chat"]["chat-model"] = "mistral-7b-instruct-v0.1.Q4_0.gguf"
-
- save_config_to_file(raw_config, args.config_file)
- return args
diff --git a/src/khoj/migrations/migrate_offline_chat_default_model_2.py b/src/khoj/migrations/migrate_offline_chat_default_model_2.py
deleted file mode 100644
index 107b7130..00000000
--- a/src/khoj/migrations/migrate_offline_chat_default_model_2.py
+++ /dev/null
@@ -1,71 +0,0 @@
-"""
-Current format of khoj.yml
----
-app:
- ...
-content-type:
- ...
-processor:
- conversation:
- offline-chat:
- enable-offline-chat: false
- chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
- ...
-search-type:
- ...
-
-New format of khoj.yml
----
-app:
- ...
-content-type:
- ...
-processor:
- conversation:
- offline-chat:
- enable-offline-chat: false
- chat-model: NousResearch/Hermes-2-Pro-Mistral-7B-GGUF
- ...
-search-type:
- ...
-"""
-import logging
-
-from packaging import version
-
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-logger = logging.getLogger(__name__)
-
-
-def migrate_offline_chat_default_model(args):
- schema_version = "1.7.0"
- raw_config = load_config_from_file(args.config_file)
- previous_version = raw_config.get("version")
-
- if "processor" not in raw_config:
- return args
- if raw_config["processor"] is None:
- return args
- if "conversation" not in raw_config["processor"]:
- return args
- if "offline-chat" not in raw_config["processor"]["conversation"]:
- return args
- if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]:
- return args
-
- if previous_version is None or version.parse(previous_version) < version.parse(schema_version):
- logger.info(
- f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF"
- )
- raw_config["version"] = schema_version
-
- # Update offline chat model to use Nous Research's Hermes-2-Pro GGUF in path format suitable for llama-cpp
- offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"]
- if offline_chat_model == "mistral-7b-instruct-v0.1.Q4_0.gguf":
- raw_config["processor"]["conversation"]["offline-chat"][
- "chat-model"
- ] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
-
- save_config_to_file(raw_config, args.config_file)
- return args
diff --git a/src/khoj/migrations/migrate_offline_chat_schema.py b/src/khoj/migrations/migrate_offline_chat_schema.py
deleted file mode 100644
index 0c221652..00000000
--- a/src/khoj/migrations/migrate_offline_chat_schema.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""
-Current format of khoj.yml
----
-app:
- ...
-content-type:
- ...
-processor:
- conversation:
- enable-offline-chat: false
- conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
- openai:
- ...
-search-type:
- ...
-
-New format of khoj.yml
----
-app:
- ...
-content-type:
- ...
-processor:
- conversation:
- offline-chat:
- enable-offline-chat: false
- chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
- tokenizer: null
- max_prompt_size: null
- conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
- openai:
- ...
-search-type:
- ...
-"""
-import logging
-
-from packaging import version
-
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-logger = logging.getLogger(__name__)
-
-
-def migrate_offline_chat_schema(args):
- schema_version = "0.12.3"
- raw_config = load_config_from_file(args.config_file)
- previous_version = raw_config.get("version")
-
- if "processor" not in raw_config:
- return args
- if raw_config["processor"] is None:
- return args
- if "conversation" not in raw_config["processor"]:
- return args
-
- if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
- logger.info(
- f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
- )
- raw_config["version"] = schema_version
-
- # Create max-prompt-size field in conversation processor schema
- raw_config["processor"]["conversation"]["max-prompt-size"] = None
- raw_config["processor"]["conversation"]["tokenizer"] = None
-
- # Create offline chat schema based on existing enable_offline_chat field in khoj config schema
- offline_chat_model = (
- raw_config["processor"]["conversation"]
- .get("offline-chat", {})
- .get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
- )
- raw_config["processor"]["conversation"]["offline-chat"] = {
- "enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
- "chat-model": offline_chat_model,
- }
-
- # Delete old enable-offline-chat field from conversation processor schema
- if "enable-offline-chat" in raw_config["processor"]["conversation"]:
- del raw_config["processor"]["conversation"]["enable-offline-chat"]
-
- save_config_to_file(raw_config, args.config_file)
- return args
diff --git a/src/khoj/migrations/migrate_offline_model.py b/src/khoj/migrations/migrate_offline_model.py
deleted file mode 100644
index 6294a4e8..00000000
--- a/src/khoj/migrations/migrate_offline_model.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import logging
-import os
-
-from packaging import version
-
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-logger = logging.getLogger(__name__)
-
-
-def migrate_offline_model(args):
- schema_version = "0.10.1"
- raw_config = load_config_from_file(args.config_file)
- previous_version = raw_config.get("version")
-
- if previous_version is None or version.parse(previous_version) < version.parse("0.10.1"):
- logger.info(
- f"Migrating offline model used for version {previous_version} to latest version for {args.version_no}"
- )
- raw_config["version"] = schema_version
-
- # If the user has downloaded the offline model, remove it from the cache.
- offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin")
- if os.path.exists(offline_model_path):
- os.remove(offline_model_path)
-
- save_config_to_file(raw_config, args.config_file)
-
- return args
diff --git a/src/khoj/migrations/migrate_processor_config_openai.py b/src/khoj/migrations/migrate_processor_config_openai.py
deleted file mode 100644
index c25e5306..00000000
--- a/src/khoj/migrations/migrate_processor_config_openai.py
+++ /dev/null
@@ -1,67 +0,0 @@
-"""
-Current format of khoj.yml
----
-app:
- should-log-telemetry: true
-content-type:
- ...
-processor:
- conversation:
- chat-model: gpt-3.5-turbo
- conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
- model: text-davinci-003
- openai-api-key: sk-secret-key
-search-type:
- ...
-
-New format of khoj.yml
----
-app:
- should-log-telemetry: true
-content-type:
- ...
-processor:
- conversation:
- openai:
- chat-model: gpt-3.5-turbo
- openai-api-key: sk-secret-key
- conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
- enable-offline-chat: false
-search-type:
- ...
-"""
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-
-def migrate_processor_conversation_schema(args):
- schema_version = "0.10.0"
- raw_config = load_config_from_file(args.config_file)
-
- if "processor" not in raw_config:
- return args
- if raw_config["processor"] is None:
- return args
- if "conversation" not in raw_config["processor"]:
- return args
-
- current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
- current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
- if current_openai_api_key is None and current_chat_model is None:
- return args
-
- raw_config["version"] = schema_version
-
- # Add enable_offline_chat to khoj config schema
- if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
- raw_config["processor"]["conversation"]["enable-offline-chat"] = False
-
- # Update conversation processor schema
- conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
- raw_config["processor"]["conversation"] = {
- "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
- "conversation-logfile": conversation_logfile,
- "enable-offline-chat": False,
- }
-
- save_config_to_file(raw_config, args.config_file)
- return args
diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py
deleted file mode 100644
index 316704b9..00000000
--- a/src/khoj/migrations/migrate_server_pg.py
+++ /dev/null
@@ -1,132 +0,0 @@
-"""
-The application config currently looks like this:
-app:
- should-log-telemetry: true
-content-type:
- ...
-processor:
- conversation:
- conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
- max-prompt-size: null
- offline-chat:
- chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
- enable-offline-chat: false
- openai:
- api-key: sk-blah
- chat-model: gpt-3.5-turbo
- tokenizer: null
-search-type:
- asymmetric:
- cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
- encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
- encoder-type: null
- model-directory: /Users/si/.khoj/search/asymmetric
- image:
- encoder: sentence-transformers/clip-ViT-B-32
- encoder-type: null
- model-directory: /Users/si/.khoj/search/image
- symmetric:
- cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
- encoder: sentence-transformers/all-MiniLM-L6-v2
- encoder-type: null
- model-directory: ~/.khoj/search/symmetric
-version: 0.14.0
-
-
-The new version will looks like this:
-app:
- should-log-telemetry: true
-processor:
- conversation:
- offline-chat:
- enabled: false
- openai:
- api-key: sk-blah
- chat-model-options:
- - chat-model: gpt-3.5-turbo
- tokenizer: null
- type: openai
- - chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
- tokenizer: null
- type: offline
-search-type:
- asymmetric:
- cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
- encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
-version: 0.15.0
-"""
-
-import logging
-
-from packaging import version
-
-from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-logger = logging.getLogger(__name__)
-
-
-def migrate_server_pg(args):
- schema_version = "0.15.0"
- raw_config = load_config_from_file(args.config_file)
- previous_version = raw_config.get("version")
-
- if previous_version is None or version.parse(previous_version) < version.parse(schema_version):
- logger.info(
- f"Migrating configuration used for version {previous_version} to latest version for server with postgres in {args.version_no}"
- )
- raw_config["version"] = schema_version
-
- if raw_config is None:
- return args
-
- if "search-type" in raw_config and raw_config["search-type"]:
- if "asymmetric" in raw_config["search-type"]:
- # Delete all existing search models
- SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
- # Create new search model from existing Khoj YAML config
- asymmetric_search = raw_config["search-type"]["asymmetric"]
- SearchModelConfig.objects.create(
- name="default",
- model_type=SearchModelConfig.ModelType.TEXT,
- bi_encoder=asymmetric_search.get("encoder"),
- cross_encoder=asymmetric_search.get("cross-encoder"),
- )
-
- if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
- processor_conversation = raw_config["processor"]["conversation"]
-
- if "offline-chat" in raw_config["processor"]["conversation"]:
- offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
- ChatModel.objects.create(
- name=offline_chat.get("chat-model"),
- tokenizer=processor_conversation.get("tokenizer"),
- max_prompt_size=processor_conversation.get("max-prompt-size"),
- model_type=ChatModel.ModelType.OFFLINE,
- )
-
- if (
- "openai" in raw_config["processor"]["conversation"]
- and raw_config["processor"]["conversation"]["openai"]
- ):
- openai = raw_config["processor"]["conversation"]["openai"]
-
- if openai.get("api-key") is None:
- logger.error("OpenAI API Key is not set. Will not be migrating OpenAI config.")
- else:
- if openai.get("chat-model") is None:
- openai["chat-model"] = "gpt-3.5-turbo"
-
- openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")
-
- ChatModel.objects.create(
- name=openai.get("chat-model"),
- tokenizer=processor_conversation.get("tokenizer"),
- max_prompt_size=processor_conversation.get("max-prompt-size"),
- model_type=ChatModel.ModelType.OPENAI,
- ai_model_api=openai_model_api,
- )
-
- save_config_to_file(raw_config, args.config_file)
-
- return args
diff --git a/src/khoj/migrations/migrate_version.py b/src/khoj/migrations/migrate_version.py
deleted file mode 100644
index de8b9571..00000000
--- a/src/khoj/migrations/migrate_version.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from khoj.utils.yaml import load_config_from_file, save_config_to_file
-
-
-def migrate_config_to_version(args):
- schema_version = "0.9.0"
- raw_config = load_config_from_file(args.config_file)
-
- # Add version to khoj config schema
- if "version" not in raw_config:
- raw_config["version"] = schema_version
- save_config_to_file(raw_config, args.config_file)
-
- # regenerate khoj index on first start of this version
- # this should refresh index and apply index corruption fixes from #325
- args.regenerate = True
-
- return args
diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py
index 31f99f84..63ed50c6 100644
--- a/src/khoj/processor/content/github/github_to_entries.py
+++ b/src/khoj/processor/content/github/github_to_entries.py
@@ -20,7 +20,6 @@ magika = Magika()
class GithubToEntries(TextToEntries):
def __init__(self, config: GithubConfig):
- super().__init__(config)
raw_repos = config.githubrepoconfig.all()
repos = []
for repo in raw_repos:
diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py
index 1e1ab4d3..23b96f63 100644
--- a/src/khoj/processor/content/notion/notion_to_entries.py
+++ b/src/khoj/processor/content/notion/notion_to_entries.py
@@ -47,7 +47,6 @@ class NotionBlockType(Enum):
class NotionToEntries(TextToEntries):
def __init__(self, config: NotionConfig):
- super().__init__(config)
self.config = NotionContentConfig(
token=config.token,
)
diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py
index 0ceda11d..0369d273 100644
--- a/src/khoj/processor/content/text_to_entries.py
+++ b/src/khoj/processor/content/text_to_entries.py
@@ -27,7 +27,6 @@ logger = logging.getLogger(__name__)
class TextToEntries(ABC):
def __init__(self, config: Any = None):
self.embeddings_model = state.embeddings_model
- self.config = config
self.date_filter = DateFilter()
@abstractmethod
diff --git a/src/khoj/processor/conversation/offline/__init__.py b/src/khoj/processor/conversation/offline/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
deleted file mode 100644
index b117a48d..00000000
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ /dev/null
@@ -1,224 +0,0 @@
-import asyncio
-import logging
-import os
-from datetime import datetime
-from threading import Thread
-from time import perf_counter
-from typing import Any, AsyncGenerator, Dict, List, Union
-
-from langchain_core.messages.chat import ChatMessage
-from llama_cpp import Llama
-
-from khoj.database.models import Agent, ChatMessageModel, ChatModel
-from khoj.processor.conversation import prompts
-from khoj.processor.conversation.offline.utils import download_model
-from khoj.processor.conversation.utils import (
- ResponseWithThought,
- commit_conversation_trace,
- generate_chatml_messages_with_context,
- messages_to_print,
-)
-from khoj.utils import state
-from khoj.utils.helpers import (
- is_none_or_empty,
- is_promptrace_enabled,
- truncate_code_context,
-)
-from khoj.utils.rawconfig import FileAttachment, LocationData
-from khoj.utils.yaml import yaml_dump
-
-logger = logging.getLogger(__name__)
-
-
-async def converse_offline(
- # Query
- user_query: str,
- # Context
- references: list[dict] = [],
- online_results={},
- code_results={},
- query_files: str = None,
- generated_files: List[FileAttachment] = None,
- additional_context: List[str] = None,
- generated_asset_results: Dict[str, Dict] = {},
- location_data: LocationData = None,
- user_name: str = None,
- chat_history: list[ChatMessageModel] = [],
- # Model
- model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
- loaded_model: Union[Any, None] = None,
- max_prompt_size=None,
- tokenizer_name=None,
- agent: Agent = None,
- tracer: dict = {},
-) -> AsyncGenerator[ResponseWithThought, None]:
- """
- Converse with user using Llama (Async Version)
- """
- # Initialize Variables
- assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
- tracer["chat_model"] = model_name
- current_date = datetime.now()
-
- if agent and agent.personality:
- system_prompt = prompts.custom_system_prompt_offline_chat.format(
- name=agent.name,
- bio=agent.personality,
- current_date=current_date.strftime("%Y-%m-%d"),
- day_of_week=current_date.strftime("%A"),
- )
- else:
- system_prompt = prompts.system_prompt_offline_chat.format(
- current_date=current_date.strftime("%Y-%m-%d"),
- day_of_week=current_date.strftime("%A"),
- )
-
- if location_data:
- location_prompt = prompts.user_location.format(location=f"{location_data}")
- system_prompt = f"{system_prompt}\n{location_prompt}"
-
- if user_name:
- user_name_prompt = prompts.user_name.format(name=user_name)
- system_prompt = f"{system_prompt}\n{user_name_prompt}"
-
- # Get Conversation Primer appropriate to Conversation Type
- context_message = ""
- if not is_none_or_empty(references):
- context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
- if not is_none_or_empty(online_results):
- simplified_online_results = online_results.copy()
- for result in online_results:
- if online_results[result].get("webpages"):
- simplified_online_results[result] = online_results[result]["webpages"]
-
- context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
- if not is_none_or_empty(code_results):
- context_message += (
- f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
- )
- context_message = context_message.strip()
-
- # Setup Prompt with Primer or Conversation History
- messages = generate_chatml_messages_with_context(
- user_query,
- system_prompt,
- chat_history,
- context_message=context_message,
- model_name=model_name,
- loaded_model=offline_chat_model,
- max_prompt_size=max_prompt_size,
- tokenizer_name=tokenizer_name,
- model_type=ChatModel.ModelType.OFFLINE,
- query_files=query_files,
- generated_files=generated_files,
- generated_asset_results=generated_asset_results,
- program_execution_context=additional_context,
- )
-
- logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
-
- # Use asyncio.Queue and a thread to bridge sync iterator
- queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue()
- stop_phrases = ["", "INST]", "Notes:"]
-
- def _sync_llm_thread():
- """Synchronous function to run in a separate thread."""
- aggregated_response = ""
- start_time = perf_counter()
- state.chat_lock.acquire()
- try:
- response_iterator = send_message_to_model_offline(
- messages,
- loaded_model=offline_chat_model,
- stop=stop_phrases,
- max_prompt_size=max_prompt_size,
- streaming=True,
- tracer=tracer,
- )
- for response in response_iterator:
- response_delta: str = response["choices"][0]["delta"].get("content", "")
- # Log the time taken to start response
- if aggregated_response == "" and response_delta != "":
- logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
- # Handle response chunk
- aggregated_response += response_delta
- # Put chunk into the asyncio queue (non-blocking)
- try:
- queue.put_nowait(ResponseWithThought(text=response_delta))
- except asyncio.QueueFull:
- # Should not happen with default queue size unless consumer is very slow
- logger.warning("Asyncio queue full during offline LLM streaming.")
- # Potentially block here or handle differently if needed
- asyncio.run(queue.put(ResponseWithThought(text=response_delta)))
-
- # Log the time taken to stream the entire response
- logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
-
- # Save conversation trace
- tracer["chat_model"] = model_name
- if is_promptrace_enabled():
- commit_conversation_trace(messages, aggregated_response, tracer)
-
- except Exception as e:
- logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
- finally:
- state.chat_lock.release()
- # Signal end of stream
- queue.put_nowait(None)
-
- # Start the synchronous thread
- thread = Thread(target=_sync_llm_thread)
- thread.start()
-
- # Asynchronously consume from the queue
- while True:
- chunk = await queue.get()
- if chunk is None: # End of stream signal
- queue.task_done()
- break
- yield chunk
- queue.task_done()
-
- # Wait for the thread to finish (optional, ensures cleanup)
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(None, thread.join)
-
-
-def send_message_to_model_offline(
- messages: List[ChatMessage],
- loaded_model=None,
- model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
- temperature: float = 0.2,
- streaming=False,
- stop=[],
- max_prompt_size: int = None,
- response_type: str = "text",
- tracer: dict = {},
-):
- assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
- messages_dict = [{"role": message.role, "content": message.content} for message in messages]
- seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
- response = offline_chat_model.create_chat_completion(
- messages_dict,
- stop=stop,
- stream=streaming,
- temperature=temperature,
- response_format={"type": response_type},
- seed=seed,
- )
-
- if streaming:
- return response
-
- response_text: str = response["choices"][0]["message"].get("content", "")
-
- # Save conversation trace for non-streaming responses
- # Streamed responses need to be saved by the calling function
- tracer["chat_model"] = model_name
- tracer["temperature"] = temperature
- if is_promptrace_enabled():
- commit_conversation_trace(messages, response_text, tracer)
-
- return ResponseWithThought(text=response_text)
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
deleted file mode 100644
index 88082ad1..00000000
--- a/src/khoj/processor/conversation/offline/utils.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import glob
-import logging
-import math
-import os
-from typing import Any, Dict
-
-from huggingface_hub.constants import HF_HUB_CACHE
-
-from khoj.utils import state
-from khoj.utils.helpers import get_device_memory
-
-logger = logging.getLogger(__name__)
-
-
-def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
- # Initialize Model Parameters
- # Use n_ctx=0 to get context size from the model
- kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False}
-
- # Decide whether to load model to GPU or CPU
- device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
- kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
-
- # Add chat format if known
- if "llama-3" in repo_id.lower():
- kwargs["chat_format"] = "llama-3"
- elif "gemma-2" in repo_id.lower():
- kwargs["chat_format"] = "gemma"
-
- # Check if the model is already downloaded
- model_path = load_model_from_cache(repo_id, filename)
- chat_model = None
- try:
- chat_model = load_model(model_path, repo_id, filename, kwargs)
- except:
- # Load model on CPU if GPU is not available
- kwargs["n_gpu_layers"], device = 0, "cpu"
- chat_model = load_model(model_path, repo_id, filename, kwargs)
-
- # Now load the model with context size set based on:
- # 1. context size supported by model and
- # 2. configured size or machine (V)RAM
- kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
- chat_model = load_model(model_path, repo_id, filename, kwargs)
-
- logger.debug(
- f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
- )
- return chat_model
-
-
-def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
- from llama_cpp.llama import Llama
-
- if model_path:
- return Llama(model_path, **kwargs)
- else:
- return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
-
-
-def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
- # Construct the path to the model file in the cache directory
- repo_org, repo_name = repo_id.split("/")
- object_id = "--".join([repo_type, repo_org, repo_name])
- model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])
-
- # Check if the model file exists
- paths = glob.glob(model_path)
- if paths:
- return paths[0]
- else:
- return None
-
-
-def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int:
- """Infer max prompt size based on device memory and max context window supported by the model"""
- configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
- vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
- configured_max_tokens = configured_max_tokens or math.inf # do not use if set to None
- return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
diff --git a/src/khoj/processor/conversation/offline/whisper.py b/src/khoj/processor/conversation/offline/whisper.py
deleted file mode 100644
index d8dd4457..00000000
--- a/src/khoj/processor/conversation/offline/whisper.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import whisper
-from asgiref.sync import sync_to_async
-
-from khoj.utils import state
-
-
-async def transcribe_audio_offline(audio_filename: str, model: str) -> str:
- """
- Transcribe audio file offline using Whisper
- """
- # Send the audio data to the Whisper API
- if not state.whisper_model:
- state.whisper_model = whisper.load_model(model)
- response = await sync_to_async(state.whisper_model.transcribe)(audio_filename)
- return response["text"]
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 47589e93..fcf30d07 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -78,38 +78,6 @@ no_entries_found = PromptTemplate.from_template(
""".strip()
)
-## Conversation Prompts for Offline Chat Models
-## --
-system_prompt_offline_chat = PromptTemplate.from_template(
- """
-You are Khoj, a smart, inquisitive and helpful personal assistant.
-- Use your general knowledge and past conversation with the user as context to inform your responses.
-- If you do not know the answer, say 'I don't know.'
-- Think step-by-step and ask questions to get the necessary information to answer the user's question.
-- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
-- Do not print verbatim Notes unless necessary.
-
-Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
-Today is {day_of_week}, {current_date} in UTC.
-""".strip()
-)
-
-custom_system_prompt_offline_chat = PromptTemplate.from_template(
- """
-You are {name}, a personal agent on Khoj.
-- Use your general knowledge and past conversation with the user as context to inform your responses.
-- If you do not know the answer, say 'I don't know.'
-- Think step-by-step and ask questions to get the necessary information to answer the user's question.
-- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
-- Do not print verbatim Notes unless necessary.
-
-Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
-Today is {day_of_week}, {current_date} in UTC.
-
-Instructions:\n{bio}
-""".strip()
-)
-
## Notes Conversation
## --
notes_conversation = PromptTemplate.from_template(
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 3bd17856..6e60c786 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -18,8 +18,6 @@ import requests
import tiktoken
import yaml
from langchain_core.messages.chat import ChatMessage
-from llama_cpp import LlamaTokenizer
-from llama_cpp.llama import Llama
from pydantic import BaseModel, ConfigDict, ValidationError, create_model
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -32,7 +30,6 @@ from khoj.database.models import (
KhojUser,
)
from khoj.processor.conversation import prompts
-from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -85,12 +82,6 @@ model_to_prompt_size = {
"claude-sonnet-4-20250514": 60000,
"claude-opus-4-0": 60000,
"claude-opus-4-20250514": 60000,
- # Offline Models
- "bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
- "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
- "bartowski/Llama-3.2-3B-Instruct-GGUF": 20000,
- "bartowski/gemma-2-9b-it-GGUF": 6000,
- "bartowski/gemma-2-2b-it-GGUF": 6000,
}
model_to_tokenizer: Dict[str, str] = {}
@@ -573,7 +564,6 @@ def generate_chatml_messages_with_context(
system_message: str = None,
chat_history: list[ChatMessageModel] = [],
model_name="gpt-4o-mini",
- loaded_model: Optional[Llama] = None,
max_prompt_size=None,
tokenizer_name=None,
query_images=None,
@@ -588,10 +578,7 @@ def generate_chatml_messages_with_context(
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs
if not max_prompt_size:
- if loaded_model:
- max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
- else:
- max_prompt_size = model_to_prompt_size.get(model_name, 10000)
+ max_prompt_size = model_to_prompt_size.get(model_name, 10000)
# Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750
@@ -735,7 +722,7 @@ def generate_chatml_messages_with_context(
message.content = [{"type": "text", "text": message.content}]
# Truncate oldest messages from conversation history until under max supported prompt size by model
- messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
+ messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
# Return message in chronological order
return messages[::-1]
@@ -743,25 +730,20 @@ def generate_chatml_messages_with_context(
def get_encoder(
model_name: str,
- loaded_model: Optional[Llama] = None,
tokenizer_name=None,
-) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
+) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
default_tokenizer = "gpt-4o"
try:
- if loaded_model:
- encoder = loaded_model.tokenizer()
- elif model_name.startswith("gpt-") or model_name.startswith("o1"):
- # as tiktoken doesn't recognize o1 model series yet
- encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
- elif tokenizer_name:
+ if tokenizer_name:
if tokenizer_name in state.pretrained_tokenizers:
encoder = state.pretrained_tokenizers[tokenizer_name]
else:
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
state.pretrained_tokenizers[tokenizer_name] = encoder
else:
- encoder = download_model(model_name).tokenizer()
+ # as tiktoken doesn't recognize o1 model series yet
+ encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
except:
encoder = tiktoken.encoding_for_model(default_tokenizer)
if state.verbose > 2:
@@ -773,7 +755,7 @@ def get_encoder(
def count_tokens(
message_content: str | list[str | dict],
- encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
+ encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
"""
Count the total number of tokens in a list of messages.
@@ -825,11 +807,10 @@ def truncate_messages(
messages: list[ChatMessage],
max_prompt_size: int,
model_name: str,
- loaded_model: Optional[Llama] = None,
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
- encoder = get_encoder(model_name, loaded_model, tokenizer_name)
+ encoder = get_encoder(model_name, tokenizer_name)
# Extract system message from messages
system_message = None
diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py
index 979205a0..2b63c40d 100644
--- a/src/khoj/processor/operator/__init__.py
+++ b/src/khoj/processor/operator/__init__.py
@@ -235,7 +235,6 @@ def is_operator_model(model: str) -> ChatModel.ModelType | None:
"claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC,
"claude-sonnet-4": ChatModel.ModelType.ANTHROPIC,
"claude-opus-4": ChatModel.ModelType.ANTHROPIC,
- "ui-tars-1.5": ChatModel.ModelType.OFFLINE,
}
for operator_model in operator_models:
if model.startswith(operator_model):
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 23805200..44c2f2b7 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -15,7 +15,6 @@ from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo
from khoj.database.models import KhojUser, SpeechToTextModelOptions
-from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import (
ApiUserRateLimiter,
@@ -88,22 +87,14 @@ def update(
force: Optional[bool] = False,
):
user = request.user.object
- if not state.config:
- error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/settings, plugins or by editing {state.config_file}."
- logger.warning(error_msg)
- raise HTTPException(status_code=500, detail=error_msg)
try:
initialize_content(user=user, regenerate=force, search_type=t)
except Exception as e:
- error_msg = f"🚨 Failed to update server via API: {e}"
+ error_msg = f"🚨 Failed to update server indexed content via API: {e}"
logger.error(error_msg, exc_info=True)
raise HTTPException(status_code=500, detail=error_msg)
else:
- components = []
- if state.search_models:
- components.append("Search models")
- components_msg = ", ".join(components)
- logger.info(f"📪 {components_msg} updated via API")
+ logger.info(f"📪 Server indexed content updated via API")
update_telemetry_state(
request=request,
@@ -150,9 +141,6 @@ async def transcribe(
if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unsupported on server error
status_code = 501
- elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
- speech2text_model = speech_to_text_config.model_name
- user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
speech2text_model = speech_to_text_config.model_name
if speech_to_text_config.ai_model_api:
diff --git a/src/khoj/routers/api_content.py b/src/khoj/routers/api_content.py
index 4f9cc407..c2732ec8 100644
--- a/src/khoj/routers/api_content.py
+++ b/src/khoj/routers/api_content.py
@@ -27,16 +27,7 @@ from khoj.database.adapters import (
get_user_notion_config,
)
from khoj.database.models import Entry as DbEntry
-from khoj.database.models import (
- GithubConfig,
- GithubRepoConfig,
- KhojUser,
- LocalMarkdownConfig,
- LocalOrgConfig,
- LocalPdfConfig,
- LocalPlaintextConfig,
- NotionConfig,
-)
+from khoj.database.models import GithubConfig, GithubRepoConfig, NotionConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.routers.helpers import (
@@ -47,17 +38,9 @@ from khoj.routers.helpers import (
get_user_config,
update_telemetry_state,
)
-from khoj.utils import constants, state
-from khoj.utils.config import SearchModels
-from khoj.utils.rawconfig import (
- ContentConfig,
- FullConfig,
- GithubContentConfig,
- NotionContentConfig,
- SearchConfig,
-)
+from khoj.utils import state
+from khoj.utils.rawconfig import GithubContentConfig, NotionContentConfig
from khoj.utils.state import SearchType
-from khoj.utils.yaml import save_config_to_file_updated_state
logger = logging.getLogger(__name__)
@@ -192,8 +175,6 @@ async def set_content_github(
updated_config: Union[GithubContentConfig, None],
client: Optional[str] = None,
):
- _initialize_config()
-
user = request.user.object
try:
@@ -225,8 +206,6 @@ async def set_content_notion(
updated_config: Union[NotionContentConfig, None],
client: Optional[str] = None,
):
- _initialize_config()
-
user = request.user.object
try:
@@ -323,10 +302,6 @@ def get_content_types(request: Request, client: Optional[str] = None):
configured_content_types = set(EntryAdapters.get_unique_file_types(user))
configured_content_types |= {"all"}
- if state.config and state.config.content_type:
- for ctype in state.config.content_type.model_dump(exclude_none=True):
- configured_content_types.add(ctype)
-
return list(configured_content_types & all_content_types)
@@ -606,28 +581,6 @@ async def indexer(
docx=index_files["docx"],
)
- if state.config == None:
- logger.info("📬 Initializing content index on first run.")
- default_full_config = FullConfig(
- content_type=None,
- search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
- processor=None,
- )
- state.config = default_full_config
- default_content_config = ContentConfig(
- org=None,
- markdown=None,
- pdf=None,
- docx=None,
- image=None,
- github=None,
- notion=None,
- plaintext=None,
- )
- state.config.content_type = default_content_config
- save_config_to_file_updated_state()
- configure_search(state.search_models, state.config.search_type)
-
loop = asyncio.get_event_loop()
success = await loop.run_in_executor(
None,
@@ -674,14 +627,6 @@ async def indexer(
return Response(content=indexed_filenames, status_code=200)
-def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
- # Run Validation Checks
- if search_models is None:
- search_models = SearchModels()
-
- return search_models
-
-
def map_config_to_object(content_source: str):
if content_source == DbEntry.EntrySource.GITHUB:
return GithubConfig
@@ -689,56 +634,3 @@ def map_config_to_object(content_source: str):
return NotionConfig
if content_source == DbEntry.EntrySource.COMPUTER:
return "Computer"
-
-
-async def map_config_to_db(config: FullConfig, user: KhojUser):
- if config.content_type:
- if config.content_type.org:
- await LocalOrgConfig.objects.filter(user=user).adelete()
- await LocalOrgConfig.objects.acreate(
- input_files=config.content_type.org.input_files,
- input_filter=config.content_type.org.input_filter,
- index_heading_entries=config.content_type.org.index_heading_entries,
- user=user,
- )
- if config.content_type.markdown:
- await LocalMarkdownConfig.objects.filter(user=user).adelete()
- await LocalMarkdownConfig.objects.acreate(
- input_files=config.content_type.markdown.input_files,
- input_filter=config.content_type.markdown.input_filter,
- index_heading_entries=config.content_type.markdown.index_heading_entries,
- user=user,
- )
- if config.content_type.pdf:
- await LocalPdfConfig.objects.filter(user=user).adelete()
- await LocalPdfConfig.objects.acreate(
- input_files=config.content_type.pdf.input_files,
- input_filter=config.content_type.pdf.input_filter,
- index_heading_entries=config.content_type.pdf.index_heading_entries,
- user=user,
- )
- if config.content_type.plaintext:
- await LocalPlaintextConfig.objects.filter(user=user).adelete()
- await LocalPlaintextConfig.objects.acreate(
- input_files=config.content_type.plaintext.input_files,
- input_filter=config.content_type.plaintext.input_filter,
- index_heading_entries=config.content_type.plaintext.index_heading_entries,
- user=user,
- )
- if config.content_type.github:
- await adapters.set_user_github_config(
- user=user,
- pat_token=config.content_type.github.pat_token,
- repos=config.content_type.github.repos,
- )
- if config.content_type.notion:
- await adapters.set_notion_config(
- user=user,
- token=config.content_type.notion.token,
- )
-
-
-def _initialize_config():
- if state.config is None:
- state.config = FullConfig()
- state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index d3554f7f..8dcda86b 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -89,10 +89,6 @@ from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
gemini_send_message_to_model,
)
-from khoj.processor.conversation.offline.chat_model import (
- converse_offline,
- send_message_to_model_offline,
-)
from khoj.processor.conversation.openai.gpt import (
converse_openai,
send_message_to_model,
@@ -117,7 +113,6 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search
from khoj.utils import state
-from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
LRU,
ConversationCommand,
@@ -168,14 +163,6 @@ async def is_ready_to_chat(user: KhojUser):
if user_chat_model == None:
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
- if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
- chat_model_name = user_chat_model.name
- max_tokens = user_chat_model.max_prompt_size
- if state.offline_chat_processor_config is None:
- logger.info("Loading Offline Chat Model...")
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
- return True
-
if (
user_chat_model
and (
@@ -231,7 +218,6 @@ def update_telemetry_state(
telemetry_type=telemetry_type,
api=api,
client=client,
- app_config=state.config.app,
disable_telemetry_env=state.telemetry_disabled,
properties=user_state,
)
@@ -1470,12 +1456,6 @@ async def send_message_to_model_wrapper(
vision_available = chat_model.vision_enabled
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
- loaded_model = None
-
- if model_type == ChatModel.ModelType.OFFLINE:
- if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
- loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
@@ -1483,7 +1463,6 @@ async def send_message_to_model_wrapper(
system_message=system_message,
chat_history=chat_history,
model_name=chat_model_name,
- loaded_model=loaded_model,
tokenizer_name=tokenizer,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
@@ -1492,18 +1471,7 @@ async def send_message_to_model_wrapper(
query_files=query_files,
)
- if model_type == ChatModel.ModelType.OFFLINE:
- return send_message_to_model_offline(
- messages=truncated_messages,
- loaded_model=loaded_model,
- model_name=chat_model_name,
- max_prompt_size=max_tokens,
- streaming=False,
- response_type=response_type,
- tracer=tracer,
- )
-
- elif model_type == ChatModel.ModelType.OPENAI:
+ if model_type == ChatModel.ModelType.OPENAI:
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
@@ -1565,19 +1533,12 @@ def send_message_to_model_wrapper_sync(
vision_available = chat_model.vision_enabled
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
- loaded_model = None
-
- if model_type == ChatModel.ModelType.OFFLINE:
- if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
- state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
- loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
chat_history=chat_history,
model_name=chat_model_name,
- loaded_model=loaded_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=model_type,
@@ -1585,18 +1546,7 @@ def send_message_to_model_wrapper_sync(
query_files=query_files,
)
- if model_type == ChatModel.ModelType.OFFLINE:
- return send_message_to_model_offline(
- messages=truncated_messages,
- loaded_model=loaded_model,
- model_name=chat_model_name,
- max_prompt_size=max_tokens,
- streaming=False,
- response_type=response_type,
- tracer=tracer,
- )
-
- elif model_type == ChatModel.ModelType.OPENAI:
+ if model_type == ChatModel.ModelType.OPENAI:
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
@@ -1678,30 +1628,7 @@ async def agenerate_chat_response(
chat_model = vision_enabled_config
vision_available = True
- if chat_model.model_type == "offline":
- loaded_model = state.offline_chat_processor_config.loaded_model
- chat_response_generator = converse_offline(
- # Query
- user_query=query_to_run,
- # Context
- references=compiled_references,
- online_results=online_results,
- generated_files=raw_generated_files,
- generated_asset_results=generated_asset_results,
- location_data=location_data,
- user_name=user_name,
- query_files=query_files,
- chat_history=chat_history,
- # Model
- loaded_model=loaded_model,
- model_name=chat_model.name,
- max_prompt_size=chat_model.max_prompt_size,
- tokenizer_name=chat_model.tokenizer,
- agent=agent,
- tracer=tracer,
- )
-
- elif chat_model.model_type == ChatModel.ModelType.OPENAI:
+ if chat_model.model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model_name = chat_model.name
@@ -2798,7 +2725,8 @@ def configure_content(
search_type = t.value if t else None
- no_documents = all([not files.get(file_type) for file_type in files])
+ # Check if client sent any documents of the supported types
+ no_client_sent_documents = all([not files.get(file_type) for file_type in files])
if files is None:
logger.warning(f"🚨 No files to process for {search_type} search.")
@@ -2872,7 +2800,8 @@ def configure_content(
success = False
try:
- if no_documents:
+ # Run server side indexing of user Github docs if no client sent documents
+ if no_client_sent_documents:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if (
search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
@@ -2892,7 +2821,8 @@ def configure_content(
success = False
try:
- if no_documents:
+ # Run server side indexing of user Notion docs if no client sent documents
+ if no_client_sent_documents:
# Initialize Notion Search
notion_config = NotionConfig.objects.filter(user=user).first()
if (
diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py
index 14581f41..a016f9a4 100644
--- a/src/khoj/utils/cli.py
+++ b/src/khoj/utils/cli.py
@@ -1,36 +1,19 @@
import argparse
import logging
-import os
import pathlib
from importlib.metadata import version
logger = logging.getLogger(__name__)
-from khoj.migrations.migrate_offline_chat_default_model import (
- migrate_offline_chat_default_model,
-)
-from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema
-from khoj.migrations.migrate_offline_model import migrate_offline_model
-from khoj.migrations.migrate_processor_config_openai import (
- migrate_processor_conversation_schema,
-)
-from khoj.migrations.migrate_server_pg import migrate_server_pg
-from khoj.migrations.migrate_version import migrate_config_to_version
-from khoj.utils.helpers import is_env_var_true, resolve_absolute_path
-from khoj.utils.yaml import parse_config_from_file
-
def cli(args=None):
# Setup Argument Parser for the Commandline Interface
parser = argparse.ArgumentParser(description="Start Khoj; An AI personal assistant for your Digital Brain")
parser.add_argument(
- "--config-file", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj"
- )
- parser.add_argument(
- "--regenerate",
- action="store_true",
- default=False,
- help="Regenerate model embeddings from source files. Default: false",
+ "--log-file",
+ default="~/.khoj/khoj.log",
+ type=pathlib.Path,
+ help="File path for server logs. Default: ~/.khoj/khoj.log",
)
parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1")
@@ -43,14 +26,11 @@ def cli(args=None):
parser.add_argument("--sslcert", type=str, help="Path to SSL certificate file")
parser.add_argument("--sslkey", type=str, help="Path to SSL key file")
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
- parser.add_argument(
- "--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
- )
parser.add_argument(
"--anonymous-mode",
action="store_true",
default=False,
- help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
+ help="Run Khoj in single user mode with no login required. Useful for personal use or testing.",
)
parser.add_argument(
"--non-interactive",
@@ -64,38 +44,10 @@ def cli(args=None):
if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
- # Set default values for arguments
- args.chat_on_gpu = not args.disable_chat_on_gpu
-
args.version_no = version("khoj")
if args.version:
# Show version of khoj installed and exit
print(args.version_no)
exit(0)
- # Normalize config_file path to absolute path
- args.config_file = resolve_absolute_path(args.config_file)
-
- if not args.config_file.exists():
- args.config = None
- else:
- args = run_migrations(args)
- args.config = parse_config_from_file(args.config_file)
- if is_env_var_true("KHOJ_TELEMETRY_DISABLE"):
- args.config.app.should_log_telemetry = False
-
- return args
-
-
-def run_migrations(args):
- migrations = [
- migrate_config_to_version,
- migrate_processor_conversation_schema,
- migrate_offline_model,
- migrate_offline_chat_schema,
- migrate_offline_chat_default_model,
- migrate_server_pg,
- ]
- for migration in migrations:
- args = migration(args)
return args
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 03dad75c..c9cd1c43 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -1,22 +1,7 @@
# System Packages
from __future__ import annotations # to avoid quoting type hints
-import logging
-from dataclasses import dataclass
from enum import Enum
-from typing import TYPE_CHECKING, Any, List, Optional, Union
-
-import torch
-
-from khoj.processor.conversation.offline.utils import download_model
-
-logger = logging.getLogger(__name__)
-
-
-if TYPE_CHECKING:
- from sentence_transformers import CrossEncoder
-
- from khoj.utils.models import BaseEncoder
class SearchType(str, Enum):
@@ -29,53 +14,3 @@ class SearchType(str, Enum):
Notion = "notion"
Plaintext = "plaintext"
Docx = "docx"
-
-
-class ProcessorType(str, Enum):
- Conversation = "conversation"
-
-
-@dataclass
-class TextContent:
- enabled: bool
-
-
-@dataclass
-class ImageContent:
- image_names: List[str]
- image_embeddings: torch.Tensor
- image_metadata_embeddings: torch.Tensor
-
-
-@dataclass
-class TextSearchModel:
- bi_encoder: BaseEncoder
- cross_encoder: Optional[CrossEncoder] = None
- top_k: Optional[int] = 15
-
-
-@dataclass
-class ImageSearchModel:
- image_encoder: BaseEncoder
-
-
-@dataclass
-class SearchModels:
- text_search: Optional[TextSearchModel] = None
-
-
-@dataclass
-class OfflineChatProcessorConfig:
- loaded_model: Union[Any, None] = None
-
-
-class OfflineChatProcessorModel:
- def __init__(self, chat_model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", max_tokens: int = None):
- self.chat_model = chat_model
- self.loaded_model = None
- try:
- self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
- except ValueError as e:
- self.loaded_model = None
- logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
- raise e
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index 4f128aa7..b2ed49de 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -10,13 +10,6 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
-default_offline_chat_models = [
- "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
- "bartowski/Llama-3.2-3B-Instruct-GGUF",
- "bartowski/gemma-2-9b-it-GGUF",
- "bartowski/gemma-2-2b-it-GGUF",
- "bartowski/Qwen2.5-14B-Instruct-GGUF",
-]
default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1", "o3", "o4-mini"]
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"]
default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"]
diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py
deleted file mode 100644
index 67e91bc9..00000000
--- a/src/khoj/utils/fs_syncer.py
+++ /dev/null
@@ -1,252 +0,0 @@
-import glob
-import logging
-import os
-from pathlib import Path
-from typing import Optional
-
-from bs4 import BeautifulSoup
-from magika import Magika
-
-from khoj.database.models import (
- KhojUser,
- LocalMarkdownConfig,
- LocalOrgConfig,
- LocalPdfConfig,
- LocalPlaintextConfig,
-)
-from khoj.utils.config import SearchType
-from khoj.utils.helpers import get_absolute_path, is_none_or_empty
-from khoj.utils.rawconfig import TextContentConfig
-
-logger = logging.getLogger(__name__)
-magika = Magika()
-
-
-def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict:
- files: dict[str, dict] = {"docx": {}, "image": {}}
-
- if search_type == SearchType.All or search_type == SearchType.Org:
- org_config = LocalOrgConfig.objects.filter(user=user).first()
- files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
- if search_type == SearchType.All or search_type == SearchType.Markdown:
- markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
- files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
- if search_type == SearchType.All or search_type == SearchType.Plaintext:
- plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
- files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
- if search_type == SearchType.All or search_type == SearchType.Pdf:
- pdf_config = LocalPdfConfig.objects.filter(user=user).first()
- files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
- files["image"] = {}
- files["docx"] = {}
- return files
-
-
-def construct_config_from_db(db_config) -> TextContentConfig:
- return TextContentConfig(
- input_files=db_config.input_files,
- input_filter=db_config.input_filter,
- index_heading_entries=db_config.index_heading_entries,
- )
-
-
-def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
- def is_plaintextfile(file: str):
- "Check if file is plaintext file"
- # Check if file path exists
- content_group = magika.identify_path(Path(file)).output.group
- # Use file extension to decide plaintext if file content is not identifiable
- valid_text_file_extensions = ("txt", "md", "markdown", "org" "mbox", "rst", "html", "htm", "xml")
- return file.endswith(valid_text_file_extensions) or content_group in ["text", "code"]
-
- def extract_html_content(html_content: str):
- "Extract content from HTML"
- soup = BeautifulSoup(html_content, "html.parser")
- return soup.get_text(strip=True, separator="\n")
-
- # Extract required fields from config
- input_files, input_filters = (
- config.input_files,
- config.input_filter,
- )
-
- # Input Validation
- if is_none_or_empty(input_files) and is_none_or_empty(input_filters):
- logger.debug("At least one of input-files or input-file-filter is required to be specified")
- return {}
-
- # Get all plain text files to process
- absolute_plaintext_files, filtered_plaintext_files = set(), set()
- if input_files:
- absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in input_files}
- if input_filters:
- filtered_plaintext_files = {
- filtered_file
- for plaintext_file_filter in input_filters
- for filtered_file in glob.glob(get_absolute_path(plaintext_file_filter), recursive=True)
- if os.path.isfile(filtered_file)
- }
-
- all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files)
-
- files_with_no_plaintext_extensions = {
- target_files for target_files in all_target_files if not is_plaintextfile(target_files)
- }
- if any(files_with_no_plaintext_extensions):
- logger.warning(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}")
- all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions)
-
- logger.debug(f"Processing files: {all_target_files}")
-
- filename_to_content_map = {}
- for file in all_target_files:
- with open(file, "r", encoding="utf8") as f:
- try:
- plaintext_content = f.read()
- if file.endswith(("html", "htm", "xml")):
- plaintext_content = extract_html_content(plaintext_content)
- filename_to_content_map[file] = plaintext_content
- except Exception as e:
- logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
- logger.warning(e, exc_info=True)
-
- return filename_to_content_map
-
-
-def get_org_files(config: TextContentConfig):
- # Extract required fields from config
- org_files, org_file_filters = (
- config.input_files,
- config.input_filter,
- )
-
- # Input Validation
- if is_none_or_empty(org_files) and is_none_or_empty(org_file_filters):
- logger.debug("At least one of org-files or org-file-filter is required to be specified")
- return {}
-
- # Get Org files to process
- absolute_org_files, filtered_org_files = set(), set()
- if org_files:
- absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
- if org_file_filters:
- filtered_org_files = {
- filtered_file
- for org_file_filter in org_file_filters
- for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True)
- if os.path.isfile(filtered_file)
- }
-
- all_org_files = sorted(absolute_org_files | filtered_org_files)
-
- files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
- if any(files_with_non_org_extensions):
- logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
-
- logger.debug(f"Processing files: {all_org_files}")
-
- filename_to_content_map = {}
- for file in all_org_files:
- with open(file, "r", encoding="utf8") as f:
- try:
- filename_to_content_map[file] = f.read()
- except Exception as e:
- logger.warning(f"Unable to read file: {file} as org. Skipping file.")
- logger.warning(e, exc_info=True)
-
- return filename_to_content_map
-
-
-def get_markdown_files(config: TextContentConfig):
- # Extract required fields from config
- markdown_files, markdown_file_filters = (
- config.input_files,
- config.input_filter,
- )
-
- # Input Validation
- if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filters):
- logger.debug("At least one of markdown-files or markdown-file-filter is required to be specified")
- return {}
-
- # Get markdown files to process
- absolute_markdown_files, filtered_markdown_files = set(), set()
- if markdown_files:
- absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
-
- if markdown_file_filters:
- filtered_markdown_files = {
- filtered_file
- for markdown_file_filter in markdown_file_filters
- for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True)
- if os.path.isfile(filtered_file)
- }
-
- all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
-
- files_with_non_markdown_extensions = {
- md_file for md_file in all_markdown_files if not md_file.endswith(".md") and not md_file.endswith(".markdown")
- }
-
- if any(files_with_non_markdown_extensions):
- logger.warning(
- f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
- )
-
- logger.debug(f"Processing files: {all_markdown_files}")
-
- filename_to_content_map = {}
- for file in all_markdown_files:
- with open(file, "r", encoding="utf8") as f:
- try:
- filename_to_content_map[file] = f.read()
- except Exception as e:
- logger.warning(f"Unable to read file: {file} as markdown. Skipping file.")
- logger.warning(e, exc_info=True)
-
- return filename_to_content_map
-
-
-def get_pdf_files(config: TextContentConfig):
- # Extract required fields from config
- pdf_files, pdf_file_filters = (
- config.input_files,
- config.input_filter,
- )
-
- # Input Validation
- if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filters):
- logger.debug("At least one of pdf-files or pdf-file-filter is required to be specified")
- return {}
-
- # Get PDF files to process
- absolute_pdf_files, filtered_pdf_files = set(), set()
- if pdf_files:
- absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files}
- if pdf_file_filters:
- filtered_pdf_files = {
- filtered_file
- for pdf_file_filter in pdf_file_filters
- for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True)
- if os.path.isfile(filtered_file)
- }
-
- all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files)
-
- files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}
-
- if any(files_with_non_pdf_extensions):
- logger.warning(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}")
-
- logger.debug(f"Processing files: {all_pdf_files}")
-
- filename_to_content_map = {}
- for file in all_pdf_files:
- with open(file, "rb") as f:
- try:
- filename_to_content_map[file] = f.read()
- except Exception as e:
- logger.warning(f"Unable to read file: {file} as PDF. Skipping file.")
- logger.warning(e, exc_info=True)
-
- return filename_to_content_map
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 508901de..523ec007 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -47,7 +47,6 @@ if TYPE_CHECKING:
from sentence_transformers import CrossEncoder, SentenceTransformer
from khoj.utils.models import BaseEncoder
- from khoj.utils.rawconfig import AppConfig
logger = logging.getLogger(__name__)
@@ -267,23 +266,16 @@ def get_server_id():
return server_id
-def telemetry_disabled(app_config: AppConfig, telemetry_disable_env) -> bool:
- if telemetry_disable_env is True:
- return True
- return not app_config or not app_config.should_log_telemetry
-
-
def log_telemetry(
telemetry_type: str,
api: str = None,
client: Optional[str] = None,
- app_config: Optional[AppConfig] = None,
disable_telemetry_env: bool = False,
properties: dict = None,
):
"""Log basic app usage telemetry like client, os, api called"""
# Do not log usage telemetry, if telemetry is disabled via app config
- if telemetry_disabled(app_config, disable_telemetry_env):
+ if disable_telemetry_env:
return []
if properties.get("server_id") is None:
diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
index 9ed1bdff..8023b3ed 100644
--- a/src/khoj/utils/initialization.py
+++ b/src/khoj/utils/initialization.py
@@ -16,7 +16,6 @@ from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tok
from khoj.utils.constants import (
default_anthropic_chat_models,
default_gemini_chat_models,
- default_offline_chat_models,
default_openai_chat_models,
)
@@ -72,7 +71,6 @@ def initialization(interactive: bool = True):
default_api_key=openai_api_key,
api_base_url=openai_base_url,
vision_enabled=True,
- is_offline=False,
interactive=interactive,
provider_name=provider,
)
@@ -118,7 +116,6 @@ def initialization(interactive: bool = True):
default_gemini_chat_models,
default_api_key=os.getenv("GEMINI_API_KEY"),
vision_enabled=True,
- is_offline=False,
interactive=interactive,
provider_name="Google Gemini",
)
@@ -145,40 +142,11 @@ def initialization(interactive: bool = True):
default_anthropic_chat_models,
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
vision_enabled=True,
- is_offline=False,
- interactive=interactive,
- )
-
- # Set up offline chat models
- _setup_chat_model_provider(
- ChatModel.ModelType.OFFLINE,
- default_offline_chat_models,
- default_api_key=None,
- vision_enabled=False,
- is_offline=True,
interactive=interactive,
)
logger.info("🗣️ Chat model configuration complete")
- # Set up offline speech to text model
- use_offline_speech2text_model = "n" if not interactive else input("Use offline speech to text model? (y/n): ")
- if use_offline_speech2text_model == "y":
- logger.info("🗣️ Setting up offline speech to text model")
- # Delete any existing speech to text model options. There can only be one.
- SpeechToTextModelOptions.objects.all().delete()
-
- default_offline_speech2text_model = "base"
- offline_speech2text_model = input(
- f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): "
- )
- offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model
- SpeechToTextModelOptions.objects.create(
- model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE
- )
-
- logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
-
def _setup_chat_model_provider(
model_type: ChatModel.ModelType,
default_chat_models: list,
@@ -186,7 +154,6 @@ def initialization(interactive: bool = True):
interactive: bool,
api_base_url: str = None,
vision_enabled: bool = False,
- is_offline: bool = False,
provider_name: str = None,
) -> Tuple[bool, AiModelApi]:
supported_vision_models = (
@@ -195,11 +162,6 @@ def initialization(interactive: bool = True):
provider_name = provider_name or model_type.name.capitalize()
default_use_model = default_api_key is not None
- # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment.
- # This usually means there's not enough RAM to load offline models directly within the application.
- # In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
- if is_offline:
- default_use_model = False
use_model_provider = (
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y"
@@ -211,13 +173,12 @@ def initialization(interactive: bool = True):
logger.info(f"️💬 Setting up your {provider_name} chat configuration")
ai_model_api = None
- if not is_offline:
- if interactive:
- user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
- api_key = user_api_key if user_api_key != "" else default_api_key
- else:
- api_key = default_api_key
- ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
+ if interactive:
+ user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
+ api_key = user_api_key if user_api_key != "" else default_api_key
+ else:
+ api_key = default_api_key
+ ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
if interactive:
user_chat_models = input(
diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py
index df7f3334..5377577b 100644
--- a/src/khoj/utils/rawconfig.py
+++ b/src/khoj/utils/rawconfig.py
@@ -48,17 +48,6 @@ class FilesFilterRequest(BaseModel):
conversation_id: str
-class TextConfigBase(ConfigBase):
- compressed_jsonl: Path
- embeddings_file: Path
-
-
-class TextContentConfig(ConfigBase):
- input_files: Optional[List[Path]] = None
- input_filter: Optional[List[str]] = None
- index_heading_entries: Optional[bool] = False
-
-
class GithubRepoConfig(ConfigBase):
name: str
owner: str
@@ -74,62 +63,6 @@ class NotionContentConfig(ConfigBase):
token: str
-class ContentConfig(ConfigBase):
- org: Optional[TextContentConfig] = None
- markdown: Optional[TextContentConfig] = None
- pdf: Optional[TextContentConfig] = None
- plaintext: Optional[TextContentConfig] = None
- github: Optional[GithubContentConfig] = None
- notion: Optional[NotionContentConfig] = None
- image: Optional[TextContentConfig] = None
- docx: Optional[TextContentConfig] = None
-
-
-class ImageSearchConfig(ConfigBase):
- encoder: str
- encoder_type: Optional[str] = None
- model_directory: Optional[Path] = None
-
- class Config:
- protected_namespaces = ()
-
-
-class SearchConfig(ConfigBase):
- image: Optional[ImageSearchConfig] = None
-
-
-class OpenAIProcessorConfig(ConfigBase):
- api_key: str
- chat_model: Optional[str] = "gpt-4o-mini"
-
-
-class OfflineChatProcessorConfig(ConfigBase):
- chat_model: Optional[str] = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
-
-
-class ConversationProcessorConfig(ConfigBase):
- openai: Optional[OpenAIProcessorConfig] = None
- offline_chat: Optional[OfflineChatProcessorConfig] = None
- max_prompt_size: Optional[int] = None
- tokenizer: Optional[str] = None
-
-
-class ProcessorConfig(ConfigBase):
- conversation: Optional[ConversationProcessorConfig] = None
-
-
-class AppConfig(ConfigBase):
- should_log_telemetry: bool = True
-
-
-class FullConfig(ConfigBase):
- content_type: Optional[ContentConfig] = None
- search_type: Optional[SearchConfig] = None
- processor: Optional[ProcessorConfig] = None
- app: Optional[AppConfig] = AppConfig()
- version: Optional[str] = None
-
-
class SearchResponse(ConfigBase):
entry: str
score: float
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index f96409c2..3b65a85b 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -12,19 +12,14 @@ from whisper import Whisper
from khoj.database.models import ProcessLock
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
-from khoj.utils.config import OfflineChatProcessorModel, SearchModels
from khoj.utils.helpers import LRU, get_device, is_env_var_true
-from khoj.utils.rawconfig import FullConfig
# Application Global State
-config = FullConfig()
-search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: Dict[str, CrossEncoderModel] = None
openai_client: OpenAI = None
-offline_chat_processor_config: OfflineChatProcessorModel = None
whisper_model: Whisper = None
-config_file: Path = None
+log_file: Path = None
verbose: int = 0
host: str = None
port: int = None
@@ -39,7 +34,6 @@ telemetry: List[Dict[str, str]] = []
telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
khoj_version: str = None
device = get_device()
-chat_on_gpu: bool = True
anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
billing_enabled: bool = (
diff --git a/src/khoj/utils/yaml.py b/src/khoj/utils/yaml.py
index f658e1eb..43b139e5 100644
--- a/src/khoj/utils/yaml.py
+++ b/src/khoj/utils/yaml.py
@@ -1,47 +1,8 @@
-from pathlib import Path
-
import yaml
-from khoj.utils import state
-from khoj.utils.rawconfig import FullConfig
-
# Do not emit tags when dumping to YAML
yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None # type: ignore[assignment]
-def save_config_to_file_updated_state():
- with open(state.config_file, "w") as outfile:
- yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
- outfile.close()
- return state.config
-
-
-def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
- "Write config to YML file"
- # Create output directory, if it doesn't exist
- yaml_config_file.parent.mkdir(parents=True, exist_ok=True)
-
- with open(yaml_config_file, "w", encoding="utf-8") as config_file:
- yaml.safe_dump(yaml_config, config_file, allow_unicode=True)
-
-
-def load_config_from_file(yaml_config_file: Path) -> dict:
- "Read config from YML file"
- config_from_file = None
- with open(yaml_config_file, "r", encoding="utf-8") as config_file:
- config_from_file = yaml.safe_load(config_file)
- return config_from_file
-
-
-def parse_config_from_string(yaml_config: dict) -> FullConfig:
- "Parse and validate config in YML string"
- return FullConfig.model_validate(yaml_config)
-
-
-def parse_config_from_file(yaml_config_file):
- "Parse and validate config in YML file"
- return parse_config_from_string(load_config_from_file(yaml_config_file))
-
-
def yaml_dump(data):
return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False)
diff --git a/tests/conftest.py b/tests/conftest.py
index 25876d61..dd448bd1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,3 @@
-import os
-from pathlib import Path
-
import pytest
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
@@ -11,6 +8,7 @@ from khoj.configure import (
configure_routes,
configure_search_types,
)
+from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
Agent,
ChatModel,
@@ -19,21 +17,14 @@ from khoj.database.models import (
GithubRepoConfig,
KhojApiUser,
KhojUser,
- LocalMarkdownConfig,
- LocalOrgConfig,
- LocalPdfConfig,
- LocalPlaintextConfig,
)
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.api_content import configure_content
from khoj.search_type import text_search
-from khoj.utils import fs_syncer, state
-from khoj.utils.config import SearchModels
+from khoj.utils import state
from khoj.utils.constants import web_directory
-from khoj.utils.helpers import resolve_absolute_path
-from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
from tests.helpers import (
AiModelApiFactory,
ChatModelFactory,
@@ -43,6 +34,8 @@ from tests.helpers import (
UserFactory,
get_chat_api_key,
get_chat_provider,
+ get_index_files,
+ get_sample_data,
)
@@ -59,23 +52,16 @@ def django_db_setup(django_db_setup, django_db_blocker):
@pytest.fixture(scope="session")
-def search_config() -> SearchConfig:
+def search_config():
+ search_model = get_default_search_model()
state.embeddings_model = dict()
- state.embeddings_model["default"] = EmbeddingsModel()
- state.cross_encoder_model = dict()
- state.cross_encoder_model["default"] = CrossEncoderModel()
-
- model_dir = resolve_absolute_path("~/.khoj/search")
- model_dir.mkdir(parents=True, exist_ok=True)
- search_config = SearchConfig()
-
- search_config.image = ImageSearchConfig(
- encoder="sentence-transformers/clip-ViT-B-32",
- model_directory=model_dir / "image/",
- encoder_type=None,
+ state.embeddings_model["default"] = EmbeddingsModel(
+ model_name=search_model.bi_encoder, model_kwargs=search_model.bi_encoder_model_config
+ )
+ state.cross_encoder_model = dict()
+ state.cross_encoder_model["default"] = CrossEncoderModel(
+ model_name=search_model.cross_encoder, model_kwargs=search_model.cross_encoder_model_config
)
-
- return search_config
@pytest.mark.django_db
@@ -196,17 +182,6 @@ def default_openai_chat_model_option():
return chat_model
-@pytest.mark.django_db
-@pytest.fixture
-def offline_agent():
- chat_model = ChatModelFactory()
- return Agent.objects.create(
- name="Accountant",
- chat_model=chat_model,
- personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
- )
-
-
@pytest.mark.django_db
@pytest.fixture
def openai_agent():
@@ -218,13 +193,6 @@ def openai_agent():
)
-@pytest.fixture(scope="session")
-def search_models(search_config: SearchConfig):
- search_models = SearchModels()
-
- return search_models
-
-
@pytest.mark.django_db
@pytest.fixture
def default_process_lock():
@@ -236,72 +204,23 @@ def anyio_backend():
return "asyncio"
-@pytest.mark.django_db
@pytest.fixture(scope="function")
-def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
- content_dir = tmp_path_factory.mktemp("content")
-
- # Generate Image Embeddings from Test Images
- content_config = ContentConfig()
-
- LocalOrgConfig.objects.create(
- input_files=None,
- input_filter=["tests/data/org/*.org"],
- index_heading_entries=False,
- user=default_user,
- )
-
- text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=False, user=default_user)
-
- if os.getenv("GITHUB_PAT_TOKEN"):
- GithubConfig.objects.create(
- pat_token=os.getenv("GITHUB_PAT_TOKEN"),
- user=default_user,
- )
-
- GithubRepoConfig.objects.create(
- owner="khoj-ai",
- name="lantern",
- branch="master",
- github_config=GithubConfig.objects.get(user=default_user),
- )
-
- LocalPlaintextConfig.objects.create(
- input_files=None,
- input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
- user=default_user,
- )
-
- return content_config
-
-
-@pytest.fixture(scope="session")
-def md_content_config():
- markdown_config = LocalMarkdownConfig.objects.create(
- input_files=None,
- input_filter=["tests/data/markdown/*.markdown"],
- )
-
- return markdown_config
-
-
-@pytest.fixture(scope="function")
-def chat_client(search_config: SearchConfig, default_user2: KhojUser):
+def chat_client(search_config, default_user2: KhojUser):
return chat_client_builder(search_config, default_user2, require_auth=False)
@pytest.fixture(scope="function")
-def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser):
+def chat_client_with_auth(search_config, default_user2: KhojUser):
return chat_client_builder(search_config, default_user2, require_auth=True)
@pytest.fixture(scope="function")
-def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
+def chat_client_no_background(search_config, default_user2: KhojUser):
return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False)
@pytest.fixture(scope="function")
-def chat_client_with_large_kb(search_config: SearchConfig, default_user2: KhojUser):
+def chat_client_with_large_kb(search_config, default_user2: KhojUser):
"""
Chat client fixture that creates a large knowledge base with many files
for stress testing atomic agent updates.
@@ -312,19 +231,14 @@ def chat_client_with_large_kb(search_config: SearchConfig, default_user2: KhojUs
@pytest.mark.django_db
def chat_client_builder(search_config, user, index_content=True, require_auth=False):
# Initialize app state
- state.config.search_type = search_config
state.SearchType = configure_search_types()
if index_content:
- LocalMarkdownConfig.objects.create(
- input_files=None,
- input_filter=["tests/data/markdown/*.markdown"],
- user=user,
- )
+ file_type = "markdown"
+ files_to_index = {file_type: get_index_files(input_filters=[f"tests/data/{file_type}/*.{file_type}"])}
# Index Markdown Content for Search
- all_files = fs_syncer.collect_files(user=user)
- configure_content(user, all_files)
+ configure_content(user, files_to_index)
# Initialize Processor from Config
chat_provider = get_chat_provider()
@@ -360,17 +274,17 @@ def large_kb_chat_client_builder(search_config, user):
import tempfile
# Initialize app state
- state.config.search_type = search_config
state.SearchType = configure_search_types()
# Create temporary directory for large number of test files
temp_dir = tempfile.mkdtemp(prefix="khoj_test_large_kb_")
+ file_type = "markdown"
large_file_list = []
try:
# Generate 200 test files with substantial content
for i in range(300):
- file_path = os.path.join(temp_dir, f"test_file_{i:03d}.markdown")
+ file_path = os.path.join(temp_dir, f"test_file_{i:03d}.{file_type}")
content = f"""
# Test File {i}
@@ -420,16 +334,9 @@ End of file {i}.
f.write(content)
large_file_list.append(file_path)
- # Create LocalMarkdownConfig with all the generated files
- LocalMarkdownConfig.objects.create(
- input_files=large_file_list,
- input_filter=None,
- user=user,
- )
-
- # Index all the files into the user's knowledge base
- all_files = fs_syncer.collect_files(user=user)
- configure_content(user, all_files)
+ # Index all generated files into the user's knowledge base
+ files_to_index = {file_type: get_index_files(input_files=large_file_list, input_filters=None)}
+ configure_content(user, files_to_index)
# Verify we have a substantial knowledge base
file_count = FileObject.objects.filter(user=user, agent=None).count()
@@ -481,12 +388,8 @@ def fastapi_app():
@pytest.fixture(scope="function")
def client(
- content_config: ContentConfig,
- search_config: SearchConfig,
api_user: KhojApiUser,
):
- state.config.content_type = content_config
- state.config.search_type = search_config
state.SearchType = configure_search_types()
state.embeddings_model = dict()
state.embeddings_model["default"] = EmbeddingsModel()
@@ -516,173 +419,18 @@ def client(
return TestClient(app)
-@pytest.fixture(scope="function")
-def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
- # Initialize app state
- state.config.search_type = search_config
- state.SearchType = configure_search_types()
-
- LocalMarkdownConfig.objects.create(
- input_files=None,
- input_filter=["tests/data/markdown/*.markdown"],
- user=default_user2,
- )
-
- all_files = fs_syncer.collect_files(user=default_user2)
- configure_content(default_user2, all_files)
-
- # Initialize Processor from Config
- ChatModelFactory(
- name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
- tokenizer=None,
- max_prompt_size=None,
- model_type="offline",
- )
- UserConversationProcessorConfigFactory(user=default_user2)
-
- state.anonymous_mode = True
-
- app = FastAPI()
-
- configure_routes(app)
- configure_middleware(app)
- app.mount("/static", StaticFiles(directory=web_directory), name="static")
- return TestClient(app)
-
-
-@pytest.fixture(scope="function")
-def new_org_file(default_user: KhojUser, content_config: ContentConfig):
- # Setup
- org_config = LocalOrgConfig.objects.filter(user=default_user).first()
- input_filters = org_config.input_filter
- new_org_file = Path(input_filters[0]).parent / "new_file.org"
- new_org_file.touch()
-
- yield new_org_file
-
- # Cleanup
- if new_org_file.exists():
- new_org_file.unlink()
-
-
-@pytest.fixture(scope="function")
-def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
- LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
- return LocalOrgConfig.objects.filter(user=default_user).first()
-
-
@pytest.fixture(scope="function")
def pdf_configured_user1(default_user: KhojUser):
- LocalPdfConfig.objects.create(
- input_files=None,
- input_filter=["tests/data/pdf/singlepage.pdf"],
- user=default_user,
- )
- # Index Markdown Content for Search
- all_files = fs_syncer.collect_files(user=default_user)
- configure_content(default_user, all_files)
+ # Read data from pdf file at tests/data/pdf/singlepage.pdf
+ pdf_file_path = "tests/data/pdf/singlepage.pdf"
+ with open(pdf_file_path, "rb") as pdf_file:
+ pdf_data = pdf_file.read()
+
+ knowledge_base = {"pdf": {"singlepage.pdf": pdf_data}}
+ # Index Content for Search
+ configure_content(default_user, knowledge_base)
@pytest.fixture(scope="function")
def sample_org_data():
return get_sample_data("org")
-
-
-def get_sample_data(type):
- sample_data = {
- "org": {
- "elisp.org": """
-* Emacs Khoj
- /An Emacs interface for [[https://github.com/khoj-ai/khoj][khoj]]/
-
-** Requirements
- - Install and Run [[https://github.com/khoj-ai/khoj][khoj]]
-
-** Installation
-*** Direct
- - Put ~khoj.el~ in your Emacs load path. For e.g. ~/.emacs.d/lisp
- - Load via ~use-package~ in your ~/.emacs.d/init.el or .emacs file by adding below snippet
- #+begin_src elisp
- ;; Khoj Package
- (use-package khoj
- :load-path "~/.emacs.d/lisp/khoj.el"
- :bind ("C-c s" . 'khoj))
- #+end_src
-
-*** Using [[https://github.com/quelpa/quelpa#installation][Quelpa]]
- - Ensure [[https://github.com/quelpa/quelpa#installation][Quelpa]], [[https://github.com/quelpa/quelpa-use-package#installation][quelpa-use-package]] are installed
- - Add below snippet to your ~/.emacs.d/init.el or .emacs config file and execute it.
- #+begin_src elisp
- ;; Khoj Package
- (use-package khoj
- :quelpa (khoj :fetcher url :url "https://raw.githubusercontent.com/khoj-ai/khoj/master/interface/emacs/khoj.el")
- :bind ("C-c s" . 'khoj))
- #+end_src
-
-** Usage
- 1. Call ~khoj~ using keybinding ~C-c s~ or ~M-x khoj~
- 2. Enter Query in Natural Language
- e.g. "What is the meaning of life?" "What are my life goals?"
- 3. Wait for results
- *Note: It takes about 15s on a Mac M1 and a ~100K lines corpus of org-mode files*
- 4. (Optional) Narrow down results further
- Include/Exclude specific words from results by adding to query
- e.g. "What is the meaning of life? -god +none"
-
-""",
- "readme.org": """
-* Khoj
- /Allow natural language search on user content like notes, images using transformer based models/
-
- All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
-
-** Dependencies
- - Python3
- - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
-
-** Install
- #+begin_src shell
- git clone https://github.com/khoj-ai/khoj && cd khoj
- conda env create -f environment.yml
- conda activate khoj
- #+end_src""",
- },
- "markdown": {
- "readme.markdown": """
-# Khoj
-Allow natural language search on user content like notes, images using transformer based models
-
-All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline
-
-## Dependencies
-- Python3
-- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links)
-
-## Install
-```shell
-git clone
-conda env create -f environment.yml
-conda activate khoj
-```
-"""
- },
- "plaintext": {
- "readme.txt": """
-Khoj
-Allow natural language search on user content like notes, images using transformer based models
-
-All data is processed locally. User can interface with khoj app via Emacs, API or Commandline
-
-Dependencies
-- Python3
-- Miniconda
-
-Install
-git clone
-conda env create -f environment.yml
-conda activate khoj
-"""
- },
- }
-
- return sample_data[type]
diff --git a/tests/helpers.py b/tests/helpers.py
index d3c94abc..6edb0946 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,3 +1,5 @@
+import glob
+import logging
import os
from datetime import datetime
@@ -17,9 +19,12 @@ from khoj.database.models import (
UserConversationConfig,
)
from khoj.processor.conversation.utils import message_to_log
+from khoj.utils.helpers import get_absolute_path, is_none_or_empty
+
+logger = logging.getLogger(__name__)
-def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
+def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE):
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
if provider and provider in ChatModel.ModelType:
return ChatModel.ModelType(provider)
@@ -61,6 +66,140 @@ def generate_chat_history(message_list):
return chat_history
+def get_sample_data(type):
+ sample_data = {
+ "org": {
+ "elisp.org": """
+* Emacs Khoj
+ /An Emacs interface for [[https://github.com/khoj-ai/khoj][khoj]]/
+
+** Requirements
+ - Install and Run [[https://github.com/khoj-ai/khoj][khoj]]
+
+** Installation
+*** Direct
+ - Put ~khoj.el~ in your Emacs load path. For e.g. ~/.emacs.d/lisp
+ - Load via ~use-package~ in your ~/.emacs.d/init.el or .emacs file by adding below snippet
+ #+begin_src elisp
+ ;; Khoj Package
+ (use-package khoj
+ :load-path "~/.emacs.d/lisp/khoj.el"
+ :bind ("C-c s" . 'khoj))
+ #+end_src
+
+*** Using [[https://github.com/quelpa/quelpa#installation][Quelpa]]
+ - Ensure [[https://github.com/quelpa/quelpa#installation][Quelpa]], [[https://github.com/quelpa/quelpa-use-package#installation][quelpa-use-package]] are installed
+ - Add below snippet to your ~/.emacs.d/init.el or .emacs config file and execute it.
+ #+begin_src elisp
+ ;; Khoj Package
+ (use-package khoj
+ :quelpa (khoj :fetcher url :url "https://raw.githubusercontent.com/khoj-ai/khoj/master/interface/emacs/khoj.el")
+ :bind ("C-c s" . 'khoj))
+ #+end_src
+
+** Usage
+ 1. Call ~khoj~ using keybinding ~C-c s~ or ~M-x khoj~
+ 2. Enter Query in Natural Language
+ e.g. "What is the meaning of life?" "What are my life goals?"
+ 3. Wait for results
+ *Note: It takes about 15s on a Mac M1 and a ~100K lines corpus of org-mode files*
+ 4. (Optional) Narrow down results further
+ Include/Exclude specific words from results by adding to query
+ e.g. "What is the meaning of life? -god +none"
+
+""",
+ "readme.org": """
+* Khoj
+ /Allow natural language search on user content like notes, images using transformer based models/
+
+ All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
+
+** Dependencies
+ - Python3
+ - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
+
+** Install
+ #+begin_src shell
+ git clone https://github.com/khoj-ai/khoj && cd khoj
+ conda env create -f environment.yml
+ conda activate khoj
+ #+end_src""",
+ },
+ "markdown": {
+ "readme.markdown": """
+# Khoj
+Allow natural language search on user content like notes, images using transformer based models
+
+All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline
+
+## Dependencies
+- Python3
+- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links)
+
+## Install
+```shell
+git clone
+conda env create -f environment.yml
+conda activate khoj
+```
+"""
+ },
+ "plaintext": {
+ "readme.txt": """
+Khoj
+Allow natural language search on user content like notes, images using transformer based models
+
+All data is processed locally. User can interface with khoj app via Emacs, API or Commandline
+
+Dependencies
+- Python3
+- Miniconda
+
+Install
+git clone
+conda env create -f environment.yml
+conda activate khoj
+"""
+ },
+ }
+
+ return sample_data[type]
+
+
+def get_index_files(
+ input_files: list[str] = None, input_filters: list[str] | None = ["tests/data/org/*.org"]
+) -> dict[str, str]:
+ # Input Validation
+ if is_none_or_empty(input_files) and is_none_or_empty(input_filters):
+ logger.debug("At least one of input_files or input_filter is required to be specified")
+ return {}
+
+ # Get files to process
+ absolute_files, filtered_files = set(), set()
+ if input_files:
+ absolute_files = {get_absolute_path(input_file) for input_file in input_files}
+ if input_filters:
+ filtered_files = {
+ filtered_file
+ for file_filter in input_filters
+ for filtered_file in glob.glob(get_absolute_path(file_filter), recursive=True)
+ if os.path.isfile(filtered_file)
+ }
+
+ all_files = sorted(absolute_files | filtered_files)
+
+ filename_to_content_map = {}
+ for file in all_files:
+ with open(file, "r", encoding="utf8") as f:
+ try:
+ filename_to_content_map[file] = f.read()
+ except Exception as e:
+ logger.warning(f"Unable to read file: {file}. Skipping file.")
+ logger.warning(e, exc_info=True)
+
+ return filename_to_content_map
+
+
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
@@ -93,7 +232,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory):
max_prompt_size = 20000
tokenizer = None
- name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
+ name = "gemini-2.0-flash"
model_type = get_chat_provider()
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 21a242ef..1d3b96ec 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -15,7 +15,7 @@ from tests.helpers import ChatModelFactory
def test_create_default_agent(default_user: KhojUser):
ChatModelFactory()
- agent = AgentAdapters.create_default_agent(default_user)
+ agent = AgentAdapters.create_default_agent()
assert agent is not None
assert agent.input_tools == []
assert agent.output_modes == []
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 211ff38e..15908653 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,49 +1,15 @@
# Standard Modules
from pathlib import Path
-from random import random
from khoj.utils.cli import cli
-from khoj.utils.helpers import resolve_absolute_path
# Test
# ----------------------------------------------------------------------------------------------------
def test_cli_minimal_default():
# Act
- actual_args = cli([])
+ actual_args = cli(["-vvv"])
# Assert
- assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml"))
- assert actual_args.regenerate == False
- assert actual_args.verbose == 0
-
-
-# ----------------------------------------------------------------------------------------------------
-def test_cli_invalid_config_file_path():
- # Arrange
- non_existent_config_file = f"non-existent-khoj-{random()}.yml"
-
- # Act
- actual_args = cli([f"--config-file={non_existent_config_file}"])
-
- # Assert
- assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
- assert actual_args.config == None
-
-
-# ----------------------------------------------------------------------------------------------------
-def test_cli_config_from_file():
- # Act
- actual_args = cli(["--config-file=tests/data/config.yml", "--regenerate", "-vvv"])
-
- # Assert
- assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
- assert actual_args.regenerate == True
- assert actual_args.config is not None
+ assert actual_args.log_file == Path("~/.khoj/khoj.log")
assert actual_args.verbose == 3
-
- # Ensure content config is loaded from file
- assert actual_args.config.content_type.org.input_files == [
- Path("~/first_from_config.org"),
- Path("~/second_from_config.org"),
- ]
diff --git a/tests/test_client.py b/tests/test_client.py
index 00507851..46732a86 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -13,7 +13,6 @@ from khoj.database.models import KhojApiUser, KhojUser
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.search_type import text_search
from khoj.utils import state
-from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test
@@ -283,10 +282,6 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
state.anonymous_mode = True
- if state.config and state.config.content_type:
- state.config.content_type = None
- state.search_models = configure_search_types()
-
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
@@ -300,7 +295,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
+def test_notes_search(client, search_config, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
@@ -319,7 +314,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
+def test_notes_search_no_results(client, search_config, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
@@ -335,9 +330,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search_with_only_filters(
- client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser
-):
+def test_notes_search_with_only_filters(client, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
@@ -401,9 +394,7 @@ def test_notes_search_with_exclude_filter(client, sample_org_data, default_user:
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
-def test_notes_search_requires_parent_context(
- client, search_config: SearchConfig, sample_org_data, default_user: KhojUser
-):
+def test_notes_search_requires_parent_context(client, search_config, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py
index 9a36bd57..21f198ea 100644
--- a/tests/test_file_filter.py
+++ b/tests/test_file_filter.py
@@ -1,6 +1,13 @@
# Application Packages
from khoj.search_filter.file_filter import FileFilter
-from khoj.utils.rawconfig import Entry
+
+
+# Mock Entry class for testing
+class Entry:
+ def __init__(self, compiled="", raw="", file=""):
+ self.compiled = compiled
+ self.raw = raw
+ self.file = file
def test_can_filter_no_file_filter():
diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py
index 30813555..b8ab37a7 100644
--- a/tests/test_markdown_to_entries.py
+++ b/tests/test_markdown_to_entries.py
@@ -3,8 +3,6 @@ import re
from pathlib import Path
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
-from khoj.utils.fs_syncer import get_markdown_files
-from khoj.utils.rawconfig import TextContentConfig
def test_extract_markdown_with_no_headings(tmp_path):
@@ -212,43 +210,6 @@ longer body line 2.1
), "Third entry is second entries child heading"
-def test_get_markdown_files(tmp_path):
- "Ensure Markdown files specified via input-filter, input-files extracted"
- # Arrange
- # Include via input-filter globs
- group1_file1 = create_file(tmp_path, filename="group1-file1.md")
- group1_file2 = create_file(tmp_path, filename="group1-file2.md")
- group2_file1 = create_file(tmp_path, filename="group2-file1.markdown")
- group2_file2 = create_file(tmp_path, filename="group2-file2.markdown")
- # Include via input-file field
- file1 = create_file(tmp_path, filename="notes.md")
- # Not included by any filter
- create_file(tmp_path, filename="not-included-markdown.md")
- create_file(tmp_path, filename="not-included-text.txt")
-
- expected_files = set(
- [os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
- )
-
- # Setup input-files, input-filters
- input_files = [tmp_path / "notes.md"]
- input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]
-
- markdown_config = TextContentConfig(
- input_files=input_files,
- input_filter=[str(filter) for filter in input_filter],
- compressed_jsonl=tmp_path / "test.jsonl",
- embeddings_file=tmp_path / "test_embeddings.jsonl",
- )
-
- # Act
- extracted_org_files = get_markdown_files(markdown_config)
-
- # Assert
- assert len(extracted_org_files) == 5
- assert set(extracted_org_files.keys()) == expected_files
-
-
def test_line_number_tracking_in_recursive_split():
"Ensure line numbers in URIs are correct after recursive splitting by checking against the actual file."
# Arrange
diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py
deleted file mode 100644
index 979710b6..00000000
--- a/tests/test_offline_chat_actors.py
+++ /dev/null
@@ -1,610 +0,0 @@
-from datetime import datetime
-
-import pytest
-
-from khoj.database.models import ChatModel
-from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
-from khoj.utils.helpers import ConversationCommand
-from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
-
-SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
-pytestmark = pytest.mark.skipif(
- SKIP_TESTS,
- reason="Disable in CI to avoid long test runs.",
-)
-
-import freezegun
-from freezegun import freeze_time
-
-from khoj.processor.conversation.offline.chat_model import converse_offline
-from khoj.processor.conversation.offline.utils import download_model
-from khoj.utils.constants import default_offline_chat_models
-
-
-@pytest.fixture(scope="session")
-def loaded_model():
- return download_model(default_offline_chat_models[0], max_tokens=5000)
-
-
-freezegun.configure(extend_ignore_list=["transformers"])
-
-
-# Test
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@freeze_time("1984-04-02", ignore=["transformers"])
-def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2):
- # Act
- response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model)
-
- assert len(response) >= 1
-
- assert any(
- [
- "dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
- "dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
- 'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
- 'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
- ]
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
-@pytest.mark.chatquality
-@freeze_time("1984-04-02", ignore=["transformers"])
-def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2):
- # Act
- response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model)
-
- # Assert
- assert len(response) >= 1
- # The user query should be the last question in the response
- assert response[-1] == ["Which countries did I visit last month?"]
- assert any(
- [
- "dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
- "dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
- 'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
- 'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
- ]
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
-@pytest.mark.chatquality
-@freeze_time("1984-04-02", ignore=["transformers"])
-def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2):
- # Act
- response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model)
-
- # Assert
- expected_responses = [
- ("dt>='1984-01-01'", ""),
- ("dt>='1984-01-01'", "dt<'1985-01-01'"),
- ("dt>='1984-01-01'", "dt<='1984-12-31'"),
- ]
- assert len(response) == 1
- assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
- "Expected date filter to limit to 1984 in response but got: " + response[0]
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2):
- # Act
- responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model)
-
- # Assert
- assert len(responses) >= 2
- assert ["the Sun" in response for response in responses]
- assert ["the Moon" in response for response in responses]
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2):
- # Act
- response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model)
-
- # Assert
- expected_responses = ["height", "taller", "shorter", "heights", "who"]
- assert len(response) <= 3
-
- for question in response:
- assert any([expected_response in question.lower() for expected_response in expected_responses]), (
- "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2):
- # Arrange
- message_list = [
- ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
- ]
- query = "Does he have any sons?"
-
- # Act
- response = extract_questions(
- query,
- default_user2,
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- use_history=True,
- )
-
- any_expected_with_barbara = [
- "sibling",
- "brother",
- ]
-
- any_expected_with_anderson = [
- "son",
- "sons",
- "children",
- "family",
- ]
-
- # Assert
- assert len(response) >= 1
- # Ensure the remaining generated search queries use proper nouns and chat history context
- for question in response:
- if "Barbara" in question:
- assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
- "Expected search queries using proper nouns and chat history for context, but got: " + question
- )
- elif "Anderson" in question:
- assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
- "Expected search queries using proper nouns and chat history for context, but got: " + question
- )
- else:
- assert False, (
- "Expected search queries using proper nouns and chat history for context, but got: " + question
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2):
- # Arrange
- message_list = [
- ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
- ]
-
- # Act
- response = extract_questions(
- "Is she a Doctor?",
- default_user2,
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- use_history=True,
- )
-
- expected_responses = [
- "Barbara",
- "Anderson",
- ]
-
- # Assert
- assert len(response) >= 1
- assert any([expected_response in response[0] for expected_response in expected_responses]), (
- "Expected chat actor to mention person's by name, but got: " + response[0]
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
-@pytest.mark.chatquality
-def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2):
- # Arrange
- message_list = [
- ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
- ]
-
- # Act
- response = extract_questions(
- "What was the Pizza place we ate at over there?",
- default_user2,
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- )
-
- # Assert
- expected_responses = [
- ("dt>='2000-04-01'", "dt<'2000-05-01'"),
- ("dt>='2000-04-01'", "dt<='2000-04-30'"),
- ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
- ('dt>="2000-04-01"', 'dt<="2000-04-30"'),
- ]
- assert len(response) == 1
- assert "Masai Mara" in response[0]
- assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
- "Expected date filter to limit to April 2000 in response but got: " + response[0]
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.parametrize(
- "user_query, expected_conversation_commands",
- [
- (
- "Where did I learn to swim?",
- {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text},
- ),
- (
- "Where is the nearest hospital?",
- {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text},
- ),
- (
- "Summarize the wikipedia page on the history of the internet",
- {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text},
- ),
- (
- "How many noble gases are there?",
- {"sources": [ConversationCommand.General], "output": ConversationCommand.Text},
- ),
- (
- "Make a painting incorporating my past diving experiences",
- {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image},
- ),
- (
- "Create a chart of the weather over the next 7 days in Timbuktu",
- {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text},
- ),
- (
- "What's the highest point in this country and have I been there?",
- {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text},
- ),
- ],
-)
-async def test_select_data_sources_actor_chooses_to_search_notes(
- client_offline_chat, user_query, expected_conversation_commands, default_user2
-):
- # Act
- selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)
-
- # Assert
- assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"])
- assert expected_conversation_commands["output"] == selected_conversation_commands["output"]
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-@pytest.mark.django_db(transaction=True)
-async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
- # Arrange
- user_query = "What's the latest in the Israel/Palestine conflict?"
- chat_log = [
- (
- "Let's talk about the current events around the world.",
- "Sure, let's discuss the current events. What would you like to know?",
- [],
- ),
- ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
- ]
- chat_history = ConversationFactory(user=default_user2, conversation_log=generate_chat_history(chat_log))
-
- # Act
- tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False)
-
- # Assert
- tools = [tool.value for tool in tools]
- assert tools == ["online"]
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
- # Act
- response_gen = converse_offline(
- references=[], # Assume no context retrieved from notes for the user_query
- user_query="Hello, my name is Testatron. Who are you?",
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- expected_responses = ["Khoj", "khoj", "KHOJ"]
- assert len(response) > 0
- assert any([expected_response in response for expected_response in expected_responses]), (
- "Expected assistants name, [K|k]hoj, in response but got: " + response
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
- "Chat actor needs to use context in previous notes and chat history to answer question"
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- (
- "When was I born?",
- "You were born on 1st April 1984.",
- ["Testatron was born on 1st April 1984 in Testville."],
- ),
- ]
-
- # Act
- response_gen = converse_offline(
- references=[], # Assume no context retrieved from notes for the user_query
- user_query="Where was I born?",
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- assert len(response) > 0
- # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
- assert "Testville" in response
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
- "Chat actor needs to use context across currently retrieved notes and chat history to answer question"
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ]
-
- # Act
- response_gen = converse_offline(
- references=[
- {"compiled": "Testatron was born on 1st April 1984 in Testville."}
- ], # Assume context retrieved from notes for the user_query
- user_query="Where was I born?",
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- assert len(response) > 0
- assert "Testville" in response
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
-@pytest.mark.chatquality
-def test_refuse_answering_unanswerable_question(loaded_model):
- "Chat actor should not try make up answers to unanswerable questions."
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ]
-
- # Act
- response_gen = converse_offline(
- references=[], # Assume no context retrieved from notes for the user_query
- user_query="Where was I born?",
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- expected_responses = [
- "don't know",
- "do not know",
- "no information",
- "do not have",
- "don't have",
- "cannot answer",
- "I'm sorry",
- ]
- assert len(response) > 0
- assert any([expected_response in response for expected_response in expected_responses]), (
- "Expected chat actor to say they don't know in response, but got: " + response
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_answer_requires_current_date_awareness(loaded_model):
- "Chat actor should be able to answer questions relative to current date using provided notes"
- # Arrange
- context = [
- {
- "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
-Expenses:Food:Dining 10.00 USD"""
- },
- {
- "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
-Expenses:Food:Dining 10.00 USD"""
- },
- {
- "compiled": f"""2020-04-01 "SuperMercado" "Bananas"
-Expenses:Food:Groceries 10.00 USD"""
- },
- {
- "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
-Expenses:Food:Dining 10.00 USD"""
- },
- ]
-
- # Act
- response_gen = converse_offline(
- references=context, # Assume context retrieved from notes for the user_query
- user_query="What did I have for Dinner today?",
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- expected_responses = ["tacos", "Tacos"]
- assert len(response) > 0
- assert any([expected_response in response for expected_response in expected_responses]), (
- "Expected [T|t]acos in response, but got: " + response
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
- "Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
- # Arrange
- context = [
- {
- "compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
-Expenses:Food:Dining 10.00 USD"""
- },
- {
- "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
-Expenses:Food:Dining 10.00 USD"""
- },
- {
- "compiled": f"""2020-04-01 "SuperMercado" "Bananas"
-Expenses:Food:Groceries 10.00 USD"""
- },
- {
- "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
-Expenses:Food:Dining 10.00 USD"""
- },
- ]
-
- # Act
- response_gen = converse_offline(
- references=context, # Assume context retrieved from notes for the user_query
- user_query="How much did I spend on dining this year?",
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- assert len(response) > 0
- assert "20" in response
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
- "Chat actor should be able to answer general questions not requiring looking at chat history or notes"
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ("Where was I born?", "You were born Testville.", []),
- ]
-
- # Act
- response_gen = converse_offline(
- references=[], # Assume no context retrieved from notes for the user_query
- user_query="Write a haiku about unit testing in 3 lines",
- chat_history=generate_chat_history(message_list),
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- expected_responses = ["test", "testing"]
- assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
- assert any([expected_response in response.lower() for expected_response in expected_responses]), (
- "Expected [T|t]est in response, but got: " + response
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
- "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
- # Arrange
- context = [
- {
- "compiled": f"""# Ramya
-My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
- },
- {
- "compiled": f"""# Fang
-My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
- },
- {
- "compiled": f"""# Aiyla
-My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
- },
- ]
-
- # Act
- response_gen = converse_offline(
- references=context, # Assume context retrieved from notes for the user_query
- user_query="How many kids does my older sister have?",
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister", "Which one"]
- assert any([expected_response in response for expected_response in expected_responses]), (
- "Expected chat actor to ask for clarification in response, but got: " + response
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-def test_agent_prompt_should_be_used(loaded_model, offline_agent):
- "Chat actor should ask be tuned to think like an accountant based on the agent definition"
- # Arrange
- context = [
- {"compiled": f"""I went to the store and bought some bananas for 2.20"""},
- {"compiled": f"""I went to the store and bought some apples for 1.30"""},
- {"compiled": f"""I went to the store and bought some oranges for 6.00"""},
- ]
-
- # Act
- response_gen = converse_offline(
- references=context, # Assume context retrieved from notes for the user_query
- user_query="What did I buy?",
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert that the model without the agent prompt does not include the summary of purchases
- expected_responses = ["9.50", "9.5"]
- assert all([expected_response not in response for expected_response in expected_responses]), (
- "Expected chat actor to summarize values of purchases" + response
- )
-
- # Act
- response_gen = converse_offline(
- references=context, # Assume context retrieved from notes for the user_query
- user_query="What did I buy?",
- loaded_model=loaded_model,
- agent=offline_agent,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert that the model with the agent prompt does include the summary of purchases
- expected_responses = ["9.50", "9.5"]
- assert any([expected_response in response for expected_response in expected_responses]), (
- "Expected chat actor to summarize values of purchases" + response
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-def test_chat_does_not_exceed_prompt_size(loaded_model):
- "Ensure chat context and response together do not exceed max prompt size for the model"
- # Arrange
- prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
- context = [{"compiled": " ".join([f"{number}" for number in range(2043)])}]
-
- # Act
- response_gen = converse_offline(
- references=context, # Assume context retrieved from notes for the user_query
- user_query="What numbers come after these?",
- loaded_model=loaded_model,
- )
- response = "".join([response_chunk for response_chunk in response_gen])
-
- # Assert
- assert prompt_size_exceeded_error not in response, (
- "Expected chat response to be within prompt limits, but got exceeded error: " + response
- )
diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py
deleted file mode 100644
index 0eb4a0dc..00000000
--- a/tests/test_offline_chat_director.py
+++ /dev/null
@@ -1,726 +0,0 @@
-import urllib.parse
-
-import pytest
-from faker import Faker
-from freezegun import freeze_time
-
-from khoj.database.models import Agent, ChatModel, Entry, KhojUser
-from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import message_to_log
-from tests.helpers import ConversationFactory, get_chat_provider
-
-SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
-pytestmark = pytest.mark.skipif(
- SKIP_TESTS,
- reason="Disable in CI to avoid long test runs.",
-)
-
-fake = Faker()
-
-
-# Helpers
-# ----------------------------------------------------------------------------------------------------
-def generate_history(message_list):
- # Generate conversation logs
- conversation_log = {"chat": []}
- for user_message, gpt_message, context in message_list:
- message_to_log(
- user_message,
- gpt_message,
- {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
- chat_history=conversation_log.get("chat", []),
- )
- return conversation_log
-
-
-def create_conversation(message_list, user, agent=None):
- # Generate conversation logs
- conversation_log = generate_history(message_list)
- # Update Conversation Metadata Logs in Database
- return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
-
-
-# Tests
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
- # Act
- query = "Hello, my name is Testatron. Who are you?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["Khoj", "khoj"]
- assert response.status_code == 200
- assert any([expected_response in response_message for expected_response in expected_responses]), (
- "Expected assistants name, [K|k]hoj, in response but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_chat_with_online_content(client_offline_chat):
- # Act
- q = "/online give me the link to paul graham's essay how to do great work"
- response = client_offline_chat.post(f"/api/chat", json={"q": q})
- response_message = response.json()["response"]
-
- # Assert
- expected_responses = [
- "https://paulgraham.com/greatwork.html",
- "https://www.paulgraham.com/greatwork.html",
- "http://www.paulgraham.com/greatwork.html",
- ]
- assert response.status_code == 200
- assert any(
- [expected_response in response_message for expected_response in expected_responses]
- ), f"Expected links: {expected_responses}. Actual response: {response_message}"
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_chat_with_online_webpage_content(client_offline_chat):
- # Act
- q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
- response = client_offline_chat.post(f"/api/chat", json={"q": q})
- response_message = response.json()["response"]
-
- # Assert
- expected_responses = ["185", "1871", "horse"]
- assert response.status_code == 200
- assert any(
- [expected_response in response_message for expected_response in expected_responses]
- ), f"Expected response with {expected_responses}. But actual response had: {response_message}"
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_from_chat_history(client_offline_chat, default_user2):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- q = "What is my name?"
- response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["Testatron", "testatron"]
- assert response.status_code == 200
- assert any([expected_response in response_message for expected_response in expected_responses]), (
- "Expected [T|t]estatron in response but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- (
- "When was I born?",
- "You were born on 1st April 1984.",
- ["Testatron was born on 1st April 1984 in Testville."],
- ),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- q = "Where was Xi Li born?"
- response = client_offline_chat.post(f"/api/chat", json={"q": q})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- assert "Fujiang" in response_message
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- (
- "When was I born?",
- "You were born on 1st April 1984.",
- ["Testatron was born on 1st April 1984 in Testville."],
- ),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- q = "Where was I born?"
- response = client_offline_chat.post(f"/api/chat", json={"q": q})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- # 1. Infer who I am from chat history
- # 2. Infer I was born in Testville from previously retrieved notes
- assert "Testville" in response_message
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
- # Arrange
- message_list = [
- ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- q = "Where was I born?"
- response = client_offline_chat.post(f"/api/chat", json={"q": q})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- # Inference in a multi-turn conversation
- # 1. Infer who I am from chat history
- # 2. Search for notes about when was born
- # 3. Extract where I was born from currently retrieved notes
- assert "Fujiang" in response_message
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
- "Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- q = "Where was I born?"
- response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
- assert response.status_code == 200
- assert any([expected_response in response_message for expected_response in expected_responses]), (
- "Expected chat director to say they don't know in response, but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_using_general_command(client_offline_chat, default_user2):
- # Arrange
- message_list = []
- create_conversation(message_list, default_user2)
-
- # Act
- query = "/general Where was Xi Li born?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- assert "Fujiang" not in response_message
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
- # Arrange
- message_list = []
- create_conversation(message_list, default_user2)
-
- # Act
- query = "/notes Where was Xi Li born?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- assert "Fujiang" in response_message
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_using_file_filter(client_offline_chat, default_user2):
- # Arrange
- no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
- answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
- message_list = []
- create_conversation(message_list, default_user2)
-
- # Act
- no_answer_response = client_offline_chat.post(
- f"/api/chat", json={"q": no_answer_query, "stream": True}
- ).content.decode("utf-8")
- answer_response = client_offline_chat.post(f"/api/chat", json={"q": answer_query, "stream": True}).content.decode(
- "utf-8"
- )
-
- # Assert
- assert "Fujiang" not in no_answer_response
- assert "Fujiang" in answer_response
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
- # Arrange
- message_list = []
- create_conversation(message_list, default_user2)
-
- # Act
- query = urllib.parse.quote("/notes Where was Testatron born?")
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- assert response_message == prompts.no_notes_found.format()
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = []
- conversation = create_conversation(message_list, default_user2)
- # post "Xi Li.markdown" file to the file filters
- file_list = (
- Entry.objects.filter(user=default_user2, file_source="computer")
- .distinct("file_path")
- .values_list("file_path", flat=True)
- )
- # pick the file that has "Xi Li.markdown" in the name
- summarization_file = ""
- for file in file_list:
- if "Birthday Gift for Xiu turning 4.markdown" in file:
- summarization_file = file
- break
- assert summarization_file != ""
-
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters",
- json={"filename": summarization_file, "conversation_id": str(conversation.id)},
- )
-
- # Act
- query = "/summarize"
- response = client_offline_chat.post(
- f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
- )
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response_message != ""
- assert response_message != "No files selected for summarization. Please add files using the section on the left."
-
-
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = []
- conversation = create_conversation(message_list, default_user2)
- # post "Xi Li.markdown" file to the file filters
- file_list = (
- Entry.objects.filter(user=default_user2, file_source="computer")
- .distinct("file_path")
- .values_list("file_path", flat=True)
- )
- # pick the file that has "Xi Li.markdown" in the name
- summarization_file = ""
- for file in file_list:
- if "Birthday Gift for Xiu turning 4.markdown" in file:
- summarization_file = file
- break
- assert summarization_file != ""
-
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters",
- json={"filename": summarization_file, "conversation_id": str(conversation.id)},
- )
-
- # Act
- query = "/summarize tell me about Xiu"
- response = client_offline_chat.post(
- f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
- )
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response_message != ""
- assert response_message != "No files selected for summarization. Please add files using the section on the left."
-
-
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = []
- conversation = create_conversation(message_list, default_user2)
- # post "Xi Li.markdown" file to the file filters
- file_list = (
- Entry.objects.filter(user=default_user2, file_source="computer")
- .distinct("file_path")
- .values_list("file_path", flat=True)
- )
-
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
- )
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
- )
-
- # Act
- query = "/summarize"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
- response_message = response.json()["response"]
-
- # Assert
- assert response_message is not None
-
-
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = []
- conversation = create_conversation(message_list, default_user2)
- # Act
- query = "/summarize"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
- response_message = response.json()["response"]
- # Assert
- assert response_message == "No files selected for summarization. Please add files using the section on the left."
-
-
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
- message_list = []
- conversation1 = create_conversation(message_list, default_user2)
- conversation2 = create_conversation(message_list, default_user2)
-
- file_list = (
- Entry.objects.filter(user=default_user2, file_source="computer")
- .distinct("file_path")
- .values_list("file_path", flat=True)
- )
- summarization_file = ""
- for file in file_list:
- if "Birthday Gift for Xiu turning 4.markdown" in file:
- summarization_file = file
- break
- assert summarization_file != ""
-
- # add file filter to conversation 1.
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters",
- json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
- )
-
- query = "/summarize"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
- response_message = response.json()["response"]
-
- # Assert
- assert response_message == "No files selected for summarization. Please add files using the section on the left."
-
- # now make sure that the file filter is still in conversation 1
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
- response_message = response.json()["response"]
-
- # Assert
- assert response_message != ""
- assert response_message != "No files selected for summarization. Please add files using the section on the left."
-
-
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = []
- conversation = create_conversation(message_list, default_user2)
- # post "imaginary.markdown" file to the file filters
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters",
- json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
- )
- # Act
- query = "/summarize"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
- response_message = response.json()["response"]
- # Assert
- assert response_message == "No files selected for summarization. Please add files using the section on the left."
-
-
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.chatquality
-def test_summarize_diff_user_file(
- client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
-):
- # Arrange
- message_list = []
- conversation = create_conversation(message_list, default_user2)
- # Get the pdf file called singlepage.pdf
- file_list = (
- Entry.objects.filter(user=default_user, file_source="computer")
- .distinct("file_path")
- .values_list("file_path", flat=True)
- )
- summarization_file = ""
- for file in file_list:
- if "singlepage.pdf" in file:
- summarization_file = file
- break
- assert summarization_file != ""
- # add singlepage.pdf to the file filters
- response = client_offline_chat.post(
- "api/chat/conversation/file-filters",
- json={"filename": summarization_file, "conversation_id": str(conversation.id)},
- )
-
- # Act
- query = "/summarize"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
- response_message = response.json()["response"]
-
- # Assert
- assert response_message == "No files selected for summarization. Please add files using the section on the left."
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-@freeze_time("2023-04-01", ignore=["transformers"])
-def test_answer_requires_current_date_awareness(client_offline_chat):
- "Chat actor should be able to answer questions relative to current date using provided notes"
- # Act
- query = "Where did I have lunch today?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["Arak", "Medellin"]
- assert response.status_code == 200
- assert any([expected_response in response_message for expected_response in expected_responses]), (
- "Expected chat director to say Arak, Medellin, but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-@freeze_time("2023-04-01", ignore=["transformers"])
-def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
- "Chat director should be able to answer questions that require date aware aggregation across multiple notes"
- # Act
- query = "How much did I spend on dining this year?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- assert "26" in response_message
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ("Where was I born?", "You were born Testville.", []),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- query = "Write a haiku about unit testing. Do not say anything else."
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["test", "Test"]
- assert response.status_code == 200
- assert len(response_message.splitlines()) == 3 # haikus are 3 lines long
- assert any([expected_response in response_message for expected_response in expected_responses]), (
- "Expected [T|t]est in response, but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
- # Act
- query = "What is the name of Namitas older son"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = [
- "which of them is the older",
- "which one is older",
- "which of them is older",
- "which one is the older",
- ]
- assert response.status_code == 200
- assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
- "Expected chat director to ask for clarification in response, but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ("Where was I born?", "You were born Testville.", []),
- ]
- create_conversation(message_list, default_user2)
-
- # Act
- query = "What is my name?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["Testatron", "testatron"]
- assert response.status_code == 200
- assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
- "Expected [T|t]estatron in response, but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ("What's my favorite color", "Your favorite color is green.", []),
- ("Where was I born?", "You were born Testville.", []),
- ]
- message_list2 = [
- ("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 14th August 1947.", []),
- ("What's my favorite color", "Your favorite color is maroon.", []),
- ("Where was I born?", "You were born in a potato farm.", []),
- ]
- conversation = create_conversation(message_list, default_user2)
- create_conversation(message_list2, default_user2)
-
- # Act
- query = "What is my favorite color?"
- response = client_offline_chat.post(
- f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
- )
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["green"]
- assert response.status_code == 200
- assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
- "Expected green in response, but got: " + response_message
- )
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet")
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_in_chat_history_by_conversation_id_with_agent(
- client_offline_chat, default_user2: KhojUser, offline_agent: Agent
-):
- # Arrange
- message_list = [
- ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
- ("When was I born?", "You were born on 1st April 1984.", []),
- ("What's my favorite color", "Your favorite color is green.", []),
- ("Where was I born?", "You were born Testville.", []),
- ("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
- ]
- conversation = create_conversation(message_list, default_user2, offline_agent)
-
- # Act
- query = "/general What did I eat for breakfast?"
- response = client_offline_chat.post(
- f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
- )
- response_message = response.content.decode("utf-8")
-
- # Assert that agent only responds with the summary of spending
- expected_responses = ["13.00", "13", "13.0", "thirteen"]
- assert response.status_code == 200
- assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
- "Expected green in response, but got: " + response_message
- )
-
-
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_chat_history_very_long(client_offline_chat, default_user2):
- # Arrange
- message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
- create_conversation(message_list, default_user2)
-
- # Act
- query = "What is my name?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- assert response.status_code == 200
- assert len(response_message) > 0
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.chatquality
-@pytest.mark.django_db(transaction=True)
-def test_answer_requires_multiple_independent_searches(client_offline_chat):
- "Chat director should be able to answer by doing multiple independent searches for required information"
- # Act
- query = "Is Xi older than Namita?"
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
- response_message = response.content.decode("utf-8")
-
- # Assert
- expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
- assert response.status_code == 200
- assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
- "Expected Xi is older than Namita, but got: " + response_message
- )
diff --git a/tests/test_org_to_entries.py b/tests/test_org_to_entries.py
index d5dcdbd2..0196ef6c 100644
--- a/tests/test_org_to_entries.py
+++ b/tests/test_org_to_entries.py
@@ -4,9 +4,8 @@ import time
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.text_to_entries import TextToEntries
-from khoj.utils.fs_syncer import get_org_files
from khoj.utils.helpers import is_none_or_empty
-from khoj.utils.rawconfig import Entry, TextContentConfig
+from khoj.utils.rawconfig import Entry
def test_configure_indexing_heading_only_entries(tmp_path):
@@ -330,46 +329,6 @@ def test_file_with_no_headings_to_entry(tmp_path):
assert len(entries[1]) == 1
-def test_get_org_files(tmp_path):
- "Ensure Org files specified via input-filter, input-files extracted"
- # Arrange
- # Include via input-filter globs
- group1_file1 = create_file(tmp_path, filename="group1-file1.org")
- group1_file2 = create_file(tmp_path, filename="group1-file2.org")
- group2_file1 = create_file(tmp_path, filename="group2-file1.org")
- group2_file2 = create_file(tmp_path, filename="group2-file2.org")
- # Include via input-file field
- orgfile1 = create_file(tmp_path, filename="orgfile1.org")
- # Not included by any filter
- create_file(tmp_path, filename="orgfile2.org")
- create_file(tmp_path, filename="text1.txt")
-
- expected_files = set(
- [
- os.path.join(tmp_path, file.name)
- for file in [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]
- ]
- )
-
- # Setup input-files, input-filters
- input_files = [tmp_path / "orgfile1.org"]
- input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]
-
- org_config = TextContentConfig(
- input_files=input_files,
- input_filter=[str(filter) for filter in input_filter],
- compressed_jsonl=tmp_path / "test.jsonl",
- embeddings_file=tmp_path / "test_embeddings.jsonl",
- )
-
- # Act
- extracted_org_files = get_org_files(org_config)
-
- # Assert
- assert len(extracted_org_files) == 5
- assert set(extracted_org_files.keys()) == expected_files
-
-
def test_extract_entries_with_different_level_headings(tmp_path):
"Extract org entries with different level headings."
# Arrange
diff --git a/tests/test_pdf_to_entries.py b/tests/test_pdf_to_entries.py
index a62eca8b..d7336fdc 100644
--- a/tests/test_pdf_to_entries.py
+++ b/tests/test_pdf_to_entries.py
@@ -4,8 +4,6 @@ import re
import pytest
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
-from khoj.utils.fs_syncer import get_pdf_files
-from khoj.utils.rawconfig import TextContentConfig
def test_single_page_pdf_to_jsonl():
@@ -61,43 +59,6 @@ def test_ocr_page_pdf_to_jsonl():
assert re.search(expected_str_with_variable_spaces, raw_entry) is not None
-def test_get_pdf_files(tmp_path):
- "Ensure Pdf files specified via input-filter, input-files extracted"
- # Arrange
- # Include via input-filter globs
- group1_file1 = create_file(tmp_path, filename="group1-file1.pdf")
- group1_file2 = create_file(tmp_path, filename="group1-file2.pdf")
- group2_file1 = create_file(tmp_path, filename="group2-file1.pdf")
- group2_file2 = create_file(tmp_path, filename="group2-file2.pdf")
- # Include via input-file field
- file1 = create_file(tmp_path, filename="document.pdf")
- # Not included by any filter
- create_file(tmp_path, filename="not-included-document.pdf")
- create_file(tmp_path, filename="not-included-text.txt")
-
- expected_files = set(
- [os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
- )
-
- # Setup input-files, input-filters
- input_files = [tmp_path / "document.pdf"]
- input_filter = [tmp_path / "group1*.pdf", tmp_path / "group2*.pdf"]
-
- pdf_config = TextContentConfig(
- input_files=input_files,
- input_filter=[str(path) for path in input_filter],
- compressed_jsonl=tmp_path / "test.jsonl",
- embeddings_file=tmp_path / "test_embeddings.jsonl",
- )
-
- # Act
- extracted_pdf_files = get_pdf_files(pdf_config)
-
- # Assert
- assert len(extracted_pdf_files) == 5
- assert set(extracted_pdf_files.keys()) == expected_files
-
-
# Helper Functions
def create_file(tmp_path, entry=None, filename="document.pdf"):
pdf_file = tmp_path / filename
diff --git a/tests/test_plaintext_to_entries.py b/tests/test_plaintext_to_entries.py
index a085b2b5..558832d3 100644
--- a/tests/test_plaintext_to_entries.py
+++ b/tests/test_plaintext_to_entries.py
@@ -1,27 +1,20 @@
-import os
from pathlib import Path
+from textwrap import dedent
-from khoj.database.models import KhojUser, LocalPlaintextConfig
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
-from khoj.utils.fs_syncer import get_plaintext_files
-from khoj.utils.rawconfig import TextContentConfig
-def test_plaintext_file(tmp_path):
+def test_plaintext_file():
"Convert files with no heading to jsonl."
# Arrange
raw_entry = f"""
Hi, I am a plaintext file and I have some plaintext words.
"""
- plaintextfile = create_file(tmp_path, raw_entry)
+ plaintextfile = "test.txt"
+ data = {plaintextfile: raw_entry}
# Act
# Extract Entries from specified plaintext files
-
- data = {
- f"{plaintextfile}": raw_entry,
- }
-
entries = PlaintextToEntries.extract_plaintext_entries(data)
# Convert each entry.file to absolute path to make them JSON serializable
@@ -37,59 +30,20 @@ def test_plaintext_file(tmp_path):
assert entries[1][0].compiled == f"{plaintextfile}\n{raw_entry}"
-def test_get_plaintext_files(tmp_path):
- "Ensure Plaintext files specified via input-filter, input-files extracted"
- # Arrange
- # Include via input-filter globs
- group1_file1 = create_file(tmp_path, filename="group1-file1.md")
- group1_file2 = create_file(tmp_path, filename="group1-file2.md")
-
- group2_file1 = create_file(tmp_path, filename="group2-file1.markdown")
- group2_file2 = create_file(tmp_path, filename="group2-file2.markdown")
- group2_file4 = create_file(tmp_path, filename="group2-file4.html")
- # Include via input-file field
- file1 = create_file(tmp_path, filename="notes.txt")
- # Include unsupported file types
- create_file(tmp_path, filename="group2-unincluded.py")
- create_file(tmp_path, filename="group2-unincluded.csv")
- create_file(tmp_path, filename="group2-unincluded.csv")
- create_file(tmp_path, filename="group2-file3.mbox")
- # Not included by any filter
- create_file(tmp_path, filename="not-included-markdown.md")
- create_file(tmp_path, filename="not-included-text.txt")
-
- expected_files = set(
- [
- os.path.join(tmp_path, file.name)
- for file in [group1_file1, group1_file2, group2_file1, group2_file2, group2_file4, file1]
- ]
- )
-
- # Setup input-files, input-filters
- input_files = [tmp_path / "notes.txt"]
- input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.*"]
-
- plaintext_config = TextContentConfig(
- input_files=input_files,
- input_filter=[str(filter) for filter in input_filter],
- compressed_jsonl=tmp_path / "test.jsonl",
- embeddings_file=tmp_path / "test_embeddings.jsonl",
- )
-
- # Act
- extracted_plaintext_files = get_plaintext_files(plaintext_config)
-
- # Assert
- assert len(extracted_plaintext_files) == len(expected_files)
- assert set(extracted_plaintext_files.keys()) == set(expected_files)
-
-
-def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
+def test_parse_html_plaintext_file(tmp_path):
"Ensure HTML files are parsed correctly"
# Arrange
- # Setup input-files, input-filters
- config = LocalPlaintextConfig.objects.filter(user=default_user).first()
- extracted_plaintext_files = get_plaintext_files(config=config)
+ raw_entry = dedent(
+ f"""
+
+ Test HTML
+
+ Test content
+
+
+ """
+ )
+ extracted_plaintext_files = {"test.html": raw_entry}
# Act
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index 712f4aba..9e532429 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -2,23 +2,16 @@
import asyncio
import logging
import os
-from pathlib import Path
import pytest
from khoj.database.adapters import EntryAdapters
-from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
-from khoj.processor.content.docx.docx_to_entries import DocxToEntries
+from khoj.database.models import Entry, GithubConfig, KhojUser
from khoj.processor.content.github.github_to_entries import GithubToEntries
-from khoj.processor.content.images.image_to_entries import ImageToEntries
-from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
-from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
-from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.search_type import text_search
-from khoj.utils.fs_syncer import collect_files, get_org_files
-from khoj.utils.rawconfig import ContentConfig, SearchConfig
+from tests.helpers import get_index_files, get_sample_data
logger = logging.getLogger(__name__)
@@ -26,53 +19,20 @@ logger = logging.getLogger(__name__)
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
- # Arrange
- # Ensure file mentioned in org.input-files is missing
- single_new_file = Path(org_config_with_only_new_file.input_files[0])
- single_new_file.unlink()
-
- # Act
- # Generate notes embeddings during asymmetric setup
- with pytest.raises(FileNotFoundError):
- get_org_files(org_config_with_only_new_file)
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.django_db
-def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
- # Arrange
- orgfile = tmp_path / "directory.org" / "file.org"
- orgfile.parent.mkdir()
- with open(orgfile, "w") as f:
- f.write("* Heading\n- List item\n")
-
- LocalOrgConfig.objects.create(
- input_filter=[f"{tmp_path}/**/*"],
- input_files=None,
- user=default_user,
- )
-
- # Act
- org_files = collect_files(user=default_user)["org"]
-
- # Assert
- # should return orgfile and not raise IsADirectoryError
- assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.django_db
-def test_text_search_setup_with_empty_file_creates_no_entries(
- org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
-):
+def test_text_search_setup_with_empty_file_creates_no_entries(search_config, default_user: KhojUser):
# Arrange
+ initial_data = {
+ "test.org": "* First heading\nFirst content",
+ "test2.org": "* Second heading\nSecond content",
+ }
+ text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
existing_entries = Entry.objects.filter(user=default_user).count()
- data = get_org_files(org_config_with_only_new_file)
+
+ final_data = {"new_file.org": ""}
# Act
# Generate notes embeddings during asymmetric setup
- text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
+ text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
# Assert
updated_entries = Entry.objects.filter(user=default_user).count()
@@ -84,13 +44,14 @@ def test_text_search_setup_with_empty_file_creates_no_entries(
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_text_indexer_deletes_embedding_before_regenerate(
- content_config: ContentConfig, default_user: KhojUser, caplog
-):
+def test_text_indexer_deletes_embedding_before_regenerate(search_config, default_user: KhojUser, caplog):
# Arrange
+ data = {
+ "test1.org": "* Test heading\nTest content",
+ "test2.org": "* Another heading\nAnother content",
+ }
+ text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
existing_entries = Entry.objects.filter(user=default_user).count()
- org_config = LocalOrgConfig.objects.filter(user=default_user).first()
- data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
@@ -107,11 +68,10 @@ def test_text_indexer_deletes_embedding_before_regenerate(
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
+def test_text_index_same_if_content_unchanged(search_config, default_user: KhojUser, caplog):
# Arrange
existing_entries = Entry.objects.filter(user=default_user)
- org_config = LocalOrgConfig.objects.filter(user=default_user).first()
- data = get_org_files(org_config)
+ data = {"test.org": "* Test heading\nTest content"}
# Act
# Generate initial notes embeddings during asymmetric setup
@@ -136,20 +96,14 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-@pytest.mark.anyio
-# @pytest.mark.asyncio
-async def test_text_search(search_config: SearchConfig):
+@pytest.mark.asyncio
+async def test_text_search(search_config):
# Arrange
- default_user = await KhojUser.objects.acreate(
+ default_user, _ = await KhojUser.objects.aget_or_create(
username="test_user", password="test_password", email="test@example.com"
)
- org_config = await LocalOrgConfig.objects.acreate(
- input_files=None,
- input_filter=["tests/data/org/*.org"],
- index_heading_entries=False,
- user=default_user,
- )
- data = get_org_files(org_config)
+ # Get some sample org data to index
+ data = get_sample_data("org")
loop = asyncio.get_event_loop()
await loop.run_in_executor(
@@ -175,17 +129,15 @@ async def test_text_search(search_config: SearchConfig):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
+def test_entry_chunking_by_max_tokens(tmp_path, search_config, default_user: KhojUser, caplog):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
- new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
- with open(new_file_to_index, "w") as f:
- f.write(f"* Entry more than {max_tokens} words\n")
- for index in range(max_tokens + 1):
- f.write(f"{index} ")
-
- data = get_org_files(org_config_with_only_new_file)
+ new_file_to_index = tmp_path / "test.org"
+ content = f"* Entry more than {max_tokens} words\n"
+ for index in range(max_tokens + 1):
+ content += f"{index} "
+ data = {str(new_file_to_index): content}
# Act
# reload embeddings, entries, notes model after adding new org-mode file
@@ -200,9 +152,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_entry_chunking_by_max_tokens_not_full_corpus(
- org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
-):
+def test_entry_chunking_by_max_tokens_not_full_corpus(tmp_path, search_config, default_user: KhojUser, caplog):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
data = {
@@ -231,13 +181,11 @@ conda activate khoj
)
max_tokens = 256
- new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
- with open(new_file_to_index, "w") as f:
- f.write(f"* Entry more than {max_tokens} words\n")
- for index in range(max_tokens + 1):
- f.write(f"{index} ")
-
- data = get_org_files(org_config_with_only_new_file)
+ new_file_to_index = tmp_path / "test.org"
+ content = f"* Entry more than {max_tokens} words\n"
+ for index in range(max_tokens + 1):
+ content += f"{index} "
+ data = {str(new_file_to_index): content}
# Act
# reload embeddings, entries, notes model after adding new org-mode file
@@ -257,34 +205,34 @@ conda activate khoj
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
+def test_regenerate_index_with_new_entry(search_config, default_user: KhojUser):
# Arrange
+ # Initial indexed files
+ text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=True, user=default_user)
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
- org_config = LocalOrgConfig.objects.filter(user=default_user).first()
- initial_data = get_org_files(org_config)
- # append org-mode entry to first org input file in config
- org_config.input_files = [f"{new_org_file}"]
- with open(new_org_file, "w") as f:
- f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
-
- final_data = get_org_files(org_config)
-
- # Act
- text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
+ # Regenerate index with only files from test data set
+ files_to_index = get_index_files()
+ text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
+ # Act
+ # Update index with the new file
+ new_file = "test.org"
+ new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+ files_to_index[new_file] = new_entry
+
# regenerate notes jsonl, model embeddings and model to include entry from new file
- text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
+ text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
for entry in updated_entries1:
assert entry in updated_entries2
- assert not any([new_org_file.name in entry for entry in updated_entries1])
- assert not any([new_org_file.name in entry for entry in existing_entries])
- assert any([new_org_file.name in entry for entry in updated_entries2])
+ assert not any([new_file in entry for entry in updated_entries1])
+ assert not any([new_file in entry for entry in existing_entries])
+ assert any([new_file in entry for entry in updated_entries2])
assert any(
["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
@@ -294,28 +242,24 @@ def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_update_index_with_duplicate_entries_in_stable_order(
- org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
-):
+def test_update_index_with_duplicate_entries_in_stable_order(tmp_path, search_config, default_user: KhojUser):
# Arrange
+ initial_data = get_sample_data("org")
+ text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
- new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
+ new_file_to_index = tmp_path / "test.org"
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
- with open(new_file_to_index, "w") as f:
- f.write(f"{new_entry}{new_entry}")
-
- data = get_org_files(org_config_with_only_new_file)
+ # Initial data with duplicate entries
+ data = {str(new_file_to_index): f"{new_entry}{new_entry}"}
# Act
# generate embeddings, entries, notes model from scratch after adding new org-mode file
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
- data = get_org_files(org_config_with_only_new_file)
-
- # update embeddings, entries, notes model with no new changes
+ # idempotent indexing when data unchanged
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
@@ -324,6 +268,7 @@ def test_update_index_with_duplicate_entries_in_stable_order(
for entry in existing_entries:
assert entry not in updated_entries1
+ # verify the second indexing update has same entries and ordering as first
for entry in updated_entries1:
assert entry in updated_entries2
@@ -334,22 +279,17 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
+def test_update_index_with_deleted_entry(tmp_path, search_config, default_user: KhojUser):
# Arrange
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
- new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
- # Insert org-mode entries with same compiled form into new org file
+ new_file_to_index = tmp_path / "test.org"
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
- with open(new_file_to_index, "w") as f:
- f.write(f"{new_entry}{new_entry} -- Tatooine")
- initial_data = get_org_files(org_config_with_only_new_file)
- # update embeddings, entries, notes model after removing an entry from the org file
- with open(new_file_to_index, "w") as f:
- f.write(f"{new_entry}")
-
- final_data = get_org_files(org_config_with_only_new_file)
+ # Initial data with two entries
+ initial_data = {str(new_file_to_index): f"{new_entry}{new_entry} -- Tatooine"}
+ # Final data with only first entry, with second entry removed
+ final_data = {str(new_file_to_index): f"{new_entry}"}
# Act
# load embeddings, entries, notes model after adding new org file with 2 entries
@@ -375,29 +315,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
+def test_update_index_with_new_entry(search_config, default_user: KhojUser):
# Arrange
- existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
- org_config = LocalOrgConfig.objects.filter(user=default_user).first()
- data = get_org_files(org_config)
- text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
+ # Initial indexed files
+ text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=True, user=default_user)
+ old_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
- # append org-mode entry to first org input file in config
- with open(new_org_file, "w") as f:
- new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
- f.write(new_entry)
-
- data = get_org_files(org_config)
+ # Regenerate index with only files from test data set
+ files_to_index = get_index_files()
+ new_entries = text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user)
# Act
- # update embeddings, entries with the newly added note
- text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
- updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
+ # Update index with the new file
+ new_file = "test.org"
+ new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+ final_data = {new_file: new_entry}
+
+ text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
+ updated_new_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
# Assert
- for entry in existing_entries:
- assert entry not in updated_entries1
- assert len(updated_entries1) == len(existing_entries) + 1
+ for old_entry in old_entries:
+ assert old_entry not in updated_new_entries
+ assert len(updated_new_entries) == len(new_entries) + 1
verify_embeddings(3, default_user)
@@ -409,9 +349,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
(OrgToEntries),
],
)
-def test_update_index_with_deleted_file(
- org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser
-):
+def test_update_index_with_deleted_file(text_to_entries: TextToEntries, search_config, default_user: KhojUser):
"Delete entries associated with new file when file path with empty content passed."
# Arrange
file_to_index = "test"
@@ -446,7 +384,7 @@ def test_update_index_with_deleted_file(
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
-def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
+def test_text_search_setup_github(search_config, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()
diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py
index ebd6cccf..5333e17f 100644
--- a/tests/test_word_filter.py
+++ b/tests/test_word_filter.py
@@ -1,6 +1,12 @@
# Application Packages
from khoj.search_filter.word_filter import WordFilter
-from khoj.utils.rawconfig import Entry
+
+
+# Mock Entry class for testing
+class Entry:
+ def __init__(self, compiled="", raw=""):
+ self.compiled = compiled
+ self.raw = raw
# Test