mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-03 21:29:08 +00:00
Improve Quality and Reliability of Offline Chat (#393)
# Incoming ## Major ### Fix Prompt Size Exceeded Issue - Fix issues related to prompt size, Closes #386. Use the correct tokenizer to calculate whether the input needs to be truncated or not. ### Improve Llama 2 Model Download - Use the correct download link for LlamaV2 -- should have been using the small model, but was using the medium - Add better downloading logic to retry download if it failed, Closes #379 ### Fix Segmentation Fault due to Race - Add a lock around generating chat responses from the offline model to avoid segmentation faults. Closes #367. - Add a loading symbol to the web chat UI when the model is thinking. Closes #392 ### Improve Chat Response Latency - Improve performance of offline chat by increasing batch size (via `n_batch`) to automatically engage more cores/GPU, using smaller model and fixing prompt vs response token generation numbers. Closes #363 ### Fix Fake Dialogue Continuation - Fix formatting of user query with offline chat, this was contributing to #398 - Stop Llama 2 from Creating Fake Dialogue Continuations. Closes #398 ## Minor - Improve default message for Chat window on web when it's not configured. Include hint to use offline chat. - Add null check in `perform_chat_checks` method - Add offline chat director unit tests ## Performance Analysis (Time to First Token) | | v0.10.0 | this branch | |-|-|-| | Query 1 | 52s | 28s | | Query 2 | 33s| 42s | | Query 3 | 67s| 38s|
This commit is contained in:
@@ -61,6 +61,7 @@
|
||||
// Add message by user to chat body
|
||||
renderMessage(query, "you");
|
||||
document.getElementById("chat-input").value = "";
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
|
||||
@@ -76,7 +77,9 @@
|
||||
new_response.appendChild(new_response_text);
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
new_response_text.innerHTML = "🤔";
|
||||
let loadingSpinner = document.createElement("div");
|
||||
loadingSpinner.classList.add("spinner");
|
||||
new_response_text.appendChild(loadingSpinner);
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
|
||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
||||
@@ -107,10 +110,10 @@
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
} else {
|
||||
// Display response from Khoj
|
||||
if (new_response_text.innerHTML === "🤔") {
|
||||
// Clear temporary status message
|
||||
new_response_text.innerHTML = "";
|
||||
if (new_response_text.getElementsByClassName("spinner").length > 0) {
|
||||
new_response_text.removeChild(loadingSpinner);
|
||||
}
|
||||
|
||||
new_response_text.innerHTML += chunk;
|
||||
readStream();
|
||||
}
|
||||
@@ -120,6 +123,7 @@
|
||||
});
|
||||
}
|
||||
readStream();
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -136,7 +140,7 @@
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
// If the server returns a 500 error with detail, render a setup hint.
|
||||
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
|
||||
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
|
||||
|
||||
// Disable chat input field and update placeholder text
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
@@ -269,6 +273,21 @@
|
||||
margin-left: auto;
|
||||
white-space: pre-line;
|
||||
}
|
||||
/* Spinner symbol when the chat message is loading */
|
||||
.spinner {
|
||||
border: 4px solid #f3f3f3;
|
||||
border-top: 4px solid var(--primary-inverse);
|
||||
border-radius: 50%;
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
animation: spin 2s linear infinite;
|
||||
margin: 0px 0px 0px 10px;
|
||||
display: inline-block;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
/* add left protrusion to khoj chat bubble */
|
||||
.chat-message-text.khoj:after {
|
||||
content: '';
|
||||
|
||||
28
src/khoj/migrations/migrate_offline_model.py
Normal file
28
src/khoj/migrations/migrate_offline_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
import logging
|
||||
from packaging import version
|
||||
|
||||
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_offline_model(args):
|
||||
schema_version = "0.10.1"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
previous_version = raw_config.get("version")
|
||||
|
||||
if previous_version is None or version.parse(previous_version) < version.parse("0.10.1"):
|
||||
logger.info(
|
||||
f"Migrating offline model used for version {previous_version} to latest version for {args.version_no}"
|
||||
)
|
||||
raw_config["version"] = schema_version
|
||||
|
||||
# If the user has downloaded the offline model, remove it from the cache.
|
||||
offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin")
|
||||
if os.path.exists(offline_model_path):
|
||||
os.remove(offline_model_path)
|
||||
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
return args
|
||||
@@ -34,10 +34,9 @@ from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
|
||||
def migrate_processor_conversation_schema(args):
|
||||
schema_version = "0.10.0"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
|
||||
raw_config["version"] = args.version_no
|
||||
|
||||
if "processor" not in raw_config:
|
||||
return args
|
||||
if raw_config["processor"] is None:
|
||||
@@ -45,22 +44,24 @@ def migrate_processor_conversation_schema(args):
|
||||
if "conversation" not in raw_config["processor"]:
|
||||
return args
|
||||
|
||||
# Add enable_offline_chat to khoj config schema
|
||||
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
|
||||
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
|
||||
current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
|
||||
if current_openai_api_key is None and current_chat_model is None:
|
||||
return args
|
||||
|
||||
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
|
||||
raw_config["version"] = schema_version
|
||||
|
||||
# Add enable_offline_chat to khoj config schema
|
||||
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
|
||||
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
|
||||
|
||||
# Update conversation processor schema
|
||||
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
|
||||
raw_config["processor"]["conversation"] = {
|
||||
"openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
|
||||
"conversation-logfile": conversation_logfile,
|
||||
"enable-offline-chat": False,
|
||||
}
|
||||
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
return args
|
||||
|
||||
@@ -2,11 +2,12 @@ from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
|
||||
def migrate_config_to_version(args):
|
||||
schema_version = "0.9.0"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
|
||||
# Add version to khoj config schema
|
||||
if "version" not in raw_config:
|
||||
raw_config["version"] = args.version_no
|
||||
raw_config["version"] = schema_version
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
# regenerate khoj index on first start of this version
|
||||
|
||||
@@ -10,6 +10,7 @@ from gpt4all import GPT4All
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils import state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,7 +59,11 @@ def extract_questions_offline(
|
||||
next_christmas_date=next_christmas_date,
|
||||
)
|
||||
message = system_prompt + example_questions
|
||||
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=256)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
@@ -119,13 +124,11 @@ def converse_offline(
|
||||
"""
|
||||
gpt4all_model = loaded_model or GPT4All(model)
|
||||
# Initialize Variables
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
# TODO If compiled_references_message is too long, we need to truncate it.
|
||||
if compiled_references_message == "":
|
||||
conversation_primer = prompts.conversation_llamav2.format(query=user_query)
|
||||
conversation_primer = user_query
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation_llamav2.format(
|
||||
query=user_query, references=compiled_references_message
|
||||
@@ -157,11 +160,20 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
||||
for message in conversation_history
|
||||
]
|
||||
|
||||
stop_words = ["<s>"]
|
||||
chat_history = "".join(formatted_messages)
|
||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
|
||||
for response in response_iterator:
|
||||
g.send(response)
|
||||
|
||||
state.chat_lock.acquire()
|
||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
|
||||
try:
|
||||
for response in response_iterator:
|
||||
if any(stop_word in response.strip() for stop_word in stop_words):
|
||||
logger.debug(f"Stop response as hit stop word in {response}")
|
||||
break
|
||||
g.send(response)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
model_name_to_url = {
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from khoj.processor.conversation.gpt4all import model_metadata
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def download_model(model_name):
|
||||
def download_model(model_name: str):
|
||||
url = model_metadata.model_name_to_url.get(model_name)
|
||||
if not url:
|
||||
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
|
||||
@@ -19,13 +19,16 @@ def download_model(model_name):
|
||||
if os.path.exists(filename):
|
||||
return GPT4All(model_name)
|
||||
|
||||
# Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
|
||||
tmp_filename = filename + ".tmp"
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
|
||||
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get("content-length", 0))
|
||||
with open(filename, "wb") as f, tqdm(
|
||||
with open(tmp_filename, "wb") as f, tqdm(
|
||||
unit="B", # unit string to be displayed.
|
||||
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
||||
unit_divisor=1024, # is used when unit_scale is true
|
||||
@@ -35,7 +38,14 @@ def download_model(model_name):
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
# Move the tmp file to the actual file
|
||||
os.rename(tmp_filename, filename)
|
||||
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
|
||||
return GPT4All(model_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
|
||||
# Remove the tmp file if it exists
|
||||
if os.path.exists(tmp_filename):
|
||||
os.remove(tmp_filename)
|
||||
return None
|
||||
|
||||
@@ -19,14 +19,13 @@ Question: {query}
|
||||
)
|
||||
|
||||
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
|
||||
Using your general knowledge and our past conversations as context, answer the following question."""
|
||||
Using your general knowledge and our past conversations as context, answer the following question.
|
||||
If you do not know the answer, say 'I don't know.'"""
|
||||
|
||||
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant.
|
||||
- When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
- Try to be as specific as possible. For example, rather than use "they" or "it", use the name of the person or thing you are referring to.
|
||||
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
- Write the question as if you can search for the answer on the user's personal notes.
|
||||
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
|
||||
- Add as much context from the previous questions and notes as required into your search queries.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- Provide search queries as a list of questions
|
||||
What follow-up questions, if any, will you need to ask to answer the user's question?
|
||||
"""
|
||||
@@ -129,11 +128,6 @@ Answer (in second person):"""
|
||||
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
|
||||
<s>[INST]<<SYS>>
|
||||
Use these notes from the user's previous conversations to provide a response:
|
||||
{chat_history}
|
||||
<</SYS>>[/INST]</s>
|
||||
|
||||
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
|
||||
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
|
||||
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
|
||||
@@ -141,6 +135,10 @@ Use these notes from the user's previous conversations to provide a response:
|
||||
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
|
||||
<s>[INST]How are you feeling today?[/INST]</s>
|
||||
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
|
||||
<s>[INST]<<SYS>>
|
||||
Use these notes from the user's previous conversations to provide a response:
|
||||
{chat_history}
|
||||
<</SYS>>[/INST]</s>
|
||||
<s>[INST]{query}[/INST]
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -7,13 +7,15 @@ import tiktoken
|
||||
|
||||
# External packages
|
||||
from langchain.schema import ChatMessage
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
# Internal Packages
|
||||
import queue
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
|
||||
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
||||
|
||||
|
||||
class ThreadedGenerator:
|
||||
@@ -40,6 +42,10 @@ class ThreadedGenerator:
|
||||
return item
|
||||
|
||||
def send(self, data):
|
||||
if self.response == "":
|
||||
time_to_first_response = perf_counter() - self.start_time
|
||||
logger.debug(f"First response took: {time_to_first_response:.3f} seconds")
|
||||
|
||||
self.response += data
|
||||
self.queue.put(data)
|
||||
|
||||
@@ -100,30 +106,35 @@ def generate_chatml_messages_with_context(
|
||||
return messages[::-1]
|
||||
|
||||
|
||||
def truncate_messages(messages, max_prompt_size, model_name):
|
||||
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
try:
|
||||
|
||||
if "llama" in model_name:
|
||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
||||
else:
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
encoder = tiktoken.encoding_for_model("text-davinci-001")
|
||||
|
||||
system_message = messages.pop()
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
while tokens > max_prompt_size and len(messages) > 1:
|
||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||
messages.pop()
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
|
||||
# Truncate last message if still over max supported prompt size by model
|
||||
if tokens > max_prompt_size:
|
||||
last_message = "\n".join(messages[-1].content.split("\n")[:-1])
|
||||
original_question = "\n".join(messages[-1].content.split("\n")[-1:])
|
||||
# Truncate current message if still over max supported prompt size by model
|
||||
if (tokens + system_message_tokens) > max_prompt_size:
|
||||
current_message = "\n".join(messages[0].content.split("\n")[:-1])
|
||||
original_question = "\n".join(messages[0].content.split("\n")[-1:])
|
||||
original_question_tokens = len(encoder.encode(original_question))
|
||||
remaining_tokens = max_prompt_size - original_question_tokens
|
||||
truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
|
||||
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
|
||||
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
||||
logger.debug(
|
||||
f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||
)
|
||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
|
||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
||||
|
||||
return messages
|
||||
return messages + [system_message]
|
||||
|
||||
|
||||
def reciprocal_conversation_to_chatml(message_pair):
|
||||
|
||||
@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def perform_chat_checks():
|
||||
if state.processor_config.conversation and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
if (
|
||||
state.processor_config
|
||||
and state.processor_config.conversation
|
||||
and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
@@ -89,37 +93,36 @@ def generate_chat_response(
|
||||
chat_response = None
|
||||
|
||||
try:
|
||||
with timer("Generating chat response took", logger):
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
if state.processor_config.conversation.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
user_query=q,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
if state.processor_config.conversation.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
user_query=q,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
@@ -8,6 +8,7 @@ from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.yaml import parse_config_from_file
|
||||
from khoj.migrations.migrate_version import migrate_config_to_version
|
||||
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
|
||||
from khoj.migrations.migrate_offline_model import migrate_offline_model
|
||||
|
||||
|
||||
def cli(args=None):
|
||||
@@ -55,7 +56,7 @@ def cli(args=None):
|
||||
|
||||
|
||||
def run_migrations(args):
|
||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
|
||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
|
||||
for migration in migrations:
|
||||
args = migration(args)
|
||||
return args
|
||||
|
||||
@@ -25,6 +25,7 @@ port: int = None
|
||||
cli_args: List[str] = None
|
||||
query_cache = LRU()
|
||||
config_lock = threading.Lock()
|
||||
chat_lock = threading.Lock()
|
||||
SearchType = utils_config.SearchType
|
||||
telemetry: List[Dict[str, str]] = []
|
||||
previous_query: str = None
|
||||
|
||||
@@ -170,6 +170,20 @@ def processor_config(tmp_path_factory):
|
||||
return processor_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def processor_config_offline_chat(tmp_path_factory):
|
||||
processor_dir = tmp_path_factory.mktemp("processor")
|
||||
|
||||
# Setup conversation processor
|
||||
processor_config = ProcessorConfig()
|
||||
processor_config.conversation = ConversationProcessorConfig(
|
||||
enable_offline_chat=True,
|
||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||
)
|
||||
|
||||
return processor_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
# Initialize app state
|
||||
@@ -211,6 +225,32 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
|
||||
configure_routes(app)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_offline_chat(
|
||||
content_config: ContentConfig, search_config: SearchConfig, processor_config_offline_chat: ProcessorConfig
|
||||
):
|
||||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||
|
||||
configure_routes(app)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestTruncateMessage:
|
||||
|
||||
def test_truncate_message_all_small(self):
|
||||
chat_messages = ChatMessageFactory.build_batch(500)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||
|
||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||
@@ -27,7 +26,6 @@ class TestTruncateMessage:
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert len(chat_messages) < 500
|
||||
assert len(chat_messages) > 1
|
||||
assert prompt == chat_messages
|
||||
assert tokens <= self.max_prompt_size
|
||||
|
||||
def test_truncate_message_first_large(self):
|
||||
@@ -52,14 +50,17 @@ class TestTruncateMessage:
|
||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
|
||||
chat_messages.append(big_chat_message)
|
||||
chat_messages.insert(0, big_chat_message)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||
|
||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert len(chat_messages) < 26
|
||||
assert len(chat_messages) > 1
|
||||
# The original object has been modified. Verify certain properties.
|
||||
assert len(prompt) == (
|
||||
len(chat_messages) + 1
|
||||
) # Because the system_prompt is popped off from the chat_messages lsit
|
||||
assert len(prompt) < 26
|
||||
assert len(prompt) > 1
|
||||
assert prompt[0] != copy_big_chat_message
|
||||
assert tokens <= self.max_prompt_size
|
||||
|
||||
@@ -35,6 +35,7 @@ freezegun.configure(extend_ignore_list=["transformers"])
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
@@ -54,7 +55,7 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
@@ -76,9 +77,29 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
|
||||
def test_extract_question_with_date_filter_from_relative_year():
|
||||
# Act
|
||||
response = extract_questions_offline("Which countries have I visited this year?")
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("dt>='1984-01-01'", ""),
|
||||
("dt>='1984-01-01'", "dt<'1985-01-01'"),
|
||||
("dt>='1984-01-01'", "dt<='1984-12-31'"),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
"Expected date filter to limit to 1984 in response but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_includes_root_question(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
|
||||
|
||||
@@ -107,13 +128,13 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("morpheus", "neo", "height", "taller", "shorter"),
|
||||
]
|
||||
assert len(response) == 3
|
||||
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
|
||||
"Expected two search queries in response but got: " + response[0]
|
||||
)
|
||||
expected_responses = ["height", "taller", "shorter", "heights"]
|
||||
assert len(response) <= 3
|
||||
|
||||
for question in response:
|
||||
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
|
||||
"Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -148,7 +169,6 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# @pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
@@ -178,7 +198,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware")
|
||||
@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
@@ -219,7 +239,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"]
|
||||
expected_responses = ["Khoj", "khoj", "KHOJ"]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected assistants name, [K|k]hoj, in response but got: " + response
|
||||
@@ -406,6 +426,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient")
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
|
||||
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
|
||||
@@ -434,6 +455,28 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_chat_does_not_exceed_prompt_size(loaded_model):
|
||||
"Ensure chat context and response together do not exceed max prompt size for the model"
|
||||
# Arrange
|
||||
prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
|
||||
context = [" ".join([f"{number}" for number in range(2043)])]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What numbers come after these?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert prompt_size_exceeded_error not in response, (
|
||||
"Expected chat response to be within prompt limits, but got exceeded error: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_filter_questions():
|
||||
test_questions = [
|
||||
"I don't know how to answer that",
|
||||
|
||||
308
tests/test_gpt4all_chat_director.py
Normal file
308
tests/test_gpt4all_chat_director.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# External Packages
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from faker import Faker
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.utils import state
|
||||
|
||||
|
||||
SKIP_TESTS = True
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
|
||||
)
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, llm_message, context in message_list:
|
||||
conversation_log["chat"] += message_to_log(
|
||||
user_message,
|
||||
llm_message,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
|
||||
# Update Conversation Metadata Logs in Application State
|
||||
state.processor_config.conversation.meta_log = conversation_log
|
||||
|
||||
|
||||
# Tests
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected [T|t]estatron in response but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
(
|
||||
"When was I born?",
|
||||
"You were born on 1st April 1984.",
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "Fujiang" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
(
|
||||
"When was I born?",
|
||||
"You were born on 1st April 1984.",
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# 1. Infer who I am from chat history
|
||||
# 2. Infer I was born in Testville from previously retrieved notes
|
||||
assert "Testville" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# Inference in a multi-turn conversation
|
||||
# 1. Infer who I am from chat history
|
||||
# 2. Search for notes about when <my_name_from_chat_history> was born
|
||||
# 3. Extract where I was born from currently retrieved notes
|
||||
assert "Fujiang" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected chat director to say they don't know in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_current_date_awareness(client_offline_chat):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where did I have lunch today?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Arak", "Medellin"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected chat director to say Arak, Medellin, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "23" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(
|
||||
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "Test"]
|
||||
assert response.status_code == 200
|
||||
assert len(response_message.splitlines()) == 3 # haikus are 3 lines long
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected [T|t]est in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
"which of them is the older",
|
||||
"which one is older",
|
||||
"which of them is older",
|
||||
"which one is the older",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected chat director to ask for clarification in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected [T|t]estatron in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_chat_history_very_long(client_offline_chat):
|
||||
# Arrange
|
||||
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
|
||||
|
||||
populate_chat_history(message_list)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert len(response_message) > 0
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_multiple_independent_searches(client_offline_chat):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
)
|
||||
Reference in New Issue
Block a user