Fix and improve the OpenAI operator agent for interrupts and the computer environment

- Create a reusable method to call the model
- Fix summarization of messages on operator runs
- Mark assistant tool calls with role = assistant, not environment

- Attempt to fix the message format when loading after interrupts.
  This does not work well yet.
This commit is contained in:
Debanjum
2025-05-30 21:14:43 -07:00
parent f517566560
commit c5c06a086e
2 changed files with 88 additions and 29 deletions

View File

@@ -97,7 +97,7 @@ class OperatorRun:
def __init__( def __init__(
self, self,
query: str, query: str,
trajectory: list[AgentMessage | dict] = None, trajectory: list[AgentMessage] | list[dict] = None,
response: str = None, response: str = None,
webpages: list[dict] = None, webpages: list[dict] = None,
): ):
@@ -138,7 +138,7 @@ class ResearchIteration:
context: list = None, context: list = None,
onlineContext: dict = None, onlineContext: dict = None,
codeContext: dict = None, codeContext: dict = None,
operatorContext: dict = None, operatorContext: dict | OperatorRun = None,
summarizedResult: str = None, summarizedResult: str = None,
warning: str = None, warning: str = None,
): ):
@@ -147,15 +147,13 @@ class ResearchIteration:
self.context = context self.context = context
self.onlineContext = onlineContext self.onlineContext = onlineContext
self.codeContext = codeContext self.codeContext = codeContext
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
self.summarizedResult = summarizedResult self.summarizedResult = summarizedResult
self.warning = warning self.warning = warning
def to_dict(self) -> dict: def to_dict(self) -> dict:
data = vars(self).copy() data = vars(self).copy()
data["operatorContext"] = ( data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None
)
return data return data

View File

@@ -8,6 +8,7 @@ from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem from openai.types.responses import Response, ResponseOutputItem
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import AgentMessage from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
@@ -24,9 +25,6 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent --- # --- Anthropic Operator Agent ---
class OpenAIOperatorAgent(OperatorAgent): class OpenAIOperatorAgent(OperatorAgent):
async def act(self, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
client = get_openai_async_client(
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:" safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
safety_check_message = None safety_check_message = None
actions: List[OperatorAction] = [] actions: List[OperatorAction] = []
@@ -34,23 +32,11 @@ class OpenAIOperatorAgent(OperatorAgent):
self._commit_trace() # Commit trace before next action self._commit_trace() # Commit trace before next action
system_prompt = self.get_instructions(self.environment_type, current_state) system_prompt = self.get_instructions(self.environment_type, current_state)
tools = self.get_tools(self.environment_type, current_state) tools = self.get_tools(self.environment_type, current_state)
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)] self.messages = [AgentMessage(role="user", content=self.query)]
messages_for_api = self._format_message_for_api(self.messages) response = await self._call_model(self.vision_model, system_prompt, tools)
response: Response = await client.responses.create( self.messages += [AgentMessage(role="assistant", content=response.output)]
model="computer-use-preview",
input=messages_for_api,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False, # Keep sequential for now
max_output_tokens=4096, # TODO: Make configurable?
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
self.messages += [AgentMessage(role="environment", content=response.output)]
rendered_response = self._render_response(response.output, current_state.screenshot) rendered_response = self._render_response(response.output, current_state.screenshot)
last_call_id = None last_call_id = None
@@ -130,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent):
"summary": [], "summary": [],
} }
) )
else:
logger.warning(f"Unsupported response block type: {block.type}")
content = f"Unsupported response block type: {block.type}"
if action_to_run or content: if action_to_run or content:
actions.append(action_to_run) actions.append(action_to_run)
if action_to_run or content: if action_to_run or content:
@@ -176,6 +165,9 @@ class OpenAIOperatorAgent(OperatorAgent):
elif action_result["type"] == "reasoning": elif action_result["type"] == "reasoning":
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
continue continue
elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress":
result_content["status"] = "completed" # Mark in-progress actions as completed
action_result["output"] = result_content
else: else:
# Add text data # Add text data
action_result["output"] = result_content action_result["output"] = result_content
@@ -185,11 +177,45 @@ class OpenAIOperatorAgent(OperatorAgent):
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
    """Ask the model to summarize the agent's trajectory so far.

    Appends the summarize prompt as a user message, calls the model with no
    tools (so ``_call_model`` uses the plain vision model rather than
    computer-use-preview), records the assistant reply, and returns its
    compiled text.

    Args:
        current_state: Current environment state (unused here; kept for
            interface parity with ``act``).
        summarize_prompt: Optional override for ``self.summarize_prompt``.

    Returns:
        The compiled summary text, or a placeholder when there is no
        trajectory to summarize.
    """
    # Guard BEFORE mutating the trajectory. The original placed this check
    # after appending the prompt, so it could never fire (dead code).
    if not self.messages:
        return "No actions to summarize."
    summarize_prompt = summarize_prompt or self.summarize_prompt
    self.messages.append(AgentMessage(role="user", content=summarize_prompt))
    # Empty tools list => _call_model selects the configured vision model.
    response = await self._call_model(self.vision_model, summarize_prompt, [])
    self.messages += [AgentMessage(role="assistant", content=response.output)]
    return self._compile_response(self.messages[-1].content)
