diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index d3dd9bf5..6f91fdf4 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -119,7 +119,7 @@ def converse_offline(
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
loaded_model: Union[GPT4All, None] = None,
completion_func=None,
- conversation_command=ConversationCommand.Notes,
+ conversation_command=ConversationCommand.Default,
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@@ -129,10 +129,10 @@ def converse_offline(
compiled_references_message = "\n\n".join({f"{item}" for item in references})
# Get Conversation Primer appropriate to Conversation Type
- if conversation_command == ConversationCommand.General:
- conversation_primer = user_query
- elif conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
+ if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()])
+ elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
+ conversation_primer = user_query
else:
conversation_primer = prompts.notes_conversation_llamav2.format(
query=user_query, references=compiled_references_message
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 9185e3c7..8105c2d7 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -109,7 +109,7 @@ def converse(
api_key: Optional[str] = None,
temperature: float = 0.2,
completion_func=None,
- conversation_command=ConversationCommand.Notes,
+ conversation_command=ConversationCommand.Default,
):
"""
Converse with user using OpenAI's ChatGPT
@@ -119,11 +119,11 @@ def converse(
compiled_references = "\n\n".join({f"# {item}" for item in references})
# Get Conversation Primer appropriate to Conversation Type
- if conversation_command == ConversationCommand.General:
- conversation_primer = prompts.general_conversation.format(current_date=current_date, query=user_query)
- elif conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
+ if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
+ elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
+ conversation_primer = prompts.general_conversation.format(current_date=current_date, query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(
current_date=current_date, query=user_query, references=compiled_references
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index e5c08ff3..dcfc1bf4 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -4,7 +4,7 @@ from langchain.prompts import PromptTemplate
## Personality
## --
-personality = PromptTemplate.from_template("You are Khoj, a friendly, smart and helpful personal assistant.")
+personality = PromptTemplate.from_template("You are Khoj, a smart, inquisitive and helpful personal assistant.")
## General Conversation
@@ -77,7 +77,9 @@ conversation_llamav2 = PromptTemplate.from_template(
## --
notes_conversation = PromptTemplate.from_template(
"""
-Using the notes and our past conversations as context, answer the following question.
+Using my personal notes and our past conversations as context, answer the following question.
+Ask crisp follow-up questions to get additional context when the answer cannot be inferred from the provided notes or past conversations.
+These questions should end with a question mark.
Current Date: {current_date}
Notes:
@@ -236,9 +238,10 @@ Q:"""
# --
help_message = PromptTemplate.from_template(
"""
+**/notes**: Chat using the information in your knowledge base.
+**/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
+**/default**: Chat using your knowledge base and Khoj's general knowledge for context.
**/help**: Show this help message.
-**/notes**: Chat using the information in your knowledge base. This is the default method.
-**/general**: Chat using general knowledge with the LLM. This will not search against your notes.
You are using the **{model}** model.
**version**: {version}
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 4c5541b1..4f7c6f42 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -705,7 +705,7 @@ async def chat(
compiled_references, inferred_queries = await extract_references_and_questions(
request, q, (n or 5), conversation_command
)
- conversation_command = get_conversation_command(query=q, any_references=is_none_or_empty(compiled_references))
+ conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references))
if conversation_command == ConversationCommand.Help:
model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai"
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
@@ -755,7 +755,7 @@ async def extract_references_and_questions(
request: Request,
q: str,
n: int,
- conversation_type: ConversationCommand = ConversationCommand.Notes,
+ conversation_type: ConversationCommand = ConversationCommand.Default,
):
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 63f82a1d..267af330 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -60,15 +60,15 @@ def update_telemetry_state(
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
if query.startswith("/notes"):
return ConversationCommand.Notes
- elif query.startswith("/general"):
- return ConversationCommand.General
elif query.startswith("/help"):
return ConversationCommand.Help
+ elif query.startswith("/general"):
+ return ConversationCommand.General
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
else:
- return ConversationCommand.Notes
+ return ConversationCommand.Default
def generate_chat_response(
@@ -76,7 +76,7 @@ def generate_chat_response(
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
- conversation_command: ConversationCommand = ConversationCommand.Notes,
+ conversation_command: ConversationCommand = ConversationCommand.Default,
) -> Union[ThreadedGenerator, Iterator[str]]:
def _save_to_conversation_log(
q: str,
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 7d02497f..9bd139d4 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -214,13 +214,15 @@ def log_telemetry(
class ConversationCommand(str, Enum):
+ Default = "default"
General = "general"
Notes = "notes"
Help = "help"
command_descriptions = {
- ConversationCommand.General: "This command allows you to search talk with the LLM without including context from your knowledge base.",
- ConversationCommand.Notes: "This command allows you to search talk with the LLM while including context from your knowledge base.",
- ConversationCommand.Help: "This command displays a help message with all available commands and other metadata.",
+ ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.",
+ ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
+    ConversationCommand.Default: "The default command when no command is specified. It intelligently auto-switches between general and notes mode.",
+ ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py
index 111a6a12..6da7f759 100644
--- a/tests/test_gpt4all_chat_director.py
+++ b/tests/test_gpt4all_chat_director.py
@@ -1,10 +1,13 @@
+# Standard Packages
+import urllib.parse
+
# External Packages
import pytest
from freezegun import freeze_time
from faker import Faker
-
# Internal Packages
+from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.utils import state
@@ -172,6 +175,57 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
)
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_using_general_command(client_offline_chat):
+ # Arrange
+ query = urllib.parse.quote("/general Where was Xi Li born?")
+ message_list = []
+ populate_chat_history(message_list)
+
+ # Act
+ response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response.status_code == 200
+ assert "Fujiang" not in response_message
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
+ # Arrange
+ query = urllib.parse.quote("/notes Where was Xi Li born?")
+ message_list = []
+ populate_chat_history(message_list)
+
+ # Act
+ response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response.status_code == 200
+ assert "Fujiang" in response_message
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_not_known_using_notes_command(client_offline_chat):
+ # Arrange
+ query = urllib.parse.quote("/notes Where was Testatron born?")
+ message_list = []
+ populate_chat_history(message_list)
+
+ # Act
+ response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
+ response_message = response.content.decode("utf-8")
+
+ # Assert
+ assert response.status_code == 200
+ assert response_message == prompts.no_notes_found.format()
+
+
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.chatquality
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index a28c3d04..4f05fc52 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -280,7 +280,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# Act
@@ -289,10 +288,10 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# Assert
expected_responses = [
- "which of them is the older",
- "which one is older",
- "which of them is older",
- "which one is the older",
+ "which of them",
+ "which one is",
+ "which of namita's sons",
+ "the birth order",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (