diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 52acfbfd..bf3a20c7 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -97,7 +97,7 @@ class OperatorRun: def __init__( self, query: str, - trajectory: list[AgentMessage | dict] = None, + trajectory: list[AgentMessage] | list[dict] = None, response: str = None, webpages: list[dict] = None, ): @@ -138,7 +138,7 @@ class ResearchIteration: context: list = None, onlineContext: dict = None, codeContext: dict = None, - operatorContext: dict = None, + operatorContext: dict | OperatorRun = None, summarizedResult: str = None, warning: str = None, ): @@ -147,15 +147,13 @@ class ResearchIteration: self.context = context self.onlineContext = onlineContext self.codeContext = codeContext - self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None + self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext self.summarizedResult = summarizedResult self.warning = warning def to_dict(self) -> dict: data = vars(self).copy() - data["operatorContext"] = ( - self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None - ) + data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None return data diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py index af115e1c..606233a8 100644 --- a/src/khoj/processor/operator/operator_agent_openai.py +++ b/src/khoj/processor/operator/operator_agent_openai.py @@ -8,6 +8,7 @@ from typing import List, Optional, cast from openai.types.responses import Response, ResponseOutputItem +from khoj.database.models import ChatModel from khoj.processor.conversation.utils import AgentMessage from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent @@ -24,9 +25,6 @@ logger = logging.getLogger(__name__) # --- Anthropic Operator Agent --- class OpenAIOperatorAgent(OperatorAgent): async def act(self, current_state: EnvState) -> AgentActResult: - client = get_openai_async_client( - self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url - ) safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:" safety_check_message = None actions: List[OperatorAction] = [] @@ -34,23 +32,11 @@ class OpenAIOperatorAgent(OperatorAgent): self._commit_trace() # Commit trace before next action system_prompt = self.get_instructions(self.environment_type, current_state) tools = self.get_tools(self.environment_type, current_state) - if is_none_or_empty(self.messages): self.messages = [AgentMessage(role="user", content=self.query)] - messages_for_api = self._format_message_for_api(self.messages) - response: Response = await client.responses.create( - model="computer-use-preview", - input=messages_for_api, - instructions=system_prompt, - tools=tools, - parallel_tool_calls=False, # Keep sequential for now - max_output_tokens=4096, # TODO: Make configurable? - truncation="auto", - ) - - logger.debug(f"Openai response: {response.model_dump_json()}") - self.messages += [AgentMessage(role="environment", content=response.output)] + response = await self._call_model(self.vision_model, system_prompt, tools) + self.messages += [AgentMessage(role="assistant", content=response.output)] rendered_response = self._render_response(response.output, current_state.screenshot) last_call_id = None @@ -130,6 +116,9 @@ class OpenAIOperatorAgent(OperatorAgent): "summary": [], } ) + else: + logger.warning(f"Unsupported response block type: {block.type}") + content = f"Unsupported response block type: {block.type}" if action_to_run or content: actions.append(action_to_run) if action_to_run or content: @@ -176,6 +165,9 @@ class OpenAIOperatorAgent(OperatorAgent): elif action_result["type"] == "reasoning": items_to_pop.append(idx) # Mark placeholder reasoning action result for removal continue + elif action_result["type"] == "computer_call" and action_result["status"] == "in_progress": + result_content["status"] = "completed" # Mark in-progress actions as completed + action_result["output"] = result_content else: # Add text data action_result["output"] = result_content @@ -185,11 +177,45 @@ class OpenAIOperatorAgent(OperatorAgent): self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] + async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str: + summarize_prompt = summarize_prompt or self.summarize_prompt + self.messages.append(AgentMessage(role="user", content=summarize_prompt)) + response = await self._call_model(self.vision_model, summarize_prompt, []) + self.messages += [AgentMessage(role="assistant", content=response.output)] + if not self.messages: + return "No actions to summarize." + return self._compile_response(self.messages[-1].content) + + async def _call_model(self, model: ChatModel, system_prompt, tools) -> Response: + client = get_openai_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url) + if tools: + model_name = "computer-use-preview" + else: + model_name = model.name + + # Format messages for OpenAI API + messages_for_api = self._format_message_for_api(self.messages) + # format messages for summary if model is not computer-use-preview + if model_name != "computer-use-preview": + messages_for_api = self._format_messages_for_summary(messages_for_api) + + response: Response = await client.responses.create( + model=model_name, + input=messages_for_api, + instructions=system_prompt, + tools=tools, + parallel_tool_calls=False, + truncation="auto", + ) + + logger.debug(f"Openai response: {response.model_dump_json()}") + return response + def _format_message_for_api(self, messages: list[AgentMessage]) -> list: """Format the message for OpenAI API.""" formatted_messages: list = [] for message in messages: - if message.role == "environment": + if message.role == "assistant": if isinstance(message.content, list): # Remove reasoning message if not followed by computer call if ( @@ -208,14 +234,19 @@ class OpenAIOperatorAgent(OperatorAgent): message.content.pop(0) formatted_messages.extend(message.content) else: - logger.warning(f"Expected message content list from environment, got {type(message.content)}") + logger.warning(f"Expected message content list from assistant, got {type(message.content)}") + elif message.role == "environment": + formatted_messages.extend(message.content) else: + if isinstance(message.content, list): + message.content = "\n".join([part["text"] for part in message.content if part["type"] == "text"]) formatted_messages.append( { "role": message.role, "content": message.content, } ) + return formatted_messages def _compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str: @@ -352,10 +383,10 @@ class OpenAIOperatorAgent(OperatorAgent): def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]: """Return the tools available for the OpenAI operator.""" if environment_type == EnvironmentType.COMPUTER: - # get os info of this computer. it can be mac, windows, linux - environment_os = ( - "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux" - ) + # TODO: Get OS info from the environment + # For now, assume Linux as the environment OS + environment_os = "linux" + # environment = "mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux" else: environment_os = "browser" @@ -393,3 +424,33 @@ class OpenAIOperatorAgent(OperatorAgent): }, ] return tools + + def _format_messages_for_summary(self, formatted_messages: List[dict]) -> List[dict]: + """Format messages for summary.""" + # Format messages to interact with non computer use AI models + items_to_drop = [] # Track indices to drop reasoning messages + for idx, msg in enumerate(formatted_messages): + if isinstance(msg, dict) and "content" in msg: + continue + elif isinstance(msg, dict) and "output" in msg: + # Drop current_url from output as not supported for non computer operations + if "current_url" in msg["output"]: + del msg["output"]["current_url"] + formatted_messages[idx] = {"role": "user", "content": [msg["output"]]} + elif isinstance(msg, str): + formatted_messages[idx] = {"role": "user", "content": [{"type": "input_text", "text": msg}]} + else: + text = self._compile_response([msg]) + if not text: + items_to_drop.append(idx) # Track index to drop reasoning message + else: + formatted_messages[idx] = { + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + + # Remove reasoning messages for non-computer use models + for idx in reversed(items_to_drop): + formatted_messages.pop(idx) + + return formatted_messages