async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response:
    """Send the current message trajectory to the OpenAI Responses API.

    Uses OpenAI's dedicated ``computer-use-preview`` model whenever tools are
    supplied; otherwise falls back to the given chat model (e.g. for
    summarization) and flattens computer-use items the plain model cannot
    ingest.

    Args:
        model: Chat model whose API credentials (and, when no tools are
            given, name) are used for the request.
        system_prompt: Instructions string passed to the Responses API.
        tools: Tool definitions; an empty list selects the plain model path.

    Returns:
        The raw ``Response`` object from the OpenAI client.
    """
    client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
    # Tool use requires OpenAI's dedicated computer-use model.
    model_name = "computer-use-preview" if tools else model.name
    api_messages = self._format_message_for_api(self.messages)
    if model_name != "computer-use-preview":
        # Plain chat models cannot ingest computer-call items; flatten them.
        api_messages = self._format_messages_for_summary(api_messages)
    response: Response = await client.responses.create(
        model=model_name,
        input=api_messages,
        instructions=system_prompt,
        tools=tools,
        parallel_tool_calls=False,
        truncation="auto",
    )
    logger.debug(f"Openai response: {response.model_dump_json()}")
    return response
def _format_message_for_api(self, messages: list[AgentMessage]) -> list: def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API.""" """Format the message for OpenAI API."""
formatted_messages: list = [] formatted_messages: list = []
for message in messages: for message in messages:
if message.role == "environment": if message.role == "assistant":
if isinstance(message.content, list): if isinstance(message.content, list):
# Remove reasoning message if not followed by computer call # Remove reasoning message if not followed by computer call
if ( if (
@@ -208,14 +234,19 @@ class OpenAIOperatorAgent(OperatorAgent):
message.content.pop(0) message.content.pop(0)
formatted_messages.extend(message.content) formatted_messages.extend(message.content)
else: else:
logger.warning(f"Expected message content list from environment, got {type(message.content)}") logger.warning(f"Expected message content list from assistant, got {type(message.content)}")
elif message.role == "environment":
formatted_messages.extend(message.content)
else: else:
if isinstance(message.content, list):
message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"])
formatted_messages.append( formatted_messages.append(
{ {
"role": message.role, "role": message.role,
"content": message.content, "content": message.content,
} }
) )
return formatted_messages return formatted_messages
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str: def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
@@ -352,10 +383,10 @@ class OpenAIOperatorAgent(OperatorAgent):
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]: def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
"""Return the tools available for the OpenAI operator.""" """Return the tools available for the OpenAI operator."""
if environment_type == EnvironmentType.COMPUTER: if environment_type == EnvironmentType.COMPUTER:
# get os info of this computer. it can be mac, windows, linux # TODO: Get OS info from the environment
environment_os = ( # For now, assume Linux as the environment OS
"mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux" environment_os = "linux"
) # environment = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
else: else:
environment_os = "browser" environment_os = "browser"
@@ -393,3 +424,33 @@ class OpenAIOperatorAgent(OperatorAgent):
}, },
] ]
return tools return tools
def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]:
    """Rewrite API-formatted messages into plain chat messages for non computer-use models.

    Mutates and returns ``formatted_messages``: tool outputs become user
    messages (with ``current_url`` stripped, as it is unsupported outside
    computer use), bare strings become ``input_text`` user messages, and any
    other item is compiled to assistant text — items that compile to nothing
    (e.g. reasoning placeholders) are dropped.
    """
    drop_indices: list = []  # Positions of items to remove after the pass
    for position, item in enumerate(formatted_messages):
        if isinstance(item, dict) and "content" in item:
            # Already a plain chat message; leave untouched.
            continue
        if isinstance(item, dict) and "output" in item:
            # Tool-call output: strip fields non computer-use models reject.
            item["output"].pop("current_url", None)
            formatted_messages[position] = {"role": "user", "content": [item["output"]]}
        elif isinstance(item, str):
            formatted_messages[position] = {
                "role": "user",
                "content": [{"type": "input_text", "text": item}],
            }
        else:
            compiled = self._compile_response([item])
            if compiled:
                formatted_messages[position] = {
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": compiled}],
                }
            else:
                # Nothing to show (e.g. reasoning item); schedule removal.
                drop_indices.append(position)
    # Pop from the back so earlier indices remain valid.
    for position in reversed(drop_indices):
        formatted_messages.pop(position)
    return formatted_messages