From 99a230524645b1e43a8a93d36118f048657b36fe Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 10 May 2025 16:15:12 -0600 Subject: [PATCH] Improve tool chat history constructor and fix its usage during research. Code tool should see code context and webpage tool should see online context during research runs Fix to include code context from past conversations to answer queries. Add all queries to tool chat history when no specific tool to limit extracting inferred queries for provided. --- src/khoj/processor/conversation/utils.py | 37 ++++++++++++++++++------ src/khoj/routers/research.py | 4 +-- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e86834f9..7901f29c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -152,19 +152,35 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A def construct_tool_chat_history( previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None ) -> Dict[str, list]: + """ + Construct chat history from previous iterations for a specific tool + + If a tool is provided, only the inferred queries for that tool is added. + If no tool is provided inferred query for all tools used are added. + """ chat_history: list = [] - inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] - if tool == ConversationCommand.Notes: - inferred_query_extractor = ( + base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] + extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = { + ConversationCommand.Notes: ( lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] - ) - elif tool == ConversationCommand.Online: - inferred_query_extractor = ( + ), + ConversationCommand.Online: ( lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] - ) - elif tool == ConversationCommand.Code: - inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] + ), + ConversationCommand.Webpage: ( + lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] + ), + ConversationCommand.Code: ( + lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] + ), + } for iteration in previous_iterations: + # If a tool is provided use the inferred query extractor for that tool if available + # If no tool is provided, use inferred query extractor for the tool used in the iteration + # Fallback to base extractor if the tool does not have an inferred query extractor + inferred_query_extractor = extract_inferred_query_map.get( + tool or ConversationCommand(iteration.tool), base_extractor + ) chat_history += [ { "by": "you", @@ -409,6 +425,9 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" + if not is_none_or_empty(chat.get("codeContext")): + message_context += f"{prompts.code_executed_context.format(online_results=chat.get('codeContext'))}" + if not is_none_or_empty(message_context): reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 9fb7c229..86f63b2d 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -361,7 +361,7 @@ async def execute_information_collection( try: async for result in run_code( this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage), + construct_tool_chat_history(previous_iterations, ConversationCommand.Code), "", location, user, @@ -388,7 +388,7 @@ async def execute_information_collection( this_iteration.query, user, file_filters, - construct_tool_chat_history(previous_iterations), + construct_tool_chat_history(previous_iterations, ConversationCommand.Summarize), query_images=query_images, agent=agent, send_status_func=send_status_func,