diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 45da7dda..641bf68b 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -18,8 +18,8 @@ ENV PATH="/opt/venv/bin:${PATH}" COPY pyproject.toml README.md ./ # Setup python environment - # Use the pre-built llama-cpp-python, torch cpu wheel -ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://abetlen.github.io/llama-cpp-python/whl/cpu" \ + # Use the pre-built torch cpu wheel +ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" \ # Avoid downloading unused cuda specific python packages CUDA_VISIBLE_DEVICES="" \ # Use static version to build app without git dependency diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e8db2c89..8e04a8fc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -64,8 +64,6 @@ jobs: DEBIAN_FRONTEND: noninteractive run: | apt update && apt install -y git libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6 - # required by llama-cpp-python prebuilt wheels - apt install -y musl-dev && ln -s /usr/lib/x86_64-linux-musl/libc.so /lib/libc.musl-x86_64.so.1 - name: ⬇️ Install Postgres env: diff --git a/documentation/docs/advanced/admin.md b/documentation/docs/advanced/admin.md index e04f81e2..961056b8 100644 --- a/documentation/docs/advanced/admin.md +++ b/documentation/docs/advanced/admin.md @@ -20,7 +20,7 @@ Add all the agents you want to use for your different use-cases like Writer, Res ### Chat Model Options Add all the chat models you want to try, use and switch between for your different use-cases. For each chat model you add: - `Chat model`: The name of an [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/en/docs/about-claude/models#model-names), [Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) or [Offline](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf) chat model. 
-- `Model type`: The chat model provider like `OpenAI`, `Offline`. +- `Model type`: The chat model provider like `OpenAI`, `Google`. - `Vision enabled`: Set to `true` if your model supports vision. This is currently only supported for vision capable OpenAI models like `gpt-4o` - `Max prompt size`, `Subscribed max prompt size`: These are optional fields. They are used to truncate the context to the maximum context size that can be passed to the model. This can help with accuracy and cost-saving.
- `Tokenizer`: This is an optional field. It is used to accurately count tokens and truncate context passed to the chat model to stay within the models max prompt size. diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx index e3c1e12f..01d60496 100644 --- a/documentation/docs/get-started/setup.mdx +++ b/documentation/docs/get-started/setup.mdx @@ -18,10 +18,6 @@ import TabItem from '@theme/TabItem'; These are the general setup instructions for self-hosted Khoj. You can install the Khoj server using either [Docker](?server=docker) or [Pip](?server=pip). -:::info[Offline Model + GPU] -To use the offline chat model with your GPU, we recommend using the Docker setup with Ollama . You can also use the local Khoj setup via the Python package directly. -::: - :::info[First Run] Restart your Khoj server after the first run to ensure all settings are applied correctly. ::: @@ -225,10 +221,6 @@ To start Khoj automatically in the background use [Task scheduler](https://www.w You can now open the web app at http://localhost:42110 and start interacting!
Nothing else is necessary, but you can customize your setup further by following the steps below. -:::info[First Message to Offline Chat Model] -The offline chat model gets downloaded when you first send a message to it. The download can take a few minutes! Subsequent messages should be faster. -::: - ### Add Chat Models

Login to the Khoj Admin Panel

Go to http://localhost:42110/server/admin and login with the admin credentials you setup during installation. @@ -301,13 +293,14 @@ Offline chat stays completely private and can work without internet using any op - A Nvidia, AMD GPU or a Mac M1+ machine would significantly speed up chat responses ::: -1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*. -2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel -3. Set the `chat-model` field to the name of your preferred chat model - - Make sure the `model-type` is set to `Offline` -4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/). -5. Restart the Khoj server and [start chatting](http://localhost:42110) with your new offline model! - +1. Install any OpenAI API compatible local AI model server like [llama-cpp-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server), Ollama, vLLM, etc. +2. Add an [AI model API](http://localhost:42110/server/admin/database/aimodelapi/add/) on the admin panel - Set the `api url` field to the URL of your local AI model provider like `http://localhost:11434/v1/` for Ollama +3. Restart the Khoj server to load models available on your local AI model provider - If that doesn't work, you'll need to manually add the available [chat model](http://localhost:42110/server/admin/database/chatmodel/add) in the admin panel. +4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) +5. [Start chatting](http://localhost:42110) with your local AI! 
+ :::tip[Multiple Chat Models] diff --git a/pyproject.toml b/pyproject.toml index 1a580dae..6491dc06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ dependencies = [ "django == 5.1.10", "django-unfold == 0.42.0", "authlib == 1.2.1", - "llama-cpp-python == 0.2.88", "itsdangerous == 2.1.2", "httpx == 0.28.1", "pgvector == 0.2.4", diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 14eaadfe..76d7578b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -72,7 +72,6 @@ from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.utils import state -from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( clean_object_for_db, clean_text_for_db, @@ -1553,14 +1552,6 @@ class ConversationAdapters: if chat_model is None: chat_model = await ConversationAdapters.aget_default_chat_model() - if chat_model.model_type == ChatModel.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - chat_model_name = chat_model.name - max_tokens = chat_model.max_prompt_size - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - - return chat_model - if ( chat_model.model_type in [ diff --git a/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py b/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py new file mode 100644 index 00000000..256aa93d --- /dev/null +++ b/src/khoj/database/migrations/0092_alter_chatmodel_model_type_alter_chatmodel_name_and_more.py @@ -0,0 +1,36 @@ +# Generated by Django 5.1.10 on 2025-07-19 21:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", 
"0091_chatmodel_friendly_name_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="chatmodel", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("anthropic", "Anthropic"), ("google", "Google")], + default="google", + max_length=200, + ), + ), + migrations.AlterField( + model_name="chatmodel", + name="name", + field=models.CharField(default="gemini-2.5-flash", max_length=200), + ), + migrations.AlterField( + model_name="speechtotextmodeloptions", + name="model_name", + field=models.CharField(default="whisper-1", max_length=200), + ), + migrations.AlterField( + model_name="speechtotextmodeloptions", + name="model_type", + field=models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 7f80459a..03b43376 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -220,16 +220,15 @@ class PriceTier(models.TextChoices): class ChatModel(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" - OFFLINE = "offline" ANTHROPIC = "anthropic" GOOGLE = "google" max_prompt_size = models.IntegerField(default=None, null=True, blank=True) subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) - name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") + name = models.CharField(max_length=200, default="gemini-2.5-flash") friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True) - model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.GOOGLE) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) 
vision_enabled = models.BooleanField(default=False) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) @@ -605,11 +604,10 @@ class TextToImageModelConfig(DbBaseModel): class SpeechToTextModelOptions(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" - OFFLINE = "offline" - model_name = models.CharField(max_length=200, default="base") + model_name = models.CharField(max_length=200, default="whisper-1") friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True) - model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) diff --git a/src/khoj/main.py b/src/khoj/main.py index 50da2624..d5918a14 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -214,7 +214,6 @@ def set_state(args): ) state.anonymous_mode = args.anonymous_mode state.khoj_version = version("khoj") - state.chat_on_gpu = args.chat_on_gpu def start_server(app, host=None, port=None, socket=None): diff --git a/src/khoj/processor/conversation/offline/__init__.py b/src/khoj/processor/conversation/offline/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py deleted file mode 100644 index b117a48d..00000000 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ /dev/null @@ -1,224 +0,0 @@ -import asyncio -import logging -import os -from datetime import datetime -from threading import Thread -from time import perf_counter -from typing import Any, AsyncGenerator, Dict, List, Union - -from langchain_core.messages.chat import 
ChatMessage -from llama_cpp import Llama - -from khoj.database.models import Agent, ChatMessageModel, ChatModel -from khoj.processor.conversation import prompts -from khoj.processor.conversation.offline.utils import download_model -from khoj.processor.conversation.utils import ( - ResponseWithThought, - commit_conversation_trace, - generate_chatml_messages_with_context, - messages_to_print, -) -from khoj.utils import state -from khoj.utils.helpers import ( - is_none_or_empty, - is_promptrace_enabled, - truncate_code_context, -) -from khoj.utils.rawconfig import FileAttachment, LocationData -from khoj.utils.yaml import yaml_dump - -logger = logging.getLogger(__name__) - - -async def converse_offline( - # Query - user_query: str, - # Context - references: list[dict] = [], - online_results={}, - code_results={}, - query_files: str = None, - generated_files: List[FileAttachment] = None, - additional_context: List[str] = None, - generated_asset_results: Dict[str, Dict] = {}, - location_data: LocationData = None, - user_name: str = None, - chat_history: list[ChatMessageModel] = [], - # Model - model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - loaded_model: Union[Any, None] = None, - max_prompt_size=None, - tokenizer_name=None, - agent: Agent = None, - tracer: dict = {}, -) -> AsyncGenerator[ResponseWithThought, None]: - """ - Converse with user using Llama (Async Version) - """ - # Initialize Variables - assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size) - tracer["chat_model"] = model_name - current_date = datetime.now() - - if agent and agent.personality: - system_prompt = prompts.custom_system_prompt_offline_chat.format( - name=agent.name, - bio=agent.personality, - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - else: - system_prompt = 
prompts.system_prompt_offline_chat.format( - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - - if location_data: - location_prompt = prompts.user_location.format(location=f"{location_data}") - system_prompt = f"{system_prompt}\n{location_prompt}" - - if user_name: - user_name_prompt = prompts.user_name.format(name=user_name) - system_prompt = f"{system_prompt}\n{user_name_prompt}" - - # Get Conversation Primer appropriate to Conversation Type - context_message = "" - if not is_none_or_empty(references): - context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n" - if not is_none_or_empty(online_results): - simplified_online_results = online_results.copy() - for result in online_results: - if online_results[result].get("webpages"): - simplified_online_results[result] = online_results[result]["webpages"] - - context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n" - if not is_none_or_empty(code_results): - context_message += ( - f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" - ) - context_message = context_message.strip() - - # Setup Prompt with Primer or Conversation History - messages = generate_chatml_messages_with_context( - user_query, - system_prompt, - chat_history, - context_message=context_message, - model_name=model_name, - loaded_model=offline_chat_model, - max_prompt_size=max_prompt_size, - tokenizer_name=tokenizer_name, - model_type=ChatModel.ModelType.OFFLINE, - query_files=query_files, - generated_files=generated_files, - generated_asset_results=generated_asset_results, - program_execution_context=additional_context, - ) - - logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") - - # Use asyncio.Queue and a thread to bridge sync iterator - queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue() - 
stop_phrases = ["", "INST]", "Notes:"] - - def _sync_llm_thread(): - """Synchronous function to run in a separate thread.""" - aggregated_response = "" - start_time = perf_counter() - state.chat_lock.acquire() - try: - response_iterator = send_message_to_model_offline( - messages, - loaded_model=offline_chat_model, - stop=stop_phrases, - max_prompt_size=max_prompt_size, - streaming=True, - tracer=tracer, - ) - for response in response_iterator: - response_delta: str = response["choices"][0]["delta"].get("content", "") - # Log the time taken to start response - if aggregated_response == "" and response_delta != "": - logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") - # Handle response chunk - aggregated_response += response_delta - # Put chunk into the asyncio queue (non-blocking) - try: - queue.put_nowait(ResponseWithThought(text=response_delta)) - except asyncio.QueueFull: - # Should not happen with default queue size unless consumer is very slow - logger.warning("Asyncio queue full during offline LLM streaming.") - # Potentially block here or handle differently if needed - asyncio.run(queue.put(ResponseWithThought(text=response_delta))) - - # Log the time taken to stream the entire response - logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - - # Save conversation trace - tracer["chat_model"] = model_name - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) - - except Exception as e: - logger.error(f"Error in offline LLM thread: {e}", exc_info=True) - finally: - state.chat_lock.release() - # Signal end of stream - queue.put_nowait(None) - - # Start the synchronous thread - thread = Thread(target=_sync_llm_thread) - thread.start() - - # Asynchronously consume from the queue - while True: - chunk = await queue.get() - if chunk is None: # End of stream signal - queue.task_done() - break - yield chunk - queue.task_done() - - # Wait for the thread to finish 
(optional, ensures cleanup) - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, thread.join) - - -def send_message_to_model_offline( - messages: List[ChatMessage], - loaded_model=None, - model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - temperature: float = 0.2, - streaming=False, - stop=[], - max_prompt_size: int = None, - response_type: str = "text", - tracer: dict = {}, -): - assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size) - messages_dict = [{"role": message.role, "content": message.content} for message in messages] - seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None - response = offline_chat_model.create_chat_completion( - messages_dict, - stop=stop, - stream=streaming, - temperature=temperature, - response_format={"type": response_type}, - seed=seed, - ) - - if streaming: - return response - - response_text: str = response["choices"][0]["message"].get("content", "") - - # Save conversation trace for non-streaming responses - # Streamed responses need to be saved by the calling function - tracer["chat_model"] = model_name - tracer["temperature"] = temperature - if is_promptrace_enabled(): - commit_conversation_trace(messages, response_text, tracer) - - return ResponseWithThought(text=response_text) diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py deleted file mode 100644 index 88082ad1..00000000 --- a/src/khoj/processor/conversation/offline/utils.py +++ /dev/null @@ -1,80 +0,0 @@ -import glob -import logging -import math -import os -from typing import Any, Dict - -from huggingface_hub.constants import HF_HUB_CACHE - -from khoj.utils import state -from khoj.utils.helpers import get_device_memory - -logger = logging.getLogger(__name__) - - -def download_model(repo_id: str, filename: str 
= "*Q4_K_M.gguf", max_tokens: int = None): - # Initialize Model Parameters - # Use n_ctx=0 to get context size from the model - kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False} - - # Decide whether to load model to GPU or CPU - device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu" - kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0 - - # Add chat format if known - if "llama-3" in repo_id.lower(): - kwargs["chat_format"] = "llama-3" - elif "gemma-2" in repo_id.lower(): - kwargs["chat_format"] = "gemma" - - # Check if the model is already downloaded - model_path = load_model_from_cache(repo_id, filename) - chat_model = None - try: - chat_model = load_model(model_path, repo_id, filename, kwargs) - except: - # Load model on CPU if GPU is not available - kwargs["n_gpu_layers"], device = 0, "cpu" - chat_model = load_model(model_path, repo_id, filename, kwargs) - - # Now load the model with context size set based on: - # 1. context size supported by model and - # 2. configured size or machine (V)RAM - kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens) - chat_model = load_model(model_path, repo_id, filename, kwargs) - - logger.debug( - f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window." 
- ) - return chat_model - - -def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}): - from llama_cpp.llama import Llama - - if model_path: - return Llama(model_path, **kwargs) - else: - return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) - - -def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): - # Construct the path to the model file in the cache directory - repo_org, repo_name = repo_id.split("/") - object_id = "--".join([repo_type, repo_org, repo_name]) - model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename]) - - # Check if the model file exists - paths = glob.glob(model_path) - if paths: - return paths[0] - else: - return None - - -def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int: - """Infer max prompt size based on device memory and max context window supported by the model""" - configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens - vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic - configured_max_tokens = configured_max_tokens or math.inf # do not use if set to None - return min(configured_max_tokens, vram_based_n_ctx, model_context_window) diff --git a/src/khoj/processor/conversation/offline/whisper.py b/src/khoj/processor/conversation/offline/whisper.py deleted file mode 100644 index d8dd4457..00000000 --- a/src/khoj/processor/conversation/offline/whisper.py +++ /dev/null @@ -1,15 +0,0 @@ -import whisper -from asgiref.sync import sync_to_async - -from khoj.utils import state - - -async def transcribe_audio_offline(audio_filename: str, model: str) -> str: - """ - Transcribe audio file offline using Whisper - """ - # Send the audio data to the Whisper API - if not state.whisper_model: - state.whisper_model = whisper.load_model(model) - response = await sync_to_async(state.whisper_model.transcribe)(audio_filename) - return response["text"] diff 
--git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 47589e93..fcf30d07 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -78,38 +78,6 @@ no_entries_found = PromptTemplate.from_template( """.strip() ) -## Conversation Prompts for Offline Chat Models -## -- -system_prompt_offline_chat = PromptTemplate.from_template( - """ -You are Khoj, a smart, inquisitive and helpful personal assistant. -- Use your general knowledge and past conversation with the user as context to inform your responses. -- If you do not know the answer, say 'I don't know.' -- Think step-by-step and ask questions to get the necessary information to answer the user's question. -- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations. -- Do not print verbatim Notes unless necessary. - -Note: More information about you, the company or Khoj apps can be found at https://khoj.dev. -Today is {day_of_week}, {current_date} in UTC. -""".strip() -) - -custom_system_prompt_offline_chat = PromptTemplate.from_template( - """ -You are {name}, a personal agent on Khoj. -- Use your general knowledge and past conversation with the user as context to inform your responses. -- If you do not know the answer, say 'I don't know.' -- Think step-by-step and ask questions to get the necessary information to answer the user's question. -- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations. -- Do not print verbatim Notes unless necessary. - -Note: More information about you, the company or Khoj apps can be found at https://khoj.dev. -Today is {day_of_week}, {current_date} in UTC. 
- -Instructions:\n{bio} -""".strip() -) - ## Notes Conversation ## -- notes_conversation = PromptTemplate.from_template( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 3bd17856..6e60c786 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -18,8 +18,6 @@ import requests import tiktoken import yaml from langchain_core.messages.chat import ChatMessage -from llama_cpp import LlamaTokenizer -from llama_cpp.llama import Llama from pydantic import BaseModel, ConfigDict, ValidationError, create_model from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -32,7 +30,6 @@ from khoj.database.models import ( KhojUser, ) from khoj.processor.conversation import prompts -from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -85,12 +82,6 @@ model_to_prompt_size = { "claude-sonnet-4-20250514": 60000, "claude-opus-4-0": 60000, "claude-opus-4-20250514": 60000, - # Offline Models - "bartowski/Qwen2.5-14B-Instruct-GGUF": 20000, - "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000, - "bartowski/Llama-3.2-3B-Instruct-GGUF": 20000, - "bartowski/gemma-2-9b-it-GGUF": 6000, - "bartowski/gemma-2-2b-it-GGUF": 6000, } model_to_tokenizer: Dict[str, str] = {} @@ -573,7 +564,6 @@ def generate_chatml_messages_with_context( system_message: str = None, chat_history: list[ChatMessageModel] = [], model_name="gpt-4o-mini", - loaded_model: Optional[Llama] = None, max_prompt_size=None, tokenizer_name=None, query_images=None, @@ -588,10 +578,7 @@ def generate_chatml_messages_with_context( """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and 
machine specs if not max_prompt_size: - if loaded_model: - max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf)) - else: - max_prompt_size = model_to_prompt_size.get(model_name, 10000) + max_prompt_size = model_to_prompt_size.get(model_name, 10000) # Scale lookback turns proportional to max prompt size supported by model lookback_turns = max_prompt_size // 750 @@ -735,7 +722,7 @@ def generate_chatml_messages_with_context( message.content = [{"type": "text", "text": message.content}] # Truncate oldest messages from conversation history until under max supported prompt size by model - messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name) + messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) # Return message in chronological order return messages[::-1] @@ -743,25 +730,20 @@ def generate_chatml_messages_with_context( def get_encoder( model_name: str, - loaded_model: Optional[Llama] = None, tokenizer_name=None, -) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer: +) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast: default_tokenizer = "gpt-4o" try: - if loaded_model: - encoder = loaded_model.tokenizer() - elif model_name.startswith("gpt-") or model_name.startswith("o1"): - # as tiktoken doesn't recognize o1 model series yet - encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) - elif tokenizer_name: + if tokenizer_name: if tokenizer_name in state.pretrained_tokenizers: encoder = state.pretrained_tokenizers[tokenizer_name] else: encoder = AutoTokenizer.from_pretrained(tokenizer_name) state.pretrained_tokenizers[tokenizer_name] = encoder else: - encoder = download_model(model_name).tokenizer() + # as tiktoken doesn't recognize o1 model series yet + encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) except: 
encoder = tiktoken.encoding_for_model(default_tokenizer) if state.verbose > 2: @@ -773,7 +755,7 @@ def get_encoder( def count_tokens( message_content: str | list[str | dict], - encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding, + encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding, ) -> int: """ Count the total number of tokens in a list of messages. @@ -825,11 +807,10 @@ def truncate_messages( messages: list[ChatMessage], max_prompt_size: int, model_name: str, - loaded_model: Optional[Llama] = None, tokenizer_name=None, ) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" - encoder = get_encoder(model_name, loaded_model, tokenizer_name) + encoder = get_encoder(model_name, tokenizer_name) # Extract system message from messages system_message = None diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 979205a0..2b63c40d 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -235,7 +235,6 @@ def is_operator_model(model: str) -> ChatModel.ModelType | None: "claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC, "claude-sonnet-4": ChatModel.ModelType.ANTHROPIC, "claude-opus-4": ChatModel.ModelType.ANTHROPIC, - "ui-tars-1.5": ChatModel.ModelType.OFFLINE, } for operator_model in operator_models: if model.startswith(operator_model): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 23805200..b1562e64 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -15,7 +15,6 @@ from khoj.configure import initialize_content from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo from khoj.database.models import KhojUser, SpeechToTextModelOptions -from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.whisper import 
transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, @@ -150,9 +149,6 @@ async def transcribe( if not speech_to_text_config: # If the user has not configured a speech to text model, return an unsupported on server error status_code = 501 - elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE: - speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio_offline(audio_filename, speech2text_model) elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI: speech2text_model = speech_to_text_config.model_name if speech_to_text_config.ai_model_api: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d3554f7f..7833d3d8 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -89,10 +89,6 @@ from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, gemini_send_message_to_model, ) -from khoj.processor.conversation.offline.chat_model import ( - converse_offline, - send_message_to_model_offline, -) from khoj.processor.conversation.openai.gpt import ( converse_openai, send_message_to_model, @@ -117,7 +113,6 @@ from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.search_type import text_search from khoj.utils import state -from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( LRU, ConversationCommand, @@ -168,14 +163,6 @@ async def is_ready_to_chat(user: KhojUser): if user_chat_model == None: user_chat_model = await ConversationAdapters.aget_default_chat_model(user) - if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE: - chat_model_name = user_chat_model.name - max_tokens = user_chat_model.max_prompt_size - if state.offline_chat_processor_config is None: - logger.info("Loading Offline Chat Model...") - state.offline_chat_processor_config = 
OfflineChatProcessorModel(chat_model_name, max_tokens) - return True - if ( user_chat_model and ( @@ -1470,12 +1457,6 @@ async def send_message_to_model_wrapper( vision_available = chat_model.vision_enabled api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - loaded_model = None - - if model_type == ChatModel.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( user_message=query, @@ -1483,7 +1464,6 @@ async def send_message_to_model_wrapper( system_message=system_message, chat_history=chat_history, model_name=chat_model_name, - loaded_model=loaded_model, tokenizer_name=tokenizer, max_prompt_size=max_tokens, vision_enabled=vision_available, @@ -1492,18 +1472,7 @@ async def send_message_to_model_wrapper( query_files=query_files, ) - if model_type == ChatModel.ModelType.OFFLINE: - return send_message_to_model_offline( - messages=truncated_messages, - loaded_model=loaded_model, - model_name=chat_model_name, - max_prompt_size=max_tokens, - streaming=False, - response_type=response_type, - tracer=tracer, - ) - - elif model_type == ChatModel.ModelType.OPENAI: + if model_type == ChatModel.ModelType.OPENAI: return send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1565,19 +1534,12 @@ def send_message_to_model_wrapper_sync( vision_available = chat_model.vision_enabled api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - loaded_model = None - - if model_type == ChatModel.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = 
OfflineChatProcessorModel(chat_model_name, max_tokens) - loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, chat_history=chat_history, model_name=chat_model_name, - loaded_model=loaded_model, max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=model_type, @@ -1585,18 +1547,7 @@ def send_message_to_model_wrapper_sync( query_files=query_files, ) - if model_type == ChatModel.ModelType.OFFLINE: - return send_message_to_model_offline( - messages=truncated_messages, - loaded_model=loaded_model, - model_name=chat_model_name, - max_prompt_size=max_tokens, - streaming=False, - response_type=response_type, - tracer=tracer, - ) - - elif model_type == ChatModel.ModelType.OPENAI: + if model_type == ChatModel.ModelType.OPENAI: return send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1678,30 +1629,7 @@ async def agenerate_chat_response( chat_model = vision_enabled_config vision_available = True - if chat_model.model_type == "offline": - loaded_model = state.offline_chat_processor_config.loaded_model - chat_response_generator = converse_offline( - # Query - user_query=query_to_run, - # Context - references=compiled_references, - online_results=online_results, - generated_files=raw_generated_files, - generated_asset_results=generated_asset_results, - location_data=location_data, - user_name=user_name, - query_files=query_files, - chat_history=chat_history, - # Model - loaded_model=loaded_model, - model_name=chat_model.name, - max_prompt_size=chat_model.max_prompt_size, - tokenizer_name=chat_model.tokenizer, - agent=agent, - tracer=tracer, - ) - - elif chat_model.model_type == ChatModel.ModelType.OPENAI: + if chat_model.model_type == ChatModel.ModelType.OPENAI: openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key chat_model_name = chat_model.name diff --git a/src/khoj/utils/cli.py 
b/src/khoj/utils/cli.py index 786cbb62..66cdda74 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -33,9 +33,6 @@ def cli(args=None): parser.add_argument("--sslcert", type=str, help="Path to SSL certificate file") parser.add_argument("--sslkey", type=str, help="Path to SSL key file") parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit") - parser.add_argument( - "--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model" - ) parser.add_argument( "--anonymous-mode", action="store_true", @@ -54,9 +51,6 @@ def cli(args=None): if len(remaining_args) > 0: logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}") - # Set default values for arguments - args.chat_on_gpu = not args.disable_chat_on_gpu - args.version_no = version("khoj") if args.version: # Show version of khoj installed and exit diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 03dad75c..d1b6f20a 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union import torch -from khoj.processor.conversation.offline.utils import download_model - logger = logging.getLogger(__name__) @@ -62,20 +60,3 @@ class ImageSearchModel: @dataclass class SearchModels: text_search: Optional[TextSearchModel] = None - - -@dataclass -class OfflineChatProcessorConfig: - loaded_model: Union[Any, None] = None - - -class OfflineChatProcessorModel: - def __init__(self, chat_model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", max_tokens: int = None): - self.chat_model = chat_model - self.loaded_model = None - try: - self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens) - except ValueError as e: - self.loaded_model = None - logger.error(f"Error while loading offline chat model: {e}", exc_info=True) - raise e diff --git a/src/khoj/utils/constants.py 
b/src/khoj/utils/constants.py index 4f128aa7..b2ed49de 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -10,13 +10,6 @@ empty_escape_sequences = "\n|\r|\t| " app_env_filepath = "~/.khoj/env" telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" content_directory = "~/.khoj/content/" -default_offline_chat_models = [ - "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", - "bartowski/Llama-3.2-3B-Instruct-GGUF", - "bartowski/gemma-2-9b-it-GGUF", - "bartowski/gemma-2-2b-it-GGUF", - "bartowski/Qwen2.5-14B-Instruct-GGUF", -] default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1", "o3", "o4-mini"] default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"] default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"] diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 9ed1bdff..336b228d 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -16,7 +16,6 @@ from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tok from khoj.utils.constants import ( default_anthropic_chat_models, default_gemini_chat_models, - default_offline_chat_models, default_openai_chat_models, ) @@ -72,7 +71,6 @@ def initialization(interactive: bool = True): default_api_key=openai_api_key, api_base_url=openai_base_url, vision_enabled=True, - is_offline=False, interactive=interactive, provider_name=provider, ) @@ -118,7 +116,6 @@ def initialization(interactive: bool = True): default_gemini_chat_models, default_api_key=os.getenv("GEMINI_API_KEY"), vision_enabled=True, - is_offline=False, interactive=interactive, provider_name="Google Gemini", ) @@ -145,17 +142,6 @@ def initialization(interactive: bool = True): default_anthropic_chat_models, default_api_key=os.getenv("ANTHROPIC_API_KEY"), vision_enabled=True, - is_offline=False, - interactive=interactive, - ) - - # Set up offline chat models - 
_setup_chat_model_provider( - ChatModel.ModelType.OFFLINE, - default_offline_chat_models, - default_api_key=None, - vision_enabled=False, - is_offline=True, interactive=interactive, ) @@ -186,7 +172,6 @@ def initialization(interactive: bool = True): interactive: bool, api_base_url: str = None, vision_enabled: bool = False, - is_offline: bool = False, provider_name: str = None, ) -> Tuple[bool, AiModelApi]: supported_vision_models = ( @@ -195,11 +180,6 @@ def initialization(interactive: bool = True): provider_name = provider_name or model_type.name.capitalize() default_use_model = default_api_key is not None - # If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment. - # This usually means there's not enough RAM to load offline models directly within the application. - # In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case. - if is_offline: - default_use_model = False use_model_provider = ( default_use_model if not interactive else input(f"Add {provider_name} chat models? 
(y/n): ") == "y" @@ -211,13 +191,12 @@ def initialization(interactive: bool = True): logger.info(f"️💬 Setting up your {provider_name} chat configuration") ai_model_api = None - if not is_offline: - if interactive: - user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ") - api_key = user_api_key if user_api_key != "" else default_api_key - else: - api_key = default_api_key - ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url) + if interactive: + user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ") + api_key = user_api_key if user_api_key != "" else default_api_key + else: + api_key = default_api_key + ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url) if interactive: user_chat_models = input( diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index df7f3334..0148511a 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -103,13 +103,8 @@ class OpenAIProcessorConfig(ConfigBase): chat_model: Optional[str] = "gpt-4o-mini" -class OfflineChatProcessorConfig(ConfigBase): - chat_model: Optional[str] = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" - - class ConversationProcessorConfig(ConfigBase): openai: Optional[OpenAIProcessorConfig] = None - offline_chat: Optional[OfflineChatProcessorConfig] = None max_prompt_size: Optional[int] = None tokenizer: Optional[str] = None diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index f96409c2..6acd3d65 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -12,7 +12,7 @@ from whisper import Whisper from khoj.database.models import ProcessLock from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config -from khoj.utils.config import OfflineChatProcessorModel, SearchModels +from khoj.utils.config import SearchModels from 
khoj.utils.helpers import LRU, get_device, is_env_var_true from khoj.utils.rawconfig import FullConfig @@ -22,7 +22,6 @@ search_models = SearchModels() embeddings_model: Dict[str, EmbeddingsModel] = None cross_encoder_model: Dict[str, CrossEncoderModel] = None openai_client: OpenAI = None -offline_chat_processor_config: OfflineChatProcessorModel = None whisper_model: Whisper = None config_file: Path = None verbose: int = 0 @@ -39,7 +38,6 @@ telemetry: List[Dict[str, str]] = [] telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE") khoj_version: str = None device = get_device() -chat_on_gpu: bool = True anonymous_mode: bool = False pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict() billing_enabled: bool = ( diff --git a/tests/conftest.py b/tests/conftest.py index 25876d61..77c86dfa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -196,17 +196,6 @@ def default_openai_chat_model_option(): return chat_model -@pytest.mark.django_db -@pytest.fixture -def offline_agent(): - chat_model = ChatModelFactory() - return Agent.objects.create( - name="Accountant", - chat_model=chat_model, - personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. 
ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.", - ) - - @pytest.mark.django_db @pytest.fixture def openai_agent(): @@ -516,40 +505,6 @@ def client( return TestClient(app) -@pytest.fixture(scope="function") -def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): - # Initialize app state - state.config.search_type = search_config - state.SearchType = configure_search_types() - - LocalMarkdownConfig.objects.create( - input_files=None, - input_filter=["tests/data/markdown/*.markdown"], - user=default_user2, - ) - - all_files = fs_syncer.collect_files(user=default_user2) - configure_content(default_user2, all_files) - - # Initialize Processor from Config - ChatModelFactory( - name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF", - tokenizer=None, - max_prompt_size=None, - model_type="offline", - ) - UserConversationProcessorConfigFactory(user=default_user2) - - state.anonymous_mode = True - - app = FastAPI() - - configure_routes(app) - configure_middleware(app) - app.mount("/static", StaticFiles(directory=web_directory), name="static") - return TestClient(app) - - @pytest.fixture(scope="function") def new_org_file(default_user: KhojUser, content_config: ContentConfig): # Setup diff --git a/tests/helpers.py b/tests/helpers.py index d3c94abc..53ce4ea6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -19,7 +19,7 @@ from khoj.database.models import ( from khoj.processor.conversation.utils import message_to_log -def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE): +def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE): provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER") if provider and provider in ChatModel.ModelType: return ChatModel.ModelType(provider) @@ -93,7 +93,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory): max_prompt_size = 20000 tokenizer = None - name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF" + name = "gemini-2.0-flash" model_type 
= get_chat_provider() ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None) diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py deleted file mode 100644 index 979710b6..00000000 --- a/tests/test_offline_chat_actors.py +++ /dev/null @@ -1,610 +0,0 @@ -from datetime import datetime - -import pytest - -from khoj.database.models import ChatModel -from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions -from khoj.utils.helpers import ConversationCommand -from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider - -SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE -pytestmark = pytest.mark.skipif( - SKIP_TESTS, - reason="Disable in CI to avoid long test runs.", -) - -import freezegun -from freezegun import freeze_time - -from khoj.processor.conversation.offline.chat_model import converse_offline -from khoj.processor.conversation.offline.utils import download_model -from khoj.utils.constants import default_offline_chat_models - - -@pytest.fixture(scope="session") -def loaded_model(): - return download_model(default_offline_chat_models[0], max_tokens=5000) - - -freezegun.configure(extend_ignore_list=["transformers"]) - - -# Test -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2): - # Act - response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model) - - assert len(response) >= 1 - - assert any( - [ - "dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0], - "dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0], - 'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0], - 'dt>="1984-04-01"' in 
response[0] and 'dt<="1984-04-01"' in response[0], - ] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting") -@pytest.mark.chatquality -@freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2): - # Act - response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model) - - # Assert - assert len(response) >= 1 - # The user query should be the last question in the response - assert response[-1] == ["Which countries did I visit last month?"] - assert any( - [ - "dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0], - "dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0], - 'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0], - 'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0], - ] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting") -@pytest.mark.chatquality -@freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2): - # Act - response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model) - - # Assert - expected_responses = [ - ("dt>='1984-01-01'", ""), - ("dt>='1984-01-01'", "dt<'1985-01-01'"), - ("dt>='1984-01-01'", "dt<='1984-12-31'"), - ] - assert len(response) == 1 - assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( - "Expected date filter to limit to 1984 in response but got: " + response[0] - ) - - -# 
---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2): - # Act - responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model) - - # Assert - assert len(responses) >= 2 - assert ["the Sun" in response for response in responses] - assert ["the Moon" in response for response in responses] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2): - # Act - response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model) - - # Assert - expected_responses = ["height", "taller", "shorter", "heights", "who"] - assert len(response) <= 3 - - for question in response: - assert any([expected_response in question.lower() for expected_response in expected_responses]), ( - "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2): - # Arrange - message_list = [ - ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), - ] - query = "Does he have any sons?" 
- - # Act - response = extract_questions( - query, - default_user2, - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - use_history=True, - ) - - any_expected_with_barbara = [ - "sibling", - "brother", - ] - - any_expected_with_anderson = [ - "son", - "sons", - "children", - "family", - ] - - # Assert - assert len(response) >= 1 - # Ensure the remaining generated search queries use proper nouns and chat history context - for question in response: - if "Barbara" in question: - assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), ( - "Expected search queries using proper nouns and chat history for context, but got: " + question - ) - elif "Anderson" in question: - assert any([expected_response in question for expected_response in any_expected_with_anderson]), ( - "Expected search queries using proper nouns and chat history for context, but got: " + question - ) - else: - assert False, ( - "Expected search queries using proper nouns and chat history for context, but got: " + question - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2): - # Arrange - message_list = [ - ("What is the name of Mr. 
Anderson's daughter?", "Miss Barbara", []), - ] - - # Act - response = extract_questions( - "Is she a Doctor?", - default_user2, - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - use_history=True, - ) - - expected_responses = [ - "Barbara", - "Anderson", - ] - - # Assert - assert len(response) >= 1 - assert any([expected_response in response[0] for expected_response in expected_responses]), ( - "Expected chat actor to mention person's by name, but got: " + response[0] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context") -@pytest.mark.chatquality -def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2): - # Arrange - message_list = [ - ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []), - ] - - # Act - response = extract_questions( - "What was the Pizza place we ate at over there?", - default_user2, - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - - # Assert - expected_responses = [ - ("dt>='2000-04-01'", "dt<'2000-05-01'"), - ("dt>='2000-04-01'", "dt<='2000-04-30'"), - ('dt>="2000-04-01"', 'dt<"2000-05-01"'), - ('dt>="2000-04-01"', 'dt<="2000-04-30"'), - ] - assert len(response) == 1 - assert "Masai Mara" in response[0] - assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( - "Expected date filter to limit to April 2000 in response but got: " + response[0] - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -@pytest.mark.parametrize( - "user_query, expected_conversation_commands", - [ - ( - "Where did I learn to swim?", - {"sources": [ConversationCommand.Notes], "output": 
ConversationCommand.Text}, - ), - ( - "Where is the nearest hospital?", - {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text}, - ), - ( - "Summarize the wikipedia page on the history of the internet", - {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text}, - ), - ( - "How many noble gases are there?", - {"sources": [ConversationCommand.General], "output": ConversationCommand.Text}, - ), - ( - "Make a painting incorporating my past diving experiences", - {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image}, - ), - ( - "Create a chart of the weather over the next 7 days in Timbuktu", - {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text}, - ), - ( - "What's the highest point in this country and have I been there?", - {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text}, - ), - ], -) -async def test_select_data_sources_actor_chooses_to_search_notes( - client_offline_chat, user_query, expected_conversation_commands, default_user2 -): - # Act - selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2) - - # Assert - assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"]) - assert expected_conversation_commands["output"] == selected_conversation_commands["output"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2): - # Arrange - user_query = "What's the latest in the Israel/Palestine conflict?" - chat_log = [ - ( - "Let's talk about the current events around the world.", - "Sure, let's discuss the current events. 
What would you like to know?", - [], - ), - ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), - ] - chat_history = ConversationFactory(user=default_user2, conversation_log=generate_chat_history(chat_log)) - - # Act - tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["online"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_chat_with_no_chat_history_or_retrieved_content(loaded_model): - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Hello, my name is Testatron. Who are you?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = ["Khoj", "khoj", "KHOJ"] - assert len(response) > 0 - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model): - "Chat actor needs to use context in previous notes and chat history to answer question" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ( - "When was I born?", - "You were born on 1st April 1984.", - ["Testatron was born on 1st April 1984 in Testville."], - ), - ] - - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Where was I born?", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert len(response) > 0 - # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes - assert "Testville" in response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): - "Chat actor needs to use context across currently retrieved notes and chat history to answer question" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - - # Act - response_gen = converse_offline( - references=[ - {"compiled": "Testatron was born on 1st April 1984 in Testville."} - ], # Assume context retrieved from notes for the user_query - user_query="Where was I born?", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert len(response) > 0 - assert "Testville" in response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer") -@pytest.mark.chatquality -def test_refuse_answering_unanswerable_question(loaded_model): - "Chat actor should not try make up answers to unanswerable questions." 
- # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Where was I born?", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = [ - "don't know", - "do not know", - "no information", - "do not have", - "don't have", - "cannot answer", - "I'm sorry", - ] - assert len(response) > 0 - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected chat actor to say they don't know in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_requires_current_date_awareness(loaded_model): - "Chat actor should be able to answer questions relative to current date using provided notes" - # Arrange - context = [ - { - "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""2020-04-01 "SuperMercado" "Bananas" -Expenses:Food:Groceries 10.00 USD""" - }, - { - "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What did I have for Dinner today?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses 
= ["tacos", "Tacos"] - assert len(response) > 0 - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected [T|t]acos in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model): - "Chat actor should be able to answer questions that require date aware aggregation across multiple notes" - # Arrange - context = [ - { - "compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch" -Expenses:Food:Dining 10.00 USD""" - }, - { - "compiled": f"""2020-04-01 "SuperMercado" "Bananas" -Expenses:Food:Groceries 10.00 USD""" - }, - { - "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" -Expenses:Food:Dining 10.00 USD""" - }, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="How much did I spend on dining this year?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert len(response) > 0 - assert "20" in response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model): - "Chat actor should be able to answer general questions not requiring looking at chat history or notes" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("Where was I born?", "You were born Testville.", []), - ] - - # Act - response_gen = converse_offline( - references=[], # Assume no context retrieved from notes for the user_query - user_query="Write a haiku about unit testing in 3 lines", - chat_history=generate_chat_history(message_list), - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = ["test", "testing"] - assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines. - assert any([expected_response in response.lower() for expected_response in expected_responses]), ( - "Expected [T|t]est in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model): - "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" - # Arrange - context = [ - { - "compiled": f"""# Ramya -My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""" - }, - { - "compiled": f"""# Fang -My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""" - }, - { - "compiled": f"""# Aiyla -My sister, Aiyla is married to Tolga. 
They have 3 kids, Yildiz, Ali and Ahmet.""" - }, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="How many kids does my older sister have?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister", "Which one"] - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected chat actor to ask for clarification in response, but got: " + response - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -def test_agent_prompt_should_be_used(loaded_model, offline_agent): - "Chat actor should ask be tuned to think like an accountant based on the agent definition" - # Arrange - context = [ - {"compiled": f"""I went to the store and bought some bananas for 2.20"""}, - {"compiled": f"""I went to the store and bought some apples for 1.30"""}, - {"compiled": f"""I went to the store and bought some oranges for 6.00"""}, - ] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What did I buy?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert that the model without the agent prompt does not include the summary of purchases - expected_responses = ["9.50", "9.5"] - assert all([expected_response not in response for expected_response in expected_responses]), ( - "Expected chat actor to summarize values of purchases" + response - ) - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What did I buy?", - loaded_model=loaded_model, - agent=offline_agent, - ) - response = 
"".join([response_chunk for response_chunk in response_gen]) - - # Assert that the model with the agent prompt does include the summary of purchases - expected_responses = ["9.50", "9.5"] - assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected chat actor to summarize values of purchases" + response - ) - - -# ---------------------------------------------------------------------------------------------------- -def test_chat_does_not_exceed_prompt_size(loaded_model): - "Ensure chat context and response together do not exceed max prompt size for the model" - # Arrange - prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed" - context = [{"compiled": " ".join([f"{number}" for number in range(2043)])}] - - # Act - response_gen = converse_offline( - references=context, # Assume context retrieved from notes for the user_query - user_query="What numbers come after these?", - loaded_model=loaded_model, - ) - response = "".join([response_chunk for response_chunk in response_gen]) - - # Assert - assert prompt_size_exceeded_error not in response, ( - "Expected chat response to be within prompt limits, but got exceeded error: " + response - ) diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py deleted file mode 100644 index 0eb4a0dc..00000000 --- a/tests/test_offline_chat_director.py +++ /dev/null @@ -1,726 +0,0 @@ -import urllib.parse - -import pytest -from faker import Faker -from freezegun import freeze_time - -from khoj.database.models import Agent, ChatModel, Entry, KhojUser -from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import message_to_log -from tests.helpers import ConversationFactory, get_chat_provider - -SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE -pytestmark = pytest.mark.skipif( - SKIP_TESTS, - reason="Disable in CI to avoid long test runs.", -) - -fake = 
Faker() - - -# Helpers -# ---------------------------------------------------------------------------------------------------- -def generate_history(message_list): - # Generate conversation logs - conversation_log = {"chat": []} - for user_message, gpt_message, context in message_list: - message_to_log( - user_message, - gpt_message, - {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - chat_history=conversation_log.get("chat", []), - ) - return conversation_log - - -def create_conversation(message_list, user, agent=None): - # Generate conversation logs - conversation_log = generate_history(message_list) - # Update Conversation Metadata Logs in Database - return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent) - - -# Tests -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat): - # Act - query = "Hello, my name is Testatron. Who are you?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Khoj", "khoj"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_chat_with_online_content(client_offline_chat): - # Act - q = "/online give me the link to paul graham's essay how to do great work" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.json()["response"] - - # Assert - expected_responses = [ - "https://paulgraham.com/greatwork.html", - "https://www.paulgraham.com/greatwork.html", - "http://www.paulgraham.com/greatwork.html", - ] - assert response.status_code == 200 - assert any( - [expected_response in response_message for expected_response in expected_responses] - ), f"Expected links: {expected_responses}. Actual response: {response_message}" - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_chat_with_online_webpage_content(client_offline_chat): - # Act - q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.json()["response"] - - # Assert - expected_responses = ["185", "1871", "horse"] - assert response.status_code == 200 - assert any( - [expected_response in response_message for expected_response in expected_responses] - ), f"Expected response with {expected_responses}. 
But actual response had: {response_message}" - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_chat_history(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - create_conversation(message_list, default_user2) - - # Act - q = "What is my name?" - response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Testatron", "testatron"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected [T|t]estatron in response but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ( - "When was I born?", - "You were born on 1st April 1984.", - ["Testatron was born on 1st April 1984 in Testville."], - ), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was Xi Li born?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "Fujiang" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ( - "When was I born?", - "You were born on 1st April 1984.", - ["Testatron was born on 1st April 1984 in Testville."], - ), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was I born?" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - # 1. Infer who I am from chat history - # 2. Infer I was born in Testville from previously retrieved notes - assert "Testville" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was I born?" - response = client_offline_chat.post(f"/api/chat", json={"q": q}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - # Inference in a multi-turn conversation - # 1. 
Infer who I am from chat history - # 2. Search for notes about when was born - # 3. Extract where I was born from currently retrieved notes - assert "Fujiang" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2): - "Chat director should say don't know as not enough contexts in chat history or retrieved to answer question" - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ] - create_conversation(message_list, default_user2) - - # Act - q = "Where was I born?" - response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected chat director to say they don't know in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_using_general_command(client_offline_chat, default_user2): - # Arrange - message_list = [] - create_conversation(message_list, default_user2) - - # Act - query = "/general Where was Xi Li born?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "Fujiang" not in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2): - # Arrange - message_list = [] - create_conversation(message_list, default_user2) - - # Act - query = "/notes Where was Xi Li born?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "Fujiang" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_using_file_filter(client_offline_chat, default_user2): - # Arrange - no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"') - answer_query = urllib.parse.quote('Where was Xi Li born? 
file:"Xi Li.markdown"') - message_list = [] - create_conversation(message_list, default_user2) - - # Act - no_answer_response = client_offline_chat.post( - f"/api/chat", json={"q": no_answer_query, "stream": True} - ).content.decode("utf-8") - answer_response = client_offline_chat.post(f"/api/chat", json={"q": answer_query, "stream": True}).content.decode( - "utf-8" - ) - - # Assert - assert "Fujiang" not in no_answer_response - assert "Fujiang" in answer_response - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_not_known_using_notes_command(client_offline_chat, default_user2): - # Arrange - message_list = [] - create_conversation(message_list, default_user2) - - # Act - query = urllib.parse.quote("/notes Where was Testatron born?") - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert response_message == prompts.no_notes_found.format() - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_one_file(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "Xi Li.markdown" file to the file filters - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - # pick the file that has "Xi Li.markdown" in the name - summarization_file = "" - for file in file_list: - if "Birthday Gift for Xiu turning 4.markdown" in file: - summarization_file = file - break - assert summarization_file != "" - - response = client_offline_chat.post( - 
"api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation.id)}, - ) - - # Act - query = "/summarize" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert - assert response_message != "" - assert response_message != "No files selected for summarization. Please add files using the section on the left." - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "Xi Li.markdown" file to the file filters - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - # pick the file that has "Xi Li.markdown" in the name - summarization_file = "" - for file in file_list: - if "Birthday Gift for Xiu turning 4.markdown" in file: - summarization_file = file - break - assert summarization_file != "" - - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation.id)}, - ) - - # Act - query = "/summarize tell me about Xiu" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert - assert response_message != "" - assert response_message != "No files selected for summarization. Please add files using the section on the left." 
- - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "Xi Li.markdown" file to the file filters - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - - response = client_offline_chat.post( - "api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)} - ) - response = client_offline_chat.post( - "api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)} - ) - - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - - # Assert - assert response_message is not None - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_no_files(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." 
- - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser): - message_list = [] - conversation1 = create_conversation(message_list, default_user2) - conversation2 = create_conversation(message_list, default_user2) - - file_list = ( - Entry.objects.filter(user=default_user2, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - summarization_file = "" - for file in file_list: - if "Birthday Gift for Xiu turning 4.markdown" in file: - summarization_file = file - break - assert summarization_file != "" - - # add file filter to conversation 1. - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation1.id)}, - ) - - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id}) - response_message = response.json()["response"] - - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." - - # now make sure that the file filter is still in conversation 1 - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id}) - response_message = response.json()["response"] - - # Assert - assert response_message != "" - assert response_message != "No files selected for summarization. Please add files using the section on the left." 
- - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # post "imaginary.markdown" file to the file filters - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)}, - ) - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." - - -@pytest.mark.django_db(transaction=True) -@pytest.mark.chatquality -def test_summarize_diff_user_file( - client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser -): - # Arrange - message_list = [] - conversation = create_conversation(message_list, default_user2) - # Get the pdf file called singlepage.pdf - file_list = ( - Entry.objects.filter(user=default_user, file_source="computer") - .distinct("file_path") - .values_list("file_path", flat=True) - ) - summarization_file = "" - for file in file_list: - if "singlepage.pdf" in file: - summarization_file = file - break - assert summarization_file != "" - # add singlepage.pdf to the file filters - response = client_offline_chat.post( - "api/chat/conversation/file-filters", - json={"filename": summarization_file, "conversation_id": str(conversation.id)}, - ) - - # Act - query = "/summarize" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) - response_message = response.json()["response"] - - # Assert - assert response_message == "No files selected for summarization. Please add files using the section on the left." 
- - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -@freeze_time("2023-04-01", ignore=["transformers"]) -def test_answer_requires_current_date_awareness(client_offline_chat): - "Chat actor should be able to answer questions relative to current date using provided notes" - # Act - query = "Where did I have lunch today?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Arak", "Medellin"] - assert response.status_code == 200 - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected chat director to say Arak, Medellin, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -@freeze_time("2023-04-01", ignore=["transformers"]) -def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat): - "Chat director should be able to answer questions that require date aware aggregation across multiple notes" - # Act - query = "How much did I spend on dining this year?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert "26" in response_message - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("Where was I born?", "You were born Testville.", []), - ] - create_conversation(message_list, default_user2) - - # Act - query = "Write a haiku about unit testing. Do not say anything else." - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["test", "Test"] - assert response.status_code == 200 - assert len(response_message.splitlines()) == 3 # haikus are 3 lines long - assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected [T|t]est in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2): - # Act - query = "What is the name of Namitas older son" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses 
= [ - "which of them is the older", - "which one is older", - "which of them is older", - "which one is the older", - ] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected chat director to ask for clarification in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat director not capable of answering this question yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("Where was I born?", "You were born Testville.", []), - ] - create_conversation(message_list, default_user2) - - # Act - query = "What is my name?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["Testatron", "testatron"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected [T|t]estatron in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("What's my favorite color", "Your favorite color is green.", []), - ("Where was I born?", "You were born Testville.", []), - ] - message_list2 = [ - ("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), - ("When was I born?", "You were born on 14th August 1947.", []), - ("What's my favorite color", "Your favorite color is maroon.", []), - ("Where was I born?", "You were born in a potato farm.", []), - ] - conversation = create_conversation(message_list, default_user2) - create_conversation(message_list2, default_user2) - - # Act - query = "What is my favorite color?" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["green"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected green in response, but got: " + response_message - ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet") -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_by_conversation_id_with_agent( - client_offline_chat, default_user2: KhojUser, offline_agent: Agent -): - # Arrange - message_list = [ - ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), - ("When was I born?", "You were born on 1st April 1984.", []), - ("What's my favorite color", "Your favorite color is green.", []), - ("Where was I born?", "You were born Testville.", []), - ("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []), - ] - conversation = create_conversation(message_list, default_user2, offline_agent) - - # Act - query = "/general What did I eat for breakfast?" - response = client_offline_chat.post( - f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True} - ) - response_message = response.content.decode("utf-8") - - # Assert that agent only responds with the summary of spending - expected_responses = ["13.00", "13", "13.0", "thirteen"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected green in response, but got: " + response_message - ) - - -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_chat_history_very_long(client_offline_chat, default_user2): - # Arrange - message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)] - create_conversation(message_list, default_user2) - - # Act - query = "What is my name?" - response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - assert response.status_code == 200 - assert len(response_message) > 0 - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.chatquality -@pytest.mark.django_db(transaction=True) -def test_answer_requires_multiple_independent_searches(client_offline_chat): - "Chat director should be able to answer by doing multiple independent searches for required information" - # Act - query = "Is Xi older than Namita?" 
- response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True}) - response_message = response.content.decode("utf-8") - - # Assert - expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] - assert response.status_code == 200 - assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( - "Expected Xi is older than Namita, but got: " + response_message - )