diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index e38ae777..cb606c67 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -6,7 +6,7 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatModel, KhojUser -from khoj.processor.conversation.utils import OperatorRun +from khoj.processor.conversation.utils import OperatorRun, construct_chat_history from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent @@ -32,6 +32,7 @@ async def operate_environment( user: KhojUser, conversation_log: dict, location_data: LocationData, + previous_trajectory: Optional[OperatorRun] = None, environment_type: EnvironmentType = EnvironmentType.COMPUTER, send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, # TODO: Handle query images @@ -50,13 +51,22 @@ async def operate_environment( if not reasoning_model: raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") + chat_history = construct_chat_history(conversation_log) + query_with_history = ( + f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query + ) # Append chat history to query if available + # Initialize Agent max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100)) operator_agent: OperatorAgent if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI: - operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer) + operator_agent = OpenAIOperatorAgent( + query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer + ) elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC: - operator_agent = AnthropicOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer) + operator_agent = AnthropicOperatorAgent( + query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer + ) else: grounding_model_name = "ui-tars-1.5" grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name) @@ -67,7 +77,13 @@ async def operate_environment( ): raise ValueError("No supported visual grounding model for binary operator agent found.") operator_agent = BinaryOperatorAgent( - query, reasoning_model, grounding_model, environment_type, max_iterations, tracer + query_with_history, + reasoning_model, + grounding_model, + environment_type, + max_iterations, + previous_trajectory, + tracer, ) # Initialize Environment diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 1aa0f238..23b99362 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -5,7 +5,11 @@ from typing import List, Literal, Optional, Union from pydantic import BaseModel from khoj.database.models import ChatModel -from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace +from khoj.processor.conversation.utils import ( + AgentMessage, + OperatorRun, + commit_conversation_trace, +) from khoj.processor.operator.operator_actions import OperatorAction from khoj.processor.operator.operator_environment_base import ( EnvironmentType, @@ -25,7 +29,13 @@ class AgentActResult(BaseModel): class OperatorAgent(ABC): def __init__( - self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict + self, + query: str, + vision_model: ChatModel, + environment_type: EnvironmentType, + max_iterations: int, + previous_trajectory: Optional[OperatorRun] = None, + tracer: dict = {}, ): self.query = query self.vision_model = vision_model @@ -35,6 +45,9 @@ class OperatorAgent(ABC): self.messages: List[AgentMessage] = [] self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." + if previous_trajectory: + self.messages = previous_trajectory.trajectory + # Context compression parameters self.context_compress_trigger = 2e3 # heuristic to determine compression trigger # turns after which compression triggered. scales with model max context size. Minimum 5 turns. diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 77e28442..2a058c42 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -2,10 +2,14 @@ import json import logging from datetime import datetime from textwrap import dedent -from typing import List +from typing import List, Optional from khoj.database.models import ChatModel -from khoj.processor.conversation.utils import AgentMessage, construct_structured_message +from khoj.processor.conversation.utils import ( + AgentMessage, + OperatorRun, + construct_structured_message, +) from khoj.processor.operator.grounding_agent import GroundingAgent from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars from khoj.processor.operator.operator_actions import * @@ -36,10 +40,16 @@ class BinaryOperatorAgent(OperatorAgent): grounding_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, - tracer: dict, + previous_trajectory: Optional[OperatorRun] = None, + tracer: dict = {}, ): super().__init__( - query, reasoning_model, environment_type, max_iterations, tracer + query, + reasoning_model, + environment_type, + max_iterations, + previous_trajectory, + tracer, ) # Use reasoning model for primary tracking self.reasoning_model = reasoning_model self.grounding_model = grounding_model diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 39599c37..23571738 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1294,6 +1294,7 @@ async def chat( user, meta_log, location, + operator_results[-1] if operator_results else None, query_images=uploaded_images, query_files=attached_file_context, send_status_func=partial(send_event, ChatEvent.STATUS), diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d0f9f4ef..902c2e0c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -423,6 +423,7 @@ async def research( user, construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), location, + previous_iterations[-1].operatorContext if previous_iterations else None, send_status_func=send_status_func, query_images=query_images, agent=agent,