Merge branch 'master' into migrate-to-llama-cpp-for-offline-chat

This commit is contained in:
Debanjum Singh Solanky
2024-03-31 00:59:20 +05:30
51 changed files with 1708 additions and 145 deletions

View File

@@ -12,6 +12,7 @@ from khoj.configure import (
configure_search_types,
)
from khoj.database.models import (
Agent,
GithubConfig,
GithubRepoConfig,
KhojApiUser,
@@ -181,6 +182,28 @@ def api_user4(default_user4):
)
@pytest.mark.django_db
@pytest.fixture
def offline_agent():
chat_model = ChatModelOptionsFactory()
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
)
@pytest.mark.django_db
@pytest.fixture
def openai_agent():
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent.",
)
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()

View File

@@ -34,7 +34,9 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Ensure raw entry with no headings do not get heading prefix prepended
assert not jsonl_data[0]["raw"].startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert jsonl_data[0]["compiled"].startswith(expected_heading)
assert expected_heading in jsonl_data[0]["compiled"]
# Ensure compiled entry also includes the file name
assert str(tmp_path) in jsonl_data[0]["compiled"]
def test_single_markdown_entry_to_jsonl(tmp_path):

View File

@@ -467,6 +467,47 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_agent_prompt_should_be_used(loaded_model, offline_agent):
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
# Arrange
context = [
f"""I went to the store and bought some bananas for 2.20""",
f"""I went to the store and bought some apples for 1.30""",
f"""I went to the store and bought some oranges for 6.00""",
]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert that the model without the agent prompt does not include the summary of purchases
expected_responses = ["9.50", "9.5"]
assert all([expected_response not in response for expected_response in expected_responses]), (
"Expected chat actor to summarize values of purchases" + response
)
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
loaded_model=loaded_model,
agent=offline_agent,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert that the model with the agent prompt does include the summary of purchases
expected_responses = ["9.50", "9.5"]
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to summarize values of purchases" + response
)
# ----------------------------------------------------------------------------------------------------
def test_chat_does_not_exceed_prompt_size(loaded_model):
"Ensure chat context and response together do not exceed max prompt size for the model"

View File

@@ -6,6 +6,7 @@ import pytest
from faker import Faker
from freezegun import freeze_time
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
@@ -26,20 +27,20 @@ def generate_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message, context in message_list:
conversation_log["chat"] += message_to_log(
message_to_log(
user_message,
gpt_message,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
conversation_log=conversation_log.get("chat", []),
)
return conversation_log
def populate_chat_history(message_list, user):
def create_conversation(message_list, user, agent=None):
# Generate conversation logs
conversation_log = generate_history(message_list)
# Update Conversation Metadata Logs in Database
ConversationFactory(user=user, conversation_log=conversation_log)
return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
# Tests
@@ -114,7 +115,7 @@ def test_answer_from_chat_history(client_offline_chat, default_user2):
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -141,7 +142,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat, default_us
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
@@ -165,7 +166,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@@ -191,7 +192,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@@ -217,7 +218,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, def
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
@@ -238,7 +239,7 @@ def test_answer_using_general_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -256,7 +257,7 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat,
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -275,7 +276,7 @@ def test_answer_using_file_filter(client_offline_chat, default_user2):
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
@@ -293,7 +294,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat, default_user2
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -351,7 +352,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(
@@ -394,14 +395,14 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -415,13 +416,77 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, defa
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
]
message_list2 = [
("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 14th August 1947.", []),
("What's my favorite color", "Your favorite color is maroon.", []),
("Where was I born?", "You were born in a potato farm.", []),
]
conversation = create_conversation(message_list, default_user2)
create_conversation(message_list2, default_user2)
# Act
response = client_offline_chat.get(
f'/api/chat?q="What is my favorite color?"&conversation_id={conversation.id}&stream=true'
)
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["green"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id_with_agent(
client_offline_chat, default_user2: KhojUser, offline_agent: Agent
):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
]
conversation = create_conversation(message_list, default_user2, offline_agent)
# Act
query = urllib.parse.quote("/general What did I eat for breakfast?")
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert that agent only responds with the summary of spending
expected_responses = ["13.00", "13", "13.0", "thirteen"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
)
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
# Arrange
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -525,7 +590,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
]
chat_history = populate_chat_history(chat_log, default_user2)
chat_history = create_conversation(chat_log, default_user2)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)

View File

@@ -414,6 +414,42 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_agent_prompt_should_be_used(openai_agent):
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
# Arrange
context = [
f"""I went to the store and bought some bananas for 2.20""",
f"""I went to the store and bought some apples for 1.30""",
f"""I went to the store and bought some oranges for 6.00""",
]
expected_responses = ["9.50", "9.5"]
# Act
response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
api_key=api_key,
)
no_agent_response = "".join([response_chunk for response_chunk in response_gen])
response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
api_key=api_key,
agent=openai_agent,
)
agent_response = "".join([response_chunk for response_chunk in response_gen])
# Assert that the model without the agent prompt does not include the summary of purchases
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (
"Expected chat actor to summarize values of purchases" + no_agent_response
)
assert any([expected_response in agent_response for expected_response in expected_responses]), (
"Expected chat actor to summarize values of purchases" + agent_response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)

View File

@@ -5,13 +5,10 @@ from urllib.parse import quote
import pytest
from freezegun import freeze_time
from khoj.database.models import KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import (
aget_relevant_information_sources,
aget_relevant_output_modes,
)
from khoj.routers.helpers import aget_relevant_information_sources
from tests.helpers import ConversationFactory
# Initialize variables for tests
@@ -29,20 +26,21 @@ def generate_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message, context in message_list:
conversation_log["chat"] += message_to_log(
message_to_log(
user_message,
gpt_message,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
conversation_log=conversation_log.get("chat", []),
)
return conversation_log
def populate_chat_history(message_list, user):
def create_conversation(message_list, user, agent=None):
# Generate conversation logs
conversation_log = generate_history(message_list)
# Update Conversation Metadata Logs in Database
ConversationFactory(user=user, conversation_log=conversation_log)
return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
# Tests
@@ -116,7 +114,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -143,7 +141,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
@@ -167,7 +165,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
@@ -190,7 +188,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
@@ -215,7 +213,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
@@ -244,7 +242,7 @@ def test_answer_using_general_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@@ -262,7 +260,7 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@@ -280,7 +278,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
@@ -335,7 +333,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true')
@@ -387,7 +385,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list, default_user2)
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -401,6 +399,68 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
]
message_list2 = [
("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 14th August 1947.", []),
("What's my favorite color", "Your favorite color is maroon.", []),
("Where was I born?", "You were born in a potato farm.", []),
]
conversation = create_conversation(message_list, default_user2)
create_conversation(message_list2, default_user2)
# Act
query = urllib.parse.quote("/general What is my favorite color?")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["green"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_by_conversation_id_with_agent(
chat_client, default_user2: KhojUser, openai_agent: Agent
):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
]
conversation = create_conversation(message_list, default_user2, openai_agent)
# Act
query = urllib.parse.quote("/general What did I eat for breakfast?")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
# Assert that agent only responds with the summary of spending
expected_responses = ["13.00", "13", "13.0", "thirteen"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality