diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 34538df0..2c5f0ada 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -107,7 +107,7 @@ class ChatMessage(PydanticBaseModel):
onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {}
researchContext: Optional[List] = None
- operatorContext: Optional[Dict[str, str]] = None
+ operatorContext: Optional[List] = None
created: str
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None
diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index e2de2c59..2e52a9f2 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
+ OperatorRun,
ResponseWithThought,
clean_json,
construct_structured_message,
@@ -144,7 +145,7 @@ async def converse_anthropic(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
- operator_results: Optional[Dict[str, str]] = None,
+ operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
@@ -216,8 +217,11 @@ async def converse_anthropic(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
+ operator_content = [
+ {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
+ ]
context_message += (
- f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
+ f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 6cbe87c8..17bb9b76 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -14,6 +14,7 @@ from khoj.processor.conversation.google.utils import (
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import (
+ OperatorRun,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -166,7 +167,7 @@ async def converse_gemini(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
- operator_results: Optional[Dict[str, str]] = None,
+ operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
@@ -240,8 +241,11 @@ async def converse_gemini(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results):
+ operator_content = [
+ {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
+ ]
context_message += (
- f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
+ f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 28739843..d5fcd0a1 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
+ OperatorRun,
ResponseWithThought,
clean_json,
construct_structured_message,
@@ -169,7 +170,7 @@ async def converse_openai(
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
- operator_results: Optional[Dict[str, str]] = None,
+ operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
@@ -242,8 +243,11 @@ async def converse_openai(
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if not is_none_or_empty(operator_results):
+ operator_content = [
+ {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
+ ]
context_message += (
- f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n"
+ f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 36aa001d..899a6786 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -10,7 +10,7 @@ from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
import PIL.Image
import pyjson5
@@ -20,6 +20,7 @@ import yaml
from langchain_core.messages.chat import ChatMessage
from llama_cpp import LlamaTokenizer
from llama_cpp.llama import Llama
+from pydantic import BaseModel
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters
@@ -87,6 +88,48 @@ model_to_prompt_size = {
model_to_tokenizer: Dict[str, str] = {}
+class AgentMessage(BaseModel):
+ role: Literal["user", "assistant", "system", "environment"]
+ content: Union[str, List]
+
+
+class OperatorRun:
+ def __init__(
+ self,
+ query: str,
+ trajectory: Optional[list[AgentMessage | dict]] = None,
+ response: str = None,
+ webpages: list[dict] = None,
+ ):
+ self.query = query
+ self.response = response
+ self.webpages = webpages or []
+ self.trajectory: list[AgentMessage] = []
+ if trajectory:
+ for item in trajectory:
+ if isinstance(item, dict):
+ self.trajectory.append(AgentMessage(**item))
+ elif hasattr(item, "role") and hasattr(item, "content"): # Heuristic for AgentMessage like object
+ self.trajectory.append(item)
+ else:
+ logger.warning(f"Unexpected item type in trajectory: {type(item)}")
+
+ def to_dict(self) -> dict:
+ # Ensure AgentMessage instances in trajectory are also dicts
+ serialized_trajectory = []
+ for msg in self.trajectory:
+ if hasattr(msg, "model_dump"): # Check if it's a Pydantic model
+ serialized_trajectory.append(msg.model_dump())
+ elif isinstance(msg, dict):
+ serialized_trajectory.append(msg) # Already a dict
+ return {
+ "query": self.query,
+ "response": self.response,
+ "trajectory": serialized_trajectory,
+ "webpages": self.webpages,
+ }
+
+
class InformationCollectionIteration:
def __init__(
self,
@@ -95,7 +138,7 @@ class InformationCollectionIteration:
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
- operatorContext: dict[str, str] = None,
+ operatorContext: dict = None,
summarizedResult: str = None,
warning: str = None,
):
@@ -104,10 +147,17 @@ class InformationCollectionIteration:
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
- self.operatorContext = operatorContext
+ self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
self.summarizedResult = summarizedResult
self.warning = warning
+ def to_dict(self) -> dict:
+ data = vars(self).copy()
+ data["operatorContext"] = (
+ self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None
+ )
+ return data
+
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration],
@@ -193,7 +243,7 @@ def construct_tool_chat_history(
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
),
ConversationCommand.Operator: (
- lambda iteration: list(iteration.operatorContext.keys()) if iteration.operatorContext else []
+ lambda iteration: [iteration.operatorContext.query] if iteration.operatorContext else []
),
}
for iteration in previous_iterations:
@@ -273,7 +323,7 @@ async def save_to_conversation_log(
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
code_results: Dict[str, Any] = {},
- operator_results: Dict[str, str] = {},
+ operator_results: List[OperatorRun] = None,
inferred_queries: List[str] = [],
intent_type: str = "remember",
client_application: ClientApplication = None,
@@ -301,8 +351,8 @@ async def save_to_conversation_log(
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
"codeContext": code_results,
- "operatorContext": operator_results,
- "researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None,
+ "operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None,
+ "researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
@@ -459,10 +509,12 @@ def generate_chatml_messages_with_context(
]
if not is_none_or_empty(chat.get("operatorContext")):
+ operator_context = chat.get("operatorContext")
+ operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
message_context += [
{
"type": "text",
- "text": f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}",
+ "text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}",
}
]
diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py
index b2ea846f..e38ae777 100644
--- a/src/khoj/processor/operator/__init__.py
+++ b/src/khoj/processor/operator/__init__.py
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatModel, KhojUser
+from khoj.processor.conversation.utils import OperatorRun
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
from khoj.processor.operator.operator_agent_base import OperatorAgent
@@ -160,11 +161,12 @@ async def operate_environment(
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
- yield {
- "query": query,
- "result": user_input_message or response,
- "webpages": webpages,
- }
+ yield OperatorRun(
+ query=query,
+ trajectory=operator_agent.messages,
+ response=response,
+ webpages=webpages,
+ )
def is_operator_model(model: str) -> ChatModel.ModelType | None:
diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py
index 567411f6..2b98d50d 100644
--- a/src/khoj/processor/operator/operator_agent_anthropic.py
+++ b/src/khoj/processor/operator/operator_agent_anthropic.py
@@ -10,12 +10,9 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo
from khoj.database.models import ChatModel
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
+from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
-from khoj.processor.operator.operator_agent_base import (
- AgentActResult,
- AgentMessage,
- OperatorAgent,
-)
+from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py
index 8c273b5e..1aa0f238 100644
--- a/src/khoj/processor/operator/operator_agent_base.py
+++ b/src/khoj/processor/operator/operator_agent_base.py
@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel
from khoj.database.models import ChatModel
-from khoj.processor.conversation.utils import commit_conversation_trace
+from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace
from khoj.processor.operator.operator_actions import OperatorAction
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
@@ -23,11 +23,6 @@ class AgentActResult(BaseModel):
rendered_response: Optional[dict] = None
-class AgentMessage(BaseModel):
- role: Literal["user", "assistant", "system", "environment"]
- content: Union[str, List]
-
-
class OperatorAgent(ABC):
def __init__(
self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict
diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py
index b869e8bb..77e28442 100644
--- a/src/khoj/processor/operator/operator_agent_binary.py
+++ b/src/khoj/processor/operator/operator_agent_binary.py
@@ -5,15 +5,11 @@ from textwrap import dedent
from typing import List
from khoj.database.models import ChatModel
-from khoj.processor.conversation.utils import construct_structured_message
+from khoj.processor.conversation.utils import AgentMessage, construct_structured_message
from khoj.processor.operator.grounding_agent import GroundingAgent
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
from khoj.processor.operator.operator_actions import *
-from khoj.processor.operator.operator_agent_base import (
- AgentActResult,
- AgentMessage,
- OperatorAgent,
-)
+from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py
index 2b665f4d..7ae71bff 100644
--- a/src/khoj/processor/operator/operator_agent_openai.py
+++ b/src/khoj/processor/operator/operator_agent_openai.py
@@ -8,12 +8,9 @@ from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem
+from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
-from khoj.processor.operator.operator_agent_base import (
- AgentActResult,
- AgentMessage,
- OperatorAgent,
-)
+from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 39e8c1a1..2cade5bc 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -26,6 +26,7 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import (
+ OperatorRun,
ResponseWithThought,
defilter_query,
save_to_conversation_log,
@@ -725,7 +726,7 @@ async def chat(
research_results: List[InformationCollectionIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
- operator_results: Dict[str, str] = {}
+ operator_results: List[OperatorRun] = []
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
attached_file_context = gather_raw_query_files(query_files)
@@ -960,11 +961,12 @@ async def chat(
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
- operator_results = last_message.operatorContext or {}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
+ operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
+ train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
@@ -1034,7 +1036,7 @@ async def chat(
if research_result.context:
compiled_references.extend(research_result.context)
if research_result.operatorContext:
- operator_results.update(research_result.operatorContext)
+ operator_results.append(research_result.operatorContext)
research_results.append(research_result)
else:
@@ -1306,16 +1308,16 @@ async def chat(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
- else:
- operator_results = {result["query"]: result["result"]}
+ elif isinstance(result, OperatorRun):
+ operator_results.append(result)
# Add webpages visited while operating browser to references
- if result.get("webpages"):
+ if result.webpages:
if not online_results.get(defiltered_query):
- online_results[defiltered_query] = {"webpages": result["webpages"]}
+ online_results[defiltered_query] = {"webpages": result.webpages}
elif not online_results[defiltered_query].get("webpages"):
- online_results[defiltered_query]["webpages"] = result["webpages"]
+ online_results[defiltered_query]["webpages"] = result.webpages
else:
- online_results[defiltered_query]["webpages"] += result["webpages"]
+ online_results[defiltered_query]["webpages"] += result.webpages
except ValueError as e:
program_execution_context.append(f"Browser operation error: {e}")
logger.warning(f"Failed to operate browser with {e}", exc_info=True)
@@ -1333,7 +1335,6 @@ async def chat(
"context": compiled_references,
"onlineContext": unique_online_results,
"codeContext": code_results,
- "operatorContext": operator_results,
},
):
yield result
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index c1ddb82d..a3322b67 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -95,6 +95,7 @@ from khoj.processor.conversation.openai.gpt import (
from khoj.processor.conversation.utils import (
ChatEvent,
InformationCollectionIteration,
+ OperatorRun,
ResponseWithThought,
clean_json,
clean_mermaidjs,
@@ -1355,7 +1356,7 @@ async def agenerate_chat_response(
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
- operator_results: Dict[str, str] = {},
+ operator_results: List[OperatorRun] = [],
research_results: List[InformationCollectionIteration] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
@@ -1414,7 +1415,7 @@ async def agenerate_chat_response(
compiled_references = []
online_results = {}
code_results = {}
- operator_results = {}
+ operator_results = []
deepthought = True
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index de55e47b..ab0c7062 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -14,6 +14,7 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
+ OperatorRun,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
@@ -248,7 +249,7 @@ async def execute_information_collection(
online_results: Dict = dict()
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
- operator_results: Dict[str, str] = {}
+ operator_results: OperatorRun = None
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
@@ -431,17 +432,17 @@ async def execute_information_collection(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
- else:
- operator_results = {result["query"]: result["result"]}
+ elif isinstance(result, OperatorRun):
+ operator_results = result
this_iteration.operatorContext = operator_results
# Add webpages visited while operating browser to references
- if result.get("webpages"):
+ if result.webpages:
if not online_results.get(this_iteration.query):
- online_results[this_iteration.query] = {"webpages": result["webpages"]}
+ online_results[this_iteration.query] = {"webpages": result.webpages}
elif not online_results[this_iteration.query].get("webpages"):
- online_results[this_iteration.query]["webpages"] = result["webpages"]
+ online_results[this_iteration.query]["webpages"] = result.webpages
else:
- online_results[this_iteration.query]["webpages"] += result["webpages"]
+ online_results[this_iteration.query]["webpages"] += result.webpages
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error operating browser: {e}"
@@ -489,7 +490,9 @@ async def execute_information_collection(
if code_results:
results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if operator_results:
- results_data += f"\n\n{next(iter(operator_results.values()))}\n"
+ results_data += (
+ f"\n\n{operator_results.response}\n"
+ )
if summarize_files:
results_data += f"\n\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if this_iteration.warning: