From 06a1a22e3b9d3380f8db03aae4715d135e695011 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 19 May 2025 10:02:14 -0700 Subject: [PATCH] Align generic grounding agent's interface with uitars grounding agent The generic grounding agent has not been tested properly but at least it should be aligned with the interface being used by the ui-tars grounding agent which has been tested. --- .../processor/operator/grounding_agent.py | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/khoj/processor/operator/grounding_agent.py b/src/khoj/processor/operator/grounding_agent.py index d6126cbe..61c36d9d 100644 --- a/src/khoj/processor/operator/grounding_agent.py +++ b/src/khoj/processor/operator/grounding_agent.py @@ -185,7 +185,7 @@ class GroundingAgent: }, ] - async def act(self, instruction: str, current_state: EnvState) -> AgentActResult: + async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]: """Call the grounding LLM to get the next action based on the current state and instruction.""" # Format the message for the API call messages_for_api = self._format_message_for_api(instruction, current_state) @@ -204,7 +204,7 @@ class GroundingAgent: # Parse tool calls grounding_message = grounding_response.choices[0].message - action_results = self._parse_action(grounding_message, instruction, current_state) + rendered_response, actions = self._parse_action(grounding_message, instruction, current_state) # Update usage by grounding model self.tracer["usage"] = get_chat_usage_metrics( @@ -215,10 +215,10 @@ class GroundingAgent: ) except Exception as e: logger.error(f"Error calling Grounding LLM: {e}") - rendered_response = f"**Thought (Vision)**: {instruction}\n- **Error**: Error contacting Grounding LLM: {e}" - action_results = AgentActResult(actions=[], action_results=[], rendered_response=rendered_response) + rendered_response = f"**Error**: Error contacting Grounding LLM: {e}" + actions = [] - return action_results + return rendered_response, actions def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List: """Format the message for the API call.""" @@ -264,14 +264,13 @@ back() # Use this to go back to the previous page. def _parse_action( self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState - ) -> AgentActResult: + ) -> tuple[str, list[OperatorAction]]: """Parse the tool calls from the grounding LLM response and convert them to action objects.""" actions: List[OperatorAction] = [] action_results: List[dict] = [] if grounding_message.tool_calls: - # Start rendering with vision output - rendered_parts = [f"**Thought (Vision)**: {instruction}"] + rendered_parts = [] for tool_call in grounding_message.tool_calls: function_name = tool_call.function.name try: @@ -336,17 +335,10 @@ back() # Use this to go back to the previous page. else: # Grounding LLM responded but didn't call a tool logger.warning("Grounding LLM did not produce a tool call.") - rendered_response = f"**Thought (Vision)**: {instruction}\n- **Response (Grounding)**: {grounding_message.content or '[No tool call]'}" + rendered_response = f"{grounding_message.content or 'No action required.'}" # Render the response - return AgentActResult( - actions=actions, - action_results=action_results, - rendered_response={ - "text": rendered_response, - "image": f"data:image/webp;base64,{current_state.screenshot}", - }, - ) + return rendered_response, actions def reset(self): """Reset the agent state."""