Type operator results. Enable storing, loading operator trajectories.

We were passing operator results around as a simple dictionary. Strongly
typing them makes sense as operator results become more complex.

Storing operator results with their trajectory on interrupts will allow
restarting an interrupted operator run, with the agent messages of the
interrupted trajectory loaded back into the operator agents.
This commit is contained in:
Debanjum
2025-05-28 21:53:33 -07:00
parent 675fc0ad05
commit da663e184c
13 changed files with 118 additions and 62 deletions

View File

@@ -107,7 +107,7 @@ class ChatMessage(PydanticBaseModel):
onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {}
researchContext: Optional[List] = None
operatorContext: Optional[Dict[str, str]] = None
operatorContext: Optional[List] = None
created: str
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None

View File

@@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
clean_json,
construct_structured_message,
@@ -144,7 +145,7 @@ async def converse_anthropic(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, str]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
@@ -216,8 +217,11 @@ async def converse_anthropic(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()

View File

@@ -14,6 +14,7 @@ from khoj.processor.conversation.google.utils import (
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import (
OperatorRun,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -166,7 +167,7 @@ async def converse_gemini(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, str]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
@@ -240,8 +241,11 @@ async def converse_gemini(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()

View File

@@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
OperatorRun,
ResponseWithThought,
clean_json,
construct_structured_message,
@@ -169,7 +170,7 @@ async def converse_openai(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, str]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
@@ -242,8 +243,11 @@ async def converse_openai(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()

View File

@@ -10,7 +10,7 @@ from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Literal, Optional, Union
import PIL.Image
import pyjson5
@@ -20,6 +20,7 @@ import yaml
from langchain_core.messages.chat import ChatMessage
from llama_cpp import LlamaTokenizer
from llama_cpp.llama import Llama
from pydantic import BaseModel
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters
@@ -87,6 +88,48 @@ model_to_prompt_size = {
model_to_tokenizer: Dict[str, str] = {}
class AgentMessage(BaseModel):
    """A single message in an operator agent's trajectory.

    role: sender of the message; "environment" presumably carries observations
        fed back from the operated environment (e.g. browser state) — confirm
        against the operator environment implementations.
    content: plain text, or a list of structured content parts.
    """

    role: Literal["user", "assistant", "system", "environment"]
    content: Union[str, List]
class OperatorRun:
    """Result of a single operator (browser/computer-use) run.

    Holds the user query, the agent's final response, webpages visited, and
    the full message trajectory, so an interrupted run can be restarted with
    its prior agent messages restored.
    """

    def __init__(
        self,
        query: str,
        trajectory: Optional[list] = None,  # list[AgentMessage | dict]
        response: Optional[str] = None,
        webpages: Optional[list[dict]] = None,
    ):
        self.query = query
        self.response = response
        self.webpages = webpages or []
        # Normalize trajectory entries to AgentMessage instances where possible.
        self.trajectory: list[AgentMessage] = []
        if trajectory:
            for item in trajectory:
                if isinstance(item, dict):
                    self.trajectory.append(AgentMessage(**item))
                elif hasattr(item, "role") and hasattr(item, "content"):  # Heuristic for AgentMessage like object
                    self.trajectory.append(item)
                else:
                    logger.warning(f"Unexpected item type in trajectory: {type(item)}")

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (e.g. for the conversation log)."""
        # Ensure AgentMessage instances in trajectory are also dicts
        serialized_trajectory = []
        for msg in self.trajectory:
            if hasattr(msg, "model_dump"):  # Check if it's a Pydantic model
                serialized_trajectory.append(msg.model_dump())
            elif isinstance(msg, dict):
                serialized_trajectory.append(msg)  # Already a dict
            else:
                # Fix: messages accepted by the __init__ duck-typing heuristic
                # (role/content attributes, no model_dump) were silently dropped
                # here. Serialize them via their attributes instead of losing them.
                serialized_trajectory.append(
                    {"role": getattr(msg, "role", None), "content": getattr(msg, "content", None)}
                )
        return {
            "query": self.query,
            "response": self.response,
            "trajectory": serialized_trajectory,
            "webpages": self.webpages,
        }
class InformationCollectionIteration:
def __init__(
self,
@@ -95,7 +138,7 @@ class InformationCollectionIteration:
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
operatorContext: dict[str, str] = None,
operatorContext: dict = None,
summarizedResult: str = None,
warning: str = None,
):
@@ -104,10 +147,17 @@ class InformationCollectionIteration:
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.operatorContext = operatorContext
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None
self.summarizedResult = summarizedResult
self.warning = warning
    def to_dict(self) -> dict:
        """Serialize this research iteration to a JSON-friendly dict.

        operatorContext may be an OperatorRun; convert it through its own
        to_dict() so the result can be stored in the conversation log.
        Any non-OperatorRun value is normalized to None.
        """
        data = vars(self).copy()
        data["operatorContext"] = (
            self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None
        )
        return data
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration],
@@ -193,7 +243,7 @@ def construct_tool_chat_history(
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
),
ConversationCommand.Operator: (
lambda iteration: list(iteration.operatorContext.keys()) if iteration.operatorContext else []
lambda iteration: list(iteration.operatorContext.query) if iteration.operatorContext else []
),
}
for iteration in previous_iterations:
@@ -273,7 +323,7 @@ async def save_to_conversation_log(
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
code_results: Dict[str, Any] = {},
operator_results: Dict[str, str] = {},
operator_results: List[OperatorRun] = None,
inferred_queries: List[str] = [],
intent_type: str = "remember",
client_application: ClientApplication = None,
@@ -301,8 +351,8 @@ async def save_to_conversation_log(
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
"codeContext": code_results,
"operatorContext": operator_results,
"researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
"operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None,
"researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
@@ -459,10 +509,12 @@ def generate_chatml_messages_with_context(
]
if not is_none_or_empty(chat.get("operatorContext")):
operator_context = chat.get("operatorContext")
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
message_context += [
{
"type": "text",
"text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
"text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}",
}
]

View File

@@ -6,6 +6,7 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation.utils import OperatorRun
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
from khoj.processor.operator.operator_agent_base import OperatorAgent
@@ -160,11 +161,12 @@ async def operate_environment(
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
yield {
"query": query,
"result": user_input_message or response,
"webpages": webpages,
}
yield OperatorRun(
query=query,
trajectory=operator_agent.messages,
response=response,
webpages=webpages,
)
def is_operator_model(model: str) -> ChatModel.ModelType | None:

View File

@@ -10,12 +10,9 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo
from khoj.database.models import ChatModel
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
AgentActResult,
AgentMessage,
OperatorAgent,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,

View File

@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace
from khoj.processor.operator.operator_actions import OperatorAction
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
@@ -23,11 +23,6 @@ class AgentActResult(BaseModel):
rendered_response: Optional[dict] = None
class AgentMessage(BaseModel):
role: Literal["user", "assistant", "system", "environment"]
content: Union[str, List]
class OperatorAgent(ABC):
def __init__(
self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict

View File

@@ -5,15 +5,11 @@ from textwrap import dedent
from typing import List
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message
from khoj.processor.conversation.utils import AgentMessage, construct_structured_message
from khoj.processor.operator.grounding_agent import GroundingAgent
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
AgentActResult,
AgentMessage,
OperatorAgent,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,

View File

@@ -8,12 +8,9 @@ from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
AgentActResult,
AgentMessage,
OperatorAgent,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,

View File

@@ -26,6 +26,7 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
defilter_query,
save_to_conversation_log,
@@ -725,7 +726,7 @@ async def chat(
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: Dict[str, str] = {}
operator_results: List[OperatorRun] = []
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
@@ -960,11 +961,12 @@ async def chat(
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
operator_results = last_message.operatorContext or {}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
@@ -1034,7 +1036,7 @@ async def chat(
if research_result.context:
compiled_references.extend(research_result.context)
if research_result.operatorContext:
operator_results.update(research_result.operatorContext)
operator_results.append(research_result.operatorContext)
research_results.append(research_result)
else:
@@ -1306,16 +1308,16 @@ async def chat(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
operator_results = {result["query"]: result["result"]}
elif isinstance(result, OperatorRun):
operator_results.append(result)
# Add webpages visited while operating browser to references
if result.get("webpages"):
if result.webpages:
if not online_results.get(defiltered_query):
online_results[defiltered_query] = {"webpages": result["webpages"]}
online_results[defiltered_query] = {"webpages": result.webpages}
elif not online_results[defiltered_query].get("webpages"):
online_results[defiltered_query]["webpages"] = result["webpages"]
online_results[defiltered_query]["webpages"] = result.webpages
else:
online_results[defiltered_query]["webpages"] += result["webpages"]
online_results[defiltered_query]["webpages"] += result.webpages
except ValueError as e:
program_execution_context.append(f"Browser operation error: {e}")
logger.warning(f"Failed to operate browser with {e}", exc_info=True)
@@ -1333,7 +1335,6 @@ async def chat(
"context": compiled_references,
"onlineContext": unique_online_results,
"codeContext": code_results,
"operatorContext": operator_results,
},
):
yield result

View File

@@ -95,6 +95,7 @@ from khoj.processor.conversation.openai.gpt import (
from khoj.processor.conversation.utils import (
ChatEvent,
InformationCollectionIteration,
OperatorRun,
ResponseWithThought,
clean_json,
clean_mermaidjs,
@@ -1355,7 +1356,7 @@ async def agenerate_chat_response(
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
operator_results: Dict[str, str] = {},
operator_results: List[OperatorRun] = [],
research_results: List[InformationCollectionIteration] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
@@ -1414,7 +1415,7 @@ async def agenerate_chat_response(
compiled_references = []
online_results = {}
code_results = {}
operator_results = {}
operator_results = []
deepthought = True
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)

View File

@@ -14,6 +14,7 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
OperatorRun,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
@@ -248,7 +249,7 @@ async def execute_information_collection(
online_results: Dict = dict()
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
operator_results: Dict[str, str] = {}
operator_results: OperatorRun = None
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
@@ -431,17 +432,17 @@ async def execute_information_collection(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
operator_results = {result["query"]: result["result"]}
elif isinstance(result, OperatorRun):
operator_results = result
this_iteration.operatorContext = operator_results
# Add webpages visited while operating browser to references
if result.get("webpages"):
if result.webpages:
if not online_results.get(this_iteration.query):
online_results[this_iteration.query] = {"webpages": result["webpages"]}
online_results[this_iteration.query] = {"webpages": result.webpages}
elif not online_results[this_iteration.query].get("webpages"):
online_results[this_iteration.query]["webpages"] = result["webpages"]
online_results[this_iteration.query]["webpages"] = result.webpages
else:
online_results[this_iteration.query]["webpages"] += result["webpages"]
online_results[this_iteration.query]["webpages"] += result.webpages
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error operating browser: {e}"
@@ -489,7 +490,9 @@ async def execute_information_collection(
if code_results:
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if operator_results:
results_data += f"\n<browser_operator_results>\n{next(iter(operator_results.values()))}\n</browser_operator_results>"
results_data += (
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
)
if summarize_files:
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
if this_iteration.warning: