diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e86834f9..7901f29c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -152,19 +152,35 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A def construct_tool_chat_history( previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None ) -> Dict[str, list]: + """ + Construct chat history from previous iterations for a specific tool + + If a tool is provided, only the inferred queries for that tool is added. + If no tool is provided inferred query for all tools used are added. + """ chat_history: list = [] - inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] - if tool == ConversationCommand.Notes: - inferred_query_extractor = ( + base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] + extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = { + ConversationCommand.Notes: ( lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] - ) - elif tool == ConversationCommand.Online: - inferred_query_extractor = ( + ), + ConversationCommand.Online: ( lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] - ) - elif tool == ConversationCommand.Code: - inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] + ), + ConversationCommand.Webpage: ( + lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] + ), + ConversationCommand.Code: ( + lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] + ), + } for iteration in previous_iterations: + # If a tool is provided use the inferred query extractor for that tool if available + # If no tool is provided, use inferred query extractor for the tool used in the iteration + # Fallback to base extractor if the tool does not have an inferred query extractor + inferred_query_extractor = extract_inferred_query_map.get( + tool or ConversationCommand(iteration.tool), base_extractor + ) chat_history += [ { "by": "you", @@ -409,6 +425,9 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" + if not is_none_or_empty(chat.get("codeContext")): + message_context += f"{prompts.code_executed_context.format(online_results=chat.get('codeContext'))}" + if not is_none_or_empty(message_context): reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 9fb7c229..86f63b2d 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -361,7 +361,7 @@ async def execute_information_collection( try: async for result in run_code( this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage), + construct_tool_chat_history(previous_iterations, ConversationCommand.Code), "", location, user, @@ -388,7 +388,7 @@ async def execute_information_collection( this_iteration.query, user, file_filters, - construct_tool_chat_history(previous_iterations), + construct_tool_chat_history(previous_iterations, ConversationCommand.Summarize), query_images=query_images, agent=agent, send_status_func=send_status_func,