mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Drop Server Side Indexer, Native Offline Chat, Old Migration Scripts (#1212)
### Overview
Make the server leaner to increase development speed. Remove the old indexing code and the native offline chat module, which was hard to maintain.
- The native offline chat module was written when the local AI model API ecosystem wasn't mature. Now it is, so reuse that ecosystem instead.
- Offline chat requires a GPU for usable speeds. Decoupling offline chat from the Khoj server is the recommended way to get practical inference speeds (e.g. Ollama on the machine, Khoj in Docker).

### Details
- Drop old code to index files on the server filesystem. Clean up the CLI and init paths.
- Drop native offline chat support via llama-cpp-python. Use established local AI APIs like llama.cpp server, Ollama and vLLM instead.
- Drop old pre-1.0 Khoj config migration scripts.
- Update the test setup to index test data now that the old indexing code is removed.
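Decoupling works because all the recommended backends speak the same OpenAI-compatible HTTP API, so one client code path replaces the in-process backend. A minimal sketch of the request such a client would send — the model name and Ollama's default port are illustrative assumptions, not part of this PR:

```python
import json

# Assumed default base URL for a local Ollama server's OpenAI-compatible API
OLLAMA_BASE_URL = "http://localhost:11434/v1"

def chat_request(model: str, prompt: str) -> tuple[str, bytes]:
    """Build the URL and JSON body for an OpenAI-compatible chat completion.

    The same payload works against Ollama, llama.cpp server or vLLM, which
    is what lets the server drop its in-process llama-cpp-python backend."""
    url = f"{OLLAMA_BASE_URL}/chat/completions"
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }).encode()
    return url, body
```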
@@ -18,8 +18,8 @@ ENV PATH="/opt/venv/bin:${PATH}"
 COPY pyproject.toml README.md ./

 # Setup python environment
-# Use the pre-built llama-cpp-python, torch cpu wheel
-ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://abetlen.github.io/llama-cpp-python/whl/cpu" \
+# Use the pre-built torch cpu wheel
+ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" \
 # Avoid downloading unused cuda specific python packages
 CUDA_VISIBLE_DEVICES="" \
 # Use static version to build app without git dependency
.github/workflows/test.yml (vendored, 2 changes)
@@ -64,8 +64,6 @@ jobs:
 DEBIAN_FRONTEND: noninteractive
 run: |
 apt update && apt install -y git libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
-# required by llama-cpp-python prebuilt wheels
-apt install -y musl-dev && ln -s /usr/lib/x86_64-linux-musl/libc.so /lib/libc.musl-x86_64.so.1

 - name: ⬇️ Install Postgres
 env:
@@ -20,7 +20,7 @@ Add all the agents you want to use for your different use-cases like Writer, Res
 ### Chat Model Options
 Add all the chat models you want to try, use and switch between for your different use-cases. For each chat model you add:
 - `Chat model`: The name of an [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/en/docs/about-claude/models#model-names), [Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) or [Offline](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf) chat model.
-- `Model type`: The chat model provider like `OpenAI`, `Offline`.
+- `Model type`: The chat model provider like `OpenAI`, `Google`.
 - `Vision enabled`: Set to `true` if your model supports vision. This is currently only supported for vision capable OpenAI models like `gpt-4o`
 - `Max prompt size`, `Subscribed max prompt size`: These are optional fields. They are used to truncate the context to the maximum context size that can be passed to the model. This can help with accuracy and cost-saving.<br />
 - `Tokenizer`: This is an optional field. It is used to accurately count tokens and truncate context passed to the chat model to stay within the model's max prompt size.
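The `Max prompt size` and `Tokenizer` fields in the hunk above work together: context is truncated, oldest messages first, until the token count fits. A rough sketch of that behavior — the whitespace token counter and the function name are illustrative stand-ins, not Khoj code:

```python
def truncate_context(messages, max_prompt_size, count_tokens=lambda s: len(s.split())):
    """Drop the oldest messages until the total token count fits max_prompt_size.

    count_tokens is a stand-in for the configured `Tokenizer`; a naive
    whitespace splitter is used here so the sketch is self-contained."""
    kept = list(messages)
    while kept and sum(count_tokens(m) for m in kept) > max_prompt_size:
        kept.pop(0)  # evict the oldest message first
    return kept
```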
@@ -18,10 +18,6 @@ import TabItem from '@theme/TabItem';
 These are the general setup instructions for self-hosted Khoj.
 You can install the Khoj server using either [Docker](?server=docker) or [Pip](?server=pip).

-:::info[Offline Model + GPU]
-To use the offline chat model with your GPU, we recommend using the Docker setup with Ollama. You can also use the local Khoj setup via the Python package directly.
-:::
-
 :::info[First Run]
 Restart your Khoj server after the first run to ensure all settings are applied correctly.
 :::
@@ -225,10 +221,6 @@ To start Khoj automatically in the background use [Task scheduler](https://www.w
 You can now open the web app at http://localhost:42110 and start interacting!<br />
 Nothing else is necessary, but you can customize your setup further by following the steps below.

-:::info[First Message to Offline Chat Model]
-The offline chat model gets downloaded when you first send a message to it. The download can take a few minutes! Subsequent messages should be faster.
-:::
-
 ### Add Chat Models
 <h4>Login to the Khoj Admin Panel</h4>
 Go to http://localhost:42110/server/admin and login with the admin credentials you set up during installation.
@@ -301,13 +293,14 @@ Offline chat stays completely private and can work without internet using any op
 - An Nvidia or AMD GPU or a Mac M1+ machine would significantly speed up chat responses
 :::

-1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
-2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
-3. Set the `chat-model` field to the name of your preferred chat model
-   - Make sure the `model-type` is set to `Offline`
-4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
-5. Restart the Khoj server and [start chatting](http://localhost:42110) with your new offline model!
-</TabItem>
+1. Install any OpenAI API compatible local AI model server like [llama-cpp-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server), Ollama, vLLM etc.
+2. Add an [AI model API](http://localhost:42110/server/admin/database/aimodelapi/add/) on the admin panel
+   - Set the `api url` field to the URL of your local AI model provider, e.g. `http://localhost:11434/v1/` for Ollama
+3. Restart the Khoj server to load the models available from your local AI model provider
+   - If that doesn't work, manually add the available [chat models](http://localhost:42110/server/admin/database/chatmodel/add) in the admin panel.
+4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings)
+5. [Start chatting](http://localhost:42110) with your local AI!
+</TabItem>
 </Tabs>

 :::tip[Multiple Chat Models]
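Before restarting the server in step 3 above, you can sanity-check that the `api url` you entered actually serves the OpenAI-compatible model listing. A small sketch — the default URL assumes Ollama, and `list_local_models` is illustrative, not Khoj code:

```python
import json
from urllib.request import urlopen

def models_endpoint(api_url: str) -> str:
    """Join the admin-panel `api url` value with the standard /models path."""
    return api_url.rstrip("/") + "/models"

def list_local_models(api_url: str = "http://localhost:11434/v1/") -> list:
    """Return the model ids a local OpenAI-compatible server advertises."""
    with urlopen(models_endpoint(api_url), timeout=5) as resp:
        return [m["id"] for m in json.load(resp)["data"]]
```

If `list_local_models()` raises a connection error, the local AI server is not running or the URL is wrong; fix that before expecting Khoj to discover any models.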
@@ -65,7 +65,6 @@ dependencies = [
 "django == 5.1.10",
 "django-unfold == 0.42.0",
 "authlib == 1.2.1",
-"llama-cpp-python == 0.2.88",
 "itsdangerous == 2.1.2",
 "httpx == 0.28.1",
 "pgvector == 0.2.4",
@@ -50,13 +50,11 @@ from khoj.database.adapters import (
 )
 from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
 from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
-from khoj.routers.api_content import configure_content, configure_search
+from khoj.routers.api_content import configure_content
 from khoj.routers.twilio import is_twilio_enabled
 from khoj.utils import constants, state
-from khoj.utils.config import SearchType
-from khoj.utils.fs_syncer import collect_files
-from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
-from khoj.utils.rawconfig import FullConfig
+from khoj.utils.helpers import is_none_or_empty

 logger = logging.getLogger(__name__)
@@ -232,14 +230,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
     return AuthCredentials(), UnauthenticatedUser()


-def initialize_server(config: Optional[FullConfig]):
-    try:
-        configure_server(config, init=True)
-    except Exception as e:
-        logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
-        raise e
-

 def clean_connections(func):
     """
     A decorator that ensures that Django database connections that have become unusable, or are obsolete, are closed
@@ -260,19 +250,7 @@ def clean_connections(func):
     return func_wrapper


-def configure_server(
-    config: FullConfig,
-    regenerate: bool = False,
-    search_type: Optional[SearchType] = None,
-    init=False,
-    user: KhojUser = None,
-):
-    # Update Config
-    if config == None:
-        logger.info(f"Initializing with default config.")
-        config = FullConfig()
-    state.config = config
-
+def initialize_server():
     if ConversationAdapters.has_valid_ai_model_api():
         ai_model_api = ConversationAdapters.get_ai_model_api()
         state.openai_client = openai.OpenAI(api_key=ai_model_api.api_key, base_url=ai_model_api.api_base_url)
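The body of the `clean_connections` decorator is elided by the hunks above. A framework-agnostic sketch of the pattern it describes — the injected `close_old_connections` stands in for `django.db.close_old_connections`, and the factory form is only here to keep the sketch testable without a configured Django project:

```python
import functools

def clean_connections(close_old_connections):
    """Decorator factory: close stale DB connections around a scheduled job.

    Django marks connections unusable after errors or CONN_MAX_AGE expiry;
    closing them before and after the job forces a fresh connection."""
    def decorator(func):
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            close_old_connections()  # drop connections that went stale while idle
            try:
                return func(*args, **kwargs)
            finally:
                close_old_connections()  # don't leave this thread's connection open
        return func_wrapper
    return decorator
```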
@@ -309,43 +287,33 @@ def configure_server(
         )

         state.SearchType = configure_search_types()
-        state.search_models = configure_search(state.search_models, state.config.search_type)
-        setup_default_agent(user)
+        setup_default_agent()

-        message = (
-            "📡 Telemetry disabled"
-            if telemetry_disabled(state.config.app, state.telemetry_disabled)
-            else "📡 Telemetry enabled"
-        )
+        message = "📡 Telemetry disabled" if state.telemetry_disabled else "📡 Telemetry enabled"
         logger.info(message)

-        if not init:
-            initialize_content(user, regenerate, search_type)
-
     except Exception as e:
         logger.error(f"Failed to load some search models: {e}", exc_info=True)

-def setup_default_agent(user: KhojUser):
-    AgentAdapters.create_default_agent(user)
+def setup_default_agent():
+    AgentAdapters.create_default_agent()

 def initialize_content(user: KhojUser, regenerate: bool, search_type: Optional[SearchType] = None):
     # Initialize Content from Config
-    if state.search_models:
-        try:
-            logger.info("📬 Updating content index...")
-            all_files = collect_files(user=user)
-            status = configure_content(
-                user,
-                all_files,
-                regenerate,
-                search_type,
-            )
-            if not status:
-                raise RuntimeError("Failed to update content index")
-        except Exception as e:
-            raise e
+    try:
+        logger.info("📬 Updating content index...")
+        status = configure_content(
+            user,
+            {},
+            regenerate,
+            search_type,
+        )
+        if not status:
+            raise RuntimeError("Failed to update content index")
+    except Exception as e:
+        raise e


 def configure_routes(app):
@@ -438,8 +406,7 @@ def configure_middleware(app, ssl_enabled: bool = False):

 def update_content_index():
     for user in get_all_users():
-        all_files = collect_files(user=user)
-        success = configure_content(user, all_files)
+        success = configure_content(user, {})
         if not success:
             raise RuntimeError("Failed to update content index")
     logger.info("📪 Content index updated via Scheduler")
@@ -464,7 +431,7 @@ def configure_search_types():
 @schedule.repeat(schedule.every(2).minutes)
 @clean_connections
 def upload_telemetry():
-    if telemetry_disabled(state.config.app, state.telemetry_disabled) or not state.telemetry:
+    if state.telemetry_disabled or not state.telemetry:
         return

     try:
@@ -72,7 +72,6 @@ from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.utils import state
-from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import (
     clean_object_for_db,
     clean_text_for_db,
@@ -789,8 +788,8 @@ class AgentAdapters:
         return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()

     @staticmethod
-    def create_default_agent(user: KhojUser):
-        default_chat_model = ConversationAdapters.get_default_chat_model(user)
+    def create_default_agent():
+        default_chat_model = ConversationAdapters.get_default_chat_model(user=None)
         if default_chat_model is None:
             logger.info("No default conversation config found, skipping default agent creation")
             return None
@@ -1553,14 +1552,6 @@ class ConversationAdapters:
         if chat_model is None:
             chat_model = await ConversationAdapters.aget_default_chat_model()

-        if chat_model.model_type == ChatModel.ModelType.OFFLINE:
-            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                chat_model_name = chat_model.name
-                max_tokens = chat_model.max_prompt_size
-                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
-
-            return chat_model
-
         if (
             chat_model.model_type
             in [
@@ -0,0 +1,36 @@
# Generated by Django 5.1.10 on 2025-07-19 21:33

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0091_chatmodel_friendly_name_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="chatmodel",
            name="model_type",
            field=models.CharField(
                choices=[("openai", "Openai"), ("anthropic", "Anthropic"), ("google", "Google")],
                default="google",
                max_length=200,
            ),
        ),
        migrations.AlterField(
            model_name="chatmodel",
            name="name",
            field=models.CharField(default="gemini-2.5-flash", max_length=200),
        ),
        migrations.AlterField(
            model_name="speechtotextmodeloptions",
            name="model_name",
            field=models.CharField(default="whisper-1", max_length=200),
        ),
        migrations.AlterField(
            model_name="speechtotextmodeloptions",
            name="model_type",
            field=models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200),
        ),
    ]
@@ -0,0 +1,36 @@
# Generated by Django 5.1.10 on 2025-07-25 23:30

from django.db import migrations


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more"),
    ]

    operations = [
        migrations.RemoveField(
            model_name="localorgconfig",
            name="user",
        ),
        migrations.RemoveField(
            model_name="localpdfconfig",
            name="user",
        ),
        migrations.RemoveField(
            model_name="localplaintextconfig",
            name="user",
        ),
        migrations.DeleteModel(
            name="LocalMarkdownConfig",
        ),
        migrations.DeleteModel(
            name="LocalOrgConfig",
        ),
        migrations.DeleteModel(
            name="LocalPdfConfig",
        ),
        migrations.DeleteModel(
            name="LocalPlaintextConfig",
        ),
    ]
@@ -220,16 +220,15 @@ class PriceTier(models.TextChoices):
 class ChatModel(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
-        OFFLINE = "offline"
         ANTHROPIC = "anthropic"
         GOOGLE = "google"

     max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
-    name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
+    name = models.CharField(max_length=200, default="gemini-2.5-flash")
     friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
-    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.GOOGLE)
     price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
     vision_enabled = models.BooleanField(default=False)
     ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -489,34 +488,6 @@ class ServerChatSettings(DbBaseModel):
         super().save(*args, **kwargs)


-class LocalOrgConfig(DbBaseModel):
-    input_files = models.JSONField(default=list, null=True)
-    input_filter = models.JSONField(default=list, null=True)
-    index_heading_entries = models.BooleanField(default=False)
-    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
-class LocalMarkdownConfig(DbBaseModel):
-    input_files = models.JSONField(default=list, null=True)
-    input_filter = models.JSONField(default=list, null=True)
-    index_heading_entries = models.BooleanField(default=False)
-    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
-class LocalPdfConfig(DbBaseModel):
-    input_files = models.JSONField(default=list, null=True)
-    input_filter = models.JSONField(default=list, null=True)
-    index_heading_entries = models.BooleanField(default=False)
-    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
-class LocalPlaintextConfig(DbBaseModel):
-    input_files = models.JSONField(default=list, null=True)
-    input_filter = models.JSONField(default=list, null=True)
-    index_heading_entries = models.BooleanField(default=False)
-    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
-
-
 class SearchModelConfig(DbBaseModel):
     class ModelType(models.TextChoices):
         TEXT = "text"
@@ -605,11 +576,10 @@ class TextToImageModelConfig(DbBaseModel):
 class SpeechToTextModelOptions(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
-        OFFLINE = "offline"

-    model_name = models.CharField(max_length=200, default="base")
+    model_name = models.CharField(max_length=200, default="whisper-1")
     friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
-    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
     price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
     ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -138,10 +138,10 @@ def run(should_start_server=True):
     initialization(not args.non_interactive)

     # Create app directory, if it doesn't exist
-    state.config_file.parent.mkdir(parents=True, exist_ok=True)
+    state.log_file.parent.mkdir(parents=True, exist_ok=True)

     # Set Log File
-    fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
+    fh = logging.FileHandler(state.log_file, encoding="utf-8")
     fh.setLevel(logging.DEBUG)
     logger.addHandler(fh)
@@ -194,7 +194,7 @@ def run(should_start_server=True):
     # Configure Middleware
     configure_middleware(app, state.ssl_config)

-    initialize_server(args.config)
+    initialize_server()

     # If the server is started through gunicorn (external to the script), don't start the server
     if should_start_server:
@@ -204,8 +204,7 @@ def run(should_start_server=True):


 def set_state(args):
-    state.config_file = args.config_file
-    state.config = args.config
+    state.log_file = args.log_file
     state.verbose = args.verbose
     state.host = args.host
     state.port = args.port
@@ -214,7 +213,6 @@ def set_state(args):
     )
     state.anonymous_mode = args.anonymous_mode
     state.khoj_version = version("khoj")
-    state.chat_on_gpu = args.chat_on_gpu


 def start_server(app, host=None, port=None, socket=None):
@@ -1,69 +0,0 @@
"""
Current format of khoj.yml
---
app:
    ...
content-type:
    ...
processor:
  conversation:
    offline-chat:
        enable-offline-chat: false
        chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
    ...
search-type:
    ...

New format of khoj.yml
---
app:
    ...
content-type:
    ...
processor:
  conversation:
    offline-chat:
        enable-offline-chat: false
        chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
    ...
search-type:
    ...
"""
import logging

from packaging import version

from khoj.utils.yaml import load_config_from_file, save_config_to_file

logger = logging.getLogger(__name__)


def migrate_offline_chat_default_model(args):
    schema_version = "0.12.4"
    raw_config = load_config_from_file(args.config_file)
    previous_version = raw_config.get("version")

    if "processor" not in raw_config:
        return args
    if raw_config["processor"] is None:
        return args
    if "conversation" not in raw_config["processor"]:
        return args
    if "offline-chat" not in raw_config["processor"]["conversation"]:
        return args
    if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]:
        return args

    if previous_version is None or version.parse(previous_version) < version.parse("0.12.4"):
        logger.info(
            f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF"
        )
        raw_config["version"] = schema_version

        # Update offline chat model to mistral in GGUF format to use latest GPT4All
        offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"]
        if offline_chat_model.endswith(".bin"):
            raw_config["processor"]["conversation"]["offline-chat"]["chat-model"] = "mistral-7b-instruct-v0.1.Q4_0.gguf"

    save_config_to_file(raw_config, args.config_file)
    return args
@@ -1,71 +0,0 @@
"""
Current format of khoj.yml
---
app:
    ...
content-type:
    ...
processor:
  conversation:
    offline-chat:
        enable-offline-chat: false
        chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
    ...
search-type:
    ...

New format of khoj.yml
---
app:
    ...
content-type:
    ...
processor:
  conversation:
    offline-chat:
        enable-offline-chat: false
        chat-model: NousResearch/Hermes-2-Pro-Mistral-7B-GGUF
    ...
search-type:
    ...
"""
import logging

from packaging import version

from khoj.utils.yaml import load_config_from_file, save_config_to_file

logger = logging.getLogger(__name__)


def migrate_offline_chat_default_model(args):
    schema_version = "1.7.0"
    raw_config = load_config_from_file(args.config_file)
    previous_version = raw_config.get("version")

    if "processor" not in raw_config:
        return args
    if raw_config["processor"] is None:
        return args
    if "conversation" not in raw_config["processor"]:
        return args
    if "offline-chat" not in raw_config["processor"]["conversation"]:
        return args
    if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]:
        return args

    if previous_version is None or version.parse(previous_version) < version.parse(schema_version):
        logger.info(
            f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF"
        )
        raw_config["version"] = schema_version

        # Update offline chat model to use Nous Research's Hermes-2-Pro GGUF in path format suitable for llama-cpp
        offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"]
        if offline_chat_model == "mistral-7b-instruct-v0.1.Q4_0.gguf":
            raw_config["processor"]["conversation"]["offline-chat"][
                "chat-model"
            ] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"

    save_config_to_file(raw_config, args.config_file)
    return args
@@ -1,83 +0,0 @@
"""
Current format of khoj.yml
---
app:
    ...
content-type:
    ...
processor:
  conversation:
    enable-offline-chat: false
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    openai:
        ...
search-type:
    ...

New format of khoj.yml
---
app:
    ...
content-type:
    ...
processor:
  conversation:
    offline-chat:
        enable-offline-chat: false
        chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
    tokenizer: null
    max_prompt_size: null
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    openai:
        ...
search-type:
    ...
"""
import logging

from packaging import version

from khoj.utils.yaml import load_config_from_file, save_config_to_file

logger = logging.getLogger(__name__)


def migrate_offline_chat_schema(args):
    schema_version = "0.12.3"
    raw_config = load_config_from_file(args.config_file)
    previous_version = raw_config.get("version")

    if "processor" not in raw_config:
        return args
    if raw_config["processor"] is None:
        return args
    if "conversation" not in raw_config["processor"]:
        return args

    if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
        logger.info(
            f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
        )
        raw_config["version"] = schema_version

        # Create max-prompt-size field in conversation processor schema
        raw_config["processor"]["conversation"]["max-prompt-size"] = None
        raw_config["processor"]["conversation"]["tokenizer"] = None

        # Create offline chat schema based on existing enable_offline_chat field in khoj config schema
        offline_chat_model = (
            raw_config["processor"]["conversation"]
            .get("offline-chat", {})
            .get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
        )
        raw_config["processor"]["conversation"]["offline-chat"] = {
            "enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
            "chat-model": offline_chat_model,
        }

        # Delete old enable-offline-chat field from conversation processor schema
        if "enable-offline-chat" in raw_config["processor"]["conversation"]:
            del raw_config["processor"]["conversation"]["enable-offline-chat"]

    save_config_to_file(raw_config, args.config_file)
    return args
@@ -1,29 +0,0 @@
import logging
import os

from packaging import version

from khoj.utils.yaml import load_config_from_file, save_config_to_file

logger = logging.getLogger(__name__)


def migrate_offline_model(args):
    schema_version = "0.10.1"
    raw_config = load_config_from_file(args.config_file)
    previous_version = raw_config.get("version")

    if previous_version is None or version.parse(previous_version) < version.parse("0.10.1"):
        logger.info(
            f"Migrating offline model used for version {previous_version} to latest version for {args.version_no}"
        )
        raw_config["version"] = schema_version

        # If the user has downloaded the offline model, remove it from the cache.
        offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin")
        if os.path.exists(offline_model_path):
            os.remove(offline_model_path)

        save_config_to_file(raw_config, args.config_file)

    return args
@@ -1,67 +0,0 @@
"""
Current format of khoj.yml
---
app:
    should-log-telemetry: true
content-type:
    ...
processor:
  conversation:
    chat-model: gpt-3.5-turbo
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    model: text-davinci-003
    openai-api-key: sk-secret-key
search-type:
    ...

New format of khoj.yml
---
app:
    should-log-telemetry: true
content-type:
    ...
processor:
  conversation:
    openai:
        chat-model: gpt-3.5-turbo
        openai-api-key: sk-secret-key
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    enable-offline-chat: false
search-type:
    ...
"""
from khoj.utils.yaml import load_config_from_file, save_config_to_file


def migrate_processor_conversation_schema(args):
    schema_version = "0.10.0"
    raw_config = load_config_from_file(args.config_file)

    if "processor" not in raw_config:
        return args
    if raw_config["processor"] is None:
        return args
    if "conversation" not in raw_config["processor"]:
        return args

    current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
    current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
    if current_openai_api_key is None and current_chat_model is None:
        return args

    raw_config["version"] = schema_version

    # Add enable_offline_chat to khoj config schema
    if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
        raw_config["processor"]["conversation"]["enable-offline-chat"] = False

    # Update conversation processor schema
    conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
    raw_config["processor"]["conversation"] = {
        "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
        "conversation-logfile": conversation_logfile,
        "enable-offline-chat": False,
    }

    save_config_to_file(raw_config, args.config_file)
    return args
@@ -1,132 +0,0 @@
"""
The application config currently looks like this:
app:
  should-log-telemetry: true
content-type:
  ...
processor:
  conversation:
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    max-prompt-size: null
    offline-chat:
      chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
      enable-offline-chat: false
    openai:
      api-key: sk-blah
      chat-model: gpt-3.5-turbo
    tokenizer: null
search-type:
  asymmetric:
    cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
    encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
    encoder-type: null
    model-directory: /Users/si/.khoj/search/asymmetric
  image:
    encoder: sentence-transformers/clip-ViT-B-32
    encoder-type: null
    model-directory: /Users/si/.khoj/search/image
  symmetric:
    cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
    encoder: sentence-transformers/all-MiniLM-L6-v2
    encoder-type: null
    model-directory: ~/.khoj/search/symmetric
version: 0.14.0


The new version will look like this:
app:
  should-log-telemetry: true
processor:
  conversation:
    offline-chat:
      enabled: false
    openai:
      api-key: sk-blah
    chat-model-options:
      - chat-model: gpt-3.5-turbo
        tokenizer: null
        type: openai
      - chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
        tokenizer: null
        type: offline
search-type:
  asymmetric:
    cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
    encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
version: 0.15.0
"""

import logging

from packaging import version

from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
from khoj.utils.yaml import load_config_from_file, save_config_to_file

logger = logging.getLogger(__name__)


def migrate_server_pg(args):
    schema_version = "0.15.0"
    raw_config = load_config_from_file(args.config_file)
    previous_version = raw_config.get("version")

    if previous_version is None or version.parse(previous_version) < version.parse(schema_version):
        logger.info(
            f"Migrating configuration used for version {previous_version} to latest version for server with postgres in {args.version_no}"
        )
        raw_config["version"] = schema_version

    if raw_config is None:
        return args

    if "search-type" in raw_config and raw_config["search-type"]:
        if "asymmetric" in raw_config["search-type"]:
            # Delete all existing search models
            SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
            # Create new search model from existing Khoj YAML config
            asymmetric_search = raw_config["search-type"]["asymmetric"]
            SearchModelConfig.objects.create(
                name="default",
                model_type=SearchModelConfig.ModelType.TEXT,
                bi_encoder=asymmetric_search.get("encoder"),
                cross_encoder=asymmetric_search.get("cross-encoder"),
            )

    if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
        processor_conversation = raw_config["processor"]["conversation"]

        if "offline-chat" in raw_config["processor"]["conversation"]:
            offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
            ChatModel.objects.create(
                name=offline_chat.get("chat-model"),
                tokenizer=processor_conversation.get("tokenizer"),
                max_prompt_size=processor_conversation.get("max-prompt-size"),
                model_type=ChatModel.ModelType.OFFLINE,
            )

        if (
            "openai" in raw_config["processor"]["conversation"]
            and raw_config["processor"]["conversation"]["openai"]
        ):
            openai = raw_config["processor"]["conversation"]["openai"]

            if openai.get("api-key") is None:
                logger.error("OpenAI API Key is not set. Will not be migrating OpenAI config.")
            else:
                if openai.get("chat-model") is None:
                    openai["chat-model"] = "gpt-3.5-turbo"

                openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")

                ChatModel.objects.create(
                    name=openai.get("chat-model"),
                    tokenizer=processor_conversation.get("tokenizer"),
                    max_prompt_size=processor_conversation.get("max-prompt-size"),
                    model_type=ChatModel.ModelType.OPENAI,
                    ai_model_api=openai_model_api,
                )

    save_config_to_file(raw_config, args.config_file)

    return args
@@ -1,17 +0,0 @@
from khoj.utils.yaml import load_config_from_file, save_config_to_file


def migrate_config_to_version(args):
    schema_version = "0.9.0"
    raw_config = load_config_from_file(args.config_file)

    # Add version to khoj config schema
    if "version" not in raw_config:
        raw_config["version"] = schema_version
        save_config_to_file(raw_config, args.config_file)

        # regenerate khoj index on first start of this version
        # this should refresh index and apply index corruption fixes from #325
        args.regenerate = True

    return args
@@ -20,7 +20,6 @@ magika = Magika()

class GithubToEntries(TextToEntries):
    def __init__(self, config: GithubConfig):
        super().__init__(config)
        raw_repos = config.githubrepoconfig.all()
        repos = []
        for repo in raw_repos:

@@ -47,7 +47,6 @@ class NotionBlockType(Enum):

class NotionToEntries(TextToEntries):
    def __init__(self, config: NotionConfig):
        super().__init__(config)
        self.config = NotionContentConfig(
            token=config.token,
        )

@@ -27,7 +27,6 @@ logger = logging.getLogger(__name__)

class TextToEntries(ABC):
    def __init__(self, config: Any = None):
        self.embeddings_model = state.embeddings_model
        self.config = config
        self.date_filter = DateFilter()

    @abstractmethod
@@ -1,224 +0,0 @@
import asyncio
import logging
import os
from datetime import datetime
from threading import Thread
from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Union

from langchain_core.messages.chat import ChatMessage
from llama_cpp import Llama

from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
    ResponseWithThought,
    commit_conversation_trace,
    generate_chatml_messages_with_context,
    messages_to_print,
)
from khoj.utils import state
from khoj.utils.helpers import (
    is_none_or_empty,
    is_promptrace_enabled,
    truncate_code_context,
)
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump

logger = logging.getLogger(__name__)


async def converse_offline(
    # Query
    user_query: str,
    # Context
    references: list[dict] = [],
    online_results={},
    code_results={},
    query_files: str = None,
    generated_files: List[FileAttachment] = None,
    additional_context: List[str] = None,
    generated_asset_results: Dict[str, Dict] = {},
    location_data: LocationData = None,
    user_name: str = None,
    chat_history: list[ChatMessageModel] = [],
    # Model
    model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    loaded_model: Union[Any, None] = None,
    max_prompt_size=None,
    tokenizer_name=None,
    agent: Agent = None,
    tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
    """
    Converse with user using Llama (Async Version)
    """
    # Initialize Variables
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
    tracer["chat_model"] = model_name
    current_date = datetime.now()

    if agent and agent.personality:
        system_prompt = prompts.custom_system_prompt_offline_chat.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.system_prompt_offline_chat.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    if location_data:
        location_prompt = prompts.user_location.format(location=f"{location_data}")
        system_prompt = f"{system_prompt}\n{location_prompt}"

    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt = f"{system_prompt}\n{user_name_prompt}"

    # Get Conversation Primer appropriate to Conversation Type
    context_message = ""
    if not is_none_or_empty(references):
        context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
    if not is_none_or_empty(online_results):
        simplified_online_results = online_results.copy()
        for result in online_results:
            if online_results[result].get("webpages"):
                simplified_online_results[result] = online_results[result]["webpages"]

        context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
    if not is_none_or_empty(code_results):
        context_message += (
            f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
        )
    context_message = context_message.strip()

    # Setup Prompt with Primer or Conversation History
    messages = generate_chatml_messages_with_context(
        user_query,
        system_prompt,
        chat_history,
        context_message=context_message,
        model_name=model_name,
        loaded_model=offline_chat_model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
        model_type=ChatModel.ModelType.OFFLINE,
        query_files=query_files,
        generated_files=generated_files,
        generated_asset_results=generated_asset_results,
        program_execution_context=additional_context,
    )

    logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

    # Use asyncio.Queue and a thread to bridge sync iterator
    queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue()
    stop_phrases = ["<s>", "INST]", "Notes:"]

    def _sync_llm_thread():
        """Synchronous function to run in a separate thread."""
        aggregated_response = ""
        start_time = perf_counter()
        state.chat_lock.acquire()
        try:
            response_iterator = send_message_to_model_offline(
                messages,
                loaded_model=offline_chat_model,
                stop=stop_phrases,
                max_prompt_size=max_prompt_size,
                streaming=True,
                tracer=tracer,
            )
            for response in response_iterator:
                response_delta: str = response["choices"][0]["delta"].get("content", "")
                # Log the time taken to start response
                if aggregated_response == "" and response_delta != "":
                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
                # Handle response chunk
                aggregated_response += response_delta
                # Put chunk into the asyncio queue (non-blocking)
                try:
                    queue.put_nowait(ResponseWithThought(text=response_delta))
                except asyncio.QueueFull:
                    # Should not happen with default queue size unless consumer is very slow
                    logger.warning("Asyncio queue full during offline LLM streaming.")
                    # Potentially block here or handle differently if needed
                    asyncio.run(queue.put(ResponseWithThought(text=response_delta)))

            # Log the time taken to stream the entire response
            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

            # Save conversation trace
            tracer["chat_model"] = model_name
            if is_promptrace_enabled():
                commit_conversation_trace(messages, aggregated_response, tracer)

        except Exception as e:
            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
        finally:
            state.chat_lock.release()
            # Signal end of stream
            queue.put_nowait(None)

    # Start the synchronous thread
    thread = Thread(target=_sync_llm_thread)
    thread.start()

    # Asynchronously consume from the queue
    while True:
        chunk = await queue.get()
        if chunk is None:  # End of stream signal
            queue.task_done()
            break
        yield chunk
        queue.task_done()

    # Wait for the thread to finish (optional, ensures cleanup)
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, thread.join)


def send_message_to_model_offline(
    messages: List[ChatMessage],
    loaded_model=None,
    model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    temperature: float = 0.2,
    streaming=False,
    stop=[],
    max_prompt_size: int = None,
    response_type: str = "text",
    tracer: dict = {},
):
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
    messages_dict = [{"role": message.role, "content": message.content} for message in messages]
    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
    response = offline_chat_model.create_chat_completion(
        messages_dict,
        stop=stop,
        stream=streaming,
        temperature=temperature,
        response_format={"type": response_type},
        seed=seed,
    )

    if streaming:
        return response

    response_text: str = response["choices"][0]["message"].get("content", "")

    # Save conversation trace for non-streaming responses
    # Streamed responses need to be saved by the calling function
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature
    if is_promptrace_enabled():
        commit_conversation_trace(messages, response_text, tracer)

    return ResponseWithThought(text=response_text)
@@ -1,80 +0,0 @@
import glob
import logging
import math
import os
from typing import Any, Dict

from huggingface_hub.constants import HF_HUB_CACHE

from khoj.utils import state
from khoj.utils.helpers import get_device_memory

logger = logging.getLogger(__name__)


def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
    # Initialize Model Parameters
    # Use n_ctx=0 to get context size from the model
    kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False}

    # Decide whether to load model to GPU or CPU
    device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
    kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0

    # Add chat format if known
    if "llama-3" in repo_id.lower():
        kwargs["chat_format"] = "llama-3"
    elif "gemma-2" in repo_id.lower():
        kwargs["chat_format"] = "gemma"

    # Check if the model is already downloaded
    model_path = load_model_from_cache(repo_id, filename)
    chat_model = None
    try:
        chat_model = load_model(model_path, repo_id, filename, kwargs)
    except:
        # Load model on CPU if GPU is not available
        kwargs["n_gpu_layers"], device = 0, "cpu"
        chat_model = load_model(model_path, repo_id, filename, kwargs)

    # Now load the model with context size set based on:
    # 1. context size supported by model and
    # 2. configured size or machine (V)RAM
    kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
    chat_model = load_model(model_path, repo_id, filename, kwargs)

    logger.debug(
        f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
    )
    return chat_model


def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
    from llama_cpp.llama import Llama

    if model_path:
        return Llama(model_path, **kwargs)
    else:
        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)


def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
    # Construct the path to the model file in the cache directory
    repo_org, repo_name = repo_id.split("/")
    object_id = "--".join([repo_type, repo_org, repo_name])
    model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])

    # Check if the model file exists
    paths = glob.glob(model_path)
    if paths:
        return paths[0]
    else:
        return None


def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int:
    """Infer max prompt size based on device memory and max context window supported by the model"""
    configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
    vram_based_n_ctx = int(get_device_memory() / 1e6)  # based on heuristic
    configured_max_tokens = configured_max_tokens or math.inf  # do not use if set to None
    return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
@@ -1,15 +0,0 @@
import whisper
from asgiref.sync import sync_to_async

from khoj.utils import state


async def transcribe_audio_offline(audio_filename: str, model: str) -> str:
    """
    Transcribe audio file offline using Whisper
    """
    # Send the audio data to the Whisper model
    if not state.whisper_model:
        state.whisper_model = whisper.load_model(model)
    response = await sync_to_async(state.whisper_model.transcribe)(audio_filename)
    return response["text"]
@@ -78,38 +78,6 @@ no_entries_found = PromptTemplate.from_template(
""".strip()
)

## Conversation Prompts for Offline Chat Models
## --
system_prompt_offline_chat = PromptTemplate.from_template(
    """
You are Khoj, a smart, inquisitive and helpful personal assistant.
- Use your general knowledge and past conversation with the user as context to inform your responses.
- If you do not know the answer, say 'I don't know.'
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
- Do not print verbatim Notes unless necessary.

Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
Today is {day_of_week}, {current_date} in UTC.
""".strip()
)

custom_system_prompt_offline_chat = PromptTemplate.from_template(
    """
You are {name}, a personal agent on Khoj.
- Use your general knowledge and past conversation with the user as context to inform your responses.
- If you do not know the answer, say 'I don't know.'
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
- Do not print verbatim Notes unless necessary.

Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
Today is {day_of_week}, {current_date} in UTC.

Instructions:\n{bio}
""".strip()
)

## Notes Conversation
## --
notes_conversation = PromptTemplate.from_template(
@@ -18,8 +18,6 @@ import requests
 import tiktoken
 import yaml
 from langchain_core.messages.chat import ChatMessage
-from llama_cpp import LlamaTokenizer
-from llama_cpp.llama import Llama
 from pydantic import BaseModel, ConfigDict, ValidationError, create_model
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -32,7 +30,6 @@ from khoj.database.models import (
     KhojUser,
 )
 from khoj.processor.conversation import prompts
-from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.search_filter.base_filter import BaseFilter
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter

@@ -85,12 +82,6 @@ model_to_prompt_size = {
     "claude-sonnet-4-20250514": 60000,
     "claude-opus-4-0": 60000,
     "claude-opus-4-20250514": 60000,
-    # Offline Models
-    "bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
-    "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
-    "bartowski/Llama-3.2-3B-Instruct-GGUF": 20000,
-    "bartowski/gemma-2-9b-it-GGUF": 6000,
-    "bartowski/gemma-2-2b-it-GGUF": 6000,
 }
 model_to_tokenizer: Dict[str, str] = {}

@@ -573,7 +564,6 @@ def generate_chatml_messages_with_context(
     system_message: str = None,
     chat_history: list[ChatMessageModel] = [],
     model_name="gpt-4o-mini",
-    loaded_model: Optional[Llama] = None,
     max_prompt_size=None,
     tokenizer_name=None,
     query_images=None,

@@ -588,10 +578,7 @@ def generate_chatml_messages_with_context(
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
-        if loaded_model:
-            max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
-        else:
-            max_prompt_size = model_to_prompt_size.get(model_name, 10000)
+        max_prompt_size = model_to_prompt_size.get(model_name, 10000)

     # Scale lookback turns proportional to max prompt size supported by model
     lookback_turns = max_prompt_size // 750

@@ -735,7 +722,7 @@ generate_chatml_messages_with_context
         message.content = [{"type": "text", "text": message.content}]

     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
+    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

     # Return message in chronological order
     return messages[::-1]

@@ -743,25 +730,20 @@

 def get_encoder(
     model_name: str,
-    loaded_model: Optional[Llama] = None,
     tokenizer_name=None,
-) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
+) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
     default_tokenizer = "gpt-4o"

     try:
-        if loaded_model:
-            encoder = loaded_model.tokenizer()
-        elif model_name.startswith("gpt-") or model_name.startswith("o1"):
-            # as tiktoken doesn't recognize o1 model series yet
-            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
-        elif tokenizer_name:
+        if tokenizer_name:
             if tokenizer_name in state.pretrained_tokenizers:
                 encoder = state.pretrained_tokenizers[tokenizer_name]
             else:
                 encoder = AutoTokenizer.from_pretrained(tokenizer_name)
                 state.pretrained_tokenizers[tokenizer_name] = encoder
         else:
-            encoder = download_model(model_name).tokenizer()
+            # as tiktoken doesn't recognize o1 model series yet
+            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
     except:
         encoder = tiktoken.encoding_for_model(default_tokenizer)
         if state.verbose > 2:

@@ -773,7 +755,7 @@ get_encoder

 def count_tokens(
     message_content: str | list[str | dict],
-    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
+    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
 ) -> int:
     """
     Count the total number of tokens in a list of messages.

@@ -825,11 +807,10 @@ def truncate_messages(
     messages: list[ChatMessage],
     max_prompt_size: int,
     model_name: str,
-    loaded_model: Optional[Llama] = None,
     tokenizer_name=None,
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
-    encoder = get_encoder(model_name, loaded_model, tokenizer_name)
+    encoder = get_encoder(model_name, tokenizer_name)

     # Extract system message from messages
     system_message = None

@@ -235,7 +235,6 @@ def is_operator_model(model: str) -> ChatModel.ModelType | None:
         "claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC,
         "claude-sonnet-4": ChatModel.ModelType.ANTHROPIC,
         "claude-opus-4": ChatModel.ModelType.ANTHROPIC,
-        "ui-tars-1.5": ChatModel.ModelType.OFFLINE,
     }
     for operator_model in operator_models:
         if model.startswith(operator_model):
@@ -15,7 +15,6 @@ from khoj.configure import initialize_content
 from khoj.database import adapters
 from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo
 from khoj.database.models import KhojUser, SpeechToTextModelOptions
-from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.whisper import transcribe_audio
 from khoj.routers.helpers import (
     ApiUserRateLimiter,

@@ -88,22 +87,14 @@ def update(
     force: Optional[bool] = False,
 ):
     user = request.user.object
-    if not state.config:
-        error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/settings, plugins or by editing {state.config_file}."
-        logger.warning(error_msg)
-        raise HTTPException(status_code=500, detail=error_msg)
     try:
         initialize_content(user=user, regenerate=force, search_type=t)
     except Exception as e:
-        error_msg = f"🚨 Failed to update server via API: {e}"
+        error_msg = f"🚨 Failed to update server indexed content via API: {e}"
         logger.error(error_msg, exc_info=True)
         raise HTTPException(status_code=500, detail=error_msg)
     else:
-        components = []
-        if state.search_models:
-            components.append("Search models")
-        components_msg = ", ".join(components)
-        logger.info(f"📪 {components_msg} updated via API")
+        logger.info(f"📪 Server indexed content updated via API")

     update_telemetry_state(
         request=request,

@@ -150,9 +141,6 @@ async def transcribe(
     if not speech_to_text_config:
         # If the user has not configured a speech to text model, return an unsupported on server error
         status_code = 501
-    elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
-        speech2text_model = speech_to_text_config.model_name
-        user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
     elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
         speech2text_model = speech_to_text_config.model_name
         if speech_to_text_config.ai_model_api:
@@ -27,16 +27,7 @@ from khoj.database.adapters import (
     get_user_notion_config,
 )
 from khoj.database.models import Entry as DbEntry
-from khoj.database.models import (
-    GithubConfig,
-    GithubRepoConfig,
-    KhojUser,
-    LocalMarkdownConfig,
-    LocalOrgConfig,
-    LocalPdfConfig,
-    LocalPlaintextConfig,
-    NotionConfig,
-)
+from khoj.database.models import GithubConfig, GithubRepoConfig, NotionConfig
 from khoj.processor.content.docx.docx_to_entries import DocxToEntries
 from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
 from khoj.routers.helpers import (

@@ -47,17 +38,9 @@ from khoj.routers.helpers import (
     get_user_config,
     update_telemetry_state,
 )
-from khoj.utils import constants, state
-from khoj.utils.config import SearchModels
-from khoj.utils.rawconfig import (
-    ContentConfig,
-    FullConfig,
-    GithubContentConfig,
-    NotionContentConfig,
-    SearchConfig,
-)
+from khoj.utils import state
+from khoj.utils.rawconfig import GithubContentConfig, NotionContentConfig
 from khoj.utils.state import SearchType
-from khoj.utils.yaml import save_config_to_file_updated_state

 logger = logging.getLogger(__name__)

@@ -192,8 +175,6 @@ async def set_content_github(
     updated_config: Union[GithubContentConfig, None],
     client: Optional[str] = None,
 ):
-    _initialize_config()
-
     user = request.user.object

     try:

@@ -225,8 +206,6 @@ async def set_content_notion(
     updated_config: Union[NotionContentConfig, None],
     client: Optional[str] = None,
 ):
-    _initialize_config()
-
     user = request.user.object

     try:

@@ -323,10 +302,6 @@ def get_content_types(request: Request, client: Optional[str] = None):
     configured_content_types = set(EntryAdapters.get_unique_file_types(user))
     configured_content_types |= {"all"}

-    if state.config and state.config.content_type:
-        for ctype in state.config.content_type.model_dump(exclude_none=True):
-            configured_content_types.add(ctype)
-
     return list(configured_content_types & all_content_types)

@@ -606,28 +581,6 @@ async def indexer(
         docx=index_files["docx"],
     )

-    if state.config == None:
-        logger.info("📬 Initializing content index on first run.")
-        default_full_config = FullConfig(
-            content_type=None,
-            search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
-            processor=None,
-        )
-        state.config = default_full_config
-        default_content_config = ContentConfig(
-            org=None,
-            markdown=None,
-            pdf=None,
-            docx=None,
-            image=None,
-            github=None,
-            notion=None,
-            plaintext=None,
-        )
-        state.config.content_type = default_content_config
-        save_config_to_file_updated_state()
-        configure_search(state.search_models, state.config.search_type)
-
     loop = asyncio.get_event_loop()
     success = await loop.run_in_executor(
         None,

@@ -674,14 +627,6 @@ async def indexer(
     return Response(content=indexed_filenames, status_code=200)


-def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
-    # Run Validation Checks
-    if search_models is None:
-        search_models = SearchModels()
-
-    return search_models
-
-
 def map_config_to_object(content_source: str):
     if content_source == DbEntry.EntrySource.GITHUB:
         return GithubConfig

@@ -689,56 +634,3 @@ def map_config_to_object(content_source: str):
         return NotionConfig
     if content_source == DbEntry.EntrySource.COMPUTER:
         return "Computer"
-
-
-async def map_config_to_db(config: FullConfig, user: KhojUser):
-    if config.content_type:
-        if config.content_type.org:
-            await LocalOrgConfig.objects.filter(user=user).adelete()
-            await LocalOrgConfig.objects.acreate(
-                input_files=config.content_type.org.input_files,
-                input_filter=config.content_type.org.input_filter,
-                index_heading_entries=config.content_type.org.index_heading_entries,
-                user=user,
-            )
-        if config.content_type.markdown:
-            await LocalMarkdownConfig.objects.filter(user=user).adelete()
-            await LocalMarkdownConfig.objects.acreate(
-                input_files=config.content_type.markdown.input_files,
-                input_filter=config.content_type.markdown.input_filter,
-                index_heading_entries=config.content_type.markdown.index_heading_entries,
-                user=user,
-            )
-        if config.content_type.pdf:
-            await LocalPdfConfig.objects.filter(user=user).adelete()
-            await LocalPdfConfig.objects.acreate(
-                input_files=config.content_type.pdf.input_files,
-                input_filter=config.content_type.pdf.input_filter,
-                index_heading_entries=config.content_type.pdf.index_heading_entries,
-                user=user,
-            )
-        if config.content_type.plaintext:
-            await LocalPlaintextConfig.objects.filter(user=user).adelete()
-            await LocalPlaintextConfig.objects.acreate(
-                input_files=config.content_type.plaintext.input_files,
-                input_filter=config.content_type.plaintext.input_filter,
-                index_heading_entries=config.content_type.plaintext.index_heading_entries,
-                user=user,
-            )
-        if config.content_type.github:
-            await adapters.set_user_github_config(
-                user=user,
-                pat_token=config.content_type.github.pat_token,
-                repos=config.content_type.github.repos,
-            )
-        if config.content_type.notion:
-            await adapters.set_notion_config(
|
||||
user=user,
|
||||
token=config.content_type.notion.token,
|
||||
)
|
||||
|
||||
|
||||
def _initialize_config():
|
||||
if state.config is None:
|
||||
state.config = FullConfig()
|
||||
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
||||
|
||||
@@ -89,10 +89,6 @@ from khoj.processor.conversation.google.gemini_chat import (
    converse_gemini,
    gemini_send_message_to_model,
)
from khoj.processor.conversation.offline.chat_model import (
    converse_offline,
    send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import (
    converse_openai,
    send_message_to_model,
@@ -117,7 +113,6 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
    LRU,
    ConversationCommand,
@@ -168,14 +163,6 @@ async def is_ready_to_chat(user: KhojUser):
    if user_chat_model == None:
        user_chat_model = await ConversationAdapters.aget_default_chat_model(user)

    if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
        chat_model_name = user_chat_model.name
        max_tokens = user_chat_model.max_prompt_size
        if state.offline_chat_processor_config is None:
            logger.info("Loading Offline Chat Model...")
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
        return True

    if (
        user_chat_model
        and (
@@ -231,7 +218,6 @@ def update_telemetry_state(
        telemetry_type=telemetry_type,
        api=api,
        client=client,
        app_config=state.config.app,
        disable_telemetry_env=state.telemetry_disabled,
        properties=user_state,
    )
@@ -1470,12 +1456,6 @@ async def send_message_to_model_wrapper(
    vision_available = chat_model.vision_enabled
    api_key = chat_model.ai_model_api.api_key
    api_base_url = chat_model.ai_model_api.api_base_url
    loaded_model = None

    if model_type == ChatModel.ModelType.OFFLINE:
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
        loaded_model = state.offline_chat_processor_config.loaded_model

    truncated_messages = generate_chatml_messages_with_context(
        user_message=query,
@@ -1483,7 +1463,6 @@
        system_message=system_message,
        chat_history=chat_history,
        model_name=chat_model_name,
        loaded_model=loaded_model,
        tokenizer_name=tokenizer,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
@@ -1492,18 +1471,7 @@
        query_files=query_files,
    )

    if model_type == ChatModel.ModelType.OFFLINE:
        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model_name=chat_model_name,
            max_prompt_size=max_tokens,
            streaming=False,
            response_type=response_type,
            tracer=tracer,
        )

    elif model_type == ChatModel.ModelType.OPENAI:
    if model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
@@ -1565,19 +1533,12 @@ def send_message_to_model_wrapper_sync(
    vision_available = chat_model.vision_enabled
    api_key = chat_model.ai_model_api.api_key
    api_base_url = chat_model.ai_model_api.api_base_url
    loaded_model = None

    if model_type == ChatModel.ModelType.OFFLINE:
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
        loaded_model = state.offline_chat_processor_config.loaded_model

    truncated_messages = generate_chatml_messages_with_context(
        user_message=message,
        system_message=system_message,
        chat_history=chat_history,
        model_name=chat_model_name,
        loaded_model=loaded_model,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
        model_type=model_type,
@@ -1585,18 +1546,7 @@ def send_message_to_model_wrapper_sync(
        query_files=query_files,
    )

    if model_type == ChatModel.ModelType.OFFLINE:
        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model_name=chat_model_name,
            max_prompt_size=max_tokens,
            streaming=False,
            response_type=response_type,
            tracer=tracer,
        )

    elif model_type == ChatModel.ModelType.OPENAI:
    if model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
@@ -1678,30 +1628,7 @@ async def agenerate_chat_response(
        chat_model = vision_enabled_config
        vision_available = True

    if chat_model.model_type == "offline":
        loaded_model = state.offline_chat_processor_config.loaded_model
        chat_response_generator = converse_offline(
            # Query
            user_query=query_to_run,
            # Context
            references=compiled_references,
            online_results=online_results,
            generated_files=raw_generated_files,
            generated_asset_results=generated_asset_results,
            location_data=location_data,
            user_name=user_name,
            query_files=query_files,
            chat_history=chat_history,
            # Model
            loaded_model=loaded_model,
            model_name=chat_model.name,
            max_prompt_size=chat_model.max_prompt_size,
            tokenizer_name=chat_model.tokenizer,
            agent=agent,
            tracer=tracer,
        )

    elif chat_model.model_type == ChatModel.ModelType.OPENAI:
    if chat_model.model_type == ChatModel.ModelType.OPENAI:
        openai_chat_config = chat_model.ai_model_api
        api_key = openai_chat_config.api_key
        chat_model_name = chat_model.name
@@ -2798,7 +2725,8 @@ def configure_content(

    search_type = t.value if t else None

    no_documents = all([not files.get(file_type) for file_type in files])
    # Check if client sent any documents of the supported types
    no_client_sent_documents = all([not files.get(file_type) for file_type in files])

    if files is None:
        logger.warning(f"🚨 No files to process for {search_type} search.")
@@ -2872,7 +2800,8 @@ def configure_content(
            success = False

    try:
        if no_documents:
        # Run server side indexing of user Github docs if no client sent documents
        if no_client_sent_documents:
            github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
            if (
                search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
@@ -2892,7 +2821,8 @@ def configure_content(
            success = False

    try:
        if no_documents:
        # Run server side indexing of user Notion docs if no client sent documents
        if no_client_sent_documents:
            # Initialize Notion Search
            notion_config = NotionConfig.objects.filter(user=user).first()
            if (

@@ -1,36 +1,19 @@
import argparse
import logging
import os
import pathlib
from importlib.metadata import version

logger = logging.getLogger(__name__)

from khoj.migrations.migrate_offline_chat_default_model import (
    migrate_offline_chat_default_model,
)
from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema
from khoj.migrations.migrate_offline_model import migrate_offline_model
from khoj.migrations.migrate_processor_config_openai import (
    migrate_processor_conversation_schema,
)
from khoj.migrations.migrate_server_pg import migrate_server_pg
from khoj.migrations.migrate_version import migrate_config_to_version
from khoj.utils.helpers import is_env_var_true, resolve_absolute_path
from khoj.utils.yaml import parse_config_from_file


def cli(args=None):
    # Setup Argument Parser for the Commandline Interface
    parser = argparse.ArgumentParser(description="Start Khoj; An AI personal assistant for your Digital Brain")
    parser.add_argument(
        "--config-file", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj"
    )
    parser.add_argument(
        "--regenerate",
        action="store_true",
        default=False,
        help="Regenerate model embeddings from source files. Default: false",
        "--log-file",
        default="~/.khoj/khoj.log",
        type=pathlib.Path,
        help="File path for server logs. Default: ~/.khoj/khoj.log",
    )
    parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1")
@@ -43,14 +26,11 @@ def cli(args=None):
    parser.add_argument("--sslcert", type=str, help="Path to SSL certificate file")
    parser.add_argument("--sslkey", type=str, help="Path to SSL key file")
    parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
    parser.add_argument(
        "--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
    )
    parser.add_argument(
        "--anonymous-mode",
        action="store_true",
        default=False,
        help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
        help="Run Khoj in single user mode with no login required. Useful for personal use or testing.",
    )
    parser.add_argument(
        "--non-interactive",
@@ -64,38 +44,10 @@ def cli(args=None):
    if len(remaining_args) > 0:
        logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")

    # Set default values for arguments
    args.chat_on_gpu = not args.disable_chat_on_gpu

    args.version_no = version("khoj")
    if args.version:
        # Show version of khoj installed and exit
        print(args.version_no)
        exit(0)

    # Normalize config_file path to absolute path
    args.config_file = resolve_absolute_path(args.config_file)

    if not args.config_file.exists():
        args.config = None
    else:
        args = run_migrations(args)
        args.config = parse_config_from_file(args.config_file)
        if is_env_var_true("KHOJ_TELEMETRY_DISABLE"):
            args.config.app.should_log_telemetry = False

    return args


def run_migrations(args):
    migrations = [
        migrate_config_to_version,
        migrate_processor_conversation_schema,
        migrate_offline_model,
        migrate_offline_chat_schema,
        migrate_offline_chat_default_model,
        migrate_server_pg,
    ]
    for migration in migrations:
        args = migration(args)
    return args
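
The `run_migrations` helper this commit deletes threads the parsed args through each migration function in turn. A minimal standalone sketch of that pipeline pattern (the migration functions here are hypothetical stand-ins, not Khoj's real migrations):

```python
# Each migration accepts the current args, upgrades them, and returns them.
def migrate_v1(args: dict) -> dict:
    args.setdefault("schema_version", 1)
    return args

def migrate_v2(args: dict) -> dict:
    # Idempotent: only bump if we are still below this migration's version.
    if args.get("schema_version", 0) < 2:
        args["schema_version"] = 2
    return args

def run_migrations(args: dict) -> dict:
    # Order matters: each migration sees the output of the previous one.
    for migration in [migrate_v1, migrate_v2]:
        args = migration(args)
    return args

print(run_migrations({}))  # → {'schema_version': 2}
```

Because each step is idempotent, re-running the pipeline on already-migrated args is safe.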

@@ -1,22 +1,7 @@
# System Packages
from __future__ import annotations # to avoid quoting type hints

import logging
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, List, Optional, Union

import torch

from khoj.processor.conversation.offline.utils import download_model

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from sentence_transformers import CrossEncoder

    from khoj.utils.models import BaseEncoder


class SearchType(str, Enum):
@@ -29,53 +14,3 @@ class SearchType(str, Enum):
    Notion = "notion"
    Plaintext = "plaintext"
    Docx = "docx"


class ProcessorType(str, Enum):
    Conversation = "conversation"


@dataclass
class TextContent:
    enabled: bool


@dataclass
class ImageContent:
    image_names: List[str]
    image_embeddings: torch.Tensor
    image_metadata_embeddings: torch.Tensor


@dataclass
class TextSearchModel:
    bi_encoder: BaseEncoder
    cross_encoder: Optional[CrossEncoder] = None
    top_k: Optional[int] = 15


@dataclass
class ImageSearchModel:
    image_encoder: BaseEncoder


@dataclass
class SearchModels:
    text_search: Optional[TextSearchModel] = None


@dataclass
class OfflineChatProcessorConfig:
    loaded_model: Union[Any, None] = None


class OfflineChatProcessorModel:
    def __init__(self, chat_model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", max_tokens: int = None):
        self.chat_model = chat_model
        self.loaded_model = None
        try:
            self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
        except ValueError as e:
            self.loaded_model = None
            logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
            raise e
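
With the native llama-cpp loader above removed, offline chat is expected to go through an OpenAI-compatible local server (Llama.cpp Server, Ollama, vLLM) as the PR description recommends. A hedged sketch of the request body such servers accept; the endpoint URL and model name in the comment are illustrative assumptions, not values from this commit:

```python
def build_chat_request(model: str, user_message: str, system_message: str = "You are Khoj.") -> dict:
    # Minimal OpenAI-compatible /v1/chat/completions payload, as accepted by
    # Llama.cpp Server, Ollama's /v1 endpoint, vLLM, and similar local AI servers.
    return {
        "model": model,
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        "stream": False,
    }

# POST this as JSON to e.g. http://localhost:11434/v1/chat/completions
# (Ollama's OpenAI-compatible endpoint); URL and model name are illustrative.
payload = build_chat_request("llama3.1", "Hello!")
print(payload["model"])  # → llama3.1
```

Decoupling inference behind this wire format is what lets the server drop the GPU-bound in-process model entirely.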

@@ -10,13 +10,6 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
default_offline_chat_models = [
    "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    "bartowski/Llama-3.2-3B-Instruct-GGUF",
    "bartowski/gemma-2-9b-it-GGUF",
    "bartowski/gemma-2-2b-it-GGUF",
    "bartowski/Qwen2.5-14B-Instruct-GGUF",
]
default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1", "o3", "o4-mini"]
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"]
default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"]

@@ -1,252 +0,0 @@
import glob
import logging
import os
from pathlib import Path
from typing import Optional

from bs4 import BeautifulSoup
from magika import Magika

from khoj.database.models import (
    KhojUser,
    LocalMarkdownConfig,
    LocalOrgConfig,
    LocalPdfConfig,
    LocalPlaintextConfig,
)
from khoj.utils.config import SearchType
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
from khoj.utils.rawconfig import TextContentConfig

logger = logging.getLogger(__name__)
magika = Magika()


def collect_files(user: KhojUser, search_type: Optional[SearchType] = SearchType.All) -> dict:
    files: dict[str, dict] = {"docx": {}, "image": {}}

    if search_type == SearchType.All or search_type == SearchType.Org:
        org_config = LocalOrgConfig.objects.filter(user=user).first()
        files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
    if search_type == SearchType.All or search_type == SearchType.Markdown:
        markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
        files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
    if search_type == SearchType.All or search_type == SearchType.Plaintext:
        plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
        files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
    if search_type == SearchType.All or search_type == SearchType.Pdf:
        pdf_config = LocalPdfConfig.objects.filter(user=user).first()
        files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
    files["image"] = {}
    files["docx"] = {}
    return files


def construct_config_from_db(db_config) -> TextContentConfig:
    return TextContentConfig(
        input_files=db_config.input_files,
        input_filter=db_config.input_filter,
        index_heading_entries=db_config.index_heading_entries,
    )


def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
    def is_plaintextfile(file: str):
        "Check if file is plaintext file"
        # Check if file path exists
        content_group = magika.identify_path(Path(file)).output.group
        # Use file extension to decide plaintext if file content is not identifiable
        valid_text_file_extensions = ("txt", "md", "markdown", "org" "mbox", "rst", "html", "htm", "xml")
        return file.endswith(valid_text_file_extensions) or content_group in ["text", "code"]

    def extract_html_content(html_content: str):
        "Extract content from HTML"
        soup = BeautifulSoup(html_content, "html.parser")
        return soup.get_text(strip=True, separator="\n")

    # Extract required fields from config
    input_files, input_filters = (
        config.input_files,
        config.input_filter,
    )

    # Input Validation
    if is_none_or_empty(input_files) and is_none_or_empty(input_filters):
        logger.debug("At least one of input-files or input-file-filter is required to be specified")
        return {}

    # Get all plain text files to process
    absolute_plaintext_files, filtered_plaintext_files = set(), set()
    if input_files:
        absolute_plaintext_files = {get_absolute_path(jsonl_file) for jsonl_file in input_files}
    if input_filters:
        filtered_plaintext_files = {
            filtered_file
            for plaintext_file_filter in input_filters
            for filtered_file in glob.glob(get_absolute_path(plaintext_file_filter), recursive=True)
            if os.path.isfile(filtered_file)
        }

    all_target_files = sorted(absolute_plaintext_files | filtered_plaintext_files)

    files_with_no_plaintext_extensions = {
        target_files for target_files in all_target_files if not is_plaintextfile(target_files)
    }
    if any(files_with_no_plaintext_extensions):
        logger.warning(f"Skipping unsupported files from plaintext indexing: {files_with_no_plaintext_extensions}")
        all_target_files = list(set(all_target_files) - files_with_no_plaintext_extensions)

    logger.debug(f"Processing files: {all_target_files}")

    filename_to_content_map = {}
    for file in all_target_files:
        with open(file, "r", encoding="utf8") as f:
            try:
                plaintext_content = f.read()
                if file.endswith(("html", "htm", "xml")):
                    plaintext_content = extract_html_content(plaintext_content)
                filename_to_content_map[file] = plaintext_content
            except Exception as e:
                logger.warning(f"Unable to read file: {file} as plaintext. Skipping file.")
                logger.warning(e, exc_info=True)

    return filename_to_content_map
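
The input-files/input-filter collection pattern shared by these removed helpers — union explicit paths with glob-expanded filter matches, then sort — can be sketched on its own. The paths here are temporary files created just for the demo:

```python
import glob
import os
import tempfile

def collect(input_files, input_filters):
    # Union explicit file paths with glob-expanded filter matches,
    # keeping only paths that are regular files.
    explicit = {os.path.abspath(f) for f in input_files or []}
    globbed = {
        match
        for pattern in input_filters or []
        for match in glob.glob(os.path.abspath(pattern), recursive=True)
        if os.path.isfile(match)
    }
    return sorted(explicit | globbed)

with tempfile.TemporaryDirectory() as d:
    for name in ("a.md", "b.md", "c.txt"):
        open(os.path.join(d, name), "w").close()
    # One explicit file plus a glob filter that matches the two markdown files.
    print(len(collect([os.path.join(d, "c.txt")], [os.path.join(d, "*.md")])))  # → 3
```

Sorting the union gives the deterministic processing order the indexer relied on.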


def get_org_files(config: TextContentConfig):
    # Extract required fields from config
    org_files, org_file_filters = (
        config.input_files,
        config.input_filter,
    )

    # Input Validation
    if is_none_or_empty(org_files) and is_none_or_empty(org_file_filters):
        logger.debug("At least one of org-files or org-file-filter is required to be specified")
        return {}

    # Get Org files to process
    absolute_org_files, filtered_org_files = set(), set()
    if org_files:
        absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
    if org_file_filters:
        filtered_org_files = {
            filtered_file
            for org_file_filter in org_file_filters
            for filtered_file in glob.glob(get_absolute_path(org_file_filter), recursive=True)
            if os.path.isfile(filtered_file)
        }

    all_org_files = sorted(absolute_org_files | filtered_org_files)

    files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
    if any(files_with_non_org_extensions):
        logger.warning(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")

    logger.debug(f"Processing files: {all_org_files}")

    filename_to_content_map = {}
    for file in all_org_files:
        with open(file, "r", encoding="utf8") as f:
            try:
                filename_to_content_map[file] = f.read()
            except Exception as e:
                logger.warning(f"Unable to read file: {file} as org. Skipping file.")
                logger.warning(e, exc_info=True)

    return filename_to_content_map


def get_markdown_files(config: TextContentConfig):
    # Extract required fields from config
    markdown_files, markdown_file_filters = (
        config.input_files,
        config.input_filter,
    )

    # Input Validation
    if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filters):
        logger.debug("At least one of markdown-files or markdown-file-filter is required to be specified")
        return {}

    # Get markdown files to process
    absolute_markdown_files, filtered_markdown_files = set(), set()
    if markdown_files:
        absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}

    if markdown_file_filters:
        filtered_markdown_files = {
            filtered_file
            for markdown_file_filter in markdown_file_filters
            for filtered_file in glob.glob(get_absolute_path(markdown_file_filter), recursive=True)
            if os.path.isfile(filtered_file)
        }

    all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)

    files_with_non_markdown_extensions = {
        md_file for md_file in all_markdown_files if not md_file.endswith(".md") and not md_file.endswith(".markdown")
    }

    if any(files_with_non_markdown_extensions):
        logger.warning(
            f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
        )

    logger.debug(f"Processing files: {all_markdown_files}")

    filename_to_content_map = {}
    for file in all_markdown_files:
        with open(file, "r", encoding="utf8") as f:
            try:
                filename_to_content_map[file] = f.read()
            except Exception as e:
                logger.warning(f"Unable to read file: {file} as markdown. Skipping file.")
                logger.warning(e, exc_info=True)

    return filename_to_content_map


def get_pdf_files(config: TextContentConfig):
    # Extract required fields from config
    pdf_files, pdf_file_filters = (
        config.input_files,
        config.input_filter,
    )

    # Input Validation
    if is_none_or_empty(pdf_files) and is_none_or_empty(pdf_file_filters):
        logger.debug("At least one of pdf-files or pdf-file-filter is required to be specified")
        return {}

    # Get PDF files to process
    absolute_pdf_files, filtered_pdf_files = set(), set()
    if pdf_files:
        absolute_pdf_files = {get_absolute_path(pdf_file) for pdf_file in pdf_files}
    if pdf_file_filters:
        filtered_pdf_files = {
            filtered_file
            for pdf_file_filter in pdf_file_filters
            for filtered_file in glob.glob(get_absolute_path(pdf_file_filter), recursive=True)
            if os.path.isfile(filtered_file)
        }

    all_pdf_files = sorted(absolute_pdf_files | filtered_pdf_files)

    files_with_non_pdf_extensions = {pdf_file for pdf_file in all_pdf_files if not pdf_file.endswith(".pdf")}

    if any(files_with_non_pdf_extensions):
        logger.warning(f"[Warning] There maybe non pdf-mode files in the input set: {files_with_non_pdf_extensions}")

    logger.debug(f"Processing files: {all_pdf_files}")

    filename_to_content_map = {}
    for file in all_pdf_files:
        with open(file, "rb") as f:
            try:
                filename_to_content_map[file] = f.read()
            except Exception as e:
                logger.warning(f"Unable to read file: {file} as PDF. Skipping file.")
                logger.warning(e, exc_info=True)

    return filename_to_content_map
@@ -47,7 +47,6 @@ if TYPE_CHECKING:
    from sentence_transformers import CrossEncoder, SentenceTransformer

    from khoj.utils.models import BaseEncoder
    from khoj.utils.rawconfig import AppConfig

logger = logging.getLogger(__name__)

@@ -267,23 +266,16 @@ def get_server_id():
    return server_id


def telemetry_disabled(app_config: AppConfig, telemetry_disable_env) -> bool:
    if telemetry_disable_env is True:
        return True
    return not app_config or not app_config.should_log_telemetry


def log_telemetry(
    telemetry_type: str,
    api: str = None,
    client: Optional[str] = None,
    app_config: Optional[AppConfig] = None,
    disable_telemetry_env: bool = False,
    properties: dict = None,
):
    """Log basic app usage telemetry like client, os, api called"""
    # Do not log usage telemetry, if telemetry is disabled via app config
    if telemetry_disabled(app_config, disable_telemetry_env):
    if disable_telemetry_env:
        return []

    if properties.get("server_id") is None:

@@ -16,7 +16,6 @@ from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tok
|
||||
from khoj.utils.constants import (
|
||||
default_anthropic_chat_models,
|
||||
default_gemini_chat_models,
|
||||
default_offline_chat_models,
|
||||
default_openai_chat_models,
|
||||
)
|
||||
|
||||
@@ -72,7 +71,6 @@ def initialization(interactive: bool = True):
|
||||
default_api_key=openai_api_key,
|
||||
api_base_url=openai_base_url,
|
||||
vision_enabled=True,
|
||||
is_offline=False,
|
||||
interactive=interactive,
|
||||
provider_name=provider,
|
||||
)
|
||||
@@ -118,7 +116,6 @@ def initialization(interactive: bool = True):
|
||||
default_gemini_chat_models,
|
||||
default_api_key=os.getenv("GEMINI_API_KEY"),
|
||||
vision_enabled=True,
|
||||
is_offline=False,
|
||||
interactive=interactive,
|
||||
provider_name="Google Gemini",
|
||||
)
|
||||
@@ -145,40 +142,11 @@ def initialization(interactive: bool = True):
|
||||
default_anthropic_chat_models,
|
||||
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
vision_enabled=True,
|
||||
is_offline=False,
|
||||
interactive=interactive,
|
||||
)
|
||||
|
||||
# Set up offline chat models
|
||||
_setup_chat_model_provider(
|
||||
ChatModel.ModelType.OFFLINE,
|
||||
default_offline_chat_models,
|
||||
default_api_key=None,
|
||||
vision_enabled=False,
|
||||
is_offline=True,
|
||||
interactive=interactive,
|
||||
)
|
||||
|
||||
logger.info("🗣️ Chat model configuration complete")
|
||||
|
||||
# Set up offline speech to text model
|
||||
use_offline_speech2text_model = "n" if not interactive else input("Use offline speech to text model? (y/n): ")
|
||||
if use_offline_speech2text_model == "y":
|
||||
logger.info("🗣️ Setting up offline speech to text model")
|
||||
# Delete any existing speech to text model options. There can only be one.
SpeechToTextModelOptions.objects.all().delete()

default_offline_speech2text_model = "base"
offline_speech2text_model = input(
    f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): "
)
offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model
SpeechToTextModelOptions.objects.create(
    model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE
)

logger.info(f"️🗣️ Offline speech to text model configured to {offline_speech2text_model}")


def _setup_chat_model_provider(
    model_type: ChatModel.ModelType,
    default_chat_models: list,
@@ -186,7 +154,6 @@ def initialization(interactive: bool = True):
    interactive: bool,
    api_base_url: str = None,
    vision_enabled: bool = False,
    is_offline: bool = False,
    provider_name: str = None,
) -> Tuple[bool, AiModelApi]:
    supported_vision_models = (
@@ -195,11 +162,6 @@ def initialization(interactive: bool = True):
    provider_name = provider_name or model_type.name.capitalize()

    default_use_model = default_api_key is not None
    # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment.
    # This usually means there's not enough RAM to load offline models directly within the application.
    # In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
    if is_offline:
        default_use_model = False

    use_model_provider = (
        default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y"
@@ -211,13 +173,12 @@ def initialization(interactive: bool = True):
    logger.info(f"️💬 Setting up your {provider_name} chat configuration")

    ai_model_api = None
    if not is_offline:
        if interactive:
            user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
            api_key = user_api_key if user_api_key != "" else default_api_key
        else:
            api_key = default_api_key
        ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
    if interactive:
        user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
        api_key = user_api_key if user_api_key != "" else default_api_key
    else:
        api_key = default_api_key
    ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)

    if interactive:
        user_chat_models = input(
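The removed code above defaulted offline models off in containerized setups, and the PR recommends pointing Khoj at an established local AI API (Ollama, Llama.cpp Server, vLLM) instead. As a minimal sketch of that integration style, the snippet below builds an OpenAI-compatible `/chat/completions` request payload for such a local server; the base URL and model name are assumptions, not values from this PR.

```python
import json

# Hypothetical local endpoint: Ollama's OpenAI-compatible API commonly listens
# on http://localhost:11434/v1 (an assumption -- adjust to your own setup).
LOCAL_API_BASE = "http://localhost:11434/v1"


def build_chat_request(model: str, user_message: str, system_prompt: str = "You are Khoj.") -> dict:
    """Build an OpenAI-compatible /chat/completions payload for a local AI server."""
    return {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "stream": False,
    }


# Serialize the payload; an HTTP client would POST this to
# f"{LOCAL_API_BASE}/chat/completions" with an (often dummy) API key.
payload = build_chat_request("llama3.1", "Hello")
body = json.dumps(payload)
```

Because the request shape is the standard OpenAI one, the same server-side `AiModelApi` record (api key + api base url) covers all of these local backends.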
@@ -48,17 +48,6 @@ class FilesFilterRequest(BaseModel):
    conversation_id: str


class TextConfigBase(ConfigBase):
    compressed_jsonl: Path
    embeddings_file: Path


class TextContentConfig(ConfigBase):
    input_files: Optional[List[Path]] = None
    input_filter: Optional[List[str]] = None
    index_heading_entries: Optional[bool] = False


class GithubRepoConfig(ConfigBase):
    name: str
    owner: str
@@ -74,62 +63,6 @@ class NotionContentConfig(ConfigBase):
    token: str


class ContentConfig(ConfigBase):
    org: Optional[TextContentConfig] = None
    markdown: Optional[TextContentConfig] = None
    pdf: Optional[TextContentConfig] = None
    plaintext: Optional[TextContentConfig] = None
    github: Optional[GithubContentConfig] = None
    notion: Optional[NotionContentConfig] = None
    image: Optional[TextContentConfig] = None
    docx: Optional[TextContentConfig] = None


class ImageSearchConfig(ConfigBase):
    encoder: str
    encoder_type: Optional[str] = None
    model_directory: Optional[Path] = None

    class Config:
        protected_namespaces = ()


class SearchConfig(ConfigBase):
    image: Optional[ImageSearchConfig] = None


class OpenAIProcessorConfig(ConfigBase):
    api_key: str
    chat_model: Optional[str] = "gpt-4o-mini"


class OfflineChatProcessorConfig(ConfigBase):
    chat_model: Optional[str] = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"


class ConversationProcessorConfig(ConfigBase):
    openai: Optional[OpenAIProcessorConfig] = None
    offline_chat: Optional[OfflineChatProcessorConfig] = None
    max_prompt_size: Optional[int] = None
    tokenizer: Optional[str] = None


class ProcessorConfig(ConfigBase):
    conversation: Optional[ConversationProcessorConfig] = None


class AppConfig(ConfigBase):
    should_log_telemetry: bool = True


class FullConfig(ConfigBase):
    content_type: Optional[ContentConfig] = None
    search_type: Optional[SearchConfig] = None
    processor: Optional[ProcessorConfig] = None
    app: Optional[AppConfig] = AppConfig()
    version: Optional[str] = None


class SearchResponse(ConfigBase):
    entry: str
    score: float
@@ -12,19 +12,14 @@ from whisper import Whisper
from khoj.database.models import ProcessLock
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import OfflineChatProcessorModel, SearchModels
from khoj.utils.helpers import LRU, get_device, is_env_var_true
from khoj.utils.rawconfig import FullConfig

# Application Global State
config = FullConfig()
search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: Dict[str, CrossEncoderModel] = None
openai_client: OpenAI = None
offline_chat_processor_config: OfflineChatProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None
log_file: Path = None
verbose: int = 0
host: str = None
port: int = None
@@ -39,7 +34,6 @@ telemetry: List[Dict[str, str]] = []
telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
khoj_version: str = None
device = get_device()
chat_on_gpu: bool = True
anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
billing_enabled: bool = (
@@ -1,47 +1,8 @@
from pathlib import Path

import yaml

from khoj.utils import state
from khoj.utils.rawconfig import FullConfig

# Do not emit tags when dumping to YAML
yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None  # type: ignore[assignment]


def save_config_to_file_updated_state():
    with open(state.config_file, "w") as outfile:
        yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
        outfile.close()
    return state.config


def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
    "Write config to YML file"
    # Create output directory, if it doesn't exist
    yaml_config_file.parent.mkdir(parents=True, exist_ok=True)

    with open(yaml_config_file, "w", encoding="utf-8") as config_file:
        yaml.safe_dump(yaml_config, config_file, allow_unicode=True)


def load_config_from_file(yaml_config_file: Path) -> dict:
    "Read config from YML file"
    config_from_file = None
    with open(yaml_config_file, "r", encoding="utf-8") as config_file:
        config_from_file = yaml.safe_load(config_file)
    return config_from_file


def parse_config_from_string(yaml_config: dict) -> FullConfig:
    "Parse and validate config in YML string"
    return FullConfig.model_validate(yaml_config)


def parse_config_from_file(yaml_config_file):
    "Parse and validate config in YML file"
    return parse_config_from_string(load_config_from_file(yaml_config_file))


def yaml_dump(data):
    return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False)
@@ -1,6 +1,3 @@
import os
from pathlib import Path

import pytest
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
@@ -11,6 +8,7 @@ from khoj.configure import (
    configure_routes,
    configure_search_types,
)
from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
    Agent,
    ChatModel,
@@ -19,21 +17,14 @@ from khoj.database.models import (
    GithubRepoConfig,
    KhojApiUser,
    KhojUser,
    LocalMarkdownConfig,
    LocalOrgConfig,
    LocalPdfConfig,
    LocalPlaintextConfig,
)
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.api_content import configure_content
from khoj.search_type import text_search
from khoj.utils import fs_syncer, state
from khoj.utils.config import SearchModels
from khoj.utils import state
from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
from tests.helpers import (
    AiModelApiFactory,
    ChatModelFactory,
@@ -43,6 +34,8 @@ from tests.helpers import (
    UserFactory,
    get_chat_api_key,
    get_chat_provider,
    get_index_files,
    get_sample_data,
)


@@ -59,23 +52,16 @@ def django_db_setup(django_db_setup, django_db_blocker):


@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
def search_config():
    search_model = get_default_search_model()
    state.embeddings_model = dict()
    state.embeddings_model["default"] = EmbeddingsModel()
    state.cross_encoder_model = dict()
    state.cross_encoder_model["default"] = CrossEncoderModel()

    model_dir = resolve_absolute_path("~/.khoj/search")
    model_dir.mkdir(parents=True, exist_ok=True)
    search_config = SearchConfig()

    search_config.image = ImageSearchConfig(
        encoder="sentence-transformers/clip-ViT-B-32",
        model_directory=model_dir / "image/",
        encoder_type=None,
    state.embeddings_model["default"] = EmbeddingsModel(
        model_name=search_model.bi_encoder, model_kwargs=search_model.bi_encoder_model_config
    )
    state.cross_encoder_model = dict()
    state.cross_encoder_model["default"] = CrossEncoderModel(
        model_name=search_model.cross_encoder, model_kwargs=search_model.cross_encoder_model_config
    )

    return search_config


@pytest.mark.django_db
@@ -196,17 +182,6 @@ def default_openai_chat_model_option():
    return chat_model


@pytest.mark.django_db
@pytest.fixture
def offline_agent():
    chat_model = ChatModelFactory()
    return Agent.objects.create(
        name="Accountant",
        chat_model=chat_model,
        personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
    )


@pytest.mark.django_db
@pytest.fixture
def openai_agent():
@@ -218,13 +193,6 @@ def openai_agent():
    )


@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
    search_models = SearchModels()

    return search_models


@pytest.mark.django_db
@pytest.fixture
def default_process_lock():
@@ -236,72 +204,23 @@ def anyio_backend():
    return "asyncio"


@pytest.mark.django_db
@pytest.fixture(scope="function")
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
    content_dir = tmp_path_factory.mktemp("content")

    # Generate Image Embeddings from Test Images
    content_config = ContentConfig()

    LocalOrgConfig.objects.create(
        input_files=None,
        input_filter=["tests/data/org/*.org"],
        index_heading_entries=False,
        user=default_user,
    )

    text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=False, user=default_user)

    if os.getenv("GITHUB_PAT_TOKEN"):
        GithubConfig.objects.create(
            pat_token=os.getenv("GITHUB_PAT_TOKEN"),
            user=default_user,
        )

        GithubRepoConfig.objects.create(
            owner="khoj-ai",
            name="lantern",
            branch="master",
            github_config=GithubConfig.objects.get(user=default_user),
        )

    LocalPlaintextConfig.objects.create(
        input_files=None,
        input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
        user=default_user,
    )

    return content_config


@pytest.fixture(scope="session")
def md_content_config():
    markdown_config = LocalMarkdownConfig.objects.create(
        input_files=None,
        input_filter=["tests/data/markdown/*.markdown"],
    )

    return markdown_config


@pytest.fixture(scope="function")
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
def chat_client(search_config, default_user2: KhojUser):
    return chat_client_builder(search_config, default_user2, require_auth=False)


@pytest.fixture(scope="function")
def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser):
def chat_client_with_auth(search_config, default_user2: KhojUser):
    return chat_client_builder(search_config, default_user2, require_auth=True)


@pytest.fixture(scope="function")
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
def chat_client_no_background(search_config, default_user2: KhojUser):
    return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False)


@pytest.fixture(scope="function")
def chat_client_with_large_kb(search_config: SearchConfig, default_user2: KhojUser):
def chat_client_with_large_kb(search_config, default_user2: KhojUser):
    """
    Chat client fixture that creates a large knowledge base with many files
    for stress testing atomic agent updates.
@@ -312,19 +231,14 @@ def chat_client_with_large_kb(search_config: SearchConfig, default_user2: KhojUs
@pytest.mark.django_db
def chat_client_builder(search_config, user, index_content=True, require_auth=False):
    # Initialize app state
    state.config.search_type = search_config
    state.SearchType = configure_search_types()

    if index_content:
        LocalMarkdownConfig.objects.create(
            input_files=None,
            input_filter=["tests/data/markdown/*.markdown"],
            user=user,
        )
        file_type = "markdown"
        files_to_index = {file_type: get_index_files(input_filters=[f"tests/data/{file_type}/*.{file_type}"])}

        # Index Markdown Content for Search
        all_files = fs_syncer.collect_files(user=user)
        configure_content(user, all_files)
        configure_content(user, files_to_index)

    # Initialize Processor from Config
    chat_provider = get_chat_provider()
@@ -360,17 +274,17 @@ def large_kb_chat_client_builder(search_config, user):
    import tempfile

    # Initialize app state
    state.config.search_type = search_config
    state.SearchType = configure_search_types()

    # Create temporary directory for large number of test files
    temp_dir = tempfile.mkdtemp(prefix="khoj_test_large_kb_")
    file_type = "markdown"
    large_file_list = []

    try:
        # Generate 200 test files with substantial content
        for i in range(300):
            file_path = os.path.join(temp_dir, f"test_file_{i:03d}.markdown")
            file_path = os.path.join(temp_dir, f"test_file_{i:03d}.{file_type}")
            content = f"""
# Test File {i}

@@ -420,16 +334,9 @@ End of file {i}.
            f.write(content)
            large_file_list.append(file_path)

        # Create LocalMarkdownConfig with all the generated files
        LocalMarkdownConfig.objects.create(
            input_files=large_file_list,
            input_filter=None,
            user=user,
        )

        # Index all the files into the user's knowledge base
        all_files = fs_syncer.collect_files(user=user)
        configure_content(user, all_files)
        # Index all generated files into the user's knowledge base
        files_to_index = {file_type: get_index_files(input_files=large_file_list, input_filters=None)}
        configure_content(user, files_to_index)

        # Verify we have a substantial knowledge base
        file_count = FileObject.objects.filter(user=user, agent=None).count()
@@ -481,12 +388,8 @@ def fastapi_app():

@pytest.fixture(scope="function")
def client(
    content_config: ContentConfig,
    search_config: SearchConfig,
    api_user: KhojApiUser,
):
    state.config.content_type = content_config
    state.config.search_type = search_config
    state.SearchType = configure_search_types()
    state.embeddings_model = dict()
    state.embeddings_model["default"] = EmbeddingsModel()
@@ -516,173 +419,18 @@
    return TestClient(app)


@pytest.fixture(scope="function")
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
    # Initialize app state
    state.config.search_type = search_config
    state.SearchType = configure_search_types()

    LocalMarkdownConfig.objects.create(
        input_files=None,
        input_filter=["tests/data/markdown/*.markdown"],
        user=default_user2,
    )

    all_files = fs_syncer.collect_files(user=default_user2)
    configure_content(default_user2, all_files)

    # Initialize Processor from Config
    ChatModelFactory(
        name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
        tokenizer=None,
        max_prompt_size=None,
        model_type="offline",
    )
    UserConversationProcessorConfigFactory(user=default_user2)

    state.anonymous_mode = True

    app = FastAPI()

    configure_routes(app)
    configure_middleware(app)
    app.mount("/static", StaticFiles(directory=web_directory), name="static")
    return TestClient(app)


@pytest.fixture(scope="function")
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
    # Setup
    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
    input_filters = org_config.input_filter
    new_org_file = Path(input_filters[0]).parent / "new_file.org"
    new_org_file.touch()

    yield new_org_file

    # Cleanup
    if new_org_file.exists():
        new_org_file.unlink()


@pytest.fixture(scope="function")
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
    LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
    return LocalOrgConfig.objects.filter(user=default_user).first()


@pytest.fixture(scope="function")
def pdf_configured_user1(default_user: KhojUser):
    LocalPdfConfig.objects.create(
        input_files=None,
        input_filter=["tests/data/pdf/singlepage.pdf"],
        user=default_user,
    )
    # Index Markdown Content for Search
    all_files = fs_syncer.collect_files(user=default_user)
    configure_content(default_user, all_files)
    # Read data from pdf file at tests/data/pdf/singlepage.pdf
    pdf_file_path = "tests/data/pdf/singlepage.pdf"
    with open(pdf_file_path, "rb") as pdf_file:
        pdf_data = pdf_file.read()

    knowledge_base = {"pdf": {"singlepage.pdf": pdf_data}}
    # Index Content for Search
    configure_content(default_user, knowledge_base)


@pytest.fixture(scope="function")
def sample_org_data():
    return get_sample_data("org")


def get_sample_data(type):
    sample_data = {
        "org": {
            "elisp.org": """
* Emacs Khoj
/An Emacs interface for [[https://github.com/khoj-ai/khoj][khoj]]/

** Requirements
- Install and Run [[https://github.com/khoj-ai/khoj][khoj]]

** Installation
*** Direct
- Put ~khoj.el~ in your Emacs load path. For e.g. ~/.emacs.d/lisp
- Load via ~use-package~ in your ~/.emacs.d/init.el or .emacs file by adding below snippet
#+begin_src elisp
;; Khoj Package
(use-package khoj
:load-path "~/.emacs.d/lisp/khoj.el"
:bind ("C-c s" . 'khoj))
#+end_src

*** Using [[https://github.com/quelpa/quelpa#installation][Quelpa]]
- Ensure [[https://github.com/quelpa/quelpa#installation][Quelpa]], [[https://github.com/quelpa/quelpa-use-package#installation][quelpa-use-package]] are installed
- Add below snippet to your ~/.emacs.d/init.el or .emacs config file and execute it.
#+begin_src elisp
;; Khoj Package
(use-package khoj
:quelpa (khoj :fetcher url :url "https://raw.githubusercontent.com/khoj-ai/khoj/master/interface/emacs/khoj.el")
:bind ("C-c s" . 'khoj))
#+end_src

** Usage
1. Call ~khoj~ using keybinding ~C-c s~ or ~M-x khoj~
2. Enter Query in Natural Language
e.g. "What is the meaning of life?" "What are my life goals?"
3. Wait for results
*Note: It takes about 15s on a Mac M1 and a ~100K lines corpus of org-mode files*
4. (Optional) Narrow down results further
Include/Exclude specific words from results by adding to query
e.g. "What is the meaning of life? -god +none"

""",
            "readme.org": """
* Khoj
/Allow natural language search on user content like notes, images using transformer based models/

All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline

** Dependencies
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]

** Install
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src""",
        },
        "markdown": {
            "readme.markdown": """
# Khoj
Allow natural language search on user content like notes, images using transformer based models

All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline

## Dependencies
- Python3
- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links)

## Install
```shell
git clone
conda env create -f environment.yml
conda activate khoj
```
"""
        },
        "plaintext": {
            "readme.txt": """
Khoj
Allow natural language search on user content like notes, images using transformer based models

All data is processed locally. User can interface with khoj app via Emacs, API or Commandline

Dependencies
- Python3
- Miniconda

Install
git clone
conda env create -f environment.yml
conda activate khoj
"""
        },
    }

    return sample_data[type]
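With the filesystem syncer gone, the test fixtures above hand content to `configure_content` as an in-memory map keyed by file type, then by filename. A minimal sketch of that shape follows; the `build_knowledge_base` helper is illustrative, not part of khoj.

```python
# Sketch of the {file_type: {filename: content}} structure the updated
# fixtures pass to configure_content() instead of scanning the filesystem.


def build_knowledge_base(file_type: str, files: dict[str, str]) -> dict[str, dict[str, str]]:
    """Wrap raw filename -> content pairs under their file type key."""
    return {file_type: dict(files)}


# Mirror of the markdown sample data used by the fixtures above.
sample_markdown = {"readme.markdown": "# Khoj\nNatural language search on your notes."}
knowledge_base = build_knowledge_base("markdown", sample_markdown)
```

The same shape carries binary content for PDFs (`{"pdf": {"singlepage.pdf": pdf_data}}`), so one code path indexes both text and binary test data.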
tests/helpers.py
@@ -1,3 +1,5 @@
import glob
import logging
import os
from datetime import datetime

@@ -17,9 +19,12 @@ from khoj.database.models import (
    UserConversationConfig,
)
from khoj.processor.conversation.utils import message_to_log
from khoj.utils.helpers import get_absolute_path, is_none_or_empty

logger = logging.getLogger(__name__)


def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE):
    provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
    if provider and provider in ChatModel.ModelType:
        return ChatModel.ModelType(provider)
@@ -61,6 +66,140 @@ def generate_chat_history(message_list):
    return chat_history


def get_sample_data(type):
    sample_data = {
        "org": {
            "elisp.org": """
* Emacs Khoj
/An Emacs interface for [[https://github.com/khoj-ai/khoj][khoj]]/

** Requirements
- Install and Run [[https://github.com/khoj-ai/khoj][khoj]]

** Installation
*** Direct
- Put ~khoj.el~ in your Emacs load path. For e.g. ~/.emacs.d/lisp
- Load via ~use-package~ in your ~/.emacs.d/init.el or .emacs file by adding below snippet
#+begin_src elisp
;; Khoj Package
(use-package khoj
:load-path "~/.emacs.d/lisp/khoj.el"
:bind ("C-c s" . 'khoj))
#+end_src

*** Using [[https://github.com/quelpa/quelpa#installation][Quelpa]]
- Ensure [[https://github.com/quelpa/quelpa#installation][Quelpa]], [[https://github.com/quelpa/quelpa-use-package#installation][quelpa-use-package]] are installed
- Add below snippet to your ~/.emacs.d/init.el or .emacs config file and execute it.
#+begin_src elisp
;; Khoj Package
(use-package khoj
:quelpa (khoj :fetcher url :url "https://raw.githubusercontent.com/khoj-ai/khoj/master/interface/emacs/khoj.el")
:bind ("C-c s" . 'khoj))
#+end_src

** Usage
1. Call ~khoj~ using keybinding ~C-c s~ or ~M-x khoj~
2. Enter Query in Natural Language
e.g. "What is the meaning of life?" "What are my life goals?"
3. Wait for results
*Note: It takes about 15s on a Mac M1 and a ~100K lines corpus of org-mode files*
4. (Optional) Narrow down results further
Include/Exclude specific words from results by adding to query
e.g. "What is the meaning of life? -god +none"

""",
            "readme.org": """
* Khoj
/Allow natural language search on user content like notes, images using transformer based models/

All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline

** Dependencies
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]

** Install
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src""",
        },
        "markdown": {
            "readme.markdown": """
# Khoj
Allow natural language search on user content like notes, images using transformer based models

All data is processed locally. User can interface with khoj app via [Emacs](./interface/emacs/khoj.el), API or Commandline

## Dependencies
- Python3
- [Miniconda](https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links)

## Install
```shell
git clone
conda env create -f environment.yml
conda activate khoj
```
"""
        },
        "plaintext": {
            "readme.txt": """
Khoj
Allow natural language search on user content like notes, images using transformer based models

All data is processed locally. User can interface with khoj app via Emacs, API or Commandline

Dependencies
- Python3
- Miniconda

Install
git clone
conda env create -f environment.yml
conda activate khoj
"""
        },
    }

    return sample_data[type]


def get_index_files(
    input_files: list[str] = None, input_filters: list[str] | None = ["tests/data/org/*.org"]
) -> dict[str, str]:
    # Input Validation
    if is_none_or_empty(input_files) and is_none_or_empty(input_filters):
        logger.debug("At least one of input_files or input_filter is required to be specified")
        return {}

    # Get files to process
    absolute_files, filtered_files = set(), set()
    if input_files:
        absolute_files = {get_absolute_path(input_file) for input_file in input_files}
    if input_filters:
        filtered_files = {
            filtered_file
            for file_filter in input_filters
            for filtered_file in glob.glob(get_absolute_path(file_filter), recursive=True)
            if os.path.isfile(filtered_file)
        }

    all_files = sorted(absolute_files | filtered_files)

    filename_to_content_map = {}
    for file in all_files:
        with open(file, "r", encoding="utf8") as f:
            try:
                filename_to_content_map[file] = f.read()
            except Exception as e:
                logger.warning(f"Unable to read file: {file}. Skipping file.")
                logger.warning(e, exc_info=True)

    return filename_to_content_map


class UserFactory(factory.django.DjangoModelFactory):
    class Meta:
        model = KhojUser
@@ -93,7 +232,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory):

    max_prompt_size = 20000
    tokenizer = None
    name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
    name = "gemini-2.0-flash"
    model_type = get_chat_provider()
    ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
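The `get_chat_provider` change above switches the test default from the removed offline provider to Google, keeping the `KHOJ_TEST_CHAT_PROVIDER` env override. A small self-contained sketch of that selection logic, using plain strings in place of the `ChatModel.ModelType` Django enum (an assumption for illustration):

```python
import os

# Stand-in for ChatModel.ModelType values; the real code validates against a Django enum.
KNOWN_PROVIDERS = {"openai", "google", "anthropic", "offline"}


def get_chat_provider(default: str = "google") -> str:
    """Mimic tests/helpers.get_chat_provider: env var override with validation, else default."""
    provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
    if provider and provider in KNOWN_PROVIDERS:
        return provider
    return default


# With the env var set, the override wins; without it, the default is returned.
os.environ["KHOJ_TEST_CHAT_PROVIDER"] = "openai"
override = get_chat_provider()
os.environ.pop("KHOJ_TEST_CHAT_PROVIDER", None)
fallback = get_chat_provider()
```

Unknown provider names fall through to the default rather than raising, which keeps the test suite runnable when the env var is stale.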
@@ -15,7 +15,7 @@ from tests.helpers import ChatModelFactory
def test_create_default_agent(default_user: KhojUser):
    ChatModelFactory()

    agent = AgentAdapters.create_default_agent(default_user)
    agent = AgentAdapters.create_default_agent()
    assert agent is not None
    assert agent.input_tools == []
    assert agent.output_modes == []
@@ -1,49 +1,15 @@
# Standard Modules
from pathlib import Path
from random import random

from khoj.utils.cli import cli
from khoj.utils.helpers import resolve_absolute_path


# Test
# ----------------------------------------------------------------------------------------------------
def test_cli_minimal_default():
    # Act
    actual_args = cli([])
    actual_args = cli(["-vvv"])

    # Assert
    assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml"))
    assert actual_args.regenerate == False
    assert actual_args.verbose == 0


# ----------------------------------------------------------------------------------------------------
def test_cli_invalid_config_file_path():
    # Arrange
    non_existent_config_file = f"non-existent-khoj-{random()}.yml"

    # Act
    actual_args = cli([f"--config-file={non_existent_config_file}"])

    # Assert
    assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
    assert actual_args.config == None


# ----------------------------------------------------------------------------------------------------
def test_cli_config_from_file():
    # Act
    actual_args = cli(["--config-file=tests/data/config.yml", "--regenerate", "-vvv"])

    # Assert
    assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
    assert actual_args.regenerate == True
    assert actual_args.config is not None
    assert actual_args.log_file == Path("~/.khoj/khoj.log")
    assert actual_args.verbose == 3

    # Ensure content config is loaded from file
    assert actual_args.config.content_type.org.input_files == [
        Path("~/first_from_config.org"),
        Path("~/second_from_config.org"),
    ]
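The CLI tests above exercise flag parsing after the indexer flags were dropped. A minimal sketch of the same flag shapes with stdlib `argparse`; the parser below is illustrative, not khoj's actual `cli()` implementation, and the default config path is copied from the tests.

```python
import argparse


def parse_cli(args: list[str]) -> argparse.Namespace:
    """Sketch of khoj-style CLI flags: config file path, regenerate toggle, counted verbosity."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", default="~/.khoj/khoj.yml")
    parser.add_argument("--regenerate", action="store_true")
    # action="count" lets -v, -vv, -vvv raise the verbosity level.
    parser.add_argument("-v", "--verbose", action="count", default=0)
    return parser.parse_args(args)


verbose_args = parse_cli(["-vvv", "--regenerate"])
default_args = parse_cli([])
```

The counted `-v` flag is why the test passes `-vvv` and can assert on an integer verbosity level rather than a boolean.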
@@ -13,7 +13,6 @@ from khoj.database.models import KhojApiUser, KhojUser
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.rawconfig import ContentConfig, SearchConfig


# Test
@@ -283,10 +282,6 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
    # Arrange
    state.anonymous_mode = True
    if state.config and state.config.content_type:
        state.config.content_type = None
    state.search_models = configure_search_types()

    configure_routes(fastapi_app)
    client = TestClient(fastapi_app)

@@ -300,7 +295,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
def test_notes_search(client, search_config, sample_org_data, default_user: KhojUser):
    # Arrange
    headers = {"Authorization": "Bearer kk-secret"}
    text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
@@ -319,7 +314,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
def test_notes_search_no_results(client, search_config, sample_org_data, default_user: KhojUser):
    # Arrange
    headers = {"Authorization": "Bearer kk-secret"}
    text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
@@ -335,9 +330,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
    client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser
):
def test_notes_search_with_only_filters(client, sample_org_data, default_user: KhojUser):
    # Arrange
    headers = {"Authorization": "Bearer kk-secret"}
    text_search.setup(
@@ -401,9 +394,7 @@ def test_notes_search_with_exclude_filter(client, sample_org_data, default_user:

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_requires_parent_context(
    client, search_config: SearchConfig, sample_org_data, default_user: KhojUser
):
def test_notes_search_requires_parent_context(client, search_config, sample_org_data, default_user: KhojUser):
    # Arrange
    headers = {"Authorization": "Bearer kk-secret"}
    text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)

@@ -1,6 +1,13 @@
# Application Packages
from khoj.search_filter.file_filter import FileFilter
from khoj.utils.rawconfig import Entry


# Mock Entry class for testing
class Entry:
    def __init__(self, compiled="", raw="", file=""):
        self.compiled = compiled
        self.raw = raw
        self.file = file


def test_can_filter_no_file_filter():

@@ -3,8 +3,6 @@ import re
from pathlib import Path

from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.utils.fs_syncer import get_markdown_files
from khoj.utils.rawconfig import TextContentConfig


def test_extract_markdown_with_no_headings(tmp_path):
@@ -212,43 +210,6 @@ longer body line 2.1
    ), "Third entry is second entries child heading"


def test_get_markdown_files(tmp_path):
    "Ensure Markdown files specified via input-filter, input-files extracted"
    # Arrange
    # Include via input-filter globs
    group1_file1 = create_file(tmp_path, filename="group1-file1.md")
    group1_file2 = create_file(tmp_path, filename="group1-file2.md")
    group2_file1 = create_file(tmp_path, filename="group2-file1.markdown")
    group2_file2 = create_file(tmp_path, filename="group2-file2.markdown")
    # Include via input-file field
    file1 = create_file(tmp_path, filename="notes.md")
    # Not included by any filter
    create_file(tmp_path, filename="not-included-markdown.md")
    create_file(tmp_path, filename="not-included-text.txt")

    expected_files = set(
        [os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
    )

    # Setup input-files, input-filters
    input_files = [tmp_path / "notes.md"]
    input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]

    markdown_config = TextContentConfig(
        input_files=input_files,
        input_filter=[str(filter) for filter in input_filter],
        compressed_jsonl=tmp_path / "test.jsonl",
        embeddings_file=tmp_path / "test_embeddings.jsonl",
    )

    # Act
    extracted_org_files = get_markdown_files(markdown_config)

    # Assert
    assert len(extracted_org_files) == 5
    assert set(extracted_org_files.keys()) == expected_files


def test_line_number_tracking_in_recursive_split():
    "Ensure line numbers in URIs are correct after recursive splitting by checking against the actual file."
    # Arrange

@@ -1,610 +0,0 @@
from datetime import datetime

import pytest

from khoj.database.models import ChatModel
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
from khoj.utils.helpers import ConversationCommand
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider

SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
    SKIP_TESTS,
    reason="Disable in CI to avoid long test runs.",
)

import freezegun
from freezegun import freeze_time

from khoj.processor.conversation.offline.chat_model import converse_offline
from khoj.processor.conversation.offline.utils import download_model
from khoj.utils.constants import default_offline_chat_models


@pytest.fixture(scope="session")
def loaded_model():
    return download_model(default_offline_chat_models[0], max_tokens=5000)


freezegun.configure(extend_ignore_list=["transformers"])


# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2):
    # Act
    response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model)

    assert len(response) >= 1

    assert any(
        [
            "dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
            "dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
            'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
            'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
        ]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2):
    # Act
    response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model)

    # Assert
    assert len(response) >= 1
    # The user query should be the last question in the response
    assert response[-1] == ["Which countries did I visit last month?"]
    assert any(
        [
            "dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
            "dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
            'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
            'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
        ]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2):
    # Act
    response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model)

    # Assert
    expected_responses = [
        ("dt>='1984-01-01'", ""),
        ("dt>='1984-01-01'", "dt<'1985-01-01'"),
        ("dt>='1984-01-01'", "dt<='1984-12-31'"),
    ]
    assert len(response) == 1
    assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
        "Expected date filter to limit to 1984 in response but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2):
    # Act
    responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model)

    # Assert
    assert len(responses) >= 2
    assert ["the Sun" in response for response in responses]
    assert ["the Moon" in response for response in responses]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2):
    # Act
    response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model)

    # Assert
    expected_responses = ["height", "taller", "shorter", "heights", "who"]
    assert len(response) <= 3

    for question in response:
        assert any([expected_response in question.lower() for expected_response in expected_responses]), (
            "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
        )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2):
    # Arrange
    message_list = [
        ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
    ]
    query = "Does he have any sons?"

    # Act
    response = extract_questions(
        query,
        default_user2,
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
        use_history=True,
    )

    any_expected_with_barbara = [
        "sibling",
        "brother",
    ]

    any_expected_with_anderson = [
        "son",
        "sons",
        "children",
        "family",
    ]

    # Assert
    assert len(response) >= 1
    # Ensure the remaining generated search queries use proper nouns and chat history context
    for question in response:
        if "Barbara" in question:
            assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
                "Expected search queries using proper nouns and chat history for context, but got: " + question
            )
        elif "Anderson" in question:
            assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
                "Expected search queries using proper nouns and chat history for context, but got: " + question
            )
        else:
            assert False, (
                "Expected search queries using proper nouns and chat history for context, but got: " + question
            )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2):
    # Arrange
    message_list = [
        ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
    ]

    # Act
    response = extract_questions(
        "Is she a Doctor?",
        default_user2,
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
        use_history=True,
    )

    expected_responses = [
        "Barbara",
        "Anderson",
    ]

    # Assert
    assert len(response) >= 1
    assert any([expected_response in response[0] for expected_response in expected_responses]), (
        "Expected chat actor to mention person's by name, but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2):
    # Arrange
    message_list = [
        ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
    ]

    # Act
    response = extract_questions(
        "What was the Pizza place we ate at over there?",
        default_user2,
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
    )

    # Assert
    expected_responses = [
        ("dt>='2000-04-01'", "dt<'2000-05-01'"),
        ("dt>='2000-04-01'", "dt<='2000-04-30'"),
        ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
        ('dt>="2000-04-01"', 'dt<="2000-04-30"'),
    ]
    assert len(response) == 1
    assert "Masai Mara" in response[0]
    assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
        "Expected date filter to limit to April 2000 in response but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
    "user_query, expected_conversation_commands",
    [
        (
            "Where did I learn to swim?",
            {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text},
        ),
        (
            "Where is the nearest hospital?",
            {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text},
        ),
        (
            "Summarize the wikipedia page on the history of the internet",
            {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text},
        ),
        (
            "How many noble gases are there?",
            {"sources": [ConversationCommand.General], "output": ConversationCommand.Text},
        ),
        (
            "Make a painting incorporating my past diving experiences",
            {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image},
        ),
        (
            "Create a chart of the weather over the next 7 days in Timbuktu",
            {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text},
        ),
        (
            "What's the highest point in this country and have I been there?",
            {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text},
        ),
    ],
)
async def test_select_data_sources_actor_chooses_to_search_notes(
    client_offline_chat, user_query, expected_conversation_commands, default_user2
):
    # Act
    selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)

    # Assert
    assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"])
    assert expected_conversation_commands["output"] == selected_conversation_commands["output"]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
    # Arrange
    user_query = "What's the latest in the Israel/Palestine conflict?"
    chat_log = [
        (
            "Let's talk about the current events around the world.",
            "Sure, let's discuss the current events. What would you like to know?",
            [],
        ),
        ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
    ]
    chat_history = ConversationFactory(user=default_user2, conversation_log=generate_chat_history(chat_log))

    # Act
    tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False)

    # Assert
    tools = [tool.value for tool in tools]
    assert tools == ["online"]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
    # Act
    response_gen = converse_offline(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Hello, my name is Testatron. Who are you?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["Khoj", "khoj", "KHOJ"]
    assert len(response) > 0
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected assistants name, [K|k]hoj, in response but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
    "Chat actor needs to use context in previous notes and chat history to answer question"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        (
            "When was I born?",
            "You were born on 1st April 1984.",
            ["Testatron was born on 1st April 1984 in Testville."],
        ),
    ]

    # Act
    response_gen = converse_offline(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert len(response) > 0
    # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
    assert "Testville" in response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
    "Chat actor needs to use context across currently retrieved notes and chat history to answer question"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse_offline(
        references=[
            {"compiled": "Testatron was born on 1st April 1984 in Testville."}
        ],  # Assume context retrieved from notes for the user_query
        user_query="Where was I born?",
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert len(response) > 0
    assert "Testville" in response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
    "Chat actor should not try make up answers to unanswerable questions."
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse_offline(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = [
        "don't know",
        "do not know",
        "no information",
        "do not have",
        "don't have",
        "cannot answer",
        "I'm sorry",
    ]
    assert len(response) > 0
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected chat actor to say they don't know in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness(loaded_model):
    "Chat actor should be able to answer questions relative to current date using provided notes"
    # Arrange
    context = [
        {
            "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD"""
        },
        {
            "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD"""
        },
        {
            "compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD"""
        },
        {
            "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD"""
        },
    ]

    # Act
    response_gen = converse_offline(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I have for Dinner today?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["tacos", "Tacos"]
    assert len(response) > 0
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected [T|t]acos in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
    "Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
    # Arrange
    context = [
        {
            "compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD"""
        },
        {
            "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD"""
        },
        {
            "compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD"""
        },
        {
            "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD"""
        },
    ]

    # Act
    response_gen = converse_offline(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="How much did I spend on dining this year?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert len(response) > 0
    assert "20" in response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
    "Chat actor should be able to answer general questions not requiring looking at chat history or notes"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("Where was I born?", "You were born Testville.", []),
    ]

    # Act
    response_gen = converse_offline(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Write a haiku about unit testing in 3 lines",
        chat_history=generate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["test", "testing"]
    assert len(response.splitlines()) >= 3  # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
    assert any([expected_response in response.lower() for expected_response in expected_responses]), (
        "Expected [T|t]est in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
    "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
    # Arrange
    context = [
        {
            "compiled": f"""# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
        },
        {
            "compiled": f"""# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
        },
        {
            "compiled": f"""# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
        },
    ]

    # Act
    response_gen = converse_offline(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="How many kids does my older sister have?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister", "Which one"]
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected chat actor to ask for clarification in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_agent_prompt_should_be_used(loaded_model, offline_agent):
    "Chat actor should ask be tuned to think like an accountant based on the agent definition"
    # Arrange
    context = [
        {"compiled": f"""I went to the store and bought some bananas for 2.20"""},
        {"compiled": f"""I went to the store and bought some apples for 1.30"""},
        {"compiled": f"""I went to the store and bought some oranges for 6.00"""},
    ]

    # Act
    response_gen = converse_offline(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I buy?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert that the model without the agent prompt does not include the summary of purchases
    expected_responses = ["9.50", "9.5"]
    assert all([expected_response not in response for expected_response in expected_responses]), (
        "Expected chat actor to summarize values of purchases" + response
    )

    # Act
    response_gen = converse_offline(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I buy?",
        loaded_model=loaded_model,
        agent=offline_agent,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert that the model with the agent prompt does include the summary of purchases
    expected_responses = ["9.50", "9.5"]
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected chat actor to summarize values of purchases" + response
    )


# ----------------------------------------------------------------------------------------------------
def test_chat_does_not_exceed_prompt_size(loaded_model):
    "Ensure chat context and response together do not exceed max prompt size for the model"
    # Arrange
    prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
    context = [{"compiled": " ".join([f"{number}" for number in range(2043)])}]

    # Act
    response_gen = converse_offline(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What numbers come after these?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert prompt_size_exceeded_error not in response, (
        "Expected chat response to be within prompt limits, but got exceeded error: " + response
    )
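The deleted test file above drove chat through an in-process llama-cpp-python model (the `loaded_model` fixture). Per this change, offline chat is instead expected to go through an established local AI server with an OpenAI-compatible chat API (e.g. Llama.cpp Server, Ollama, vLLM). As a rough illustration of what replaces those `converse_offline(references=..., user_query=...)` calls, here is a minimal sketch that builds the request payload such a server would receive; the model name `llama3` and the message layout are assumptions for illustration, not part of this commit.

```python
import json


def build_chat_request(references, user_query, model="llama3"):
    """Sketch: turn retrieved note snippets plus a user query into an
    OpenAI-compatible /v1/chat/completions request body."""
    # Inline retrieved note snippets into a system message, then ask the query
    context = "\n".join(ref["compiled"] for ref in references)
    return json.dumps(
        {
            "model": model,
            "messages": [
                {"role": "system", "content": f"Notes:\n{context}"},
                {"role": "user", "content": user_query},
            ],
            "stream": True,
        }
    )


body = build_chat_request(
    [{"compiled": "Testatron was born on 1st April 1984 in Testville."}],
    "Where was I born?",
)
payload = json.loads(body)
print(payload["messages"][1]["content"])  # → Where was I born?
```

The payload would then be POSTed to the local server's `/v1/chat/completions` endpoint and the streamed chunks joined, much as the deleted tests joined `response_gen`; decoupling inference this way lets the model run wherever a GPU is available (e.g. Ollama on the host, Khoj in Docker).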
@@ -1,726 +0,0 @@
|
||||
import urllib.parse
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from tests.helpers import ConversationFactory, get_chat_provider
|
||||
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="Disable in CI to avoid long test runs.",
|
||||
)
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
# Helpers
# ----------------------------------------------------------------------------------------------------
def generate_history(message_list):
    # Generate conversation logs
    conversation_log = {"chat": []}
    for user_message, gpt_message, context in message_list:
        message_to_log(
            user_message,
            gpt_message,
            {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
            chat_history=conversation_log.get("chat", []),
        )
    return conversation_log


def create_conversation(message_list, user, agent=None):
    # Generate conversation logs
    conversation_log = generate_history(message_list)
    # Update Conversation Metadata Logs in Database
    return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)

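The helpers above build the conversation log structure that the chat tests replay. A self-contained sketch of the shape such a log takes — `build_history` below is a hypothetical stand-in for khoj's `message_to_log`, shown only to illustrate the `{"chat": [...]}` structure the tests rely on:

```python
def build_history(message_list):
    # Hypothetical stand-in for message_to_log: append one user turn and one
    # assistant turn per (user_message, gpt_message, context) triple.
    chat = []
    for user_message, gpt_message, context in message_list:
        chat.append({"by": "you", "message": user_message})
        chat.append({"by": "khoj", "message": gpt_message, "context": context})
    return {"chat": chat}


history = build_history([("When was I born?", "You were born on 1st April 1984.", [])])
assert len(history["chat"]) == 2
```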
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
    # Act
    query = "Hello, my name is Testatron. Who are you?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["Khoj", "khoj"]
    assert response.status_code == 200
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected assistant's name, [K|k]hoj, in response but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(client_offline_chat):
    # Act
    q = "/online give me the link to paul graham's essay how to do great work"
    response = client_offline_chat.post(f"/api/chat", json={"q": q})
    response_message = response.json()["response"]

    # Assert
    expected_responses = [
        "https://paulgraham.com/greatwork.html",
        "https://www.paulgraham.com/greatwork.html",
        "http://www.paulgraham.com/greatwork.html",
    ]
    assert response.status_code == 200
    assert any(
        [expected_response in response_message for expected_response in expected_responses]
    ), f"Expected links: {expected_responses}. Actual response: {response_message}"

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_webpage_content(client_offline_chat):
    # Act
    q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
    response = client_offline_chat.post(f"/api/chat", json={"q": q})
    response_message = response.json()["response"]

    # Assert
    expected_responses = ["185", "1871", "horse"]
    assert response.status_code == 200
    assert any(
        [expected_response in response_message for expected_response in expected_responses]
    ), f"Expected response with {expected_responses}. But actual response had: {response_message}"

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history(client_offline_chat, default_user2):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]
    create_conversation(message_list, default_user2)

    # Act
    q = "What is my name?"
    response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["Testatron", "testatron"]
    assert response.status_code == 200
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected [T|t]estatron in response but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        (
            "When was I born?",
            "You were born on 1st April 1984.",
            ["Testatron was born on 1st April 1984 in Testville."],
        ),
    ]
    create_conversation(message_list, default_user2)

    # Act
    q = "Where was Xi Li born?"
    response = client_offline_chat.post(f"/api/chat", json={"q": q})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    assert "Fujiang" in response_message

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        (
            "When was I born?",
            "You were born on 1st April 1984.",
            ["Testatron was born on 1st April 1984 in Testville."],
        ),
    ]
    create_conversation(message_list, default_user2)

    # Act
    q = "Where was I born?"
    response = client_offline_chat.post(f"/api/chat", json={"q": q})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    # 1. Infer who I am from chat history
    # 2. Infer I was born in Testville from previously retrieved notes
    assert "Testville" in response_message

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
    # Arrange
    message_list = [
        ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]
    create_conversation(message_list, default_user2)

    # Act
    q = "Where was I born?"
    response = client_offline_chat.post(f"/api/chat", json={"q": q})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    # Inference in a multi-turn conversation
    # 1. Infer who I am from chat history
    # 2. Search for notes about when <my_name_from_chat_history> was born
    # 3. Extract where I was born from currently retrieved notes
    assert "Fujiang" in response_message

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(raises=AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
    "Chat director should say it doesn't know when chat history and retrieved content lack the answer"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]
    create_conversation(message_list, default_user2)

    # Act
    q = "Where was I born?"
    response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
    assert response.status_code == 200
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected chat director to say they don't know in response, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_using_general_command(client_offline_chat, default_user2):
    # Arrange
    message_list = []
    create_conversation(message_list, default_user2)

    # Act
    query = "/general Where was Xi Li born?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    assert "Fujiang" not in response_message

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
    # Arrange
    message_list = []
    create_conversation(message_list, default_user2)

    # Act
    query = "/notes Where was Xi Li born?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    assert "Fujiang" in response_message

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(client_offline_chat, default_user2):
    # Arrange
    no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
    answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
    message_list = []
    create_conversation(message_list, default_user2)

    # Act
    no_answer_response = client_offline_chat.post(
        f"/api/chat", json={"q": no_answer_query, "stream": True}
    ).content.decode("utf-8")
    answer_response = client_offline_chat.post(f"/api/chat", json={"q": answer_query, "stream": True}).content.decode(
        "utf-8"
    )

    # Assert
    assert "Fujiang" not in no_answer_response
    assert "Fujiang" in answer_response

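The file-filter test URL-encodes its queries because they embed quotes and spaces inside a `file:"…"` filter. A quick sketch of what `urllib.parse.quote` does to such a query:

```python
import urllib.parse

query = 'Where was Xi Li born? file:"Xi Li.markdown"'
encoded = urllib.parse.quote(query)

# Spaces and double quotes are percent-escaped; decoding round-trips losslessly
assert " " not in encoded and '"' not in encoded
assert urllib.parse.unquote(encoded) == query
```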
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
    # Arrange
    message_list = []
    create_conversation(message_list, default_user2)

    # Act
    query = urllib.parse.quote("/notes Where was Testatron born?")
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    assert response_message == prompts.no_notes_found.format()

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = []
    conversation = create_conversation(message_list, default_user2)
    # List the file paths of the user's indexed files
    file_list = (
        Entry.objects.filter(user=default_user2, file_source="computer")
        .distinct("file_path")
        .values_list("file_path", flat=True)
    )
    # Pick the file that has "Birthday Gift for Xiu turning 4.markdown" in the name
    summarization_file = ""
    for file in file_list:
        if "Birthday Gift for Xiu turning 4.markdown" in file:
            summarization_file = file
            break
    assert summarization_file != ""

    # Add the file to the conversation's file filters
    response = client_offline_chat.post(
        "api/chat/conversation/file-filters",
        json={"filename": summarization_file, "conversation_id": str(conversation.id)},
    )

    # Act
    query = "/summarize"
    response = client_offline_chat.post(
        f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
    )
    response_message = response.content.decode("utf-8")

    # Assert
    assert response_message != ""
    assert response_message != "No files selected for summarization. Please add files using the section on the left."

@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = []
    conversation = create_conversation(message_list, default_user2)
    # List the file paths of the user's indexed files
    file_list = (
        Entry.objects.filter(user=default_user2, file_source="computer")
        .distinct("file_path")
        .values_list("file_path", flat=True)
    )
    # Pick the file that has "Birthday Gift for Xiu turning 4.markdown" in the name
    summarization_file = ""
    for file in file_list:
        if "Birthday Gift for Xiu turning 4.markdown" in file:
            summarization_file = file
            break
    assert summarization_file != ""

    # Add the file to the conversation's file filters
    response = client_offline_chat.post(
        "api/chat/conversation/file-filters",
        json={"filename": summarization_file, "conversation_id": str(conversation.id)},
    )

    # Act
    query = "/summarize tell me about Xiu"
    response = client_offline_chat.post(
        f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
    )
    response_message = response.content.decode("utf-8")

    # Assert
    assert response_message != ""
    assert response_message != "No files selected for summarization. Please add files using the section on the left."

@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = []
    conversation = create_conversation(message_list, default_user2)
    # Add the first two of the user's indexed files to the conversation's file filters
    file_list = (
        Entry.objects.filter(user=default_user2, file_source="computer")
        .distinct("file_path")
        .values_list("file_path", flat=True)
    )

    response = client_offline_chat.post(
        "api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
    )
    response = client_offline_chat.post(
        "api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
    )

    # Act
    query = "/summarize"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
    response_message = response.json()["response"]

    # Assert
    assert response_message is not None

@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = []
    conversation = create_conversation(message_list, default_user2)

    # Act
    query = "/summarize"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
    response_message = response.json()["response"]

    # Assert
    assert response_message == "No files selected for summarization. Please add files using the section on the left."

@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = []
    conversation1 = create_conversation(message_list, default_user2)
    conversation2 = create_conversation(message_list, default_user2)

    file_list = (
        Entry.objects.filter(user=default_user2, file_source="computer")
        .distinct("file_path")
        .values_list("file_path", flat=True)
    )
    summarization_file = ""
    for file in file_list:
        if "Birthday Gift for Xiu turning 4.markdown" in file:
            summarization_file = file
            break
    assert summarization_file != ""

    # Add file filter to conversation 1
    response = client_offline_chat.post(
        "api/chat/conversation/file-filters",
        json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
    )

    # Act
    query = "/summarize"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
    response_message = response.json()["response"]

    # Assert summarization fails in conversation 2, which has no file filter
    assert response_message == "No files selected for summarization. Please add files using the section on the left."

    # Now make sure that the file filter is still in conversation 1
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
    response_message = response.json()["response"]

    # Assert
    assert response_message != ""
    assert response_message != "No files selected for summarization. Please add files using the section on the left."

@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = []
    conversation = create_conversation(message_list, default_user2)
    # Post nonexistent "imaginary.markdown" file to the conversation's file filters
    response = client_offline_chat.post(
        "api/chat/conversation/file-filters",
        json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
    )

    # Act
    query = "/summarize"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
    response_message = response.json()["response"]

    # Assert
    assert response_message == "No files selected for summarization. Please add files using the section on the left."

@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_diff_user_file(
    client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
):
    # Arrange
    message_list = []
    conversation = create_conversation(message_list, default_user2)
    # Get the pdf file called singlepage.pdf, indexed by a different user
    file_list = (
        Entry.objects.filter(user=default_user, file_source="computer")
        .distinct("file_path")
        .values_list("file_path", flat=True)
    )
    summarization_file = ""
    for file in file_list:
        if "singlepage.pdf" in file:
            summarization_file = file
            break
    assert summarization_file != ""
    # Add singlepage.pdf to the conversation's file filters
    response = client_offline_chat.post(
        "api/chat/conversation/file-filters",
        json={"filename": summarization_file, "conversation_id": str(conversation.id)},
    )

    # Act
    query = "/summarize"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
    response_message = response.json()["response"]

    # Assert
    assert response_message == "No files selected for summarization. Please add files using the section on the left."

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01", ignore=["transformers"])
def test_answer_requires_current_date_awareness(client_offline_chat):
    "Chat director should be able to answer questions relative to the current date using provided notes"
    # Act
    query = "Where did I have lunch today?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["Arak", "Medellin"]
    assert response.status_code == 200
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected chat director to say Arak, Medellin, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(raises=AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01", ignore=["transformers"])
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
    "Chat director should be able to answer questions that require date aware aggregation across multiple notes"
    # Act
    query = "How much did I spend on dining this year?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    assert "26" in response_message

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("Where was I born?", "You were born Testville.", []),
    ]
    create_conversation(message_list, default_user2)

    # Act
    query = "Write a haiku about unit testing. Do not say anything else."
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["test", "Test"]
    assert response.status_code == 200
    assert len(response_message.splitlines()) == 3  # haikus are 3 lines long
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected [T|t]est in response, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
    # Act
    query = "What is the name of Namitas older son"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = [
        "which of them is the older",
        "which one is older",
        "which of them is older",
        "which one is the older",
    ]
    assert response.status_code == 200
    assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
        "Expected chat director to ask for clarification in response, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("Where was I born?", "You were born Testville.", []),
    ]
    create_conversation(message_list, default_user2)

    # Act
    query = "What is my name?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["Testatron", "testatron"]
    assert response.status_code == 200
    assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
        "Expected [T|t]estatron in response, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("What's my favorite color", "Your favorite color is green.", []),
        ("Where was I born?", "You were born Testville.", []),
    ]
    message_list2 = [
        ("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 14th August 1947.", []),
        ("What's my favorite color", "Your favorite color is maroon.", []),
        ("Where was I born?", "You were born in a potato farm.", []),
    ]
    conversation = create_conversation(message_list, default_user2)
    create_conversation(message_list2, default_user2)

    # Act
    query = "What is my favorite color?"
    response = client_offline_chat.post(
        f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
    )
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["green"]
    assert response.status_code == 200
    assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
        "Expected green in response, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id_with_agent(
    client_offline_chat, default_user2: KhojUser, offline_agent: Agent
):
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("What's my favorite color", "Your favorite color is green.", []),
        ("Where was I born?", "You were born Testville.", []),
        ("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
    ]
    conversation = create_conversation(message_list, default_user2, offline_agent)

    # Act
    query = "/general What did I eat for breakfast?"
    response = client_offline_chat.post(
        f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
    )
    response_message = response.content.decode("utf-8")

    # Assert that agent only responds with the summary of spending
    expected_responses = ["13.00", "13", "13.0", "thirteen"]
    assert response.status_code == 200
    assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
        "Expected spend summary of 13.00 in response, but got: " + response_message
    )

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
    # Arrange
    message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
    create_conversation(message_list, default_user2)

    # Act
    query = "What is my name?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    assert response.status_code == 200
    assert len(response_message) > 0

# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_requires_multiple_independent_searches(client_offline_chat):
    "Chat director should be able to answer by doing multiple independent searches for required information"
    # Act
    query = "Is Xi older than Namita?"
    response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
    response_message = response.content.decode("utf-8")

    # Assert
    expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
    assert response.status_code == 200
    assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
        "Expected Xi is older than Namita, but got: " + response_message
    )
@@ -4,9 +4,8 @@ import time

from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.fs_syncer import get_org_files
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils.rawconfig import Entry


def test_configure_indexing_heading_only_entries(tmp_path):
@@ -330,46 +329,6 @@ def test_file_with_no_headings_to_entry(tmp_path):
    assert len(entries[1]) == 1


def test_get_org_files(tmp_path):
    "Ensure Org files specified via input-filter, input-files extracted"
    # Arrange
    # Include via input-filter globs
    group1_file1 = create_file(tmp_path, filename="group1-file1.org")
    group1_file2 = create_file(tmp_path, filename="group1-file2.org")
    group2_file1 = create_file(tmp_path, filename="group2-file1.org")
    group2_file2 = create_file(tmp_path, filename="group2-file2.org")
    # Include via input-file field
    orgfile1 = create_file(tmp_path, filename="orgfile1.org")
    # Not included by any filter
    create_file(tmp_path, filename="orgfile2.org")
    create_file(tmp_path, filename="text1.txt")

    expected_files = set(
        [
            os.path.join(tmp_path, file.name)
            for file in [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]
        ]
    )

    # Setup input-files, input-filters
    input_files = [tmp_path / "orgfile1.org"]
    input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]

    org_config = TextContentConfig(
        input_files=input_files,
        input_filter=[str(filter) for filter in input_filter],
        compressed_jsonl=tmp_path / "test.jsonl",
        embeddings_file=tmp_path / "test_embeddings.jsonl",
    )

    # Act
    extracted_org_files = get_org_files(org_config)

    # Assert
    assert len(extracted_org_files) == 5
    assert set(extracted_org_files.keys()) == expected_files


def test_extract_entries_with_different_level_headings(tmp_path):
    "Extract org entries with different level headings."
    # Arrange

@@ -4,8 +4,6 @@ import re
import pytest

from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.utils.fs_syncer import get_pdf_files
from khoj.utils.rawconfig import TextContentConfig


def test_single_page_pdf_to_jsonl():
@@ -61,43 +59,6 @@ def test_ocr_page_pdf_to_jsonl():
    assert re.search(expected_str_with_variable_spaces, raw_entry) is not None


def test_get_pdf_files(tmp_path):
    "Ensure Pdf files specified via input-filter, input-files extracted"
    # Arrange
    # Include via input-filter globs
    group1_file1 = create_file(tmp_path, filename="group1-file1.pdf")
    group1_file2 = create_file(tmp_path, filename="group1-file2.pdf")
    group2_file1 = create_file(tmp_path, filename="group2-file1.pdf")
    group2_file2 = create_file(tmp_path, filename="group2-file2.pdf")
    # Include via input-file field
    file1 = create_file(tmp_path, filename="document.pdf")
    # Not included by any filter
    create_file(tmp_path, filename="not-included-document.pdf")
    create_file(tmp_path, filename="not-included-text.txt")

    expected_files = set(
        [os.path.join(tmp_path, file.name) for file in [group1_file1, group1_file2, group2_file1, group2_file2, file1]]
    )

    # Setup input-files, input-filters
    input_files = [tmp_path / "document.pdf"]
    input_filter = [tmp_path / "group1*.pdf", tmp_path / "group2*.pdf"]

    pdf_config = TextContentConfig(
        input_files=input_files,
        input_filter=[str(path) for path in input_filter],
        compressed_jsonl=tmp_path / "test.jsonl",
        embeddings_file=tmp_path / "test_embeddings.jsonl",
    )

    # Act
    extracted_pdf_files = get_pdf_files(pdf_config)

    # Assert
    assert len(extracted_pdf_files) == 5
    assert set(extracted_pdf_files.keys()) == expected_files


# Helper Functions
def create_file(tmp_path, entry=None, filename="document.pdf"):
    pdf_file = tmp_path / filename

@@ -1,27 +1,20 @@
import os
from pathlib import Path
from textwrap import dedent

from khoj.database.models import KhojUser, LocalPlaintextConfig
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig


def test_plaintext_file(tmp_path):
def test_plaintext_file():
    "Convert files with no heading to jsonl."
    # Arrange
    raw_entry = f"""
Hi, I am a plaintext file and I have some plaintext words.
    """
    plaintextfile = create_file(tmp_path, raw_entry)
    plaintextfile = "test.txt"
    data = {plaintextfile: raw_entry}

    # Act
    # Extract Entries from specified plaintext files

    data = {
        f"{plaintextfile}": raw_entry,
    }

    entries = PlaintextToEntries.extract_plaintext_entries(data)

    # Convert each entry.file to absolute path to make them JSON serializable
@@ -37,59 +30,20 @@ def test_plaintext_file(tmp_path):
    assert entries[1][0].compiled == f"{plaintextfile}\n{raw_entry}"


def test_get_plaintext_files(tmp_path):
    "Ensure Plaintext files specified via input-filter, input-files extracted"
    # Arrange
    # Include via input-filter globs
    group1_file1 = create_file(tmp_path, filename="group1-file1.md")
    group1_file2 = create_file(tmp_path, filename="group1-file2.md")

    group2_file1 = create_file(tmp_path, filename="group2-file1.markdown")
    group2_file2 = create_file(tmp_path, filename="group2-file2.markdown")
    group2_file4 = create_file(tmp_path, filename="group2-file4.html")
    # Include via input-file field
    file1 = create_file(tmp_path, filename="notes.txt")
    # Include unsupported file types
    create_file(tmp_path, filename="group2-unincluded.py")
    create_file(tmp_path, filename="group2-unincluded.csv")
    create_file(tmp_path, filename="group2-unincluded.csv")
    create_file(tmp_path, filename="group2-file3.mbox")
    # Not included by any filter
    create_file(tmp_path, filename="not-included-markdown.md")
    create_file(tmp_path, filename="not-included-text.txt")

    expected_files = set(
        [
            os.path.join(tmp_path, file.name)
            for file in [group1_file1, group1_file2, group2_file1, group2_file2, group2_file4, file1]
        ]
    )

    # Setup input-files, input-filters
    input_files = [tmp_path / "notes.txt"]
    input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.*"]

    plaintext_config = TextContentConfig(
        input_files=input_files,
        input_filter=[str(filter) for filter in input_filter],
        compressed_jsonl=tmp_path / "test.jsonl",
        embeddings_file=tmp_path / "test_embeddings.jsonl",
    )

    # Act
    extracted_plaintext_files = get_plaintext_files(plaintext_config)

    # Assert
    assert len(extracted_plaintext_files) == len(expected_files)
    assert set(extracted_plaintext_files.keys()) == set(expected_files)


def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
def test_parse_html_plaintext_file(tmp_path):
    "Ensure HTML files are parsed correctly"
    # Arrange
    # Setup input-files, input-filters
    config = LocalPlaintextConfig.objects.filter(user=default_user).first()
    extracted_plaintext_files = get_plaintext_files(config=config)
    raw_entry = dedent(
        f"""
        <html>
        <head><title>Test HTML</title></head>
        <body>
        <div>Test content</div>
        </body>
        </html>
        """
    )
    extracted_plaintext_files = {"test.html": raw_entry}

    # Act
    entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)

@@ -2,23 +2,16 @@
import asyncio
import logging
import os
from pathlib import Path

import pytest

from khoj.database.adapters import EntryAdapters
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.database.models import Entry, GithubConfig, KhojUser
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.search_type import text_search
from khoj.utils.fs_syncer import collect_files, get_org_files
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from tests.helpers import get_index_files, get_sample_data

logger = logging.getLogger(__name__)

@@ -26,53 +19,20 @@ logger = logging.getLogger(__name__)
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
    # Arrange
    # Ensure file mentioned in org.input-files is missing
    single_new_file = Path(org_config_with_only_new_file.input_files[0])
    single_new_file.unlink()

    # Act
    # Generate notes embeddings during asymmetric setup
    with pytest.raises(FileNotFoundError):
        get_org_files(org_config_with_only_new_file)


# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
    # Arrange
    orgfile = tmp_path / "directory.org" / "file.org"
    orgfile.parent.mkdir()
    with open(orgfile, "w") as f:
        f.write("* Heading\n- List item\n")

    LocalOrgConfig.objects.create(
        input_filter=[f"{tmp_path}/**/*"],
        input_files=None,
        user=default_user,
    )

    # Act
    org_files = collect_files(user=default_user)["org"]

    # Assert
    # should return orgfile and not raise IsADirectoryError
    assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}


# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_creates_no_entries(
    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
):
def test_text_search_setup_with_empty_file_creates_no_entries(search_config, default_user: KhojUser):
    # Arrange
    initial_data = {
        "test.org": "* First heading\nFirst content",
        "test2.org": "* Second heading\nSecond content",
    }
    text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
    existing_entries = Entry.objects.filter(user=default_user).count()
    data = get_org_files(org_config_with_only_new_file)

    final_data = {"new_file.org": ""}

    # Act
    # Generate notes embeddings during asymmetric setup
    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
    text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)

    # Assert
    updated_entries = Entry.objects.filter(user=default_user).count()
@@ -84,13 +44,14 @@ def test_text_search_setup_with_empty_file_creates_no_entries(

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_indexer_deletes_embedding_before_regenerate(
    content_config: ContentConfig, default_user: KhojUser, caplog
):
def test_text_indexer_deletes_embedding_before_regenerate(search_config, default_user: KhojUser, caplog):
    # Arrange
    data = {
        "test1.org": "* Test heading\nTest content",
        "test2.org": "* Another heading\nAnother content",
    }
    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
    existing_entries = Entry.objects.filter(user=default_user).count()
    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
    data = get_org_files(org_config)

    # Act
    # Generate notes embeddings during asymmetric setup
@@ -107,11 +68,10 @@ def test_text_indexer_deletes_embedding_before_regenerate(

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
def test_text_index_same_if_content_unchanged(search_config, default_user: KhojUser, caplog):
    # Arrange
    existing_entries = Entry.objects.filter(user=default_user)
    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
    data = get_org_files(org_config)
    data = {"test.org": "* Test heading\nTest content"}

    # Act
    # Generate initial notes embeddings during asymmetric setup
@@ -136,20 +96,14 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.anyio
# @pytest.mark.asyncio
async def test_text_search(search_config: SearchConfig):
@pytest.mark.asyncio
async def test_text_search(search_config):
    # Arrange
    default_user = await KhojUser.objects.acreate(
    default_user, _ = await KhojUser.objects.aget_or_create(
        username="test_user", password="test_password", email="test@example.com"
    )
    org_config = await LocalOrgConfig.objects.acreate(
        input_files=None,
        input_filter=["tests/data/org/*.org"],
        index_heading_entries=False,
        user=default_user,
    )
    data = get_org_files(org_config)
    # Get some sample org data to index
    data = get_sample_data("org")

    loop = asyncio.get_event_loop()
    await loop.run_in_executor(
@@ -175,17 +129,15 @@ async def test_text_search(search_config: SearchConfig):

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
def test_entry_chunking_by_max_tokens(tmp_path, search_config, default_user: KhojUser, caplog):
    # Arrange
    # Insert org-mode entry with size exceeding max token limit to new org file
    max_tokens = 256
    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
    with open(new_file_to_index, "w") as f:
        f.write(f"* Entry more than {max_tokens} words\n")
        for index in range(max_tokens + 1):
            f.write(f"{index} ")

    data = get_org_files(org_config_with_only_new_file)
    new_file_to_index = tmp_path / "test.org"
    content = f"* Entry more than {max_tokens} words\n"
    for index in range(max_tokens + 1):
        content += f"{index} "
    data = {str(new_file_to_index): content}

    # Act
    # reload embeddings, entries, notes model after adding new org-mode file
@@ -200,9 +152,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus(
    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
def test_entry_chunking_by_max_tokens_not_full_corpus(tmp_path, search_config, default_user: KhojUser, caplog):
    # Arrange
    # Insert org-mode entry with size exceeding max token limit to new org file
    data = {
@@ -231,13 +181,11 @@ conda activate khoj
    )

    max_tokens = 256
    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
    with open(new_file_to_index, "w") as f:
        f.write(f"* Entry more than {max_tokens} words\n")
        for index in range(max_tokens + 1):
            f.write(f"{index} ")

    data = get_org_files(org_config_with_only_new_file)
    new_file_to_index = tmp_path / "test.org"
    content = f"* Entry more than {max_tokens} words\n"
    for index in range(max_tokens + 1):
        content += f"{index} "
    data = {str(new_file_to_index): content}

    # Act
    # reload embeddings, entries, notes model after adding new org-mode file
@@ -257,34 +205,34 @@ conda activate khoj

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
def test_regenerate_index_with_new_entry(search_config, default_user: KhojUser):
    # Arrange
    # Initial indexed files
    text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=True, user=default_user)
    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
    initial_data = get_org_files(org_config)

    # append org-mode entry to first org input file in config
    org_config.input_files = [f"{new_org_file}"]
    with open(new_org_file, "w") as f:
        f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

    final_data = get_org_files(org_config)

    # Act
    text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
    # Regenerate index with only files from test data set
    files_to_index = get_index_files()
    text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user)
    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

    # Act
    # Update index with the new file
    new_file = "test.org"
    new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    files_to_index[new_file] = new_entry

    # regenerate notes jsonl, model embeddings and model to include entry from new file
    text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
    text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user)
    updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

    # Assert
    for entry in updated_entries1:
        assert entry in updated_entries2

    assert not any([new_org_file.name in entry for entry in updated_entries1])
    assert not any([new_org_file.name in entry for entry in existing_entries])
    assert any([new_org_file.name in entry for entry in updated_entries2])
    assert not any([new_file in entry for entry in updated_entries1])
    assert not any([new_file in entry for entry in existing_entries])
    assert any([new_file in entry for entry in updated_entries2])

    assert any(
        ["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
@@ -294,28 +242,24 @@ def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
):
def test_update_index_with_duplicate_entries_in_stable_order(tmp_path, search_config, default_user: KhojUser):
    # Arrange
    initial_data = get_sample_data("org")
    text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])

    # Insert org-mode entries with same compiled form into new org file
    new_file_to_index = tmp_path / "test.org"
    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    with open(new_file_to_index, "w") as f:
        f.write(f"{new_entry}{new_entry}")

    data = get_org_files(org_config_with_only_new_file)
    # Initial data with duplicate entries
    data = {str(new_file_to_index): f"{new_entry}{new_entry}"}

    # Act
    # generate embeddings, entries, notes model from scratch after adding new org-mode file
    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

    data = get_org_files(org_config_with_only_new_file)

    # update embeddings, entries, notes model with no new changes
    # idempotent indexing when data unchanged
    text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
    updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

@@ -324,6 +268,7 @@ def test_update_index_with_duplicate_entries_in_stable_order(
    for entry in existing_entries:
        assert entry not in updated_entries1

    # verify the second indexing update has same entries and ordering as first
    for entry in updated_entries1:
        assert entry in updated_entries2

@@ -334,22 +279,17 @@ def test_update_index_with_duplicate_entries_in_stable_order(

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
def test_update_index_with_deleted_entry(tmp_path, search_config, default_user: KhojUser):
    # Arrange
    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])

    # Insert org-mode entries with same compiled form into new org file
    new_file_to_index = tmp_path / "test.org"
    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    with open(new_file_to_index, "w") as f:
        f.write(f"{new_entry}{new_entry} -- Tatooine")
    initial_data = get_org_files(org_config_with_only_new_file)

    # update embeddings, entries, notes model after removing an entry from the org file
    with open(new_file_to_index, "w") as f:
        f.write(f"{new_entry}")

    final_data = get_org_files(org_config_with_only_new_file)
    # Initial data with two entries
    initial_data = {str(new_file_to_index): f"{new_entry}{new_entry} -- Tatooine"}
    # Final data with only first entry, with second entry removed
    final_data = {str(new_file_to_index): f"{new_entry}"}

    # Act
    # load embeddings, entries, notes model after adding new org file with 2 entries
@@ -375,29 +315,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg

# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
def test_update_index_with_new_entry(search_config, default_user: KhojUser):
    # Arrange
    existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
    data = get_org_files(org_config)
    text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
    # Initial indexed files
    text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=True, user=default_user)
    old_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

    # append org-mode entry to first org input file in config
    with open(new_org_file, "w") as f:
        new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
        f.write(new_entry)

    data = get_org_files(org_config)
    # Regenerate index with only files from test data set
    files_to_index = get_index_files()
    new_entries = text_search.setup(OrgToEntries, files_to_index, regenerate=True, user=default_user)

    # Act
    # update embeddings, entries with the newly added note
    text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
    updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
    # Update index with the new file
    new_file = "test.org"
    new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    final_data = {new_file: new_entry}

    text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
    updated_new_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))

    # Assert
    for entry in existing_entries:
        assert entry not in updated_entries1
    assert len(updated_entries1) == len(existing_entries) + 1
    for old_entry in old_entries:
        assert old_entry not in updated_new_entries
    assert len(updated_new_entries) == len(new_entries) + 1
    verify_embeddings(3, default_user)


@@ -409,9 +349,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
        (OrgToEntries),
    ],
)
def test_update_index_with_deleted_file(
    org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser
):
def test_update_index_with_deleted_file(text_to_entries: TextToEntries, search_config, default_user: KhojUser):
    "Delete entries associated with new file when file path with empty content passed."
    # Arrange
    file_to_index = "test"
@@ -446,7 +384,7 @@ def test_update_index_with_deleted_file(

# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
def test_text_search_setup_github(search_config, default_user: KhojUser):
    # Arrange
    github_config = GithubConfig.objects.filter(user=default_user).first()


@@ -1,6 +1,12 @@
# Application Packages
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.rawconfig import Entry


# Mock Entry class for testing
class Entry:
    def __init__(self, compiled="", raw=""):
        self.compiled = compiled
        self.raw = raw


# Test
