Drop native offline chat support with llama-cpp-python

It is recommended to chat with open-source models by running an
open-source server like Ollama, Llama.cpp on your GPU powered machine
or use a commercial provider of open-source models like DeepInfra or
OpenRouter.

These chat model serving options provide a mature Openai compatible
API that already works with Khoj.

Directly using offline chat models only worked reasonably with pip
install on a machine with GPU. Docker setup of khoj had trouble with
accessing GPU. And without GPU access offline chat is too slow.

Deprecating support for an offline chat provider directly from within
Khoj will reduce code complexity and increase developement velocity.
Offline models are subsumed to use existing Openai ai model provider.
This commit is contained in:
Debanjum
2025-07-03 01:49:18 -07:00
parent 3f8cc71aca
commit b1f2737c9a
28 changed files with 71 additions and 1945 deletions

View File

@@ -18,8 +18,8 @@ ENV PATH="/opt/venv/bin:${PATH}"
COPY pyproject.toml README.md ./
# Setup python environment
# Use the pre-built llama-cpp-python, torch cpu wheel
ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://abetlen.github.io/llama-cpp-python/whl/cpu" \
# Use the pre-built torch cpu wheel
ENV PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" \
# Avoid downloading unused cuda specific python packages
CUDA_VISIBLE_DEVICES="" \
# Use static version to build app without git dependency

View File

@@ -64,8 +64,6 @@ jobs:
DEBIAN_FRONTEND: noninteractive
run: |
apt update && apt install -y git libegl1 sqlite3 libsqlite3-dev libsqlite3-0 ffmpeg libsm6 libxext6
# required by llama-cpp-python prebuilt wheels
apt install -y musl-dev && ln -s /usr/lib/x86_64-linux-musl/libc.so /lib/libc.musl-x86_64.so.1
- name: ⬇️ Install Postgres
env:

View File

@@ -20,7 +20,7 @@ Add all the agents you want to use for your different use-cases like Writer, Res
### Chat Model Options
Add all the chat models you want to try, use and switch between for your different use-cases. For each chat model you add:
- `Chat model`: The name of an [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/en/docs/about-claude/models#model-names), [Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) or [Offline](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf) chat model.
- `Model type`: The chat model provider like `OpenAI`, `Offline`.
- `Model type`: The chat model provider like `OpenAI`, `Google`.
- `Vision enabled`: Set to `true` if your model supports vision. This is currently only supported for vision capable OpenAI models like `gpt-4o`
- `Max prompt size`, `Subscribed max prompt size`: These are optional fields. They are used to truncate the context to the maximum context size that can be passed to the model. This can help with accuracy and cost-saving.<br />
- `Tokenizer`: This is an optional field. It is used to accurately count tokens and truncate context passed to the chat model to stay within the models max prompt size.

View File

@@ -18,10 +18,6 @@ import TabItem from '@theme/TabItem';
These are the general setup instructions for self-hosted Khoj.
You can install the Khoj server using either [Docker](?server=docker) or [Pip](?server=pip).
:::info[Offline Model + GPU]
To use the offline chat model with your GPU, we recommend using the Docker setup with Ollama . You can also use the local Khoj setup via the Python package directly.
:::
:::info[First Run]
Restart your Khoj server after the first run to ensure all settings are applied correctly.
:::
@@ -225,10 +221,6 @@ To start Khoj automatically in the background use [Task scheduler](https://www.w
You can now open the web app at http://localhost:42110 and start interacting!<br />
Nothing else is necessary, but you can customize your setup further by following the steps below.
:::info[First Message to Offline Chat Model]
The offline chat model gets downloaded when you first send a message to it. The download can take a few minutes! Subsequent messages should be faster.
:::
### Add Chat Models
<h4>Login to the Khoj Admin Panel</h4>
Go to http://localhost:42110/server/admin and login with the admin credentials you setup during installation.
@@ -301,13 +293,14 @@ Offline chat stays completely private and can work without internet using any op
- A Nvidia, AMD GPU or a Mac M1+ machine would significantly speed up chat responses
:::
1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
3. Set the `chat-model` field to the name of your preferred chat model
- Make sure the `model-type` is set to `Offline`
4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
5. Restart the Khoj server and [start chatting](http://localhost:42110) with your new offline model!
</TabItem>
1. Install any Openai API compatible local ai model server like [llama-cpp-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server), Ollama, vLLM etc.
2. Add an [ai model api](http://localhost:42110/server/admin/database/aimodelapi/add/) on the admin panel
- Set the `api url` field to the url of your local ai model provider like `http://localhost:11434/v1/` for Ollama
3. Restart the Khoj server to load models available on your local ai model provider
- If that doesn't work, you'll need to manually add available [chat model](http://localhost:42110/server/admin/database/chatmodel/add) in the admin panel.
4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings)
5. [Start chatting](http://localhost:42110) with your local AI!
</TabItem>
</Tabs>
:::tip[Multiple Chat Models]

View File

@@ -65,7 +65,6 @@ dependencies = [
"django == 5.1.10",
"django-unfold == 0.42.0",
"authlib == 1.2.1",
"llama-cpp-python == 0.2.88",
"itsdangerous == 2.1.2",
"httpx == 0.28.1",
"pgvector == 0.2.4",

View File

@@ -72,7 +72,6 @@ from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
clean_object_for_db,
clean_text_for_db,
@@ -1553,14 +1552,6 @@ class ConversationAdapters:
if chat_model is None:
chat_model = await ConversationAdapters.aget_default_chat_model()
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
return chat_model
if (
chat_model.model_type
in [

View File

@@ -0,0 +1,36 @@
# Generated by Django 5.1.10 on 2025-07-19 21:33
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0091_chatmodel_friendly_name_and_more"),
]
operations = [
migrations.AlterField(
model_name="chatmodel",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("anthropic", "Anthropic"), ("google", "Google")],
default="google",
max_length=200,
),
),
migrations.AlterField(
model_name="chatmodel",
name="name",
field=models.CharField(default="gemini-2.5-flash", max_length=200),
),
migrations.AlterField(
model_name="speechtotextmodeloptions",
name="model_name",
field=models.CharField(default="whisper-1", max_length=200),
),
migrations.AlterField(
model_name="speechtotextmodeloptions",
name="model_type",
field=models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200),
),
]

View File

@@ -220,16 +220,15 @@ class PriceTier(models.TextChoices):
class ChatModel(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
ANTHROPIC = "anthropic"
GOOGLE = "google"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
name = models.CharField(max_length=200, default="gemini-2.5-flash")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.GOOGLE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -605,11 +604,10 @@ class TextToImageModelConfig(DbBaseModel):
class SpeechToTextModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
model_name = models.CharField(max_length=200, default="base")
model_name = models.CharField(max_length=200, default="whisper-1")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)

View File

@@ -214,7 +214,6 @@ def set_state(args):
)
state.anonymous_mode = args.anonymous_mode
state.khoj_version = version("khoj")
state.chat_on_gpu = args.chat_on_gpu
def start_server(app, host=None, port=None, socket=None):

View File

@@ -1,224 +0,0 @@
import asyncio
import logging
import os
from datetime import datetime
from threading import Thread
from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Union
from langchain_core.messages.chat import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ResponseWithThought,
commit_conversation_trace,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils import state
from khoj.utils.helpers import (
is_none_or_empty,
is_promptrace_enabled,
truncate_code_context,
)
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
async def converse_offline(
# Query
user_query: str,
# Context
references: list[dict] = [],
online_results={},
code_results={},
query_files: str = None,
generated_files: List[FileAttachment] = None,
additional_context: List[str] = None,
generated_asset_results: Dict[str, Dict] = {},
location_data: LocationData = None,
user_name: str = None,
chat_history: list[ChatMessageModel] = [],
# Model
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
max_prompt_size=None,
tokenizer_name=None,
agent: Agent = None,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using Llama (Async Version)
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
tracer["chat_model"] = model_name
current_date = datetime.now()
if agent and agent.personality:
system_prompt = prompts.custom_system_prompt_offline_chat.format(
name=agent.name,
bio=agent.personality,
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
else:
system_prompt = prompts.system_prompt_offline_chat.format(
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
if location_data:
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
if not is_none_or_empty(online_results):
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
user_query,
system_prompt,
chat_history,
context_message=context_message,
model_name=model_name,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
model_type=ChatModel.ModelType.OFFLINE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=additional_context,
)
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
# Use asyncio.Queue and a thread to bridge sync iterator
queue: asyncio.Queue[ResponseWithThought] = asyncio.Queue()
stop_phrases = ["<s>", "INST]", "Notes:"]
def _sync_llm_thread():
"""Synchronous function to run in a separate thread."""
aggregated_response = ""
start_time = perf_counter()
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
messages,
loaded_model=offline_chat_model,
stop=stop_phrases,
max_prompt_size=max_prompt_size,
streaming=True,
tracer=tracer,
)
for response in response_iterator:
response_delta: str = response["choices"][0]["delta"].get("content", "")
# Log the time taken to start response
if aggregated_response == "" and response_delta != "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Handle response chunk
aggregated_response += response_delta
# Put chunk into the asyncio queue (non-blocking)
try:
queue.put_nowait(ResponseWithThought(text=response_delta))
except asyncio.QueueFull:
# Should not happen with default queue size unless consumer is very slow
logger.warning("Asyncio queue full during offline LLM streaming.")
# Potentially block here or handle differently if needed
asyncio.run(queue.put(ResponseWithThought(text=response_delta)))
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace
tracer["chat_model"] = model_name
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
finally:
state.chat_lock.release()
# Signal end of stream
queue.put_nowait(None)
# Start the synchronous thread
thread = Thread(target=_sync_llm_thread)
thread.start()
# Asynchronously consume from the queue
while True:
chunk = await queue.get()
if chunk is None: # End of stream signal
queue.task_done()
break
yield chunk
queue.task_done()
# Wait for the thread to finish (optional, ensures cleanup)
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, thread.join)
def send_message_to_model_offline(
messages: List[ChatMessage],
loaded_model=None,
model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
temperature: float = 0.2,
streaming=False,
stop=[],
max_prompt_size: int = None,
response_type: str = "text",
tracer: dict = {},
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
response = offline_chat_model.create_chat_completion(
messages_dict,
stop=stop,
stream=streaming,
temperature=temperature,
response_format={"type": response_type},
seed=seed,
)
if streaming:
return response
response_text: str = response["choices"][0]["message"].get("content", "")
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer)
return ResponseWithThought(text=response_text)

View File

@@ -1,80 +0,0 @@
import glob
import logging
import math
import os
from typing import Any, Dict
from huggingface_hub.constants import HF_HUB_CACHE
from khoj.utils import state
from khoj.utils.helpers import get_device_memory
logger = logging.getLogger(__name__)
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
# Initialize Model Parameters
# Use n_ctx=0 to get context size from the model
kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False}
# Decide whether to load model to GPU or CPU
device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
# Add chat format if known
if "llama-3" in repo_id.lower():
kwargs["chat_format"] = "llama-3"
elif "gemma-2" in repo_id.lower():
kwargs["chat_format"] = "gemma"
# Check if the model is already downloaded
model_path = load_model_from_cache(repo_id, filename)
chat_model = None
try:
chat_model = load_model(model_path, repo_id, filename, kwargs)
except:
# Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu"
chat_model = load_model(model_path, repo_id, filename, kwargs)
# Now load the model with context size set based on:
# 1. context size supported by model and
# 2. configured size or machine (V)RAM
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
chat_model = load_model(model_path, repo_id, filename, kwargs)
logger.debug(
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
)
return chat_model
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
from llama_cpp.llama import Llama
if model_path:
return Llama(model_path, **kwargs)
else:
return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/")
object_id = "--".join([repo_type, repo_org, repo_name])
model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])
# Check if the model file exists
paths = glob.glob(model_path)
if paths:
return paths[0]
else:
return None
def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int:
"""Infer max prompt size based on device memory and max context window supported by the model"""
configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
configured_max_tokens = configured_max_tokens or math.inf # do not use if set to None
return min(configured_max_tokens, vram_based_n_ctx, model_context_window)

