diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index d154a965..011e8045 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -322,7 +322,7 @@ def construct_tool_chat_history( chat_history: list = [] base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: [] extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = { - ConversationCommand.Notes: ( + ConversationCommand.SemanticSearchFiles: ( lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] ), ConversationCommand.SearchWeb: ( diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 12c40d36..5e476d4b 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -58,7 +58,7 @@ async def apick_next_tool( query_files: str = None, max_document_searches: int = 7, max_online_searches: int = 3, - max_webpages_to_read: int = 1, + max_webpages_to_read: int = 3, send_status_func: Optional[Callable] = None, tracer: dict = {}, ): @@ -86,25 +86,40 @@ async def apick_next_tool( # Construct tool options for the agent to choose from tools = [] tool_options_str = "" - agent_tools = agent.input_tools if agent else [] + agent_input_tools = agent.input_tools if agent else [] + agent_tools = [] + + # Map agent user facing tools to research tools to include in agents toolbox + document_research_tools = [ + ConversationCommand.SemanticSearchFiles, + ConversationCommand.RegexSearchFiles, + ConversationCommand.ViewFile, + ConversationCommand.ListFiles, + ] + input_tools_to_research_tools = { + ConversationCommand.Notes.value: [tool.value for tool in document_research_tools], + ConversationCommand.Webpage.value: [ConversationCommand.ReadWebpage.value], + ConversationCommand.Online.value: [ConversationCommand.SearchWeb.value], + ConversationCommand.Code.value: [ConversationCommand.RunCode.value], + ConversationCommand.Operator.value: [ConversationCommand.OperateComputer.value], + } + for input_tool, research_tools in input_tools_to_research_tools.items(): + if input_tool in agent_input_tools: + agent_tools += research_tools + user_has_entries = await EntryAdapters.auser_has_entries(user) for tool, tool_data in tools_for_research_llm.items(): # Skip showing operator tool as an option if not enabled - if tool == ConversationCommand.Operator and not is_operator_enabled(): + if tool == ConversationCommand.OperateComputer and not is_operator_enabled(): continue # Skip showing document related tools if user has no documents - if ( - tool == ConversationCommand.SemanticSearchFiles - or tool == ConversationCommand.RegexSearchFiles - or tool == ConversationCommand.ViewFile - or tool == ConversationCommand.ListFiles - ) and not user_has_entries: + if tool in document_research_tools and not user_has_entries: continue if tool == ConversationCommand.SemanticSearchFiles: description = tool_data.description.format(max_search_queries=max_document_searches) - elif tool == ConversationCommand.Webpage: + elif tool == ConversationCommand.ReadWebpage: description = tool_data.description.format(max_webpages_to_read=max_webpages_to_read) - elif tool == ConversationCommand.Online: + elif tool == ConversationCommand.SearchWeb: description = tool_data.description.format(max_search_queries=max_online_searches) else: description = tool_data.description @@ -321,7 +336,9 @@ async def research( try: async for result in search_online( **this_iteration.query.args, - conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Online), + conversation_history=construct_tool_chat_history( + previous_iterations, ConversationCommand.SearchWeb + ), location=location, user=user, send_status_func=send_status_func, @@ -377,7 +394,7 @@ async def research( try: async for result in run_code( **this_iteration.query.args, - conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Code), + conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.RunCode), context="", location_data=location, user=user,