diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index a1b3033a..03eb61da 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -630,37 +630,36 @@ pick_relevant_tools = PromptTemplate.from_template( """ You are Khoj, an extremely smart and helpful search assistant. {personality_context} -- You have access to a variety of data sources to help you answer the user's question -- You can use the data sources listed below to collect more relevant information -- You can select certain types of output to respond to the user's question. Select just one output type to answer the user's question -- You can use any combination of these data sources and output types to answer the user's question -- You can only select one output type to answer the user's question +- You have access to a variety of data sources to help you answer the user's question. +- You can use any subset of data sources listed below to collect more relevant information. +- You can select the most appropriate output format from the options listed below to respond to the user's question. +- Both the data sources and output format should be selected based on the user's query and relevant context provided in the chat history. -Which of the tools listed below you would use to answer the user's question? You **only** have access to the following: +Which of the data sources, output format listed below would you use to answer the user's question? You **only** have access to the following: -Inputs: -{tools} +Data Sources: +{sources} -Outputs: +Output Formats: {outputs} Here are some examples: Example: Chat History: -User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco. +User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco AI: Moving to a new city can be challenging. 
Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene. -Q: What is the population of each of those cities? -Khoj: {{"source": ["online"], "output": ["text"]}} +Q: Chart the population growth of each of those cities in the last decade +Khoj: {{"source": ["online", "code"], "output": "text"}} Example: Chat History: -User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting. +User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting AI: Excellent! Taking a vacation is a great way to relax and recharge. Q: Where did Grandma grow up? -Khoj: {{"source": ["notes"], "output": ["text"]}} +Khoj: {{"source": ["notes"], "output": "text"}} Example: Chat History: @@ -668,7 +667,7 @@ User: Good morning AI: Good morning! How can I help you today? Q: How can I share my files with Khoj? -Khoj: {{"source": ["default", "online"], "output": ["text"]}} +Khoj: {{"source": ["default", "online"], "output": "text"}} Example: Chat History: @@ -676,17 +675,18 @@ User: What is the first element in the periodic table? AI: The first element in the periodic table is Hydrogen. Q: Summarize this article https://en.wikipedia.org/wiki/Hydrogen -Khoj: {{"source": ["webpage"], "output": ["text"]}} +Khoj: {{"source": ["webpage"], "output": "text"}} Example: Chat History: -User: I want to start a new hobby. I'm thinking of learning to play the guitar. -AI: Learning to play the guitar is a great hobby. It can be a lot of fun and a great way to express yourself. +User: I'm learning to play the guitar, so I can make a band with my friends +AI: Learning to play the guitar is a great hobby. It can be a fun way to socialize and express yourself. -Q: Draw a painting of a guitar. 
-Khoj: {{"source": ["general"], "output": ["image"]}} +Q: Create a painting of my recent jamming sessions +Khoj: {{"source": ["notes"], "output": "image"}} -Now it's your turn to pick the sources and output to answer the user's query. Respond with a JSON object, including both `source` and `output`. The values should be a list of strings. Do not say anything else. +Now it's your turn to pick the appropriate data sources and output format to answer the user's query. Respond with a JSON object, including both `source` and `output` in the following format. Do not say anything else. +{{"source": list[str], "output": str}} Chat History: {chat_history} diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index be469439..a9086dd0 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -46,7 +46,7 @@ from khoj.routers.helpers import ( FeedbackData, acreate_title_from_history, agenerate_chat_response, - aget_relevant_tools_to_execute, + aget_data_sources_and_output_format, construct_automation_created_message, create_automation, gather_raw_query_files, @@ -752,7 +752,7 @@ async def chat( attached_file_context = gather_raw_query_files(query_files) if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_tools_to_execute( + chosen_io = await aget_data_sources_and_output_format( q, meta_log, is_automated_task, @@ -762,6 +762,7 @@ async def chat( query_files=attached_file_context, tracer=tracer, ) + conversation_commands = chosen_io.get("sources", []) + ([chosen_io.get("output")] if chosen_io.get("output") else []) # If we're doing research, we don't want to do anything else if ConversationCommand.Research in conversation_commands: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 109deca2..52d67c6e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -336,7 +336,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: return
is_safe, reason -async def aget_relevant_tools_to_execute( +async def aget_data_sources_and_output_format( query: str, conversation_history: dict, is_task: bool, @@ -345,33 +345,33 @@ async def aget_relevant_tools_to_execute( agent: Agent = None, query_files: str = None, tracer: dict = {}, -): +) -> Dict[str, Any]: """ - Given a query, determine which of the available tools the agent should use in order to answer appropriately. + Given a query, determine which of the available data sources and output modes the agent should use to answer appropriately. """ - tool_options = dict() - tool_options_str = "" + source_options = dict() + source_options_str = "" - agent_tools = agent.input_tools if agent else [] + agent_sources = agent.input_tools if agent else [] - for tool, description in tool_descriptions_for_llm.items(): - tool_options[tool.value] = description - if len(agent_tools) == 0 or tool.value in agent_tools: - tool_options_str += f'- "{tool.value}": "{description}"\n' + for source, description in tool_descriptions_for_llm.items(): + source_options[source.value] = description + if len(agent_sources) == 0 or source.value in agent_sources: + source_options_str += f'- "{source.value}": "{description}"\n' - mode_options = dict() - mode_options_str = "" + output_options = dict() + output_options_str = "" - output_modes = agent.output_modes if agent else [] + agent_outputs = agent.output_modes if agent else [] - for mode, description in mode_descriptions_for_llm.items(): + for output, description in mode_descriptions_for_llm.items(): # Do not allow tasks to schedule another task - if is_task and mode == ConversationCommand.Automation: + if is_task and output == ConversationCommand.Automation: continue - mode_options[mode.value] = description - if len(output_modes) == 0 or mode.value in output_modes: - mode_options_str += f'- "{mode.value}": "{description}"\n' + output_options[output.value] = description + if len(agent_outputs) == 0 or output.value in agent_outputs: + 
output_options_str += f'- "{output.value}": "{description}"\n' chat_history = construct_chat_history(conversation_history) @@ -384,8 +384,8 @@ async def aget_relevant_tools_to_execute( relevant_tools_prompt = prompts.pick_relevant_tools.format( query=query, - tools=tool_options_str, - outputs=mode_options_str, + sources=source_options_str, + outputs=output_options_str, chat_history=chat_history, personality_context=personality_context, ) @@ -403,43 +403,42 @@ response = clean_json(response) response = json.loads(response) - input_tools = [q.strip() for q in response.get("source", []) if q.strip()] - output_modes = [q.strip() for q in response.get("output", ["text"]) if q.strip()] # Default to text output + selected_sources = [q.strip() for q in response.get("source", []) if q.strip()] + selected_output = response.get("output", "text").strip() # Default to text output - if not isinstance(input_tools, list) or not input_tools or len(input_tools) == 0: + if not isinstance(selected_sources, list) or not selected_sources or len(selected_sources) == 0: raise ValueError( - f"Invalid response for determining relevant tools: {input_tools}. Raw Response: {response}" + f"Invalid response for determining relevant tools: {selected_sources}. Raw Response: {response}" ) - final_response = [] if not is_task else [ConversationCommand.AutomatedTask] - for llm_suggested_tool in input_tools: + result: Dict = {"sources": [], "output": None} if not is_task else {"sources": [], "output": ConversationCommand.AutomatedTask} + for selected_source in selected_sources: # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
- if llm_suggested_tool in tool_options.keys() and ( - len(agent_tools) == 0 or llm_suggested_tool in agent_tools + if ( + selected_source in source_options.keys() + and isinstance(result["sources"], list) + and (len(agent_sources) == 0 or selected_source in agent_sources) ): # Check whether the tool exists as a valid ConversationCommand - final_response.append(ConversationCommand(llm_suggested_tool)) + result["sources"].append(ConversationCommand(selected_source)) - for llm_suggested_output in output_modes: - # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options. - if llm_suggested_output in mode_options.keys() and ( - len(output_modes) == 0 or llm_suggested_output in output_modes - ): - # Check whether the tool exists as a valid ConversationCommand - final_response.append(ConversationCommand(llm_suggested_output)) + # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options. + if selected_output in output_options.keys() and (len(agent_outputs) == 0 or selected_output in agent_outputs): + # Check whether the tool exists as a valid ConversationCommand + result["output"] = ConversationCommand(selected_output) - if is_none_or_empty(final_response): - if len(agent_tools) == 0: - final_response = [ConversationCommand.Default, ConversationCommand.Text] + if not is_task and is_none_or_empty(result["sources"]): + if len(agent_sources) == 0: + result = {"sources": [ConversationCommand.Default], "output": ConversationCommand.Text} else: - final_response = [ConversationCommand.General, ConversationCommand.Text] + result = {"sources": [ConversationCommand.General], "output": ConversationCommand.Text} except Exception as e: logger.error(f"Invalid response for determining relevant tools: {response}.
Error: {e}", exc_info=True) - if len(agent_tools) == 0: - final_response = [ConversationCommand.Default, ConversationCommand.Text] - else: - final_response = agent_tools - return final_response + sources = agent_sources if len(agent_sources) > 0 else [ConversationCommand.Default] + output = agent_outputs[0] if len(agent_outputs) > 0 else ConversationCommand.Text + result = {"sources": sources, "output": output} + + return result async def infer_webpage_urls( diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index 9e5df09a..aa9bd5d1 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -7,7 +7,7 @@ from freezegun import freeze_time from khoj.database.models import Agent, Entry, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log -from khoj.routers.helpers import aget_relevant_tools_to_execute +from khoj.routers.helpers import aget_data_sources_and_output_format from tests.helpers import ConversationFactory SKIP_TESTS = True @@ -735,7 +735,7 @@ async def test_get_correct_tools_online(client_offline_chat): user_query = "What's the weather in Patagonia this week?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False) + tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) # Assert tools = [tool.value for tool in tools] @@ -750,7 +750,7 @@ async def test_get_correct_tools_notes(client_offline_chat): user_query = "Where did I go for my first battleship training?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False) + tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) # Assert tools = [tool.value for tool in tools] @@ -765,7 +765,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat user_query = "What's the highest point in Patagonia and have I been there?" 
# Act - tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False) + tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) # Assert tools = [tool.value for tool in tools] @@ -782,7 +782,7 @@ async def test_get_correct_tools_general(client_offline_chat): user_query = "How many noble gases are there?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False) + tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) # Assert tools = [tool.value for tool in tools] @@ -806,7 +806,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_ chat_history = create_conversation(chat_log, default_user2) # Act - tools = await aget_relevant_tools_to_execute(user_query, chat_history, is_task=False) + tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False) # Assert tools = [tool.value for tool in tools] diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py index dd3c5575..87533ab4 100644 --- a/tests/test_openai_chat_actors.py +++ b/tests/test_openai_chat_actors.py @@ -1,4 +1,3 @@ -import os from datetime import datetime import freezegun @@ -8,7 +7,7 @@ from freezegun import freeze_time from khoj.processor.conversation.openai.gpt import converse, extract_questions from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import ( - aget_relevant_tools_to_execute, + aget_data_sources_and_output_format, generate_online_subqueries, infer_webpage_urls, schedule_query, @@ -529,19 +528,36 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u @pytest.mark.parametrize( "user_query, expected_conversation_commands", [ - ("Where did I learn to swim?", [ConversationCommand.Notes]), - ("Where is the nearest hospital?", [ConversationCommand.Online]), - ("Summarize the wikipedia page on the history of the internet", [ConversationCommand.Webpage]), + ( + 
"Where did I learn to swim?", + {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text}, + ), + ( + "Where is the nearest hospital?", + {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text}, + ), + ( + "Summarize the wikipedia page on the history of the internet", + {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text}, + ), + ( + "Make a painting incorporating my past diving experiences", + {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image}, + ), + ( + "Create a chart of the weather over the next 7 days in Timbuktu", + {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text}, + ), ], ) async def test_select_data_sources_actor_chooses_to_search_notes( - chat_client, user_query, expected_conversation_commands + chat_client, user_query, expected_conversation_commands, default_user2 ): # Act - conversation_commands = await aget_relevant_tools_to_execute(user_query, {}, False, False) + selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2) # Assert - assert set(expected_conversation_commands) == set(conversation_commands) + assert expected_conversation_commands == selected_conversation_commands # ---------------------------------------------------------------------------------------------------- diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index bb0ca85a..49cda98b 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -8,7 +8,7 @@ from freezegun import freeze_time from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log -from khoj.routers.helpers import aget_relevant_tools_to_execute +from khoj.routers.helpers import aget_data_sources_and_output_format from tests.helpers 
import ConversationFactory # Initialize variables for tests @@ -719,7 +719,7 @@ async def test_get_correct_tools_online(chat_client): user_query = "What's the weather in Patagonia this week?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, False, False) + tools = await aget_data_sources_and_output_format(user_query, {}, False, False) # Assert tools = [tool.value for tool in tools] @@ -734,7 +734,7 @@ async def test_get_correct_tools_notes(chat_client): user_query = "Where did I go for my first battleship training?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, False, False) + tools = await aget_data_sources_and_output_format(user_query, {}, False, False) # Assert tools = [tool.value for tool in tools] @@ -749,7 +749,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client): user_query = "What's the highest point in Patagonia and have I been there?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, False, False) + tools = await aget_data_sources_and_output_format(user_query, {}, False, False) # Assert tools = [tool.value for tool in tools] @@ -766,7 +766,7 @@ async def test_get_correct_tools_general(chat_client): user_query = "How many noble gases are there?" # Act - tools = await aget_relevant_tools_to_execute(user_query, {}, False, False) + tools = await aget_data_sources_and_output_format(user_query, {}, False, False) # Assert tools = [tool.value for tool in tools] @@ -790,7 +790,7 @@ async def test_get_correct_tools_with_chat_history(chat_client): chat_history = generate_history(chat_log) # Act - tools = await aget_relevant_tools_to_execute(user_query, chat_history, False, False) + tools = await aget_data_sources_and_output_format(user_query, chat_history, False, False) # Assert tools = [tool.value for tool in tools]