View File

@@ -1,15 +0,0 @@
import whisper
from asgiref.sync import sync_to_async
from khoj.utils import state
async def transcribe_audio_offline(audio_filename: str, model: str) -> str:
"""
Transcribe audio file offline using Whisper
"""
# Send the audio data to the Whisper API
if not state.whisper_model:
state.whisper_model = whisper.load_model(model)
response = await sync_to_async(state.whisper_model.transcribe)(audio_filename)
return response["text"]

View File

@@ -78,38 +78,6 @@ no_entries_found = PromptTemplate.from_template(
""".strip()
)
## Conversation Prompts for Offline Chat Models
## --
system_prompt_offline_chat = PromptTemplate.from_template(
"""
You are Khoj, a smart, inquisitive and helpful personal assistant.
- Use your general knowledge and past conversation with the user as context to inform your responses.
- If you do not know the answer, say 'I don't know.'
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
- Do not print verbatim Notes unless necessary.
Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
Today is {day_of_week}, {current_date} in UTC.
""".strip()
)
custom_system_prompt_offline_chat = PromptTemplate.from_template(
"""
You are {name}, a personal agent on Khoj.
- Use your general knowledge and past conversation with the user as context to inform your responses.
- If you do not know the answer, say 'I don't know.'
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
- Do not print verbatim Notes unless necessary.
Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
Today is {day_of_week}, {current_date} in UTC.
Instructions:\n{bio}
""".strip()
)
## Notes Conversation
## --
notes_conversation = PromptTemplate.from_template(

View File

@@ -18,8 +18,6 @@ import requests
import tiktoken
import yaml
from langchain_core.messages.chat import ChatMessage
from llama_cpp import LlamaTokenizer
from llama_cpp.llama import Llama
from pydantic import BaseModel, ConfigDict, ValidationError, create_model
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -32,7 +30,6 @@ from khoj.database.models import (
KhojUser,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -85,12 +82,6 @@ model_to_prompt_size = {
"claude-sonnet-4-20250514": 60000,
"claude-opus-4-0": 60000,
"claude-opus-4-20250514": 60000,
# Offline Models
"bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
"bartowski/Llama-3.2-3B-Instruct-GGUF": 20000,
"bartowski/gemma-2-9b-it-GGUF": 6000,
"bartowski/gemma-2-2b-it-GGUF": 6000,
}
model_to_tokenizer: Dict[str, str] = {}
@@ -573,7 +564,6 @@ def generate_chatml_messages_with_context(
system_message: str = None,
chat_history: list[ChatMessageModel] = [],
model_name="gpt-4o-mini",
loaded_model: Optional[Llama] = None,
max_prompt_size=None,
tokenizer_name=None,
query_images=None,
@@ -588,10 +578,7 @@ def generate_chatml_messages_with_context(
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs
if not max_prompt_size:
if loaded_model:
max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
else:
max_prompt_size = model_to_prompt_size.get(model_name, 10000)
max_prompt_size = model_to_prompt_size.get(model_name, 10000)
# Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750
@@ -735,7 +722,7 @@ def generate_chatml_messages_with_context(
message.content = [{"type": "text", "text": message.content}]
# Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
# Return message in chronological order
return messages[::-1]
@@ -743,25 +730,20 @@ def generate_chatml_messages_with_context(
def get_encoder(
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
default_tokenizer = "gpt-4o"
try:
if loaded_model:
encoder = loaded_model.tokenizer()
elif model_name.startswith("gpt-") or model_name.startswith("o1"):
# as tiktoken doesn't recognize o1 model series yet
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
elif tokenizer_name:
if tokenizer_name:
if tokenizer_name in state.pretrained_tokenizers:
encoder = state.pretrained_tokenizers[tokenizer_name]
else:
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
state.pretrained_tokenizers[tokenizer_name] = encoder
else:
encoder = download_model(model_name).tokenizer()
# as tiktoken doesn't recognize o1 model series yet
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
except:
encoder = tiktoken.encoding_for_model(default_tokenizer)
if state.verbose > 2:
@@ -773,7 +755,7 @@ def get_encoder(
def count_tokens(
message_content: str | list[str | dict],
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
"""
Count the total number of tokens in a list of messages.
@@ -825,11 +807,10 @@ def truncate_messages(
messages: list[ChatMessage],
max_prompt_size: int,
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
encoder = get_encoder(model_name, loaded_model, tokenizer_name)
encoder = get_encoder(model_name, tokenizer_name)
# Extract system message from messages
system_message = None

View File

@@ -235,7 +235,6 @@ def is_operator_model(model: str) -> ChatModel.ModelType | None:
"claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC,
"claude-sonnet-4": ChatModel.ModelType.ANTHROPIC,
"claude-opus-4": ChatModel.ModelType.ANTHROPIC,
"ui-tars-1.5": ChatModel.ModelType.OFFLINE,
}
for operator_model in operator_models:
if model.startswith(operator_model):

View File

@@ -15,7 +15,6 @@ from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo
from khoj.database.models import KhojUser, SpeechToTextModelOptions
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import (
ApiUserRateLimiter,
@@ -150,9 +149,6 @@ async def transcribe(
if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unsupported on server error
status_code = 501
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
speech2text_model = speech_to_text_config.model_name
if speech_to_text_config.ai_model_api:

View File

@@ -89,10 +89,6 @@ from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
gemini_send_message_to_model,
)
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import (
converse_openai,
send_message_to_model,
@@ -117,7 +113,6 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
LRU,
ConversationCommand,
@@ -168,14 +163,6 @@ async def is_ready_to_chat(user: KhojUser):
if user_chat_model == None:
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
chat_model_name = user_chat_model.name
max_tokens = user_chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
return True
if (
user_chat_model
and (
@@ -1470,12 +1457,6 @@ async def send_message_to_model_wrapper(
vision_available = chat_model.vision_enabled
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
loaded_model = None
if model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
@@ -1483,7 +1464,6 @@ async def send_message_to_model_wrapper(
system_message=system_message,
chat_history=chat_history,
model_name=chat_model_name,
loaded_model=loaded_model,
tokenizer_name=tokenizer,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
@@ -1492,18 +1472,7 @@ async def send_message_to_model_wrapper(
query_files=query_files,
)
if model_type == ChatModel.ModelType.OFFLINE:
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModel.ModelType.OPENAI:
if model_type == ChatModel.ModelType.OPENAI:
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
@@ -1565,19 +1534,12 @@ def send_message_to_model_wrapper_sync(
vision_available = chat_model.vision_enabled
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
loaded_model = None
if model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
chat_history=chat_history,
model_name=chat_model_name,
loaded_model=loaded_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=model_type,
@@ -1585,18 +1547,7 @@ def send_message_to_model_wrapper_sync(
query_files=query_files,
)
if model_type == ChatModel.ModelType.OFFLINE:
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModel.ModelType.OPENAI:
if model_type == ChatModel.ModelType.OPENAI:
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
@@ -1678,30 +1629,7 @@ async def agenerate_chat_response(
chat_model = vision_enabled_config
vision_available = True
if chat_model.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response_generator = converse_offline(
# Query
user_query=query_to_run,
# Context
references=compiled_references,
online_results=online_results,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
location_data=location_data,
user_name=user_name,
query_files=query_files,
chat_history=chat_history,
# Model
loaded_model=loaded_model,
model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
if chat_model.model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model_name = chat_model.name

View File

@@ -33,9 +33,6 @@ def cli(args=None):
parser.add_argument("--sslcert", type=str, help="Path to SSL certificate file")
parser.add_argument("--sslkey", type=str, help="Path to SSL key file")
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
parser.add_argument(
"--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
)
parser.add_argument(
"--anonymous-mode",
action="store_true",
@@ -54,9 +51,6 @@ def cli(args=None):
if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu
args.version_no = version("khoj")
if args.version:
# Show version of khoj installed and exit

View File

@@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import torch
from khoj.processor.conversation.offline.utils import download_model
logger = logging.getLogger(__name__)
@@ -62,20 +60,3 @@ class ImageSearchModel:
@dataclass
class SearchModels:
text_search: Optional[TextSearchModel] = None
@dataclass
class OfflineChatProcessorConfig:
loaded_model: Union[Any, None] = None
class OfflineChatProcessorModel:
def __init__(self, chat_model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", max_tokens: int = None):
self.chat_model = chat_model
self.loaded_model = None
try:
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
raise e

View File

@@ -10,13 +10,6 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
default_offline_chat_models = [
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
"bartowski/Llama-3.2-3B-Instruct-GGUF",
"bartowski/gemma-2-9b-it-GGUF",
"bartowski/gemma-2-2b-it-GGUF",
"bartowski/Qwen2.5-14B-Instruct-GGUF",
]
default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1", "o3", "o4-mini"]
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"]
default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"]

View File

@@ -16,7 +16,6 @@ from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tok
from khoj.utils.constants import (
default_anthropic_chat_models,
default_gemini_chat_models,
default_offline_chat_models,
default_openai_chat_models,
)
@@ -72,7 +71,6 @@ def initialization(interactive: bool = True):
default_api_key=openai_api_key,
api_base_url=openai_base_url,
vision_enabled=True,
is_offline=False,
interactive=interactive,
provider_name=provider,
)
@@ -118,7 +116,6 @@ def initialization(interactive: bool = True):
default_gemini_chat_models,
default_api_key=os.getenv("GEMINI_API_KEY"),
vision_enabled=True,
is_offline=False,
interactive=interactive,
provider_name="Google Gemini",
)
@@ -145,17 +142,6 @@ def initialization(interactive: bool = True):
default_anthropic_chat_models,
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
vision_enabled=True,
is_offline=False,
interactive=interactive,
)
# Set up offline chat models
_setup_chat_model_provider(
ChatModel.ModelType.OFFLINE,
default_offline_chat_models,
default_api_key=None,
vision_enabled=False,
is_offline=True,
interactive=interactive,
)
@@ -186,7 +172,6 @@ def initialization(interactive: bool = True):
interactive: bool,
api_base_url: str = None,
vision_enabled: bool = False,
is_offline: bool = False,
provider_name: str = None,
) -> Tuple[bool, AiModelApi]:
supported_vision_models = (
@@ -195,11 +180,6 @@ def initialization(interactive: bool = True):
provider_name = provider_name or model_type.name.capitalize()
default_use_model = default_api_key is not None
# If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment.
# This usually means there's not enough RAM to load offline models directly within the application.
# In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
if is_offline:
default_use_model = False
use_model_provider = (
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y"
@@ -211,13 +191,12 @@ def initialization(interactive: bool = True):
logger.info(f"️💬 Setting up your {provider_name} chat configuration")
ai_model_api = None
if not is_offline:
if interactive:
user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
api_key = user_api_key if user_api_key != "" else default_api_key
else:
api_key = default_api_key
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
if interactive:
user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
api_key = user_api_key if user_api_key != "" else default_api_key
else:
api_key = default_api_key
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
if interactive:
user_chat_models = input(

View File

@@ -103,13 +103,8 @@ class OpenAIProcessorConfig(ConfigBase):
chat_model: Optional[str] = "gpt-4o-mini"
class OfflineChatProcessorConfig(ConfigBase):
chat_model: Optional[str] = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
class ConversationProcessorConfig(ConfigBase):
openai: Optional[OpenAIProcessorConfig] = None
offline_chat: Optional[OfflineChatProcessorConfig] = None
max_prompt_size: Optional[int] = None
tokenizer: Optional[str] = None

View File

@@ -12,7 +12,7 @@ from whisper import Whisper
from khoj.database.models import ProcessLock
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import OfflineChatProcessorModel, SearchModels
from khoj.utils.config import SearchModels
from khoj.utils.helpers import LRU, get_device, is_env_var_true
from khoj.utils.rawconfig import FullConfig
@@ -22,7 +22,6 @@ search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: Dict[str, CrossEncoderModel] = None
openai_client: OpenAI = None
offline_chat_processor_config: OfflineChatProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None
verbose: int = 0
@@ -39,7 +38,6 @@ telemetry: List[Dict[str, str]] = []
telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
khoj_version: str = None
device = get_device()
chat_on_gpu: bool = True
anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
billing_enabled: bool = (

View File

@@ -196,17 +196,6 @@ def default_openai_chat_model_option():
return chat_model
@pytest.mark.django_db
@pytest.fixture
def offline_agent():
chat_model = ChatModelFactory()
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
)
@pytest.mark.django_db
@pytest.fixture
def openai_agent():
@@ -516,40 +505,6 @@ def client(
return TestClient(app)
@pytest.fixture(scope="function")
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=default_user2,
)
all_files = fs_syncer.collect_files(user=default_user2)
configure_content(default_user2, all_files)
# Initialize Processor from Config
ChatModelFactory(
name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
tokenizer=None,
max_prompt_size=None,
model_type="offline",
)
UserConversationProcessorConfigFactory(user=default_user2)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
@pytest.fixture(scope="function")
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
# Setup

View File

@@ -19,7 +19,7 @@ from khoj.database.models import (
from khoj.processor.conversation.utils import message_to_log
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE):
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
if provider and provider in ChatModel.ModelType:
return ChatModel.ModelType(provider)
@@ -93,7 +93,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory):
max_prompt_size = 20000
tokenizer = None
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
name = "gemini-2.0-flash"
model_type = get_chat_provider()
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)

View File

@@ -1,610 +0,0 @@
from datetime import datetime
import pytest
from khoj.database.models import ChatModel
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
from khoj.utils.helpers import ConversationCommand
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="Disable in CI to avoid long test runs.",
)
import freezegun
from freezegun import freeze_time
from khoj.processor.conversation.offline.chat_model import converse_offline
from khoj.processor.conversation.offline.utils import download_model
from khoj.utils.constants import default_offline_chat_models
@pytest.fixture(scope="session")
def loaded_model():
return download_model(default_offline_chat_models[0], max_tokens=5000)
freezegun.configure(extend_ignore_list=["transformers"])
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2):
# Act
response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model)
assert len(response) >= 1
assert any(
[
"dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
"dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2):
# Act
response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model)
# Assert
assert len(response) >= 1
# The user query should be the last question in the response
assert response[-1] == ["Which countries did I visit last month?"]
assert any(
[
"dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
"dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2):
# Act
response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model)
# Assert
expected_responses = [
("dt>='1984-01-01'", ""),
("dt>='1984-01-01'", "dt<'1985-01-01'"),
("dt>='1984-01-01'", "dt<='1984-12-31'"),
]
assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to 1984 in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2):
# Act
responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model)
# Assert
assert len(responses) >= 2
assert ["the Sun" in response for response in responses]
assert ["the Moon" in response for response in responses]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2):
# Act
response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model)
# Assert
expected_responses = ["height", "taller", "shorter", "heights", "who"]
assert len(response) <= 3
for question in response:
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
"Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2):
# Arrange
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
query = "Does he have any sons?"
# Act
response = extract_questions(
query,
default_user2,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
any_expected_with_barbara = [
"sibling",
"brother",
]
any_expected_with_anderson = [
"son",
"sons",
"children",
"family",
]
# Assert
assert len(response) >= 1
# Ensure the remaining generated search queries use proper nouns and chat history context
for question in response:
if "Barbara" in question:
assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
elif "Anderson" in question:
assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
else:
assert False, (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2):
# Arrange
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
# Act
response = extract_questions(
"Is she a Doctor?",
default_user2,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
expected_responses = [
"Barbara",
"Anderson",
]
# Assert
assert len(response) >= 1
assert any([expected_response in response[0] for expected_response in expected_responses]), (
"Expected chat actor to mention person's by name, but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2):
# Arrange
message_list = [
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
]
# Act
response = extract_questions(
"What was the Pizza place we ate at over there?",
default_user2,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
# Assert
expected_responses = [
("dt>='2000-04-01'", "dt<'2000-05-01'"),
("dt>='2000-04-01'", "dt<='2000-04-30'"),
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
]
assert len(response) == 1
assert "Masai Mara" in response[0]
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to April 2000 in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
"user_query, expected_conversation_commands",
[
(
"Where did I learn to swim?",
{"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text},
),
(
"Where is the nearest hospital?",
{"sources": [ConversationCommand.Online], "output": ConversationCommand.Text},
),
(
"Summarize the wikipedia page on the history of the internet",
{"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text},
),
(
"How many noble gases are there?",
{"sources": [ConversationCommand.General], "output": ConversationCommand.Text},
),
(
"Make a painting incorporating my past diving experiences",
{"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image},
),
(
"Create a chart of the weather over the next 7 days in Timbuktu",
{"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text},
),
(
"What's the highest point in this country and have I been there?",
{"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text},
),
],
)
async def test_select_data_sources_actor_chooses_to_search_notes(
client_offline_chat, user_query, expected_conversation_commands, default_user2
):
# Act
selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)
# Assert
assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"])
assert expected_conversation_commands["output"] == selected_conversation_commands["output"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [
(
"Let's talk about the current events around the world.",
"Sure, let's discuss the current events. What would you like to know?",
[],
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
]
chat_history = ConversationFactory(user=default_user2, conversation_log=generate_chat_history(chat_log))
# Act
tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False)
# Assert
tools = [tool.value for tool in tools]
assert tools == ["online"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
# Act
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["Khoj", "khoj", "KHOJ"]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
"Chat actor needs to use context in previous notes and chat history to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
(
"When was I born?",
"You were born on 1st April 1984.",
["Testatron was born on 1st April 1984 in Testville."],
),
]
# Act
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
# Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
assert "Testville" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
"Chat actor needs to use context across currently retrieved notes and chat history to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
# Act
response_gen = converse_offline(
references=[
{"compiled": "Testatron was born on 1st April 1984 in Testville."}
], # Assume context retrieved from notes for the user_query
user_query="Where was I born?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
assert "Testville" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
# Act
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = [
"don't know",
"do not know",
"no information",
"do not have",
"don't have",
"cannot answer",
"I'm sorry",
]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to say they don't know in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness(loaded_model):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Arrange
context = [
{
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD"""
},
{
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD"""
},
{
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD"""
},
{
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD"""
},
]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["tacos", "Tacos"]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected [T|t]acos in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
# Arrange
context = [
{
"compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD"""
},
{
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD"""
},
{
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD"""
},
{
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD"""
},
]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
assert "20" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
"Chat actor should be able to answer general questions not requiring looking at chat history or notes"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
# Act
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["test", "testing"]
assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
assert any([expected_response in response.lower() for expected_response in expected_responses]), (
"Expected [T|t]est in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
# Arrange
context = [
{
"compiled": f"""# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
},
{
"compiled": f"""# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
},
{
"compiled": f"""# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
},
]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister", "Which one"]
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to ask for clarification in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_agent_prompt_should_be_used(loaded_model, offline_agent):
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
# Arrange
context = [
{"compiled": f"""I went to the store and bought some bananas for 2.20"""},
{"compiled": f"""I went to the store and bought some apples for 1.30"""},
{"compiled": f"""I went to the store and bought some oranges for 6.00"""},
]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert that the model without the agent prompt does not include the summary of purchases
expected_responses = ["9.50", "9.5"]
assert all([expected_response not in response for expected_response in expected_responses]), (
"Expected chat actor to summarize values of purchases" + response
)
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
loaded_model=loaded_model,
agent=offline_agent,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert that the model with the agent prompt does include the summary of purchases
expected_responses = ["9.50", "9.5"]
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to summarize values of purchases" + response
)
# ----------------------------------------------------------------------------------------------------
def test_chat_does_not_exceed_prompt_size(loaded_model):
"Ensure chat context and response together do not exceed max prompt size for the model"
# Arrange
prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
context = [{"compiled": " ".join([f"{number}" for number in range(2043)])}]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What numbers come after these?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert prompt_size_exceeded_error not in response, (
"Expected chat response to be within prompt limits, but got exceeded error: " + response
)

View File

@@ -1,726 +0,0 @@
import urllib.parse
import pytest
from faker import Faker
from freezegun import freeze_time
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from tests.helpers import ConversationFactory, get_chat_provider
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="Disable in CI to avoid long test runs.",
)
fake = Faker()
# Helpers
# ----------------------------------------------------------------------------------------------------
def generate_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message, context in message_list:
message_to_log(
user_message,
gpt_message,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
chat_history=conversation_log.get("chat", []),
)
return conversation_log
def create_conversation(message_list, user, agent=None):
# Generate conversation logs
conversation_log = generate_history(message_list)
# Update Conversation Metadata Logs in Database
return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
# Act
query = "Hello, my name is Testatron. Who are you?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Khoj", "khoj"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(client_offline_chat):
# Act
q = "/online give me the link to paul graham's essay how to do great work"
response = client_offline_chat.post(f"/api/chat", json={"q": q})
response_message = response.json()["response"]
# Assert
expected_responses = [
"https://paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html",
"http://www.paulgraham.com/greatwork.html",
]
assert response.status_code == 200
assert any(
[expected_response in response_message for expected_response in expected_responses]
), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_webpage_content(client_offline_chat):
# Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
response = client_offline_chat.post(f"/api/chat", json={"q": q})
response_message = response.json()["response"]
# Assert
expected_responses = ["185", "1871", "horse"]
assert response.status_code == 200
assert any(
[expected_response in response_message for expected_response in expected_responses]
), f"Expected response with {expected_responses}. But actual response had: {response_message}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
create_conversation(message_list, default_user2)
# Act
q = "What is my name?"
response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Testatron", "testatron"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected [T|t]estatron in response but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
(
"When was I born?",
"You were born on 1st April 1984.",
["Testatron was born on 1st April 1984 in Testville."],
),
]
create_conversation(message_list, default_user2)
# Act
q = "Where was Xi Li born?"
response = client_offline_chat.post(f"/api/chat", json={"q": q})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
(
"When was I born?",
"You were born on 1st April 1984.",
["Testatron was born on 1st April 1984 in Testville."],
),
]
create_conversation(message_list, default_user2)
# Act
q = "Where was I born?"
response = client_offline_chat.post(f"/api/chat", json={"q": q})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
# 1. Infer who I am from chat history
# 2. Infer I was born in Testville from previously retrieved notes
assert "Testville" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
create_conversation(message_list, default_user2)
# Act
q = "Where was I born?"
response = client_offline_chat.post(f"/api/chat", json={"q": q})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
# Inference in a multi-turn conversation
# 1. Infer who I am from chat history
# 2. Search for notes about when <my_name_from_chat_history> was born
# 3. Extract where I was born from currently retrieved notes
assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
create_conversation(message_list, default_user2)
# Act
q = "Where was I born?"
response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to say they don't know in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_using_general_command(client_offline_chat, default_user2):
# Arrange
message_list = []
create_conversation(message_list, default_user2)
# Act
query = "/general Where was Xi Li born?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert "Fujiang" not in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
# Arrange
message_list = []
create_conversation(message_list, default_user2)
# Act
query = "/notes Where was Xi Li born?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(client_offline_chat, default_user2):
# Arrange
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
message_list = []
create_conversation(message_list, default_user2)
# Act
no_answer_response = client_offline_chat.post(
f"/api/chat", json={"q": no_answer_query, "stream": True}
).content.decode("utf-8")
answer_response = client_offline_chat.post(f"/api/chat", json={"q": answer_query, "stream": True}).content.decode(
"utf-8"
)
# Assert
assert "Fujiang" not in no_answer_response
assert "Fujiang" in answer_response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
# Arrange
message_list = []
create_conversation(message_list, default_user2)
# Act
query = urllib.parse.quote("/notes Where was Testatron born?")
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert response_message == prompts.no_notes_found.format()
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
# pick the file that has "Xi Li.markdown" in the name
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
# Act
query = "/summarize"
response = client_offline_chat.post(
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
)
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
# pick the file that has "Xi Li.markdown" in the name
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
# Act
query = "/summarize tell me about Xiu"
response = client_offline_chat.post(
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
)
response_message = response.content.decode("utf-8")
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "Xi Li.markdown" file to the file filters
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
response = client_offline_chat.post(
"api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
)
response = client_offline_chat.post(
"api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
)
# Act
query = "/summarize"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
response_message = response.json()["response"]
# Assert
assert response_message is not None
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = []
conversation = create_conversation(message_list, default_user2)
# Act
query = "/summarize"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
message_list = []
conversation1 = create_conversation(message_list, default_user2)
conversation2 = create_conversation(message_list, default_user2)
file_list = (
Entry.objects.filter(user=default_user2, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
summarization_file = ""
for file in file_list:
if "Birthday Gift for Xiu turning 4.markdown" in file:
summarization_file = file
break
assert summarization_file != ""
# add file filter to conversation 1.
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
)
query = "/summarize"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# now make sure that the file filter is still in conversation 1
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
response_message = response.json()["response"]
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = []
conversation = create_conversation(message_list, default_user2)
# post "imaginary.markdown" file to the file filters
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
)
# Act
query = "/summarize"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_summarize_diff_user_file(
client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
):
# Arrange
message_list = []
conversation = create_conversation(message_list, default_user2)
# Get the pdf file called singlepage.pdf
file_list = (
Entry.objects.filter(user=default_user, file_source="computer")
.distinct("file_path")
.values_list("file_path", flat=True)
)
summarization_file = ""
for file in file_list:
if "singlepage.pdf" in file:
summarization_file = file
break
assert summarization_file != ""
# add singlepage.pdf to the file filters
response = client_offline_chat.post(
"api/chat/conversation/file-filters",
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
# Act
query = "/summarize"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01", ignore=["transformers"])
def test_answer_requires_current_date_awareness(client_offline_chat):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Act
query = "Where did I have lunch today?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Arak", "Medellin"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to say Arak, Medellin, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01", ignore=["transformers"])
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
query = "How much did I spend on dining this year?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert "26" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
create_conversation(message_list, default_user2)
# Act
query = "Write a haiku about unit testing. Do not say anything else."
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["test", "Test"]
assert response.status_code == 200
assert len(response_message.splitlines()) == 3 # haikus are 3 lines long
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected [T|t]est in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
# Act
query = "What is the name of Namitas older son"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = [
"which of them is the older",
"which one is older",
"which of them is older",
"which one is the older",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected chat director to ask for clarification in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
create_conversation(message_list, default_user2)
# Act
query = "What is my name?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Testatron", "testatron"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected [T|t]estatron in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
]
message_list2 = [
("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 14th August 1947.", []),
("What's my favorite color", "Your favorite color is maroon.", []),
("Where was I born?", "You were born in a potato farm.", []),
]
conversation = create_conversation(message_list, default_user2)
create_conversation(message_list2, default_user2)
# Act
query = "What is my favorite color?"
response = client_offline_chat.post(
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
)
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["green"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id_with_agent(
client_offline_chat, default_user2: KhojUser, offline_agent: Agent
):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
]
conversation = create_conversation(message_list, default_user2, offline_agent)
# Act
query = "/general What did I eat for breakfast?"
response = client_offline_chat.post(
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
)
response_message = response.content.decode("utf-8")
# Assert that agent only responds with the summary of spending
expected_responses = ["13.00", "13", "13.0", "thirteen"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
)
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
# Arrange
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
create_conversation(message_list, default_user2)
# Act
query = "What is my name?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert len(response_message) > 0
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_requires_multiple_independent_searches(client_offline_chat):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
query = "Is Xi older than Namita?"
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message
)