mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Fix, improve openai operator agent for interrupts, computer environment
- Create reusable method to call model - Fix to summarize messages on operator run. - Mark assistant tool calls with role = assistant, not environment - Try fix message format when load after interrupts. Does not work well yet
This commit is contained in:
@@ -97,7 +97,7 @@ class OperatorRun:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
trajectory: list[AgentMessage | dict] = None,
|
trajectory: list[AgentMessage] | list[dict] = None,
|
||||||
response: str = None,
|
response: str = None,
|
||||||
webpages: list[dict] = None,
|
webpages: list[dict] = None,
|
||||||
):
|
):
|
||||||
@@ -138,7 +138,7 @@ class ResearchIteration:
|
|||||||
context: list = None,
|
context: list = None,
|
||||||
onlineContext: dict = None,
|
onlineContext: dict = None,
|
||||||
codeContext: dict = None,
|
codeContext: dict = None,
|
||||||
operatorContext: dict = None,
|
operatorContext: dict | OperatorRun = None,
|
||||||
summarizedResult: str = None,
|
summarizedResult: str = None,
|
||||||
warning: str = None,
|
warning: str = None,
|
||||||
):
|
):
|
||||||
@@ -147,15 +147,13 @@ class ResearchIteration:
|
|||||||
self.context = context
|
self.context = context
|
||||||
self.onlineContext = onlineContext
|
self.onlineContext = onlineContext
|
||||||
self.codeContext = codeContext
|
self.codeContext = codeContext
|
||||||
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None
|
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
|
||||||
self.summarizedResult = summarizedResult
|
self.summarizedResult = summarizedResult
|
||||||
self.warning = warning
|
self.warning = warning
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
data = vars(self).copy()
|
data = vars(self).copy()
|
||||||
data["operatorContext"] = (
|
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
|
||||||
self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None
|
|
||||||
)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import List, Optional, cast
|
|||||||
|
|
||||||
from openai.types.responses import Response, ResponseOutputItem
|
from openai.types.responses import Response, ResponseOutputItem
|
||||||
|
|
||||||
|
from khoj.database.models import ChatModel
|
||||||
from khoj.processor.conversation.utils import AgentMessage
|
from khoj.processor.conversation.utils import AgentMessage
|
||||||
from khoj.processor.operator.operator_actions import *
|
from khoj.processor.operator.operator_actions import *
|
||||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||||
@@ -24,9 +25,6 @@ logger = logging.getLogger(__name__)
|
|||||||
# --- Anthropic Operator Agent ---
|
# --- Anthropic Operator Agent ---
|
||||||
class OpenAIOperatorAgent(OperatorAgent):
|
class OpenAIOperatorAgent(OperatorAgent):
|
||||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
client = get_openai_async_client(
|
|
||||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
|
||||||
)
|
|
||||||
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
||||||
safety_check_message = None
|
safety_check_message = None
|
||||||
actions: List[OperatorAction] = []
|
actions: List[OperatorAction] = []
|
||||||
@@ -34,23 +32,11 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
self._commit_trace() # Commit trace before next action
|
self._commit_trace() # Commit trace before next action
|
||||||
system_prompt = self.get_instructions(self.environment_type, current_state)
|
system_prompt = self.get_instructions(self.environment_type, current_state)
|
||||||
tools = self.get_tools(self.environment_type, current_state)
|
tools = self.get_tools(self.environment_type, current_state)
|
||||||
|
|
||||||
if is_none_or_empty(self.messages):
|
if is_none_or_empty(self.messages):
|
||||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||||
|
|
||||||
messages_for_api = self._format_message_for_api(self.messages)
|
response = await self._call_model(self.vision_model, system_prompt, tools)
|
||||||
response: Response = await client.responses.create(
|
self.messages += [AgentMessage(role="assistant", content=response.output)]
|
||||||
model="computer-use-preview",
|
|
||||||
input=messages_for_api,
|
|
||||||
instructions=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
parallel_tool_calls=False, # Keep sequential for now
|
|
||||||
max_output_tokens=4096, # TODO: Make configurable?
|
|
||||||
truncation="auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
|
||||||
self.messages += [AgentMessage(role="environment", content=response.output)]
|
|
||||||
rendered_response = self._render_response(response.output, current_state.screenshot)
|
rendered_response = self._render_response(response.output, current_state.screenshot)
|
||||||
|
|
||||||
last_call_id = None
|
last_call_id = None
|
||||||
@@ -130,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
"summary": [],
|
"summary": [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported response block type: {block.type}")
|
||||||
|
content = f"Unsupported response block type: {block.type}"
|
||||||
if action_to_run or content:
|
if action_to_run or content:
|
||||||
actions.append(action_to_run)
|
actions.append(action_to_run)
|
||||||
if action_to_run or content:
|
if action_to_run or content:
|
||||||
@@ -176,6 +165,9 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
elif action_result["type"] == "reasoning":
|
elif action_result["type"] == "reasoning":
|
||||||
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
|
items_to_pop.append(idx) # Mark placeholder reasoning action result for removal
|
||||||
continue
|
continue
|
||||||
|
elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress":
|
||||||
|
result_content["status"] = "completed" # Mark in-progress actions as completed
|
||||||
|
action_result["output"] = result_content
|
||||||
else:
|
else:
|
||||||
# Add text data
|
# Add text data
|
||||||
action_result["output"] = result_content
|
action_result["output"] = result_content
|
||||||
@@ -185,11 +177,45 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||||
|
|
||||||
|
async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str:
|
||||||
|
summarize_prompt = summarize_prompt or self.summarize_prompt
|
||||||
|
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||||
|
response = await self._call_model(self.vision_model, summarize_prompt, [])
|
||||||
|
self.messages += [AgentMessage(role="assistant", content=response.output)]
|
||||||
|
if not self.messages:
|
||||||
|
return "No actions to summarize."
|
||||||
|
return self._compile_response(self.messages[-1].content)
|
||||||
|
|
||||||
|
async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response:
|
||||||
|
client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
|
||||||
|
if tools:
|
||||||
|
model_name = "computer-use-preview"
|
||||||
|
else:
|
||||||
|
model_name = model.name
|
||||||
|
|
||||||
|
# Format messages for OpenAI API
|
||||||
|
messages_for_api = self._format_message_for_api(self.messages)
|
||||||
|
# format messages for summary if model is not computer-use-preview
|
||||||
|
if model_name != "computer-use-preview":
|
||||||
|
messages_for_api = self._format_messages_for_summary(messages_for_api)
|
||||||
|
|
||||||
|
response: Response = await client.responses.create(
|
||||||
|
model=model_name,
|
||||||
|
input=messages_for_api,
|
||||||
|
instructions=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
parallel_tool_calls=False,
|
||||||
|
truncation="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||||
|
return response
|
||||||
|
|
||||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
||||||
"""Format the message for OpenAI API."""
|
"""Format the message for OpenAI API."""
|
||||||
formatted_messages: list = []
|
formatted_messages: list = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == "environment":
|
if message.role == "assistant":
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
# Remove reasoning message if not followed by computer call
|
# Remove reasoning message if not followed by computer call
|
||||||
if (
|
if (
|
||||||
@@ -208,14 +234,19 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
message.content.pop(0)
|
message.content.pop(0)
|
||||||
formatted_messages.extend(message.content)
|
formatted_messages.extend(message.content)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
|
logger.warning(f"Expected message content list from assistant, got {type(message.content)}")
|
||||||
|
elif message.role == "environment":
|
||||||
|
formatted_messages.extend(message.content)
|
||||||
else:
|
else:
|
||||||
|
if isinstance(message.content, list):
|
||||||
|
message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||||
formatted_messages.append(
|
formatted_messages.append(
|
||||||
{
|
{
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": message.content,
|
"content": message.content,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return formatted_messages
|
return formatted_messages
|
||||||
|
|
||||||
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
|
def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
|
||||||
@@ -352,10 +383,10 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
|
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
|
||||||
"""Return the tools available for the OpenAI operator."""
|
"""Return the tools available for the OpenAI operator."""
|
||||||
if environment_type == EnvironmentType.COMPUTER:
|
if environment_type == EnvironmentType.COMPUTER:
|
||||||
# get os info of this computer. it can be mac, windows, linux
|
# TODO: Get OS info from the environment
|
||||||
environment_os = (
|
# For now, assume Linux as the environment OS
|
||||||
"mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
|
environment_os = "linux"
|
||||||
)
|
# environment = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
|
||||||
else:
|
else:
|
||||||
environment_os = "browser"
|
environment_os = "browser"
|
||||||
|
|
||||||
@@ -393,3 +424,33 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]:
|
||||||
|
"""Format messages for summary."""
|
||||||
|
# Format messages to interact with non computer use AI models
|
||||||
|
items_to_drop = [] # Track indices to drop reasoning messages
|
||||||
|
for idx, msg in enumerate(formatted_messages):
|
||||||
|
if isinstance(msg, dict) and "content" in msg:
|
||||||
|
continue
|
||||||
|
elif isinstance(msg, dict) and "output" in msg:
|
||||||
|
# Drop current_url from output as not supported for non computer operations
|
||||||
|
if "current_url" in msg["output"]:
|
||||||
|
del msg["output"]["current_url"]
|
||||||
|
formatted_messages[idx] = {"role": "user", "content": [msg["output"]]}
|
||||||
|
elif isinstance(msg, str):
|
||||||
|
formatted_messages[idx] = {"role": "user", "content": [{"type": "input_text", "text": msg}]}
|
||||||
|
else:
|
||||||
|
text = self._compile_response([msg])
|
||||||
|
if not text:
|
||||||
|
items_to_drop.append(idx) # Track index to drop reasoning message
|
||||||
|
else:
|
||||||
|
formatted_messages[idx] = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "output_text", "text": text}],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove reasoning messages for non-computer use models
|
||||||
|
for idx in reversed(items_to_drop):
|
||||||
|
formatted_messages.pop(idx)
|
||||||
|
|
||||||
|
return formatted_messages
|
||||||
|
|||||||
Reference in New Issue
Block a user