mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Continue interrupted operator run with new query and previous context
Track research and operator results at each nested iteration step using Python object references plus async events bubbled up from nested iterators. Instantiates the operator with interrupted operator messages from research or normal mode. Reflects the actual interaction trajectory to the agent as closely as possible — including conversation history, the partial operator trajectory, and the new query — for fine-grained, corrigible steerability. Research mode continues with the operator tool directly if the previous iteration was an interrupted operator run.
This commit is contained in:
@@ -261,6 +261,27 @@ def construct_question_history(
|
||||
return history_parts
|
||||
|
||||
|
||||
def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]:
    """
    Construct chat history for the operator agent from the conversation log.

    Only include up to n messages from completed turns (i.e. turns with both a
    user and a khoj message). Each user message bundles the raw query text with
    the contents of any files attached to that query.

    Args:
        conversation_history: Conversation log; reads the "chat" list, where each
            entry has "by" ("you" or "khoj"), "message", and optional "queryFiles".
        n: Maximum number of messages (not turns) to include. Defaults to 6.

    Returns:
        List of AgentMessage objects alternating user/assistant roles.
    """
    chat_history: list[AgentMessage] = []
    user_message: Optional[AgentMessage] = None

    for chat in conversation_history.get("chat", []):
        if len(chat_history) >= n:
            break
        if chat["by"] == "you" and chat.get("message"):
            content = [{"type": "text", "text": chat["message"]}]
            # Attach contents of any files uploaded with this query as extra text parts
            for file in chat.get("queryFiles", []):
                content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
            user_message = AgentMessage(role="user", content=content)
        elif chat["by"] == "khoj" and chat.get("message"):
            # Bug fix: only pair a khoj response with a preceding user message.
            # Previously an unpaired khoj message appended None into the
            # list[AgentMessage]. Also clear the pending user message after
            # pairing so a later khoj response cannot reuse a stale user turn.
            if user_message is not None:
                chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])]
                user_message = None
    return chat_history
