From da663e184c1fcb5398276d70bd3379e59b0dfe37 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 28 May 2025 21:53:33 -0700 Subject: [PATCH] Type operator results. Enable storing, loading operator trajectories. We were passing operator results as a simple dictionary. Strongly typing it makes sense as operator results become more complex. Storing operator results with trajectory on interrupts will allow restarting an interrupted operator run with the agent messages of the interrupted trajectory loaded into operator agents --- src/khoj/database/models/__init__.py | 2 +- .../conversation/anthropic/anthropic_chat.py | 8 ++- .../conversation/google/gemini_chat.py | 8 ++- src/khoj/processor/conversation/openai/gpt.py | 8 ++- src/khoj/processor/conversation/utils.py | 68 ++++++++++++++++--- src/khoj/processor/operator/__init__.py | 12 ++-- .../operator/operator_agent_anthropic.py | 7 +- .../processor/operator/operator_agent_base.py | 7 +- .../operator/operator_agent_binary.py | 8 +-- .../operator/operator_agent_openai.py | 7 +- src/khoj/routers/api_chat.py | 21 +++--- src/khoj/routers/helpers.py | 5 +- src/khoj/routers/research.py | 19 +++--- 13 files changed, 118 insertions(+), 62 deletions(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 34538df0..2c5f0ada 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -107,7 +107,7 @@ class ChatMessage(PydanticBaseModel): onlineContext: Dict[str, OnlineContext] = {} codeContext: Dict[str, CodeContextData] = {} researchContext: Optional[List] = None - operatorContext: Optional[Dict[str, str]] = None + operatorContext: Optional[List] = None created: str images: Optional[List[str]] = None queryFiles: Optional[List[Dict]] = None diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index e2de2c59..2e52a9f2 100644 --- 
a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import ( format_messages_for_anthropic, ) from khoj.processor.conversation.utils import ( + OperatorRun, ResponseWithThought, clean_json, construct_structured_message, @@ -144,7 +145,7 @@ async def converse_anthropic( references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, str]] = None, + operator_results: Optional[List[OperatorRun]] = None, conversation_log={}, model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, @@ -216,8 +217,11 @@ async def converse_anthropic( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results + ] context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n" + f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" ) context_message = context_message.strip() diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 6cbe87c8..17bb9b76 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.google.utils import ( gemini_completion_with_backoff, ) from khoj.processor.conversation.utils import ( + OperatorRun, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -166,7 +167,7 @@ async def converse_gemini( references: list[dict], online_results: 
Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, str]] = None, + operator_results: Optional[List[OperatorRun]] = None, conversation_log={}, model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, @@ -240,8 +241,11 @@ async def converse_gemini( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results + ] context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n" + f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" ) context_message = context_message.strip() diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 28739843..d5fcd0a1 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import ( ) from khoj.processor.conversation.utils import ( JsonSupport, + OperatorRun, ResponseWithThought, clean_json, construct_structured_message, @@ -169,7 +170,7 @@ async def converse_openai( references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, str]] = None, + operator_results: Optional[List[OperatorRun]] = None, conversation_log={}, model: str = "gpt-4o-mini", api_key: Optional[str] = None, @@ -242,8 +243,11 @@ async def converse_openai( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) if not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, 
"webpages": oc.webpages} for oc in operator_results + ] context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n" + f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" ) context_message = context_message.strip() diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 36aa001d..899a6786 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional, Union import PIL.Image import pyjson5 @@ -20,6 +20,7 @@ import yaml from langchain_core.messages.chat import ChatMessage from llama_cpp import LlamaTokenizer from llama_cpp.llama import Llama +from pydantic import BaseModel from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters @@ -87,6 +88,48 @@ model_to_prompt_size = { model_to_tokenizer: Dict[str, str] = {} +class AgentMessage(BaseModel): + role: Literal["user", "assistant", "system", "environment"] + content: Union[str, List] + + +class OperatorRun: + def __init__( + self, + query: str, + trajectory: list[AgentMessage | dict] = None, + response: str = None, + webpages: list[dict] = None, + ): + self.query = query + self.response = response + self.webpages = webpages or [] + self.trajectory: list[AgentMessage] = [] + if trajectory: + for item in trajectory: + if isinstance(item, dict): + self.trajectory.append(AgentMessage(**item)) + elif hasattr(item, "role") and hasattr(item, "content"): # Heuristic for AgentMessage like object + self.trajectory.append(item) + else: + logger.warning(f"Unexpected item type in trajectory: {type(item)}") + + def 
to_dict(self) -> dict: + # Ensure AgentMessage instances in trajectory are also dicts + serialized_trajectory = [] + for msg in self.trajectory: + if hasattr(msg, "model_dump"): # Check if it's a Pydantic model + serialized_trajectory.append(msg.model_dump()) + elif isinstance(msg, dict): + serialized_trajectory.append(msg) # Already a dict + return { + "query": self.query, + "response": self.response, + "trajectory": serialized_trajectory, + "webpages": self.webpages, + } + + class InformationCollectionIteration: def __init__( self, @@ -95,7 +138,7 @@ class InformationCollectionIteration: context: list = None, onlineContext: dict = None, codeContext: dict = None, - operatorContext: dict[str, str] = None, + operatorContext: dict = None, summarizedResult: str = None, warning: str = None, ): @@ -104,10 +147,17 @@ class InformationCollectionIteration: self.context = context self.onlineContext = onlineContext self.codeContext = codeContext - self.operatorContext = operatorContext + self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None self.summarizedResult = summarizedResult self.warning = warning + def to_dict(self) -> dict: + data = vars(self).copy() + data["operatorContext"] = ( + self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None + ) + return data + def construct_iteration_history( previous_iterations: List[InformationCollectionIteration], @@ -193,7 +243,7 @@ def construct_tool_chat_history( lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] ), ConversationCommand.Operator: ( - lambda iteration: list(iteration.operatorContext.keys()) if iteration.operatorContext else [] + lambda iteration: list(iteration.operatorContext.query) if iteration.operatorContext else [] ), } for iteration in previous_iterations: @@ -273,7 +323,7 @@ async def save_to_conversation_log( compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, 
code_results: Dict[str, Any] = {}, - operator_results: Dict[str, str] = {}, + operator_results: List[OperatorRun] = None, inferred_queries: List[str] = [], intent_type: str = "remember", client_application: ClientApplication = None, @@ -301,8 +351,8 @@ async def save_to_conversation_log( "intent": {"inferred-queries": inferred_queries, "type": intent_type}, "onlineContext": online_results, "codeContext": code_results, - "operatorContext": operator_results, - "researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None, + "operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None, + "researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None, "automationId": automation_id, "trainOfThought": train_of_thought, "turnId": turn_id, @@ -459,10 +509,12 @@ def generate_chatml_messages_with_context( ] if not is_none_or_empty(chat.get("operatorContext")): + operator_context = chat.get("operatorContext") + operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context]) message_context += [ { "type": "text", - "text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}", + "text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}", } ] diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index b2ea846f..e38ae777 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.processor.conversation.utils import OperatorRun from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_anthropic import 
AnthropicOperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent @@ -160,11 +161,12 @@ async def operate_environment( if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"): webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls] - yield { - "query": query, - "result": user_input_message or response, - "webpages": webpages, - } + yield OperatorRun( + query=query, + trajectory=operator_agent.messages, + response=response, + webpages=webpages, + ) def is_operator_model(model: str) -> ChatModel.ModelType | None: diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index 567411f6..2b98d50d 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -10,12 +10,9 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo from khoj.database.models import ChatModel from khoj.processor.conversation.anthropic.utils import is_reasoning_model +from khoj.processor.conversation.utils import AgentMessage from khoj.processor.operator.operator_actions import * -from khoj.processor.operator.operator_agent_base import ( - AgentActResult, - AgentMessage, - OperatorAgent, -) +from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, EnvState, diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 8c273b5e..1aa0f238 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union from pydantic import BaseModel from khoj.database.models import ChatModel -from khoj.processor.conversation.utils import commit_conversation_trace +from 
khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace from khoj.processor.operator.operator_actions import OperatorAction from khoj.processor.operator.operator_environment_base import ( EnvironmentType, @@ -23,11 +23,6 @@ class AgentActResult(BaseModel): rendered_response: Optional[dict] = None -class AgentMessage(BaseModel): - role: Literal["user", "assistant", "system", "environment"] - content: Union[str, List] - - class OperatorAgent(ABC): def __init__( self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index b869e8bb..77e28442 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -5,15 +5,11 @@ from textwrap import dedent from typing import List from khoj.database.models import ChatModel -from khoj.processor.conversation.utils import construct_structured_message +from khoj.processor.conversation.utils import AgentMessage, construct_structured_message from khoj.processor.operator.grounding_agent import GroundingAgent from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars from khoj.processor.operator.operator_actions import * -from khoj.processor.operator.operator_agent_base import ( - AgentActResult, - AgentMessage, - OperatorAgent, -) +from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, EnvState, diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py index 2b665f4d..7ae71bff 100644 --- a/src/khoj/processor/operator/operator_agent_openai.py +++ b/src/khoj/processor/operator/operator_agent_openai.py @@ -8,12 +8,9 @@ from typing import List, Optional, cast from openai.types.responses 
import Response, ResponseOutputItem +from khoj.processor.conversation.utils import AgentMessage from khoj.processor.operator.operator_actions import * -from khoj.processor.operator.operator_agent_base import ( - AgentActResult, - AgentMessage, - OperatorAgent, -) +from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, EnvState, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 39e8c1a1..2cade5bc 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -26,6 +26,7 @@ from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import ( + OperatorRun, ResponseWithThought, defilter_query, save_to_conversation_log, @@ -725,7 +726,7 @@ async def chat( research_results: List[InformationCollectionIteration] = [] online_results: Dict = dict() code_results: Dict = dict() - operator_results: Dict[str, str] = {} + operator_results: List[OperatorRun] = [] compiled_references: List[Any] = [] inferred_queries: List[Any] = [] attached_file_context = gather_raw_query_files(query_files) @@ -960,11 +961,12 @@ async def chat( last_message = conversation.messages[-1] online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} - operator_results = last_message.operatorContext or {} compiled_references = [ref.model_dump() for ref in last_message.context or []] research_results = [ InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or [] ] + operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] + train_of_thought = [thought.model_dump() for thought in 
last_message.trainOfThought or []] # Drop the interrupted message from conversation history meta_log["chat"].pop() logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") @@ -1034,7 +1036,7 @@ async def chat( if research_result.context: compiled_references.extend(research_result.context) if research_result.operatorContext: - operator_results.update(research_result.operatorContext) + operator_results.append(research_result.operatorContext) research_results.append(research_result) else: @@ -1306,16 +1308,16 @@ async def chat( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - operator_results = {result["query"]: result["result"]} + elif isinstance(result, OperatorRun): + operator_results.append(result) # Add webpages visited while operating browser to references - if result.get("webpages"): + if result.webpages: if not online_results.get(defiltered_query): - online_results[defiltered_query] = {"webpages": result["webpages"]} + online_results[defiltered_query] = {"webpages": result.webpages} elif not online_results[defiltered_query].get("webpages"): - online_results[defiltered_query]["webpages"] = result["webpages"] + online_results[defiltered_query]["webpages"] = result.webpages else: - online_results[defiltered_query]["webpages"] += result["webpages"] + online_results[defiltered_query]["webpages"] += result.webpages except ValueError as e: program_execution_context.append(f"Browser operation error: {e}") logger.warning(f"Failed to operate browser with {e}", exc_info=True) @@ -1333,7 +1335,6 @@ async def chat( "context": compiled_references, "onlineContext": unique_online_results, "codeContext": code_results, - "operatorContext": operator_results, }, ): yield result diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c1ddb82d..a3322b67 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -95,6 +95,7 @@ from 
khoj.processor.conversation.openai.gpt import ( from khoj.processor.conversation.utils import ( ChatEvent, InformationCollectionIteration, + OperatorRun, ResponseWithThought, clean_json, clean_mermaidjs, @@ -1355,7 +1356,7 @@ async def agenerate_chat_response( compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {}, - operator_results: Dict[str, str] = {}, + operator_results: List[OperatorRun] = [], research_results: List[InformationCollectionIteration] = [], inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], @@ -1414,7 +1415,7 @@ async def agenerate_chat_response( compiled_references = [] online_results = {} code_results = {} - operator_results = {} + operator_results = [] deepthought = True chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index de55e47b..ab0c7062 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -14,6 +14,7 @@ from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( InformationCollectionIteration, + OperatorRun, construct_iteration_history, construct_tool_chat_history, load_complex_json, @@ -248,7 +249,7 @@ async def execute_information_collection( online_results: Dict = dict() code_results: Dict = dict() document_results: List[Dict[str, str]] = [] - operator_results: Dict[str, str] = {} + operator_results: OperatorRun = None summarize_files: str = "" this_iteration = InformationCollectionIteration(tool=None, query=query) @@ -431,17 +432,17 @@ async def execute_information_collection( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - operator_results = {result["query"]: result["result"]} + elif isinstance(result, OperatorRun): + 
operator_results = result this_iteration.operatorContext = operator_results # Add webpages visited while operating browser to references - if result.get("webpages"): + if result.webpages: if not online_results.get(this_iteration.query): - online_results[this_iteration.query] = {"webpages": result["webpages"]} + online_results[this_iteration.query] = {"webpages": result.webpages} elif not online_results[this_iteration.query].get("webpages"): - online_results[this_iteration.query]["webpages"] = result["webpages"] + online_results[this_iteration.query]["webpages"] = result.webpages else: - online_results[this_iteration.query]["webpages"] += result["webpages"] + online_results[this_iteration.query]["webpages"] += result.webpages this_iteration.onlineContext = online_results except Exception as e: this_iteration.warning = f"Error operating browser: {e}" @@ -489,7 +490,9 @@ async def execute_information_collection( if code_results: results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if operator_results: - results_data += f"\n\n{next(iter(operator_results.values()))}\n" + results_data += ( + f"\n\n{operator_results.response}\n" + ) if summarize_files: results_data += f"\n\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if this_iteration.warning: