mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Pass previous trajectory to operator agents for context
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Callable, List, Optional
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation.utils import OperatorRun
|
||||
from khoj.processor.conversation.utils import OperatorRun, construct_chat_history
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
|
||||
from khoj.processor.operator.operator_agent_base import OperatorAgent
|
||||
@@ -32,6 +32,7 @@ async def operate_environment(
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
location_data: LocationData,
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
environment_type: EnvironmentType = EnvironmentType.COMPUTER,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None, # TODO: Handle query images
|
||||
@@ -50,13 +51,22 @@ async def operate_environment(
|
||||
if not reasoning_model:
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
|
||||
chat_history = construct_chat_history(conversation_log)
|
||||
query_with_history = (
|
||||
f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query
|
||||
) # Append chat history to query if available
|
||||
|
||||
# Initialize Agent
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
|
||||
operator_agent: OperatorAgent
|
||||
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
|
||||
operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
|
||||
operator_agent = OpenAIOperatorAgent(
|
||||
query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
|
||||
)
|
||||
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
|
||||
operator_agent = AnthropicOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
|
||||
operator_agent = AnthropicOperatorAgent(
|
||||
query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
|
||||
)
|
||||
else:
|
||||
grounding_model_name = "ui-tars-1.5"
|
||||
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
||||
@@ -67,7 +77,13 @@ async def operate_environment(
|
||||
):
|
||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||
operator_agent = BinaryOperatorAgent(
|
||||
query, reasoning_model, grounding_model, environment_type, max_iterations, tracer
|
||||
query_with_history,
|
||||
reasoning_model,
|
||||
grounding_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
)
|
||||
|
||||
# Initialize Environment
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.processor.operator.operator_actions import OperatorAction
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
@@ -25,7 +29,13 @@ class AgentActResult(BaseModel):
|
||||
|
||||
class OperatorAgent(ABC):
|
||||
def __init__(
|
||||
self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict
|
||||
self,
|
||||
query: str,
|
||||
vision_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
self.query = query
|
||||
self.vision_model = vision_model
|
||||
@@ -35,6 +45,9 @@ class OperatorAgent(ABC):
|
||||
self.messages: List[AgentMessage] = []
|
||||
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
|
||||
|
||||
if previous_trajectory:
|
||||
self.messages = previous_trajectory.trajectory
|
||||
|
||||
# Context compression parameters
|
||||
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
|
||||
# turns after which compression triggered. scales with model max context size. Minimum 5 turns.
|
||||
|
||||
@@ -2,10 +2,14 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import AgentMessage, construct_structured_message
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
construct_structured_message,
|
||||
)
|
||||
from khoj.processor.operator.grounding_agent import GroundingAgent
|
||||
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
@@ -36,10 +40,16 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
grounding_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
tracer: dict,
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
super().__init__(
|
||||
query, reasoning_model, environment_type, max_iterations, tracer
|
||||
query,
|
||||
reasoning_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
) # Use reasoning model for primary tracking
|
||||
self.reasoning_model = reasoning_model
|
||||
self.grounding_model = grounding_model
|
||||
|
||||
@@ -1294,6 +1294,7 @@ async def chat(
|
||||
user,
|
||||
meta_log,
|
||||
location,
|
||||
operator_results[-1] if operator_results else None,
|
||||
query_images=uploaded_images,
|
||||
query_files=attached_file_context,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
|
||||
@@ -423,6 +423,7 @@ async def research(
|
||||
user,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
||||
location,
|
||||
previous_iterations[-1].operatorContext if previous_iterations else None,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
|
||||
Reference in New Issue
Block a user