|
||||
|
||||
|
||||
def construct_tool_chat_history(
|
||||
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
|
||||
) -> Dict[str, list]:
|
||||
@@ -285,9 +306,6 @@ def construct_tool_chat_history(
|
||||
ConversationCommand.Code: (
|
||||
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
|
||||
),
|
||||
ConversationCommand.Operator: (
|
||||
lambda iteration: list(iteration.operatorContext.query) if iteration.operatorContext else []
|
||||
),
|
||||
}
|
||||
for iteration in previous_iterations:
|
||||
# If a tool is provided use the inferred query extractor for that tool if available
|
||||
|
||||
@@ -6,7 +6,11 @@ from typing import Callable, List, Optional
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation.utils import OperatorRun, construct_chat_history
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
construct_chat_history,
|
||||
construct_chat_history_for_operator,
|
||||
)
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
|
||||
from khoj.processor.operator.operator_agent_base import OperatorAgent
|
||||
@@ -43,6 +47,10 @@ async def operate_environment(
|
||||
):
|
||||
response, user_input_message = None, None
|
||||
|
||||
# Only use partial previous trajectories to continue existing task
|
||||
if previous_trajectory and previous_trajectory.response:
|
||||
previous_trajectory = None
|
||||
|
||||
# Get the agent chat model
|
||||
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
|
||||
reasoning_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
|
||||
@@ -51,21 +59,19 @@ async def operate_environment(
|
||||
if not reasoning_model:
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
|
||||
chat_history = construct_chat_history(conversation_log)
|
||||
query_with_history = (
|
||||
f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query
|
||||
) # Append chat history to query if available
|
||||
# Create conversation history from conversation log
|
||||
chat_history = construct_chat_history_for_operator(conversation_log)
|
||||
|
||||
# Initialize Agent
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
|
||||
operator_agent: OperatorAgent
|
||||
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
|
||||
operator_agent = OpenAIOperatorAgent(
|
||||
query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
|
||||
query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer
|
||||
)
|
||||
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
|
||||
operator_agent = AnthropicOperatorAgent(
|
||||
query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
|
||||
query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer
|
||||
)
|
||||
else:
|
||||
grounding_model_name = "ui-tars-1.5"
|
||||
@@ -77,11 +83,12 @@ async def operate_environment(
|
||||
):
|
||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||
operator_agent = BinaryOperatorAgent(
|
||||
query_with_history,
|
||||
query,
|
||||
reasoning_model,
|
||||
grounding_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
chat_history,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
)
|
||||
@@ -100,6 +107,8 @@ async def operate_environment(
|
||||
try:
|
||||
task_completed = False
|
||||
iterations = 0
|
||||
operator_run = OperatorRun(query=query, trajectory=operator_agent.messages, response=response)
|
||||
yield operator_run
|
||||
|
||||
with timer(
|
||||
f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger
|
||||
@@ -159,30 +168,27 @@ async def operate_environment(
|
||||
|
||||
# 4. Update agent on the results of its action on the environment
|
||||
operator_agent.add_action_results(env_steps, agent_result)
|
||||
operator_run.trajectory = operator_agent.messages
|
||||
|
||||
# Determine final response message
|
||||
if user_input_message:
|
||||
response = user_input_message
|
||||
operator_run.response = user_input_message
|
||||
elif task_completed:
|
||||
response = summary_message
|
||||
operator_run.response = summary_message
|
||||
elif cancellation_event and cancellation_event.is_set():
|
||||
operator_run.response = None
|
||||
else: # Hit iteration limit
|
||||
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
|
||||
operator_run.response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
|
||||
finally:
|
||||
if environment and not user_input_message: # Don't close environment if user input required
|
||||
await environment.close()
|
||||
if operator_agent:
|
||||
operator_agent.reset()
|
||||
|
||||
webpages = []
|
||||
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
|
||||
webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
|
||||
operator_run.webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
|
||||
|
||||
yield OperatorRun(
|
||||
query=query,
|
||||
trajectory=operator_agent.messages,
|
||||
response=response,
|
||||
webpages=webpages,
|
||||
)
|
||||
yield operator_run
|
||||
|
||||
|
||||
def is_operator_model(model: str) -> ChatModel.ModelType | None:
|
||||
|
||||
@@ -34,6 +34,7 @@ class OperatorAgent(ABC):
|
||||
vision_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
chat_history: List[AgentMessage] = [],
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
@@ -42,11 +43,15 @@ class OperatorAgent(ABC):
|
||||
self.environment_type = environment_type
|
||||
self.max_iterations = max_iterations
|
||||
self.tracer = tracer
|
||||
self.messages: List[AgentMessage] = []
|
||||
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
|
||||
|
||||
self.messages: List[AgentMessage] = chat_history
|
||||
if previous_trajectory:
|
||||
self.messages = previous_trajectory.trajectory
|
||||
# Remove tool call from previous trajectory as tool call w/o result not supported
|
||||
if previous_trajectory.trajectory and previous_trajectory.trajectory[-1].role == "assistant":
|
||||
previous_trajectory.trajectory.pop()
|
||||
self.messages += previous_trajectory.trajectory
|
||||
self.messages += [AgentMessage(role="user", content=query)]
|
||||
|
||||
# Context compression parameters
|
||||
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
|
||||
|
||||
@@ -40,6 +40,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
grounding_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
chat_history: List[AgentMessage] = [],
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
@@ -48,6 +49,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
reasoning_model,
|
||||
environment_type,
|
||||
max_iterations,
|
||||
chat_history,
|
||||
previous_trajectory,
|
||||
tracer,
|
||||
) # Use reasoning model for primary tracking
|
||||
|
||||
@@ -1011,7 +1011,7 @@ async def chat(
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
conversation_history=meta_log,
|
||||
previous_iterations=research_results,
|
||||
previous_iterations=list(research_results),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1030,13 +1030,26 @@ async def chat(
|
||||
code_results.update(research_result.codeContext)
|
||||
if research_result.context:
|
||||
compiled_references.extend(research_result.context)
|
||||
if research_result.operatorContext:
|
||||
operator_results.append(research_result.operatorContext)
|
||||
if not research_results or research_results[-1] is not research_result:
|
||||
research_results.append(research_result)
|
||||
|
||||
else:
|
||||
yield research_result
|
||||
|
||||
# Track operator results across research and operator iterations
|
||||
# This relies on two conditions:
|
||||
# 1. Check to append new (partial) operator results
|
||||
# Relies on triggering this check on every status updates.
|
||||
# Status updates cascade up from operator to research to chat api on every step.
|
||||
# 2. Keep operator results in sync with each research operator step
|
||||
# Relies on python object references to ensure operator results
|
||||
# are implicitly kept in sync after the initial append
|
||||
if (
|
||||
research_results
|
||||
and research_results[-1].operatorContext
|
||||
and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
|
||||
):
|
||||
operator_results.append(research_results[-1].operatorContext)
|
||||
|
||||
# researched_results = await extract_relevant_info(q, researched_results, agent)
|
||||
if state.verbose > 1:
|
||||
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
|
||||
@@ -1294,7 +1307,7 @@ async def chat(
|
||||
user,
|
||||
meta_log,
|
||||
location,
|
||||
operator_results[-1] if operator_results else None,
|
||||
list(operator_results)[-1] if operator_results else None,
|
||||
query_images=uploaded_images,
|
||||
query_files=attached_file_context,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1305,7 +1318,8 @@ async def chat(
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
elif isinstance(result, OperatorRun):
|
||||
operator_results.append(result)
|
||||
if not operator_results or operator_results[-1] is not result:
|
||||
operator_results.append(result)
|
||||
# Add webpages visited while operating browser to references
|
||||
if result.webpages:
|
||||
if not online_results.get(defiltered_query):
|
||||
|
||||
@@ -96,6 +96,24 @@ async def apick_next_tool(
|
||||
):
|
||||
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
|
||||
|
||||
# Continue with previous iteration if a multi-step tool use is in progress
|
||||
if (
|
||||
previous_iterations
|
||||
and previous_iterations[-1].tool == ConversationCommand.Operator
|
||||
and not previous_iterations[-1].summarizedResult
|
||||
):
|
||||
previous_iteration = previous_iterations[-1]
|
||||
yield ResearchIteration(
|
||||
tool=previous_iteration.tool,
|
||||
query=query,
|
||||
context=previous_iteration.context,
|
||||
onlineContext=previous_iteration.onlineContext,
|
||||
codeContext=previous_iteration.codeContext,
|
||||
operatorContext=previous_iteration.operatorContext,
|
||||
warning=previous_iteration.warning,
|
||||
)
|
||||
return
|
||||
|
||||
# Construct tool options for the agent to choose from
|
||||
tool_options = dict()
|
||||
tool_options_str = ""
|
||||
@@ -274,6 +292,7 @@ async def research(
|
||||
yield result[ChatEvent.STATUS]
|
||||
elif isinstance(result, ResearchIteration):
|
||||
this_iteration = result
|
||||
yield this_iteration
|
||||
|
||||
# Skip running iteration if warning present in iteration
|
||||
if this_iteration.warning:
|
||||
|
||||
Reference in New Issue
Block a user