Improve Quality and Reliability of Offline Chat (#393)

# Incoming
## Major
### Fix Prompt Size Exceeded Issue
- Fix issues related to prompt size, Closes #386. Use the correct tokenizer to calculate whether the input needs to be truncated or not.

### Improve Llama 2 Model Download
- Use the correct download link for LlamaV2 -- should have been using the small model, but was using the medium
- Add better downloading logic to retry download if it failed, Closes #379 

### Fix Segmentation Fault due to Race
- Add a lock around generating chat responses from the offline model to avoid segmentation faults. Closes #367.
- Add a loading symbol to the web chat UI when the model is thinking. Closes #392

### Improve Chat Response Latency
- Improve performance of offline chat by increasing batch size (via `n_batch`) to automatically engage more cores/GPU, using smaller model and fixing prompt vs response token generation numbers. Closes #363

### Fix Fake Dialogue Continuation
- Fix formatting of user query with offline chat, this was contributing to #398
- Stop Llama 2 from Creating Fake Dialogue Continuations. Closes #398

## Minor
- Improve default message for Chat window on web when it's not configured. Include hint to use offline chat.
- Add null check in `perform_chat_checks` method
- Add offline chat director unit tests

## Performance Analysis (Time to First Token)
|  | v0.10.0 | this branch |
|-|-|-|
| Query 1 | 52s | 28s |
| Query 2 | 33s| 42s |
| Query 3 | 67s| 38s|
This commit is contained in:
Debanjum
2023-08-01 22:07:27 -07:00
committed by GitHub
16 changed files with 578 additions and 101 deletions

View File

@@ -61,6 +61,7 @@
// Add message by user to chat body
renderMessage(query, "you");
document.getElementById("chat-input").value = "";
document.getElementById("chat-input").setAttribute("disabled", "disabled");
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
@@ -76,7 +77,9 @@
new_response.appendChild(new_response_text);
// Temporary status message to indicate that Khoj is thinking
new_response_text.innerHTML = "🤔";
let loadingSpinner = document.createElement("div");
loadingSpinner.classList.add("spinner");
new_response_text.appendChild(loadingSpinner);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
// Call specified Khoj API which returns a streamed response of type text/plain
@@ -107,10 +110,10 @@
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
} else {
// Display response from Khoj
if (new_response_text.innerHTML === "🤔") {
// Clear temporary status message
new_response_text.innerHTML = "";
if (new_response_text.getElementsByClassName("spinner").length > 0) {
new_response_text.removeChild(loadingSpinner);
}
new_response_text.innerHTML += chunk;
readStream();
}
@@ -120,6 +123,7 @@
});
}
readStream();
document.getElementById("chat-input").removeAttribute("disabled");
});
}
@@ -136,7 +140,7 @@
.then(data => {
if (data.detail) {
// If the server returns a 500 error with detail, render a setup hint.
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
// Disable chat input field and update placeholder text
document.getElementById("chat-input").setAttribute("disabled", "disabled");
@@ -269,6 +273,21 @@
margin-left: auto;
white-space: pre-line;
}
/* Spinner symbol when the chat message is loading */
.spinner {
border: 4px solid #f3f3f3;
border-top: 4px solid var(--primary-inverse);
border-radius: 50%;
width: 12px;
height: 12px;
animation: spin 2s linear infinite;
margin: 0px 0px 0px 10px;
display: inline-block;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* add left protrusion to khoj chat bubble */
.chat-message-text.khoj:after {
content: '';

View File

@@ -0,0 +1,28 @@
import os
import logging
from packaging import version
from khoj.utils.yaml import load_config_from_file, save_config_to_file
logger = logging.getLogger(__name__)
def migrate_offline_model(args):
schema_version = "0.10.1"
raw_config = load_config_from_file(args.config_file)
previous_version = raw_config.get("version")
if previous_version is None or version.parse(previous_version) < version.parse("0.10.1"):
logger.info(
f"Migrating offline model used for version {previous_version} to latest version for {args.version_no}"
)
raw_config["version"] = schema_version
# If the user has downloaded the offline model, remove it from the cache.
offline_model_path = os.path.expanduser("~/.cache/gpt4all/llama-2-7b-chat.ggmlv3.q4_K_S.bin")
if os.path.exists(offline_model_path):
os.remove(offline_model_path)
save_config_to_file(raw_config, args.config_file)
return args

View File

@@ -34,10 +34,9 @@ from khoj.utils.yaml import load_config_from_file, save_config_to_file
def migrate_processor_conversation_schema(args):
schema_version = "0.10.0"
raw_config = load_config_from_file(args.config_file)
raw_config["version"] = args.version_no
if "processor" not in raw_config:
return args
if raw_config["processor"] is None:
@@ -45,22 +44,24 @@ def migrate_processor_conversation_schema(args):
if "conversation" not in raw_config["processor"]:
return args
# Add enable_offline_chat to khoj config schema
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
save_config_to_file(raw_config, args.config_file)
current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
if current_openai_api_key is None and current_chat_model is None:
return args
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
raw_config["version"] = schema_version
# Add enable_offline_chat to khoj config schema
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
# Update conversation processor schema
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
raw_config["processor"]["conversation"] = {
"openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
"conversation-logfile": conversation_logfile,
"enable-offline-chat": False,
}
save_config_to_file(raw_config, args.config_file)
return args

View File

@@ -2,11 +2,12 @@ from khoj.utils.yaml import load_config_from_file, save_config_to_file
def migrate_config_to_version(args):
schema_version = "0.9.0"
raw_config = load_config_from_file(args.config_file)
# Add version to khoj config schema
if "version" not in raw_config:
raw_config["version"] = args.version_no
raw_config["version"] = schema_version
save_config_to_file(raw_config, args.config_file)
# regenerate khoj index on first start of this version

View File

@@ -10,6 +10,7 @@ from gpt4all import GPT4All
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
from khoj.processor.conversation import prompts
from khoj.utils.constants import empty_escape_sequences
from khoj.utils import state
logger = logging.getLogger(__name__)
@@ -58,7 +59,11 @@ def extract_questions_offline(
next_christmas_date=next_christmas_date,
)
message = system_prompt + example_questions
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
state.chat_lock.acquire()
try:
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=256)
finally:
state.chat_lock.release()
# Extract, Clean Message from GPT's Response
try:
@@ -119,13 +124,11 @@ def converse_offline(
"""
gpt4all_model = loaded_model or GPT4All(model)
# Initialize Variables
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references_message = "\n\n".join({f"{item}" for item in references})
# Get Conversation Primer appropriate to Conversation Type
# TODO If compiled_references_message is too long, we need to truncate it.
if compiled_references_message == "":
conversation_primer = prompts.conversation_llamav2.format(query=user_query)
conversation_primer = user_query
else:
conversation_primer = prompts.notes_conversation_llamav2.format(
query=user_query, references=compiled_references_message
@@ -157,11 +160,20 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
for message in conversation_history
]
stop_words = ["<s>"]
chat_history = "".join(formatted_messages)
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
for response in response_iterator:
g.send(response)
state.chat_lock.acquire()
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
try:
for response in response_iterator:
if any(stop_word in response.strip() for stop_word in stop_words):
logger.debug(f"Stop response as hit stop word in {response}")
break
g.send(response)
finally:
state.chat_lock.release()
g.close()

View File

@@ -1,3 +1,3 @@
model_name_to_url = {
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
}

View File

@@ -9,7 +9,7 @@ from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__)
def download_model(model_name):
def download_model(model_name: str):
url = model_metadata.model_name_to_url.get(model_name)
if not url:
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
@@ -19,13 +19,16 @@ def download_model(model_name):
if os.path.exists(filename):
return GPT4All(model_name)
# Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
tmp_filename = filename + ".tmp"
try:
os.makedirs(os.path.dirname(filename), exist_ok=True)
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
with requests.get(url, stream=True) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0))
with open(filename, "wb") as f, tqdm(
with open(tmp_filename, "wb") as f, tqdm(
unit="B", # unit string to be displayed.
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
unit_divisor=1024, # is used when unit_scale is true
@@ -35,7 +38,14 @@ def download_model(model_name):
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
progress_bar.update(len(chunk))
# Move the tmp file to the actual file
os.rename(tmp_filename, filename)
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
return GPT4All(model_name)
except Exception as e:
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
# Remove the tmp file if it exists
if os.path.exists(tmp_filename):
os.remove(tmp_filename)
return None

View File

@@ -19,14 +19,13 @@ Question: {query}
)
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question."""
Using your general knowledge and our past conversations as context, answer the following question.
If you do not know the answer, say 'I don't know.'"""
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant.
- When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Try to be as specific as possible. For example, rather than use "they" or "it", use the name of the person or thing you are referring to.
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Write the question as if you can search for the answer on the user's personal notes.
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
- Add as much context from the previous questions and notes as required into your search queries.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- Provide search queries as a list of questions
What follow-up questions, if any, will you need to ask to answer the user's question?
"""
@@ -129,11 +128,6 @@ Answer (in second person):"""
extract_questions_llamav2_sample = PromptTemplate.from_template(
"""
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
<s>[INST]<<SYS>>
Use these notes from the user's previous conversations to provide a response:
{chat_history}
<</SYS>>[/INST]</s>
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
@@ -141,6 +135,10 @@ Use these notes from the user's previous conversations to provide a response:
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
<s>[INST]How are you feeling today?[/INST]</s>
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
<s>[INST]<<SYS>>
Use these notes from the user's previous conversations to provide a response:
{chat_history}
<</SYS>>[/INST]</s>
<s>[INST]{query}[/INST]
"""
)

View File

@@ -7,13 +7,15 @@ import tiktoken
# External packages
from langchain.schema import ChatMessage
from transformers import LlamaTokenizerFast
# Internal Packages
import queue
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
class ThreadedGenerator:
@@ -40,6 +42,10 @@ class ThreadedGenerator:
return item
def send(self, data):
if self.response == "":
time_to_first_response = perf_counter() - self.start_time
logger.debug(f"First response took: {time_to_first_response:.3f} seconds")
self.response += data
self.queue.put(data)
@@ -100,30 +106,35 @@ def generate_chatml_messages_with_context(
return messages[::-1]
def truncate_messages(messages, max_prompt_size, model_name):
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
try:
if "llama" in model_name:
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
else:
encoder = tiktoken.encoding_for_model(model_name)
except KeyError:
encoder = tiktoken.encoding_for_model("text-davinci-001")
system_message = messages.pop()
system_message_tokens = len(encoder.encode(system_message.content))
tokens = sum([len(encoder.encode(message.content)) for message in messages])
while tokens > max_prompt_size and len(messages) > 1:
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop()
tokens = sum([len(encoder.encode(message.content)) for message in messages])
# Truncate last message if still over max supported prompt size by model
if tokens > max_prompt_size:
last_message = "\n".join(messages[-1].content.split("\n")[:-1])
original_question = "\n".join(messages[-1].content.split("\n")[-1:])
# Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size:
current_message = "\n".join(messages[0].content.split("\n")[:-1])
original_question = "\n".join(messages[0].content.split("\n")[-1:])
original_question_tokens = len(encoder.encode(original_question))
remaining_tokens = max_prompt_size - original_question_tokens
truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
logger.debug(
f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
)
messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
return messages
return messages + [system_message]
def reciprocal_conversation_to_chatml(message_pair):

View File

@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
def perform_chat_checks():
if state.processor_config.conversation and (
state.processor_config.conversation.openai_model
or state.processor_config.conversation.gpt4all_model.loaded_model
if (
state.processor_config
and state.processor_config.conversation
and (
state.processor_config.conversation.openai_model
or state.processor_config.conversation.gpt4all_model.loaded_model
)
):
return
@@ -89,37 +93,36 @@ def generate_chat_response(
chat_response = None
try:
with timer("Generating chat response took", logger):
partial_completion = partial(
_save_to_conversation_log,
q,
user_message_time=user_message_time,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
meta_log=meta_log,
partial_completion = partial(
_save_to_conversation_log,
q,
user_message_time=user_message_time,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
meta_log=meta_log,
)
if state.processor_config.conversation.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_offline(
references=compiled_references,
user_query=q,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
)
if state.processor_config.conversation.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_offline(
references=compiled_references,
user_query=q,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
chat_response = converse(
compiled_references,
q,
meta_log,
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
chat_response = converse(
compiled_references,
q,
meta_log,
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
)
except Exception as e:
logger.error(e, exc_info=True)

View File

@@ -8,6 +8,7 @@ from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.yaml import parse_config_from_file
from khoj.migrations.migrate_version import migrate_config_to_version
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
from khoj.migrations.migrate_offline_model import migrate_offline_model
def cli(args=None):
@@ -55,7 +56,7 @@ def cli(args=None):
def run_migrations(args):
migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
for migration in migrations:
args = migration(args)
return args

View File

@@ -25,6 +25,7 @@ port: int = None
cli_args: List[str] = None
query_cache = LRU()
config_lock = threading.Lock()
chat_lock = threading.Lock()
SearchType = utils_config.SearchType
telemetry: List[Dict[str, str]] = []
previous_query: str = None

View File

@@ -170,6 +170,20 @@ def processor_config(tmp_path_factory):
return processor_config
@pytest.fixture(scope="session")
def processor_config_offline_chat(tmp_path_factory):
processor_dir = tmp_path_factory.mktemp("processor")
# Setup conversation processor
processor_config = ProcessorConfig()
processor_config.conversation = ConversationProcessorConfig(
enable_offline_chat=True,
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)
return processor_config
@pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
# Initialize app state
@@ -211,6 +225,32 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
content_config.image, state.search_models.image_search, regenerate=False
)
state.processor_config = configure_processor(processor_config)
configure_routes(app)
return TestClient(app)
@pytest.fixture(scope="function")
def client_offline_chat(
content_config: ContentConfig, search_config: SearchConfig, processor_config_offline_chat: ProcessorConfig
):
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
state.processor_config = configure_processor(processor_config_offline_chat)
configure_routes(app)
return TestClient(app)

View File

@@ -19,7 +19,6 @@ class TestTruncateMessage:
def test_truncate_message_all_small(self):
chat_messages = ChatMessageFactory.build_batch(500)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
@@ -27,7 +26,6 @@ class TestTruncateMessage:
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 500
assert len(chat_messages) > 1
assert prompt == chat_messages
assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self):
@@ -52,14 +50,17 @@ class TestTruncateMessage:
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.append(big_chat_message)
chat_messages.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 26
assert len(chat_messages) > 1
# The original object has been modified. Verify certain properties.
assert len(prompt) == (
len(chat_messages) + 1
) # Because the system_prompt is popped off from the chat_messages lsit
assert len(prompt) < 26
assert len(prompt) > 1
assert prompt[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size

View File

@@ -35,6 +35,7 @@ freezegun.configure(extend_ignore_list=["transformers"])
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@@ -54,7 +55,7 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
@@ -76,9 +77,29 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
def test_extract_question_with_date_filter_from_relative_year():
# Act
response = extract_questions_offline("Which countries have I visited this year?")
# Assert
expected_responses = [
("dt>='1984-01-01'", ""),
("dt>='1984-01-01'", "dt<'1985-01-01'"),
("dt>='1984-01-01'", "dt<='1984-12-31'"),
]
assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to 1984 in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_includes_root_question(loaded_model):
# Act
response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
@@ -107,13 +128,13 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
# Assert
expected_responses = [
("morpheus", "neo", "height", "taller", "shorter"),
]
assert len(response) == 3
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
"Expected two search queries in response but got: " + response[0]
)
expected_responses = ["height", "taller", "shorter", "heights"]
assert len(response) <= 3
for question in response:
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
"Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
)
# ----------------------------------------------------------------------------------------------------
@@ -148,7 +169,6 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# ----------------------------------------------------------------------------------------------------
# @pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Arrange
@@ -178,7 +198,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware")
@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
# Arrange
@@ -219,7 +239,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"]
expected_responses = ["Khoj", "khoj", "KHOJ"]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response
@@ -406,6 +426,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
@@ -434,6 +455,28 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
# ----------------------------------------------------------------------------------------------------
def test_chat_does_not_exceed_prompt_size(loaded_model):
"Ensure chat context and response together do not exceed max prompt size for the model"
# Arrange
prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
context = [" ".join([f"{number}" for number in range(2043)])]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What numbers come after these?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert prompt_size_exceeded_error not in response, (
"Expected chat response to be within prompt limits, but got exceeded error: " + response
)
# ----------------------------------------------------------------------------------------------------
def test_filter_questions():
test_questions = [
"I don't know how to answer that",

View File

@@ -0,0 +1,308 @@
# External Packages
import pytest
from freezegun import freeze_time
from faker import Faker
# Internal Packages
from khoj.processor.conversation.utils import message_to_log
from khoj.utils import state
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
)
fake = Faker()
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, llm_message, context in message_list:
conversation_log["chat"] += message_to_log(
user_message,
llm_message,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
# Update Conversation Metadata Logs in Application State
state.processor_config.conversation.meta_log = conversation_log
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Khoj", "khoj"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history(client_offline_chat):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Testatron", "testatron"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected [T|t]estatron in response but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_from_currently_retrieved_content(client_offline_chat):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
(
"When was I born?",
"You were born on 1st April 1984.",
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
(
"When was I born?",
"You were born on 1st April 1984.",
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
# 1. Infer who I am from chat history
# 2. Infer I was born in Testville from previously retrieved notes
assert "Testville" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
# Inference in a multi-turn conversation
# 1. Infer who I am from chat history
# 2. Search for notes about when <my_name_from_chat_history> was born
# 3. Extract where I was born from currently retrieved notes
assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to say they don't know in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(client_offline_chat):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Act
response = client_offline_chat.get(f'/api/chat?q="Where did I have lunch today?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Arak", "Medellin"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to say Arak, Medellin, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = client_offline_chat.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert "23" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
)
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["test", "Test"]
assert response.status_code == 200
assert len(response_message.splitlines()) == 3 # haikus are 3 lines long
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected [T|t]est in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = [
"which of them is the older",
"which one is older",
"which of them is older",
"which one is the older",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected chat director to ask for clarification in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Testatron", "testatron"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected [T|t]estatron in response, but got: " + response_message
)
@pytest.mark.chatquality
def test_answer_chat_history_very_long(client_offline_chat):
# Arrange
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
populate_chat_history(message_list)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert len(response_message) > 0
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_requires_multiple_independent_searches(client_offline_chat):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
response = client_offline_chat.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message
)