diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index c6de9557..52acfbfd 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -261,6 +261,27 @@ def construct_question_history( return history_parts +def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]: + """ + Construct chat history for operator agent in conversation log. + Only include last n completed turns (i.e with user and khoj message). + """ + chat_history: list[AgentMessage] = [] + user_message: Optional[AgentMessage] = None + + for chat in conversation_history.get("chat", []): + if len(chat_history) >= n: + break + if chat["by"] == "you" and chat.get("message"): + content = [{"type": "text", "text": chat["message"]}] + for file in chat.get("queryFiles", []): + content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}] + user_message = AgentMessage(role="user", content=content) + elif chat["by"] == "khoj" and chat.get("message"): + chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])] + return chat_history + + def construct_tool_chat_history( previous_iterations: List[ResearchIteration], tool: ConversationCommand = None ) -> Dict[str, list]: @@ -285,9 +306,6 @@ def construct_tool_chat_history( ConversationCommand.Code: ( lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] ), - ConversationCommand.Operator: ( - lambda iteration: list(iteration.operatorContext.query) if iteration.operatorContext else [] - ), } for iteration in previous_iterations: # If a tool is provided use the inferred query extractor for that tool if available diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index cb606c67..34f4ad7d 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -6,7 +6,11 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatModel, KhojUser -from khoj.processor.conversation.utils import OperatorRun, construct_chat_history +from khoj.processor.conversation.utils import ( + OperatorRun, + construct_chat_history, + construct_chat_history_for_operator, +) from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent @@ -43,6 +47,10 @@ async def operate_environment( ): response, user_input_message = None, None + # Only use partial previous trajectories to continue existing task + if previous_trajectory and previous_trajectory.response: + previous_trajectory = None + # Get the agent chat model agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None reasoning_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) @@ -51,21 +59,19 @@ async def operate_environment( if not reasoning_model: raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") - chat_history = construct_chat_history(conversation_log) - query_with_history = ( - f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query - ) # Append chat history to query if available + # Create conversation history from conversation log + chat_history = construct_chat_history_for_operator(conversation_log) # Initialize Agent max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100)) operator_agent: OperatorAgent if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI: operator_agent = OpenAIOperatorAgent( - query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer + query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer ) elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC: operator_agent = AnthropicOperatorAgent( - query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer + query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer ) else: grounding_model_name = "ui-tars-1.5" @@ -77,11 +83,12 @@ async def operate_environment( ): raise ValueError("No supported visual grounding model for binary operator agent found.") operator_agent = BinaryOperatorAgent( - query_with_history, + query, reasoning_model, grounding_model, environment_type, max_iterations, + chat_history, previous_trajectory, tracer, ) @@ -100,6 +107,8 @@ async def operate_environment( try: task_completed = False iterations = 0 + operator_run = OperatorRun(query=query, trajectory=operator_agent.messages, response=response) + yield operator_run with timer( f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger @@ -159,30 +168,27 @@ async def operate_environment( # 4. Update agent on the results of its action on the environment operator_agent.add_action_results(env_steps, agent_result) + operator_run.trajectory = operator_agent.messages # Determine final response message if user_input_message: - response = user_input_message + operator_run.response = user_input_message elif task_completed: - response = summary_message + operator_run.response = summary_message + elif cancellation_event and cancellation_event.is_set(): + operator_run.response = None else: # Hit iteration limit - response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}" + operator_run.response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}" finally: if environment and not user_input_message: # Don't close environment if user input required await environment.close() if operator_agent: operator_agent.reset() - webpages = [] if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"): - webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls] + operator_run.webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls] - yield OperatorRun( - query=query, - trajectory=operator_agent.messages, - response=response, - webpages=webpages, - ) + yield operator_run def is_operator_model(model: str) -> ChatModel.ModelType | None: diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 23b99362..8686b708 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -34,6 +34,7 @@ class OperatorAgent(ABC): vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, + chat_history: List[AgentMessage] = [], previous_trajectory: Optional[OperatorRun] = None, tracer: dict = {}, ): @@ -42,11 +43,15 @@ class OperatorAgent(ABC): self.environment_type = environment_type self.max_iterations = max_iterations self.tracer = tracer - self.messages: List[AgentMessage] = [] self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." + self.messages: List[AgentMessage] = chat_history if previous_trajectory: - self.messages = previous_trajectory.trajectory + # Remove tool call from previous trajectory as tool call w/o result not supported + if previous_trajectory.trajectory and previous_trajectory.trajectory[-1].role == "assistant": + previous_trajectory.trajectory.pop() + self.messages += previous_trajectory.trajectory + self.messages += [AgentMessage(role="user", content=query)] # Context compression parameters self.context_compress_trigger = 2e3 # heuristic to determine compression trigger diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 2a058c42..ade98176 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -40,6 +40,7 @@ class BinaryOperatorAgent(OperatorAgent): grounding_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, + chat_history: List[AgentMessage] = [], previous_trajectory: Optional[OperatorRun] = None, tracer: dict = {}, ): @@ -48,6 +49,7 @@ class BinaryOperatorAgent(OperatorAgent): reasoning_model, environment_type, max_iterations, + chat_history, previous_trajectory, tracer, ) # Use reasoning model for primary tracking diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 23571738..a7bb513f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1011,7 +1011,7 @@ async def chat( query=defiltered_query, conversation_id=conversation_id, conversation_history=meta_log, - previous_iterations=research_results, + previous_iterations=list(research_results), query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1030,13 +1030,26 @@ async def chat( code_results.update(research_result.codeContext) if research_result.context: compiled_references.extend(research_result.context) - if research_result.operatorContext: - operator_results.append(research_result.operatorContext) + if not research_results or research_results[-1] is not research_result: research_results.append(research_result) - else: yield research_result + # Track operator results across research and operator iterations + # This relies on two conditions: + # 1. Check to append new (partial) operator results + # Relies on triggering this check on every status updates. + # Status updates cascade up from operator to research to chat api on every step. + # 2. Keep operator results in sync with each research operator step + # Relies on python object references to ensure operator results + # are implicitly kept in sync after the initial append + if ( + research_results + and research_results[-1].operatorContext + and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext) + ): + operator_results.append(research_results[-1].operatorContext) + # researched_results = await extract_relevant_info(q, researched_results, agent) if state.verbose > 1: logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}') @@ -1294,7 +1307,7 @@ async def chat( user, meta_log, location, - operator_results[-1] if operator_results else None, + list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, query_files=attached_file_context, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1305,7 +1318,8 @@ async def chat( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] elif isinstance(result, OperatorRun): - operator_results.append(result) + if not operator_results or operator_results[-1] is not result: + operator_results.append(result) # Add webpages visited while operating browser to references if result.webpages: if not online_results.get(defiltered_query): diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 902c2e0c..83ec141a 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -96,6 +96,24 @@ async def apick_next_tool( ): """Given a query, determine which of the available tools the agent should use in order to answer appropriately.""" + # Continue with previous iteration if a multi-step tool use is in progress + if ( + previous_iterations + and previous_iterations[-1].tool == ConversationCommand.Operator + and not previous_iterations[-1].summarizedResult + ): + previous_iteration = previous_iterations[-1] + yield ResearchIteration( + tool=previous_iteration.tool, + query=query, + context=previous_iteration.context, + onlineContext=previous_iteration.onlineContext, + codeContext=previous_iteration.codeContext, + operatorContext=previous_iteration.operatorContext, + warning=previous_iteration.warning, + ) + return + # Construct tool options for the agent to choose from tool_options = dict() tool_options_str = "" @@ -274,6 +292,7 @@ async def research( yield result[ChatEvent.STATUS] elif isinstance(result, ResearchIteration): this_iteration = result + yield this_iteration # Skip running iteration if warning present in iteration if this_iteration.warning: