diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index efaf0f0a..ec531675 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -39,7 +39,7 @@ async def operate_environment( cancellation_event: Optional[asyncio.Event] = None, tracer: dict = {}, ): - response, summary_message, user_input_message = None, None, None + response, user_input_message = None, None # Get the agent chat model agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None @@ -81,7 +81,6 @@ async def operate_environment( # Start Operator Loop try: - summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." task_completed = False iterations = 0 @@ -137,7 +136,7 @@ async def operate_environment( if task_completed or trigger_iteration_limit: # Summarize results of operator run on last iteration operator_agent.add_action_results(env_steps, agent_result) - summary_message = await operator_agent.summarize(summarize_prompt, env_state) + summary_message = await operator_agent.summarize(env_state) logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}") break diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 8aec6225..eb9bd544 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -38,6 +38,7 @@ class OperatorAgent(ABC): self.max_iterations = max_iterations self.tracer = tracer self.messages: List[AgentMessage] = [] + self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." @abstractmethod async def act(self, current_state: EnvState) -> AgentActResult: @@ -48,8 +49,9 @@ class OperatorAgent(ABC): """Track results of agent actions on the environment.""" pass - async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str: + async def summarize(self, current_state: EnvState, summarize_prompt: str = None) -> str: """Summarize the agent's actions and results.""" + summarize_prompt = summarize_prompt or self.summarize_prompt self.messages.append(AgentMessage(role="user", content=summarize_prompt)) await self.act(current_state) if not self.messages: diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index f6dd63c9..b869e8bb 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -226,7 +226,8 @@ class BinaryOperatorAgent(OperatorAgent): action_results_content.extend(action_result["content"]) self.messages.append(AgentMessage(role="environment", content=action_results_content)) - async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str: + async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str: + summarize_prompt = summarize_prompt or self.summarize_prompt conversation_history = {"chat": self._format_message_for_api(self.messages)} try: summary = await send_message_to_model_wrapper(