diff --git a/src/khoj/processor/operator/operate_browser.py b/src/khoj/processor/operator/operate_browser.py index 23bc2123..e5a56bfc 100644 --- a/src/khoj/processor/operator/operate_browser.py +++ b/src/khoj/processor/operator/operate_browser.py @@ -47,10 +47,10 @@ async def operate_browser( # Initialize Agent max_iterations = 40 # TODO: Configurable? operator_agent: OperatorAgent - if chat_model.name.startswith("gpt-"): - operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer) - elif chat_model.name.startswith("claude-"): - operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer) + if reasoning_model.name.startswith("gpt-"): + operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer) + elif reasoning_model.name.startswith("claude-"): + operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer) else: grounding_model_name = "ui-tars-1.5-7b" grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name) @@ -60,7 +60,7 @@ async def operate_browser( or grounding_model.model_type != ChatModel.ModelType.OPENAI ): raise ValueError("No supported visual grounding model for binary operator agent found.") - operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer) + operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer) # Initialize Environment if send_status_func: @@ -87,7 +87,7 @@ async def operate_browser( browser_state = await environment.get_state() # 2. Agent decides action(s) - agent_result = await operator_agent.act(query, browser_state) + agent_result = await operator_agent.act(browser_state) # Render status update rendered_response = agent_result.rendered_response @@ -118,8 +118,8 @@ async def operate_browser( break if task_completed or trigger_iteration_limit: # Summarize results of operator run on last iteration - operator_agent.add_action_results(env_steps, agent_result, summarize_prompt) - summary_message = await operator_agent.summarize(query, browser_state) + operator_agent.add_action_results(env_steps, agent_result) + summary_message = await operator_agent.summarize(summarize_prompt, browser_state) logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}") break diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index 5e28d012..4e332a6f 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -20,9 +20,9 @@ logger = logging.getLogger(__name__) # --- Anthropic Operator Agent --- class AnthropicOperatorAgent(OperatorAgent): - async def act(self, query: str, current_state: EnvState) -> AgentActResult: + async def act(self, current_state: EnvState) -> AgentActResult: client = get_anthropic_async_client( - self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url + self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url ) tool_version = "2025-01-24" betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"] @@ -51,7 +51,7 @@ class AnthropicOperatorAgent(OperatorAgent): """ if is_none_or_empty(self.messages): - self.messages = [AgentMessage(role="user", content=query)] + self.messages = [AgentMessage(role="user", content=self.query)] tools = [ { @@ -77,13 +77,13 @@ class AnthropicOperatorAgent(OperatorAgent): ] thinking = {"type": "disabled"} - if self.chat_model.name.startswith("claude-3-7"): + if self.vision_model.name.startswith("claude-3-7"): thinking = {"type": "enabled", "budget_tokens": 1024} messages_for_api = self._format_message_for_api(self.messages) response = await client.beta.messages.create( messages=messages_for_api, - model=self.chat_model.name, + model=self.vision_model.name, system=system_prompt, tools=tools, betas=betas, @@ -187,8 +187,8 @@ class AnthropicOperatorAgent(OperatorAgent): { "type": "tool_result", "tool_use_id": tool_use_id, - "content": None, # Updated by environment step - "is_error": False, # Updated by environment step + "content": None, # Updated after environment step + "is_error": False, # Updated after environment step } ) @@ -206,13 +206,9 @@ class AnthropicOperatorAgent(OperatorAgent): rendered_response=rendered_response, ) - def add_action_results( - self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None - ): - if not agent_action.action_results and not summarize_prompt: + def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult): + if not agent_action.action_results: return - elif not agent_action.action_results: - agent_action.action_results = [] # Update action results with results of applying suggested actions on the environment for idx, env_step in enumerate(env_steps): @@ -236,10 +232,6 @@ class AnthropicOperatorAgent(OperatorAgent): if env_step.error: action_result["is_error"] = True - # If summarize prompt provided, append as text within the tool results user message - if summarize_prompt: - agent_action.action_results.append({"type": "text", "text": summarize_prompt}) - # Append tool results to the message history self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index f5cfaaf2..76ad4f54 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -25,26 +25,26 @@ class AgentMessage(BaseModel): class OperatorAgent(ABC): - def __init__(self, chat_model: ChatModel, max_iterations: int, tracer: dict): - self.chat_model = chat_model + def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict): + self.query = query + self.vision_model = vision_model self.max_iterations = max_iterations self.tracer = tracer self.messages: List[AgentMessage] = [] @abstractmethod - async def act(self, query: str, current_state: EnvState) -> AgentActResult: + async def act(self, current_state: EnvState) -> AgentActResult: pass @abstractmethod - def add_action_results( - self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None - ) -> None: + def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None: """Track results of agent actions on the environment.""" pass - async def summarize(self, query: str, current_state: EnvState) -> str: + async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str: """Summarize the agent's actions and results.""" - await self.act(query, current_state) + self.messages.append(AgentMessage(role="user", content=summarize_prompt)) + await self.act(current_state) if not self.messages: return "No actions to summarize." return self.compile_response(self.messages[-1].content) @@ -63,12 +63,12 @@ class OperatorAgent(ABC): def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0): self.tracer["usage"] = get_chat_usage_metrics( - self.chat_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage") + self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage") ) - logger.debug(f"Operator usage by {self.chat_model.model_type}: {self.tracer['usage']}") + logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}") def _commit_trace(self): - self.tracer["chat_model"] = self.chat_model.name + self.tracer["chat_model"] = self.vision_model.name if is_promptrace_enabled() and len(self.messages) > 1: compiled_messages = [ AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 5010821e..0f62580b 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -3,7 +3,6 @@ import logging from datetime import datetime from typing import List, Optional -from openai import AsyncOpenAI from openai.types.chat import ChatCompletion from khoj.database.models import ChatModel @@ -18,7 +17,6 @@ from khoj.processor.operator.operator_environment_base import EnvState, EnvStepR from khoj.routers.helpers import send_message_to_model_wrapper from khoj.utils.helpers import ( convert_image_to_png, - get_chat_usage_metrics, get_openai_async_client, is_none_or_empty, ) @@ -29,34 +27,35 @@ logger = logging.getLogger(__name__) # --- Binary Operator Agent --- class BinaryOperatorAgent(OperatorAgent): """ - An OperatorAgent that uses two LLMs (OpenAI compatible): - 1. Vision LLM: Determines the next high-level action based on the visual state. + An OperatorAgent that uses two LLMs: + 1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory. 2. Grounding LLM: Converts the high-level action into specific, executable browser actions. """ def __init__( self, - vision_chat_model: ChatModel, - grounding_chat_model: ChatModel, # Assuming a second model is provided/configured + query: str, + reasoning_model: ChatModel, + grounding_model: ChatModel, max_iterations: int, tracer: dict, ): - super().__init__(vision_chat_model, max_iterations, tracer) # Use vision model for primary tracking - self.vision_chat_model = vision_chat_model - self.grounding_chat_model = grounding_chat_model - # Initialize OpenAI clients - self.grounding_client: AsyncOpenAI = get_openai_async_client( - grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url + super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking + self.reasoning_model = reasoning_model + self.grounding_model = grounding_model + # Initialize openai api compatible client for grounding model + self.grounding_client = get_openai_async_client( + grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url ) - async def act(self, query: str, current_state: EnvState) -> AgentActResult: + async def act(self, current_state: EnvState) -> AgentActResult: """ Uses a two-step LLM process to determine and structure the next action. """ self._commit_trace() # Commit trace before next action # --- Step 1: Reasoning LLM determines high-level action --- - reasoner_response = await self.act_reason(query, current_state) + reasoner_response = await self.act_reason(current_state) natural_language_action = reasoner_response["message"] if reasoner_response["type"] == "error": logger.error(natural_language_action) @@ -75,7 +74,7 @@ class BinaryOperatorAgent(OperatorAgent): # --- Step 2: Grounding LLM converts NL action to structured action --- return await self.act_ground(natural_language_action, current_state) - async def act_reason(self, query: str, current_state: EnvState) -> dict[str, str]: + async def act_reason(self, current_state: EnvState) -> dict[str, str]: """ Uses the reasoning LLM to determine the next high-level action based on the operation trajectory. """ @@ -118,12 +117,12 @@ Focus on the visual action and provide all necessary context. """.strip() if is_none_or_empty(self.messages): - query_text = f"**Main Objective**: {query}" + query_text = f"**Main Objective**: {self.query}" query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"] first_message_content = construct_structured_message( message=query_text, images=query_screenshot, - model_type=self.vision_chat_model.model_type, + model_type=self.reasoning_model.model_type, vision_enabled=True, ) current_message = AgentMessage(role="user", content=first_message_content) @@ -140,7 +139,7 @@ Focus on the visual action and provide all necessary context. query_images=query_screenshot, system_message=reasoning_system_prompt, conversation_log=visual_reasoner_history, - agent_chat_model=self.vision_chat_model, + agent_chat_model=self.reasoning_model, tracer=self.tracer, ) self.messages.append(current_message) @@ -371,7 +370,7 @@ back() # Use this to go back to the previous page. try: grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create( - model=self.grounding_chat_model.name, + model=self.grounding_model.name, messages=grounding_messages_for_api, tools=grounding_tools, tool_choice="required", @@ -465,14 +464,12 @@ back() # Use this to go back to the previous page. rendered_response=rendered_response, ) - def add_action_results( - self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None - ) -> None: + def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None: """ Adds the results of executed actions back into the message history, formatted for the next OpenAI vision LLM call. """ - if not agent_action.action_results and not summarize_prompt: + if not agent_action.action_results: return tool_outputs = [] @@ -493,44 +490,38 @@ back() # Use this to go back to the previous page. tool_output_content = construct_structured_message( message=tool_outputs_str, images=[formatted_screenshot], - model_type=self.vision_chat_model.model_type, + model_type=self.reasoning_model.model_type, vision_enabled=True, ) self.messages.append(AgentMessage(role="environment", content=tool_output_content)) - # Append summarize prompt if provided - if summarize_prompt: - self.messages.append(AgentMessage(role="user", content=summarize_prompt)) - - async def summarize(self, query: str, env_state: EnvState) -> str: - # Construct vision LLM input following OpenAI format + async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str: conversation_history = self._format_message_for_api(self.messages) try: summary = await send_message_to_model_wrapper( - query=query, + query=summarize_prompt, conversation_log=conversation_history, - agent_chat_model=self.vision_chat_model, + agent_chat_model=self.reasoning_model, tracer=self.tracer, ) - - # Return last action message if no summary + # Set summary to last action message if not summary: - return self.compile_response(self.messages[-1].content) # Compile the last action message - - # Append summary messages to history - trigger_summary = AgentMessage(role="user", content=query) - summary_message = AgentMessage(role="assistant", content=summary) - self.messages.extend([trigger_summary, summary_message]) - - return summary + raise ValueError("Summary is empty.") except Exception as e: - logger.error(f"Error calling Vision LLM for summary: {e}") - return f"Error generating summary: {e}" + logger.error(f"Error calling Reasoning LLM for summary: {e}") + summary = "\n".join([self._get_message_text(msg) for msg in self.messages]) + + # Append summary messages to history + trigger_summary = AgentMessage(role="user", content=summarize_prompt) + summary_message = AgentMessage(role="assistant", content=summary) + self.messages.extend([trigger_summary, summary_message]) + + return summary def compile_response(self, response_content: Union[str, List, dict]) -> str: """Compile response content into a string, handling OpenAI message structures.""" if isinstance(response_content, str): - return response_content # Simple text (e.g., initial user query, vision response) + return response_content if isinstance(response_content, dict) and response_content.get("role") == "assistant": # Grounding LLM response message (might contain tool calls) @@ -544,7 +535,7 @@ back() # Use this to go back to the previous page. compiled.append( f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}" ) - return "\n- ".join(filter(None, compiled)) or "[Assistant Message]" + return "\n- ".join(filter(None, compiled)) if isinstance(response_content, list): # Tool results list compiled = ["**Tool Results**:"] diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py index 768cf1e5..7eec9e3b 100644 --- a/src/khoj/processor/operator/operator_agent_openai.py +++ b/src/khoj/processor/operator/operator_agent_openai.py @@ -20,9 +20,9 @@ logger = logging.getLogger(__name__) # --- Anthropic Operator Agent --- class OpenAIOperatorAgent(OperatorAgent): - async def act(self, query: str, current_state: EnvState) -> AgentActResult: + async def act(self, current_state: EnvState) -> AgentActResult: client = get_openai_async_client( - self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url + self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url ) safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:" safety_check_message = None @@ -80,7 +80,7 @@ class OpenAIOperatorAgent(OperatorAgent): ] if is_none_or_empty(self.messages): - self.messages = [AgentMessage(role="user", content=query)] + self.messages = [AgentMessage(role="user", content=self.query)] messages_for_api = self._format_message_for_api(self.messages) response: Response = await client.responses.create( @@ -168,7 +168,7 @@ class OpenAIOperatorAgent(OperatorAgent): action_results.append( { "type": f"{block.type}_output", - "output": content, # Updated by environment step + "output": content, # Updated after environment step "call_id": last_call_id, } ) @@ -181,10 +181,8 @@ class OpenAIOperatorAgent(OperatorAgent): rendered_response=rendered_response, ) - def add_action_results( - self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None - ) -> None: - if not agent_action.action_results and not summarize_prompt: + def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None: + if not agent_action.action_results: return # Update action results with results of applying suggested actions on the environment @@ -209,11 +207,7 @@ class OpenAIOperatorAgent(OperatorAgent): # Add text data action_result["output"] = result_content - if agent_action.action_results: - self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] - # Append summarize prompt as a user message after tool results - if summarize_prompt: - self.messages += [AgentMessage(role="user", content=summarize_prompt)] + self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] def _format_message_for_api(self, messages: list[AgentMessage]) -> list: """Format the message for OpenAI API."""