Fix, improve openai operator agent for interrupts, computer environment

- Create reusable method to call model
- Fix to summarize messages on operator run.
- Mark assistant tool calls with role = assistant, not environment

- Try fix message format when load after interrupts.
  Does not work well yet
This commit is contained in:
Debanjum
2025-05-30 21:14:43 -07:00
parent f517566560
commit c5c06a086e
2 changed files with 88 additions and 29 deletions

View File

@@ -97,7 +97,7 @@ class OperatorRun:
def __init__(
self,
query: str,
trajectory: list[AgentMessage | dict] = None,
trajectory: list[AgentMessage] | list[dict] = None,
response: str = None,
webpages: list[dict] = None,
):
@@ -138,7 +138,7 @@ class ResearchIteration:
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
operatorContext: dict = None,
operatorContext: dict | OperatorRun = None,
summarizedResult: str = None,
warning: str = None,
):
@@ -147,15 +147,13 @@ class ResearchIteration:
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
self.summarizedResult = summarizedResult
self.warning = warning
def to_dict(self) -> dict:
data = vars(self).copy()
data["operatorContext"] = (
self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None
)
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
return data

View File

@@ -8,6 +8,7 @@ from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
@@ -24,9 +25,6 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent ---
class OpenAIOperatorAgent(OperatorAgent):
async def act(self, current_state: EnvState) -> AgentActResult:
client = get_openai_async_client(
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
safety_check_message = None
actions: List[OperatorAction] = []
@@ -34,23 +32,11 @@ class OpenAIOperatorAgent(OperatorAgent):
self._commit_trace() # Commit trace before next action
system_prompt = self.get_instructions(self.environment_type, current_state)
tools = self.get_tools(self.environment_type, current_state)
if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)]
messages_for_api = self._format_message_for_api(self.messages)
response: Response = await client.responses.create(
model="computer-use-preview",
input=messages_for_api,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False, # Keep sequential for now
max_output_tokens=4096, # TODO: Make configurable?
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
self.messages += [AgentMessage(role="environment", content=response.output)]
response = await self._call_model(self.vision_model, system_prompt, tools)
self.messages += [AgentMessage(role="assistant", content=response.output)]
rendered_response = self._render_response(response.output, current_state.screenshot)
last_call_id = None
@@ -130,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent):
"summary": [],
}
)
else:
logger.warning(f"Unsupported response block type: {block.type}")
content = f"Unsupported response block type: {block.type}"
if action_to_run or content:
actions.append(action_to_run)
if action_to_run or content:
@@ -176,6 +165,9 @@ class OpenAIOperatorAgent(OperatorAgent):
elif action_result["type"] == "reasoning":
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
continue
elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress":
result_content["status"] = "completed" # Mark in-progress actions as completed
action_result["output"] = result_content
else:
# Add text data
action_result["output"] = result_content
@@ -185,11 +177,45 @@ class OpenAIOperatorAgent(OperatorAgent):
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
summarize_prompt = summarize_prompt or self.summarize_prompt
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
response = await self._call_model(self.vision_model, summarize_prompt, [])
self.messages += [AgentMessage(role="assistant", content=response.output)]
if not self.messages:
return "No actions to summarize."
return self._compile_response(self.messages[-1].content)
async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response:
client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
if tools:
model_name = "computer-use-preview"
else:
model_name = model.name
# Format messages for OpenAI API
messages_for_api = self._format_message_for_api(self.messages)
# format messages for summary if model is not computer-use-preview
if model_name != "computer-use-preview":
messages_for_api = self._format_messages_for_summary(messages_for_api)
response: Response = await client.responses.create(
model=model_name,
input=messages_for_api,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False,
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
return response
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API."""
formatted_messages: list = []
for message in messages:
if message.role == "environment":
if message.role == "assistant":
if isinstance(message.content, list):
# Remove reasoning message if not followed by computer call
if (
@@ -208,14 +234,19 @@ class OpenAIOperatorAgent(OperatorAgent):
message.content.pop(0)
formatted_messages.extend(message.content)
else:
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
logger.warning(f"Expected message content list from assistant, got {type(message.content)}")
elif message.role == "environment":
formatted_messages.extend(message.content)
else:
if isinstance(message.content, list):
message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"])
formatted_messages.append(
{
"role": message.role,
"content": message.content,
}
)
return formatted_messages
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
@@ -352,10 +383,10 @@ class OpenAIOperatorAgent(OperatorAgent):
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
"""Return the tools available for the OpenAI operator."""
if environment_type == EnvironmentType.COMPUTER:
# get os info of this computer. it can be mac, windows, linux
environment_os = (
"mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
)
# TODO: Get OS info from the environment
# For now, assume Linux as the environment OS
environment_os = "linux"
# environment = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
else:
environment_os = "browser"
@@ -393,3 +424,33 @@ class OpenAIOperatorAgent(OperatorAgent):
},
]
return tools
def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]:
"""Format messages for summary."""
# Format messages to interact with non computer use AI models
items_to_drop = [] # Track indices to drop reasoning messages
for idx, msg in enumerate(formatted_messages):
if isinstance(msg, dict) and "content" in msg:
continue
elif isinstance(msg, dict) and "output" in msg:
# Drop current_url from output as not supported for non computer operations
if "current_url" in msg["output"]:
del msg["output"]["current_url"]
formatted_messages[idx] = {"role": "user", "content": [msg["output"]]}
elif isinstance(msg, str):
formatted_messages[idx] = {"role": "user", "content": [{"type": "input_text", "text": msg}]}
else:
text = self._compile_response([msg])
if not text:
items_to_drop.append(idx) # Track index to drop reasoning message
else:
formatted_messages[idx] = {
"role": "assistant",
"content": [{"type": "output_text", "text": text}],
}
# Remove reasoning messages for non-computer use models
for idx in reversed(items_to_drop):
formatted_messages.pop(idx)
return formatted_messages