mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Fix, improve openai operator agent for interrupts, computer environment
- Create reusable method to call model - Fix to summarize messages on operator run. - Mark assistant tool calls with role = assistant, not environment - Try fix message format when load after interrupts. Does not work well yet
This commit is contained in:
@@ -97,7 +97,7 @@ class OperatorRun:
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
trajectory: list[AgentMessage | dict] = None,
|
||||
trajectory: list[AgentMessage] | list[dict] = None,
|
||||
response: str = None,
|
||||
webpages: list[dict] = None,
|
||||
):
|
||||
@@ -138,7 +138,7 @@ class ResearchIteration:
|
||||
context: list = None,
|
||||
onlineContext: dict = None,
|
||||
codeContext: dict = None,
|
||||
operatorContext: dict = None,
|
||||
operatorContext: dict | OperatorRun = None,
|
||||
summarizedResult: str = None,
|
||||
warning: str = None,
|
||||
):
|
||||
@@ -147,15 +147,13 @@ class ResearchIteration:
|
||||
self.context = context
|
||||
self.onlineContext = onlineContext
|
||||
self.codeContext = codeContext
|
||||
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None
|
||||
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
|
||||
self.summarizedResult = summarizedResult
|
||||
self.warning = warning
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
data = vars(self).copy()
|
||||
data["operatorContext"] = (
|
||||
self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None
|
||||
)
|
||||
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional, cast
|
||||
|
||||
from openai.types.responses import Response, ResponseOutputItem
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import AgentMessage
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
@@ -24,9 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
# --- Anthropic Operator Agent ---
|
||||
class OpenAIOperatorAgent(OperatorAgent):
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
client = get_openai_async_client(
|
||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||
)
|
||||
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
||||
safety_check_message = None
|
||||
actions: List[OperatorAction] = []
|
||||
@@ -34,23 +32,11 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
self._commit_trace() # Commit trace before next action
|
||||
system_prompt = self.get_instructions(self.environment_type, current_state)
|
||||
tools = self.get_tools(self.environment_type, current_state)
|
||||
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
response: Response = await client.responses.create(
|
||||
model="computer-use-preview",
|
||||
input=messages_for_api,
|
||||
instructions=system_prompt,
|
||||
tools=tools,
|
||||
parallel_tool_calls=False, # Keep sequential for now
|
||||
max_output_tokens=4096, # TODO: Make configurable?
|
||||
truncation="auto",
|
||||
)
|
||||
|
||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||
self.messages += [AgentMessage(role="environment", content=response.output)]
|
||||
response = await self._call_model(self.vision_model, system_prompt, tools)
|
||||
self.messages += [AgentMessage(role="assistant", content=response.output)]
|
||||
rendered_response = self._render_response(response.output, current_state.screenshot)
|
||||
|
||||
last_call_id = None
|
||||
@@ -130,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
"summary": [],
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported response block type: {block.type}")
|
||||
content = f"Unsupported response block type: {block.type}"
|
||||
if action_to_run or content:
|
||||
actions.append(action_to_run)
|
||||
if action_to_run or content:
|
||||
@@ -176,6 +165,9 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
elif action_result["type"] == "reasoning":
|
||||
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
|
||||
continue
|
||||
elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress":
|
||||
result_content["status"] = "completed" # Mark in-progress actions as completed
|
||||
action_result["output"] = result_content
|
||||
else:
|
||||
# Add text data
|
||||
action_result["output"] = result_content
|
||||
@@ -185,11 +177,45 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
|
||||
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
|
||||
summarize_prompt = summarize_prompt or self.summarize_prompt
|
||||
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||
response = await self._call_model(self.vision_model, summarize_prompt, [])
|
||||
self.messages += [AgentMessage(role="assistant", content=response.output)]
|
||||
if not self.messages:
|
||||
return "No actions to summarize."
|
||||
return self._compile_response(self.messages[-1].content)
|
||||
|
||||
async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response:
|
||||
client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
|
||||
if tools:
|
||||
model_name = "computer-use-preview"
|
||||
else:
|
||||
model_name = model.name
|
||||
|
||||
# Format messages for OpenAI API
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
# format messages for summary if model is not computer-use-preview
|
||||
if model_name != "computer-use-preview":
|
||||
messages_for_api = self._format_messages_for_summary(messages_for_api)
|
||||
|
||||
response: Response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=messages_for_api,
|
||||
instructions=system_prompt,
|
||||
tools=tools,
|
||||
parallel_tool_calls=False,
|
||||
truncation="auto",
|
||||
)
|
||||
|
||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||
return response
|
||||
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
||||
"""Format the message for OpenAI API."""
|
||||
formatted_messages: list = []
|
||||
for message in messages:
|
||||
if message.role == "environment":
|
||||
if message.role == "assistant":
|
||||
if isinstance(message.content, list):
|
||||
# Remove reasoning message if not followed by computer call
|
||||
if (
|
||||
@@ -208,14 +234,19 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
message.content.pop(0)
|
||||
formatted_messages.extend(message.content)
|
||||
else:
|
||||
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
|
||||
logger.warning(f"Expected message content list from assistant, got {type(message.content)}")
|
||||
elif message.role == "environment":
|
||||
formatted_messages.extend(message.content)
|
||||
else:
|
||||
if isinstance(message.content, list):
|
||||
message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
|
||||
@@ -352,10 +383,10 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
|
||||
"""Return the tools available for the OpenAI operator."""
|
||||
if environment_type == EnvironmentType.COMPUTER:
|
||||
# get os info of this computer. it can be mac, windows, linux
|
||||
environment_os = (
|
||||
"mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
|
||||
)
|
||||
# TODO: Get OS info from the environment
|
||||
# For now, assume Linux as the environment OS
|
||||
environment_os = "linux"
|
||||
# environment = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
|
||||
else:
|
||||
environment_os = "browser"
|
||||
|
||||
@@ -393,3 +424,33 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
},
|
||||
]
|
||||
return tools
|
||||
|
||||
def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]:
|
||||
"""Format messages for summary."""
|
||||
# Format messages to interact with non computer use AI models
|
||||
items_to_drop = [] # Track indices to drop reasoning messages
|
||||
for idx, msg in enumerate(formatted_messages):
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
continue
|
||||
elif isinstance(msg, dict) and "output" in msg:
|
||||
# Drop current_url from output as not supported for non computer operations
|
||||
if "current_url" in msg["output"]:
|
||||
del msg["output"]["current_url"]
|
||||
formatted_messages[idx] = {"role": "user", "content": [msg["output"]]}
|
||||
elif isinstance(msg, str):
|
||||
formatted_messages[idx] = {"role": "user", "content": [{"type": "input_text", "text": msg}]}
|
||||
else:
|
||||
text = self._compile_response([msg])
|
||||
if not text:
|
||||
items_to_drop.append(idx) # Track index to drop reasoning message
|
||||
else:
|
||||
formatted_messages[idx] = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": text}],
|
||||
}
|
||||
|
||||
# Remove reasoning messages for non-computer use models
|
||||
for idx in reversed(items_to_drop):
|
||||
formatted_messages.pop(idx)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
Reference in New Issue
Block a user