Improve data source, output mode selection

- Set output mode to single string. Specify output schema in prompt
  - Both these should encourage the model to select only 1 output mode
    instead of encouraging it in the prompt too many times
  - Output schema should also improve schema following in general
- Standardize variable, func names of the io selector for readability
- Fix chat actor tests to test the io selector chat actor
- Make chat actor return sources, output separately for better
  disambiguation, at least during tests, for now
This commit is contained in:
Debanjum
2024-11-18 12:49:48 -08:00
parent e3fd51d14b
commit 653127bf1d
6 changed files with 104 additions and 88 deletions
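For reference, the new selection schema the prompt asks the model to follow is `{"source": list[str], "output": str}`: any number of data sources, exactly one output format. A minimal sketch of parsing and validating such a response; the `parse_io_selection` helper is illustrative only, the commit's actual parsing lives in `aget_data_sources_and_output_format`:

```python
import json

def parse_io_selection(raw: str) -> dict:
    """Validate an LLM response of the form {"source": list[str], "output": str}.

    Illustrative helper, not part of the commit itself.
    """
    response = json.loads(raw)
    # Sources: a non-empty list of non-blank strings
    sources = [s.strip() for s in response.get("source", []) if s.strip()]
    # Output: a single mode string, defaulting to text output
    output = response.get("output", "text").strip()
    if not isinstance(sources, list) or not sources:
        raise ValueError(f"Invalid sources in response: {response}")
    if not isinstance(output, str) or not output:
        raise ValueError(f"Expected a single output mode string, got: {response}")
    return {"sources": sources, "output": output}
```

Keeping `output` a bare string rather than a list is what nudges the model toward a single output mode without repeating that instruction throughout the prompt.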


@@ -630,37 +630,36 @@ pick_relevant_tools = PromptTemplate.from_template(
 """
 You are Khoj, an extremely smart and helpful search assistant.
 {personality_context}
-- You have access to a variety of data sources to help you answer the user's question
-- You can use the data sources listed below to collect more relevant information
-- You can select certain types of output to respond to the user's question. Select just one output type to answer the user's question
-- You can use any combination of these data sources and output types to answer the user's question
-- You can only select one output type to answer the user's question
+- You have access to a variety of data sources to help you answer the user's question.
+- You can use any subset of data sources listed below to collect more relevant information.
+- You can select the most appropriate output format from the options listed below to respond to the user's question.
+- Both the data sources and output format should be selected based on the user's query and relevant context provided in the chat history.
-Which of the tools listed below you would use to answer the user's question? You **only** have access to the following:
+Which of the data sources, output format listed below would you use to answer the user's question? You **only** have access to the following:
-Inputs:
-{tools}
+Data Sources:
+{sources}
-Outputs:
+Output Formats:
 {outputs}
 Here are some examples:
 Example:
 Chat History:
-User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco.
+User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco
 AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
-Q: What is the population of each of those cities?
+Q: Chart the population growth of each of those cities in the last decade
-Khoj: {{"source": ["online"], "output": ["text"]}}
+Khoj: {{"source": ["online", "code"], "output": "text"}}
 Example:
 Chat History:
-User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting.
+User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting
 AI: Excellent! Taking a vacation is a great way to relax and recharge.
 Q: Where did Grandma grow up?
-Khoj: {{"source": ["notes"], "output": ["text"]}}
+Khoj: {{"source": ["notes"], "output": "text"}}
 Example:
 Chat History:
@@ -668,7 +667,7 @@ User: Good morning
 AI: Good morning! How can I help you today?
 Q: How can I share my files with Khoj?
-Khoj: {{"source": ["default", "online"], "output": ["text"]}}
+Khoj: {{"source": ["default", "online"], "output": "text"}}
 Example:
 Chat History:
@@ -676,17 +675,18 @@ User: What is the first element in the periodic table?
 AI: The first element in the periodic table is Hydrogen.
 Q: Summarize this article https://en.wikipedia.org/wiki/Hydrogen
-Khoj: {{"source": ["webpage"], "output": ["text"]}}
+Khoj: {{"source": ["webpage"], "output": "text"}}
 Example:
 Chat History:
-User: I want to start a new hobby. I'm thinking of learning to play the guitar.
+User: I'm learning to play the guitar, so I can make a band with my friends
-AI: Learning to play the guitar is a great hobby. It can be a lot of fun and a great way to express yourself.
+AI: Learning to play the guitar is a great hobby. It can be a fun way to socialize and express yourself.
-Q: Draw a painting of a guitar.
+Q: Create a painting of my recent jamming sessions
-Khoj: {{"source": ["general"], "output": ["image"]}}
+Khoj: {{"source": ["notes"], "output": "image"}}
-Now it's your turn to pick the sources and output to answer the user's query. Respond with a JSON object, including both `source` and `output`. The values should be a list of strings. Do not say anything else.
+Now it's your turn to pick the appropriate data sources and output format to answer the user's query. Respond with a JSON object, including both `source` and `output` in the following format. Do not say anything else.
+{{"source": list[str], "output": str}}
 Chat History:
 {chat_history}


@@ -46,7 +46,7 @@ from khoj.routers.helpers import (
 FeedbackData,
 acreate_title_from_history,
 agenerate_chat_response,
-aget_relevant_tools_to_execute,
+aget_data_sources_and_output_format,
 construct_automation_created_message,
 create_automation,
 gather_raw_query_files,
@@ -752,7 +752,7 @@ async def chat(
 attached_file_context = gather_raw_query_files(query_files)
 if conversation_commands == [ConversationCommand.Default] or is_automated_task:
-conversation_commands = await aget_relevant_tools_to_execute(
+chosen_io = await aget_data_sources_and_output_format(
 q,
 meta_log,
 is_automated_task,
@@ -762,6 +762,7 @@ async def chat(
 query_files=attached_file_context,
 tracer=tracer,
 )
+conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
 # If we're doing research, we don't want to do anything else
 if ConversationCommand.Research in conversation_commands:
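The route-level change can be seen in isolation: the io selector now returns a dict, which the chat route flattens back into the single `conversation_commands` list that downstream code expects. A toy sketch, with a minimal stand-in enum in place of khoj's actual `ConversationCommand`:

```python
from enum import Enum

# Stand-in for khoj's ConversationCommand; only the members used here
class ConversationCommand(str, Enum):
    Online = "online"
    Code = "code"
    Text = "text"

# What the io selector now returns: sources as a list, output as one value
chosen_io = {
    "sources": [ConversationCommand.Online, ConversationCommand.Code],
    "output": ConversationCommand.Text,
}

# The route flattens it back for code that still consumes a flat command list
conversation_commands = chosen_io.get("sources") + [chosen_io.get("output")]
```

The dict keeps sources and output disambiguated at the selector boundary (and in tests), while this one-liner preserves backward compatibility with the rest of the route.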


@@ -336,7 +336,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
 return is_safe, reason
-async def aget_relevant_tools_to_execute(
+async def aget_data_sources_and_output_format(
 query: str,
 conversation_history: dict,
 is_task: bool,
@@ -345,33 +345,33 @@ async def aget_relevant_tools_to_execute(
 agent: Agent = None,
 query_files: str = None,
 tracer: dict = {},
-):
+) -> Dict[str, Any]:
 """
-Given a query, determine which of the available tools the agent should use in order to answer appropriately.
+Given a query, determine which of the available data sources and output modes the agent should use to answer appropriately.
 """
-tool_options = dict()
-tool_options_str = ""
-agent_tools = agent.input_tools if agent else []
-for tool, description in tool_descriptions_for_llm.items():
-tool_options[tool.value] = description
-if len(agent_tools) == 0 or tool.value in agent_tools:
-tool_options_str += f'- "{tool.value}": "{description}"\n'
+source_options = dict()
+source_options_str = ""
+agent_sources = agent.input_tools if agent else []
+for source, description in tool_descriptions_for_llm.items():
+source_options[source.value] = description
+if len(agent_sources) == 0 or source.value in agent_sources:
+source_options_str += f'- "{source.value}": "{description}"\n'
-mode_options = dict()
-mode_options_str = ""
-output_modes = agent.output_modes if agent else []
-for mode, description in mode_descriptions_for_llm.items():
+output_options = dict()
+output_options_str = ""
+agent_outputs = agent.output_modes if agent else []
+for output, description in mode_descriptions_for_llm.items():
 # Do not allow tasks to schedule another task
-if is_task and mode == ConversationCommand.Automation:
+if is_task and output == ConversationCommand.Automation:
 continue
-mode_options[mode.value] = description
-if len(output_modes) == 0 or mode.value in output_modes:
-mode_options_str += f'- "{mode.value}": "{description}"\n'
+output_options[output.value] = description
+if len(agent_outputs) == 0 or output.value in agent_outputs:
+output_options_str += f'- "{output.value}": "{description}"\n'
 chat_history = construct_chat_history(conversation_history)
@@ -384,8 +384,8 @@ async def aget_relevant_tools_to_execute(
 relevant_tools_prompt = prompts.pick_relevant_tools.format(
 query=query,
-tools=tool_options_str,
-outputs=mode_options_str,
+sources=source_options_str,
+outputs=output_options_str,
 chat_history=chat_history,
 personality_context=personality_context,
 )
@@ -403,43 +403,42 @@ async def aget_relevant_tools_to_execute(
 response = clean_json(response)
 response = json.loads(response)
-input_tools = [q.strip() for q in response.get("source", []) if q.strip()]
-output_modes = [q.strip() for q in response.get("output", ["text"]) if q.strip()] # Default to text output
+selected_sources = [q.strip() for q in response.get("source", []) if q.strip()]
+selected_output = response.get("output", "text").strip() # Default to text output
-if not isinstance(input_tools, list) or not input_tools or len(input_tools) == 0:
+if not isinstance(selected_sources, list) or not selected_sources or len(selected_sources) == 0:
 raise ValueError(
-f"Invalid response for determining relevant tools: {input_tools}. Raw Response: {response}"
+f"Invalid response for determining relevant tools: {selected_sources}. Raw Response: {response}"
 )
-final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
+result: Dict = {"sources": [], "output": None} if not is_task else {"output": ConversationCommand.AutomatedTask}
-for llm_suggested_tool in input_tools:
+for selected_source in selected_sources:
 # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
-if llm_suggested_tool in tool_options.keys() and (
-len(agent_tools) == 0 or llm_suggested_tool in agent_tools
+if (
+selected_source in source_options.keys()
+and isinstance(result["sources"], list)
+and (len(agent_sources) == 0 or selected_source in agent_sources)
 ):
 # Check whether the tool exists as a valid ConversationCommand
-final_response.append(ConversationCommand(llm_suggested_tool))
-for llm_suggested_output in output_modes:
-# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
-if llm_suggested_output in mode_options.keys() and (
-len(output_modes) == 0 or llm_suggested_output in output_modes
-):
-# Check whether the tool exists as a valid ConversationCommand
-final_response.append(ConversationCommand(llm_suggested_output))
+result["sources"].append(ConversationCommand(selected_source))
+# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
+if selected_output in output_options.keys() and (len(agent_outputs) == 0 or selected_output in agent_outputs):
+# Check whether the tool exists as a valid ConversationCommand
+result["output"] = ConversationCommand(selected_output)
-if is_none_or_empty(final_response):
+if is_none_or_empty(result):
-if len(agent_tools) == 0:
+if len(agent_sources) == 0:
-final_response = [ConversationCommand.Default, ConversationCommand.Text]
+result = {"sources": [ConversationCommand.Default], "output": ConversationCommand.Text}
 else:
-final_response = [ConversationCommand.General, ConversationCommand.Text]
+result = {"sources": [ConversationCommand.General], "output": ConversationCommand.Text}
 except Exception as e:
 logger.error(f"Invalid response for determining relevant tools: {response}. Error: {e}", exc_info=True)
-if len(agent_tools) == 0:
-final_response = [ConversationCommand.Default, ConversationCommand.Text]
-else:
-final_response = agent_tools
+sources = agent_sources if len(agent_sources) > 0 else [ConversationCommand.Default]
+output = agent_outputs[0] if len(agent_outputs) > 0 else ConversationCommand.Text
+result = {"sources": sources, "output": output}
-return final_response
+return result
 async def infer_webpage_urls(
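The reworked error path above now falls back to the agent's configured sources plus its first output mode, instead of returning the raw `agent_tools` list as before. A hedged sketch of that fallback logic, with plain strings standing in for `ConversationCommand` members:

```python
def fallback_io(agent_sources: list, agent_outputs: list) -> dict:
    """Illustrative sketch of the exception-path fallback in
    aget_data_sources_and_output_format (strings stand in for enum members).

    Prefer the agent's configured sources and first output mode; otherwise
    fall back to the default source and text output.
    """
    sources = agent_sources if len(agent_sources) > 0 else ["default"]
    output = agent_outputs[0] if len(agent_outputs) > 0 else "text"
    return {"sources": sources, "output": output}
```

Note the fallback always produces the same `{"sources": ..., "output": ...}` shape as the happy path, so callers like the chat route can flatten it uniformly.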


@@ -7,7 +7,7 @@ from freezegun import freeze_time
 from khoj.database.models import Agent, Entry, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import message_to_log
-from khoj.routers.helpers import aget_relevant_tools_to_execute
+from khoj.routers.helpers import aget_data_sources_and_output_format
 from tests.helpers import ConversationFactory
 SKIP_TESTS = True
@@ -735,7 +735,7 @@ async def test_get_correct_tools_online(client_offline_chat):
 user_query = "What's the weather in Patagonia this week?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
+tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -750,7 +750,7 @@ async def test_get_correct_tools_notes(client_offline_chat):
 user_query = "Where did I go for my first battleship training?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
+tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -765,7 +765,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat
 user_query = "What's the highest point in Patagonia and have I been there?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
+tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -782,7 +782,7 @@ async def test_get_correct_tools_general(client_offline_chat):
 user_query = "How many noble gases are there?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
+tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -806,7 +806,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
 chat_history = create_conversation(chat_log, default_user2)
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, chat_history, is_task=False)
+tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False)
 # Assert
 tools = [tool.value for tool in tools]


@@ -1,4 +1,3 @@
-import os
 from datetime import datetime
 import freezegun
@@ -8,7 +7,7 @@ from freezegun import freeze_time
 from khoj.processor.conversation.openai.gpt import converse, extract_questions
 from khoj.processor.conversation.utils import message_to_log
 from khoj.routers.helpers import (
-aget_relevant_tools_to_execute,
+aget_data_sources_and_output_format,
 generate_online_subqueries,
 infer_webpage_urls,
 schedule_query,
@@ -529,19 +528,36 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u
 @pytest.mark.parametrize(
 "user_query, expected_conversation_commands",
 [
-("Where did I learn to swim?", [ConversationCommand.Notes]),
-("Where is the nearest hospital?", [ConversationCommand.Online]),
-("Summarize the wikipedia page on the history of the internet", [ConversationCommand.Webpage]),
+(
+"Where did I learn to swim?",
+{"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text},
+),
+(
+"Where is the nearest hospital?",
+{"sources": [ConversationCommand.Online], "output": ConversationCommand.Text},
+),
+(
+"Summarize the wikipedia page on the history of the internet",
+{"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text},
+),
+(
+"Make a painting incorporating my past diving experiences",
+{"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image},
+),
+(
+"Create a chart of the weather over the next 7 days in Timbuktu",
+{"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text},
+),
 ],
 )
 async def test_select_data_sources_actor_chooses_to_search_notes(
-chat_client, user_query, expected_conversation_commands
+chat_client, user_query, expected_conversation_commands, default_user2
 ):
 # Act
-conversation_commands = await aget_relevant_tools_to_execute(user_query, {}, False, False)
+selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)
 # Assert
-assert set(expected_conversation_commands) == set(conversation_commands)
+assert expected_conversation_commands == selected_conversation_commands
 # ----------------------------------------------------------------------------------------------------
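Returning sources and output separately lets the actor test assert on the full dict rather than a flattened command set, so a wrong output mode is caught even when the chosen sources match. A toy illustration (not from the commit):

```python
# Dict comparison pinpoints which half of the selection is wrong; a flattened
# set comparison would only report that the overall sets differ.
expected = {"sources": ["notes"], "output": "image"}
selected = {"sources": ["notes"], "output": "text"}

sources_match = expected["sources"] == selected["sources"]  # True
output_match = expected["output"] == selected["output"]     # False
```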


@@ -8,7 +8,7 @@ from freezegun import freeze_time
 from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import message_to_log
-from khoj.routers.helpers import aget_relevant_tools_to_execute
+from khoj.routers.helpers import aget_data_sources_and_output_format
 from tests.helpers import ConversationFactory
 # Initialize variables for tests
@@ -719,7 +719,7 @@ async def test_get_correct_tools_online(chat_client):
 user_query = "What's the weather in Patagonia this week?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
+tools = await aget_data_sources_and_output_format(user_query, {}, False, False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -734,7 +734,7 @@ async def test_get_correct_tools_notes(chat_client):
 user_query = "Where did I go for my first battleship training?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
+tools = await aget_data_sources_and_output_format(user_query, {}, False, False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -749,7 +749,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
 user_query = "What's the highest point in Patagonia and have I been there?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
+tools = await aget_data_sources_and_output_format(user_query, {}, False, False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -766,7 +766,7 @@ async def test_get_correct_tools_general(chat_client):
 user_query = "How many noble gases are there?"
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
+tools = await aget_data_sources_and_output_format(user_query, {}, False, False)
 # Assert
 tools = [tool.value for tool in tools]
@@ -790,7 +790,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
 chat_history = generate_history(chat_log)
 # Act
-tools = await aget_relevant_tools_to_execute(user_query, chat_history, False, False)
+tools = await aget_data_sources_and_output_format(user_query, chat_history, False, False)
 # Assert
 tools = [tool.value for tool in tools]