diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 34538df0..2c5f0ada 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -107,7 +107,7 @@ class ChatMessage(PydanticBaseModel): onlineContext: Dict[str, OnlineContext] = {} codeContext: Dict[str, CodeContextData] = {} researchContext: Optional[List] = None - operatorContext: Optional[Dict[str, str]] = None + operatorContext: Optional[List] = None created: str images: Optional[List[str]] = None queryFiles: Optional[List[Dict]] = None diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index e2de2c59..2e52a9f2 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import ( format_messages_for_anthropic, ) from khoj.processor.conversation.utils import ( + OperatorRun, ResponseWithThought, clean_json, construct_structured_message, @@ -144,7 +145,7 @@ async def converse_anthropic( references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, str]] = None, + operator_results: Optional[List[OperatorRun]] = None, conversation_log={}, model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, @@ -216,8 +217,11 @@ async def converse_anthropic( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results + ] context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n" + 
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" ) context_message = context_message.strip() diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 6cbe87c8..17bb9b76 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.google.utils import ( gemini_completion_with_backoff, ) from khoj.processor.conversation.utils import ( + OperatorRun, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -166,7 +167,7 @@ async def converse_gemini( references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, str]] = None, + operator_results: Optional[List[OperatorRun]] = None, conversation_log={}, model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, @@ -240,8 +241,11 @@ async def converse_gemini( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) if ConversationCommand.Operator in conversation_commands and not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results + ] context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n" + f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" ) context_message = context_message.strip() diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 28739843..d5fcd0a1 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import ( ) from 
khoj.processor.conversation.utils import ( JsonSupport, + OperatorRun, ResponseWithThought, clean_json, construct_structured_message, @@ -169,7 +170,7 @@ async def converse_openai( references: list[dict], online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, str]] = None, + operator_results: Optional[List[OperatorRun]] = None, conversation_log={}, model: str = "gpt-4o-mini", api_key: Optional[str] = None, @@ -242,8 +243,11 @@ async def converse_openai( f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" ) if not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results + ] context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_results))}\n\n" + f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" ) context_message = context_message.strip() diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 36aa001d..899a6786 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional, Union import PIL.Image import pyjson5 @@ -20,6 +20,7 @@ import yaml from langchain_core.messages.chat import ChatMessage from llama_cpp import LlamaTokenizer from llama_cpp.llama import Llama +from pydantic import BaseModel from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters @@ -87,6 +88,48 @@ model_to_prompt_size = { model_to_tokenizer: Dict[str, str] = {} +class 
AgentMessage(BaseModel): + role: Literal["user", "assistant", "system", "environment"] + content: Union[str, List] + + +class OperatorRun: + def __init__( + self, + query: str, + trajectory: list[AgentMessage | dict] = None, + response: str = None, + webpages: list[dict] = None, + ): + self.query = query + self.response = response + self.webpages = webpages or [] + self.trajectory: list[AgentMessage] = [] + if trajectory: + for item in trajectory: + if isinstance(item, dict): + self.trajectory.append(AgentMessage(**item)) + elif hasattr(item, "role") and hasattr(item, "content"): # Heuristic for AgentMessage like object + self.trajectory.append(item) + else: + logger.warning(f"Unexpected item type in trajectory: {type(item)}") + + def to_dict(self) -> dict: + # Ensure AgentMessage instances in trajectory are also dicts + serialized_trajectory = [] + for msg in self.trajectory: + if hasattr(msg, "model_dump"): # Check if it's a Pydantic model + serialized_trajectory.append(msg.model_dump()) + elif isinstance(msg, dict): + serialized_trajectory.append(msg) # Already a dict + return { + "query": self.query, + "response": self.response, + "trajectory": serialized_trajectory, + "webpages": self.webpages, + } + + class InformationCollectionIteration: def __init__( self, @@ -95,7 +138,7 @@ class InformationCollectionIteration: context: list = None, onlineContext: dict = None, codeContext: dict = None, - operatorContext: dict[str, str] = None, + operatorContext: dict = None, summarizedResult: str = None, warning: str = None, ): @@ -104,10 +147,17 @@ class InformationCollectionIteration: self.context = context self.onlineContext = onlineContext self.codeContext = codeContext - self.operatorContext = operatorContext + self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else None self.summarizedResult = summarizedResult self.warning = warning + def to_dict(self) -> dict: + data = vars(self).copy() + data["operatorContext"] = ( + 
self.operatorContext.to_dict() if isinstance(self.operatorContext, OperatorRun) else None + ) + return data + def construct_iteration_history( previous_iterations: List[InformationCollectionIteration], @@ -193,7 +243,7 @@ def construct_tool_chat_history( lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] ), ConversationCommand.Operator: ( - lambda iteration: list(iteration.operatorContext.keys()) if iteration.operatorContext else [] + lambda iteration: [iteration.operatorContext.query] if iteration.operatorContext else [] ), } for iteration in previous_iterations: @@ -273,7 +323,7 @@ async def save_to_conversation_log( compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, code_results: Dict[str, Any] = {}, - operator_results: Dict[str, str] = {}, + operator_results: List[OperatorRun] = None, inferred_queries: List[str] = [], intent_type: str = "remember", client_application: ClientApplication = None, @@ -301,8 +351,8 @@ async def save_to_conversation_log( "intent": {"inferred-queries": inferred_queries, "type": intent_type}, "onlineContext": online_results, "codeContext": code_results, - "operatorContext": operator_results, - "researchContext": [vars(r) for r in research_results] if research_results and not chat_response else None, + "operatorContext": [o.to_dict() for o in operator_results] if operator_results and not chat_response else None, + "researchContext": [r.to_dict() for r in research_results] if research_results and not chat_response else None, "automationId": automation_id, "trainOfThought": train_of_thought, "turnId": turn_id, @@ -459,10 +509,12 @@ def generate_chatml_messages_with_context( ] if not is_none_or_empty(chat.get("operatorContext")): + operator_context = chat.get("operatorContext") + operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context]) message_context += [ { "type": "text", - "text": 
f"{prompts.operator_execution_context.format(operator_results=chat.get('operatorContext'))}", + "text": f"{prompts.operator_execution_context.format(operator_results=operator_content)}", } ] diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index b2ea846f..e38ae777 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.processor.conversation.utils import OperatorRun from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent @@ -160,11 +161,12 @@ async def operate_environment( if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"): webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls] - yield { - "query": query, - "result": user_input_message or response, - "webpages": webpages, - } + yield OperatorRun( + query=query, + trajectory=operator_agent.messages, + response=response, + webpages=webpages, + ) def is_operator_model(model: str) -> ChatModel.ModelType | None: diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index 567411f6..2b98d50d 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -10,12 +10,9 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo from khoj.database.models import ChatModel from khoj.processor.conversation.anthropic.utils import is_reasoning_model +from khoj.processor.conversation.utils import AgentMessage from khoj.processor.operator.operator_actions import * -from 
khoj.processor.operator.operator_agent_base import ( - AgentActResult, - AgentMessage, - OperatorAgent, -) +from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, EnvState, diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 8c273b5e..1aa0f238 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union from pydantic import BaseModel from khoj.database.models import ChatModel -from khoj.processor.conversation.utils import commit_conversation_trace +from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace from khoj.processor.operator.operator_actions import OperatorAction from khoj.processor.operator.operator_environment_base import ( EnvironmentType, @@ -23,11 +23,6 @@ class AgentActResult(BaseModel): rendered_response: Optional[dict] = None -class AgentMessage(BaseModel): - role: Literal["user", "assistant", "system", "environment"] - content: Union[str, List] - - class OperatorAgent(ABC): def __init__( self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index b869e8bb..77e28442 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -5,15 +5,11 @@ from textwrap import dedent from typing import List from khoj.database.models import ChatModel -from khoj.processor.conversation.utils import construct_structured_message +from khoj.processor.conversation.utils import AgentMessage, construct_structured_message from khoj.processor.operator.grounding_agent import GroundingAgent from 
khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars from khoj.processor.operator.operator_actions import * -from khoj.processor.operator.operator_agent_base import ( - AgentActResult, - AgentMessage, - OperatorAgent, -) +from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, EnvState, diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py index 2b665f4d..7ae71bff 100644 --- a/src/khoj/processor/operator/operator_agent_openai.py +++ b/src/khoj/processor/operator/operator_agent_openai.py @@ -8,12 +8,9 @@ from typing import List, Optional, cast from openai.types.responses import Response, ResponseOutputItem +from khoj.processor.conversation.utils import AgentMessage from khoj.processor.operator.operator_actions import * -from khoj.processor.operator.operator_agent_base import ( - AgentActResult, - AgentMessage, - OperatorAgent, -) +from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, EnvState, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 39e8c1a1..2cade5bc 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -26,6 +26,7 @@ from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import ( + OperatorRun, ResponseWithThought, defilter_query, save_to_conversation_log, @@ -725,7 +726,7 @@ async def chat( research_results: List[InformationCollectionIteration] = [] online_results: Dict = dict() code_results: Dict = dict() - operator_results: Dict[str, str] = {} + operator_results: List[OperatorRun] = [] compiled_references: List[Any] = [] 
inferred_queries: List[Any] = [] attached_file_context = gather_raw_query_files(query_files) @@ -960,11 +961,12 @@ async def chat( last_message = conversation.messages[-1] online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} - operator_results = last_message.operatorContext or {} compiled_references = [ref.model_dump() for ref in last_message.context or []] research_results = [ InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or [] ] + operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] + train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] # Drop the interrupted message from conversation history meta_log["chat"].pop() logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") @@ -1034,7 +1036,7 @@ async def chat( if research_result.context: compiled_references.extend(research_result.context) if research_result.operatorContext: - operator_results.update(research_result.operatorContext) + operator_results.append(research_result.operatorContext) research_results.append(research_result) else: @@ -1306,16 +1308,16 @@ async def chat( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - operator_results = {result["query"]: result["result"]} + elif isinstance(result, OperatorRun): + operator_results.append(result) # Add webpages visited while operating browser to references - if result.get("webpages"): + if result.webpages: if not online_results.get(defiltered_query): - online_results[defiltered_query] = {"webpages": result["webpages"]} + online_results[defiltered_query] = {"webpages": result.webpages} elif not online_results[defiltered_query].get("webpages"): - online_results[defiltered_query]["webpages"] = result["webpages"] + 
online_results[defiltered_query]["webpages"] = result.webpages else: - online_results[defiltered_query]["webpages"] += result["webpages"] + online_results[defiltered_query]["webpages"] += result.webpages except ValueError as e: program_execution_context.append(f"Browser operation error: {e}") logger.warning(f"Failed to operate browser with {e}", exc_info=True) @@ -1333,7 +1335,6 @@ async def chat( "context": compiled_references, "onlineContext": unique_online_results, "codeContext": code_results, - "operatorContext": operator_results, }, ): yield result diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c1ddb82d..a3322b67 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -95,6 +95,7 @@ from khoj.processor.conversation.openai.gpt import ( from khoj.processor.conversation.utils import ( ChatEvent, InformationCollectionIteration, + OperatorRun, ResponseWithThought, clean_json, clean_mermaidjs, @@ -1355,7 +1356,7 @@ async def agenerate_chat_response( compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {}, - operator_results: Dict[str, str] = {}, + operator_results: List[OperatorRun] = [], research_results: List[InformationCollectionIteration] = [], inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], @@ -1414,7 +1415,7 @@ async def agenerate_chat_response( compiled_references = [] online_results = {} code_results = {} - operator_results = {} + operator_results = [] deepthought = True chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index de55e47b..ab0c7062 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -14,6 +14,7 @@ from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from 
khoj.processor.conversation.utils import ( InformationCollectionIteration, + OperatorRun, construct_iteration_history, construct_tool_chat_history, load_complex_json, @@ -248,7 +249,7 @@ async def execute_information_collection( online_results: Dict = dict() code_results: Dict = dict() document_results: List[Dict[str, str]] = [] - operator_results: Dict[str, str] = {} + operator_results: OperatorRun = None summarize_files: str = "" this_iteration = InformationCollectionIteration(tool=None, query=query) @@ -431,17 +432,17 @@ async def execute_information_collection( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - operator_results = {result["query"]: result["result"]} + elif isinstance(result, OperatorRun): + operator_results = result this_iteration.operatorContext = operator_results # Add webpages visited while operating browser to references - if result.get("webpages"): + if result.webpages: if not online_results.get(this_iteration.query): - online_results[this_iteration.query] = {"webpages": result["webpages"]} + online_results[this_iteration.query] = {"webpages": result.webpages} elif not online_results[this_iteration.query].get("webpages"): - online_results[this_iteration.query]["webpages"] = result["webpages"] + online_results[this_iteration.query]["webpages"] = result.webpages else: - online_results[this_iteration.query]["webpages"] += result["webpages"] + online_results[this_iteration.query]["webpages"] += result.webpages this_iteration.onlineContext = online_results except Exception as e: this_iteration.warning = f"Error operating browser: {e}" @@ -489,7 +490,9 @@ async def execute_information_collection( if code_results: results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if operator_results: - results_data += f"\n\n{next(iter(operator_results.values()))}\n" + results_data += ( + f"\n\n{operator_results.response}\n" + 
) if summarize_files: results_data += f"\n\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if this_iteration.warning: