mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Execute tool calls in parallel to make research iterations faster
This commit is contained in:
@@ -48,6 +48,261 @@ from khoj.utils.rawconfig import LocationData
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutionResult:
    """Result of executing a single tool call.

    A plain mutable container filled in by `execute_tool` and consumed by the
    research loop after the parallel `asyncio.gather` completes.
    """

    def __init__(self):
        # Status messages collected during tool execution; yielded in batch later.
        self.status_messages: List = []
        # Document references produced by file/document search tools.
        self.document_results: List[Dict[str, str]] = []
        # Online search / webpage results, keyed by query.
        self.online_results: Dict = {}
        # Results of executed code snippets.
        self.code_results: Dict = {}
        # Trajectory of a computer/browser operator run, when one was executed.
        # Annotated Optional: the attribute is initialized to None and only set
        # when the operator tool actually runs (the original non-Optional
        # `OperatorRun` annotation was inaccurate).
        self.operator_results: Optional["OperatorRun"] = None
        # Raw results returned by an MCP server tool invocation.
        self.mcp_results: List = []
        # Signals the research loop to stop iterating (no valid tool selected).
        self.should_terminate: bool = False
async def execute_tool(
    iteration: ResearchIteration,
    user: KhojUser,
    conversation_id: str,
    previous_iterations: List[ResearchIteration],
    location: LocationData,
    query_images: List[str],
    query_files: str,
    max_document_searches: int,
    max_online_searches: int,
    mcp_clients: List[MCPClient],
    cancellation_event: Optional[asyncio.Event],
    interrupt_queue: Optional[asyncio.Queue],
    agent: Agent,
    tracer: dict,
) -> ToolExecutionResult:
    """Execute a single tool call and return results. Designed for parallel execution.

    Dispatches on `iteration.query.name` to the matching tool function
    (document search, web search, webpage reading, code execution, computer
    operation, file view/list/grep, or an MCP server tool). Instead of
    streaming status events, status messages are captured via a local
    collector so the caller can yield them after `asyncio.gather` completes.

    Side effects: mutates `iteration` (context, onlineContext, codeContext,
    operatorContext, warning) in addition to filling the returned
    `ToolExecutionResult`. Any exception from a tool is caught and recorded
    as `iteration.warning` rather than propagated.
    """
    result = ToolExecutionResult()

    # Skip if warning present (e.g. planner flagged this iteration as invalid)
    if iteration.warning:
        logger.warning(f"Research mode: {iteration.warning}.")
        return result

    # Check for termination conditions: no tool selected, a bare text query,
    # or the explicit Text tool all end the research loop.
    if not iteration.query or isinstance(iteration.query, str) or iteration.query.name == ConversationCommand.Text:
        result.should_terminate = True
        return result

    # Create a status collector that captures messages for later batch yield
    async def status_collector(message: str):
        """Async generator that collects status messages instead of streaming them."""
        # Just collect the message - we'll process it later in the main loop
        result.status_messages.append(message)
        yield message  # Yield to satisfy async generator protocol expected by tool functions

    try:
        if iteration.query.name == ConversationCommand.SemanticSearchFiles:
            iteration.context = []
            # Avoid re-running queries already inferred in earlier iterations.
            # NOTE(review): `iter` shadows the builtin; kept byte-identical here.
            previous_inferred_queries = {
                c["query"] for iter in previous_iterations if iter.context for c in iter.context
            }
            async for res in search_documents(
                **iteration.query.args,
                n=max_document_searches,
                d=None,
                user=user,
                chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
                conversation_id=conversation_id,
                conversation_commands=[ConversationCommand.Notes],
                location_data=location,
                send_status_func=status_collector,
                query_images=query_images,
                query_files=query_files,
                previous_inferred_queries=previous_inferred_queries,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if isinstance(res, tuple):
                    result.document_results = res[0]
                    iteration.context += result.document_results

            if not is_none_or_empty(result.document_results):
                try:
                    distinct_files = {d["file"] for d in result.document_results}
                    # First line of each compiled entry is treated as its heading.
                    distinct_headings = set(
                        [d["compiled"].split("\n")[0] for d in result.document_results if "compiled" in d]
                    )
                    headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
                    async for _ in status_collector(
                        f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
                    ):
                        pass
                except Exception as e:
                    iteration.warning = f"Error extracting document references: {e}"
                    logger.error(iteration.warning, exc_info=True)
            else:
                iteration.warning = "No matching document references found"

        elif iteration.query.name == ConversationCommand.SearchWeb:
            # Avoid repeating subqueries already searched in earlier iterations.
            previous_subqueries = {
                subquery for iter in previous_iterations if iter.onlineContext for subquery in iter.onlineContext.keys()
            }
            async for res in search_online(
                **iteration.query.args,
                conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SearchWeb),
                location=location,
                user=user,
                send_status_func=status_collector,
                custom_filters=[],
                max_online_searches=max_online_searches,
                max_webpages_to_read=0,
                query_images=query_images,
                previous_subqueries=previous_subqueries,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if is_none_or_empty(res):
                    # Empty result signals the tool skipped duplicate queries.
                    iteration.warning = (
                        "Detected previously run online search queries. Skipping iteration. Try something different."
                    )
                elif not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    result.online_results = res
                    iteration.onlineContext = result.online_results

        elif iteration.query.name == ConversationCommand.ReadWebpage:
            async for res in read_webpages_content(
                **iteration.query.args,
                user=user,
                send_status_func=status_collector,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    direct_web_pages: Dict[str, Dict] = res
                    # Merge read webpages into any existing results for the same query.
                    for web_query in direct_web_pages:
                        if result.online_results.get(web_query):
                            result.online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
                        else:
                            result.online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
                    iteration.onlineContext = result.online_results

        elif iteration.query.name == ConversationCommand.PythonCoder:
            async for res in run_code(
                **iteration.query.args,
                conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.PythonCoder),
                context="",
                location_data=location,
                user=user,
                send_status_func=status_collector,
                query_images=query_images,
                query_files=query_files,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    result.code_results = res
                    iteration.codeContext = result.code_results
            if iteration.codeContext:
                async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
                    pass

        elif iteration.query.name == ConversationCommand.OperateComputer:
            async for res in operate_environment(
                **iteration.query.args,
                user=user,
                conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
                location_data=location,
                # Resume from the last iteration's operator trajectory, if any.
                previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
                send_status_func=status_collector,
                query_images=query_images,
                agent=agent,
                query_files=query_files,
                cancellation_event=cancellation_event,
                interrupt_queue=interrupt_queue,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if isinstance(res, OperatorRun):
                    result.operator_results = res
                    iteration.operatorContext = result.operator_results
                    # Add webpages visited while operating the browser to references.
                    if res.webpages:
                        if not result.online_results.get(iteration.query):
                            result.online_results[iteration.query] = {"webpages": res.webpages}
                        elif not result.online_results[iteration.query].get("webpages"):
                            result.online_results[iteration.query]["webpages"] = res.webpages
                        else:
                            result.online_results[iteration.query]["webpages"] += res.webpages
                    iteration.onlineContext = result.online_results

        elif iteration.query.name == ConversationCommand.ViewFile:
            async for res in view_file_content(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = res
                    iteration.context += result.document_results
            async for _ in status_collector(f"**Viewed file**: {iteration.query.args['path']}"):
                pass

        elif iteration.query.name == ConversationCommand.ListFiles:
            async for res in list_files(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = [res]
                    iteration.context += result.document_results
            if result.document_results:
                async for _ in status_collector(result.document_results[-1].get("query", "Listed files")):
                    pass

        elif iteration.query.name == ConversationCommand.RegexSearchFiles:
            async for res in grep_files(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = [res]
                    iteration.context += result.document_results
            if result.document_results:
                async for _ in status_collector(result.document_results[-1].get("query", "Searched files")):
                    pass

        elif "/" in iteration.query.name:
            # Tool names of the form "<server>/<tool>" route to an MCP server.
            server_name, tool_name = iteration.query.name.split("/", 1)
            mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
            if not mcp_client:
                raise ValueError(f"Could not find MCP server with name {server_name}")

            result.mcp_results = await mcp_client.run_tool(tool_name, iteration.query.args)
            # Record tool result in context
            if iteration.context is None:
                iteration.context = []
            iteration.context += result.mcp_results
            async for _ in status_collector(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
                pass

        else:
            # No valid tool selected: signal the research loop to exit.
            result.should_terminate = True

    except Exception as e:
        # Broad catch is deliberate: a single failing tool must not crash the
        # parallel gather; the failure is surfaced via iteration.warning.
        tool_name = iteration.query.name if iteration.query else "unknown"
        iteration.warning = f"Error executing {tool_name}: {e}"
        logger.error(iteration.warning, exc_info=True)

    return result
async def apick_next_tool(
|
async def apick_next_tool(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: List[ChatMessageModel],
|
conversation_history: List[ChatMessageModel],
|
||||||
@@ -331,316 +586,83 @@ async def research(
|
|||||||
iterations_to_process.append(result)
|
iterations_to_process.append(result)
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
# Process all tool calls from this planning step
|
if iterations_to_process:
|
||||||
for this_iteration in iterations_to_process:
|
# Create tasks for parallel execution
|
||||||
online_results: Dict = dict()
|
tasks = [
|
||||||
code_results: Dict = dict()
|
execute_tool(
|
||||||
document_results: List[Dict[str, str]] = []
|
iteration=iteration,
|
||||||
operator_results: OperatorRun = None
|
|
||||||
mcp_results: List = []
|
|
||||||
|
|
||||||
# Skip running iteration if warning present in iteration
|
|
||||||
if this_iteration.warning:
|
|
||||||
logger.warning(f"Research mode: {this_iteration.warning}.")
|
|
||||||
|
|
||||||
# Terminate research if selected text tool or query, tool not set for next iteration
|
|
||||||
elif (
|
|
||||||
not this_iteration.query
|
|
||||||
or isinstance(this_iteration.query, str)
|
|
||||||
or this_iteration.query.name == ConversationCommand.Text
|
|
||||||
):
|
|
||||||
current_iteration = MAX_ITERATIONS
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
|
|
||||||
this_iteration.context = []
|
|
||||||
document_results = []
|
|
||||||
previous_inferred_queries = {
|
|
||||||
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
|
|
||||||
}
|
|
||||||
async for result in search_documents(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
n=max_document_searches,
|
|
||||||
d=None,
|
|
||||||
user=user,
|
user=user,
|
||||||
chat_history=construct_tool_chat_history(
|
|
||||||
previous_iterations, ConversationCommand.SemanticSearchFiles
|
|
||||||
),
|
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
conversation_commands=[ConversationCommand.Notes],
|
previous_iterations=previous_iterations,
|
||||||
location_data=location,
|
location=location,
|
||||||
send_status_func=send_status_func,
|
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
previous_inferred_queries=previous_inferred_queries,
|
max_document_searches=max_document_searches,
|
||||||
|
max_online_searches=max_online_searches,
|
||||||
|
mcp_clients=mcp_clients,
|
||||||
|
cancellation_event=cancellation_event,
|
||||||
|
interrupt_queue=interrupt_queue,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
for iteration in iterations_to_process
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute all tools in parallel
|
||||||
|
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Process results and yield status messages
|
||||||
|
for this_iteration, tool_result in zip(iterations_to_process, tool_results):
|
||||||
|
# Handle exceptions from asyncio.gather
|
||||||
|
if isinstance(tool_result, Exception):
|
||||||
|
this_iteration.warning = f"Error executing tool: {tool_result}"
|
||||||
|
logger.error(this_iteration.warning, exc_info=True)
|
||||||
|
tool_result = ToolExecutionResult()
|
||||||
|
|
||||||
|
# Check for termination
|
||||||
|
if tool_result.should_terminate:
|
||||||
|
current_iteration = MAX_ITERATIONS
|
||||||
|
|
||||||
|
# Yield all collected status messages through the real send_status_func
|
||||||
|
for status_msg in tool_result.status_messages:
|
||||||
|
if send_status_func:
|
||||||
|
async for status_event in send_status_func(status_msg):
|
||||||
|
yield status_event
|
||||||
|
|
||||||
|
current_iteration += 1
|
||||||
|
|
||||||
|
# Build summarized results
|
||||||
|
if (
|
||||||
|
tool_result.document_results
|
||||||
|
or tool_result.online_results
|
||||||
|
or tool_result.code_results
|
||||||
|
or tool_result.operator_results
|
||||||
|
or tool_result.mcp_results
|
||||||
|
or this_iteration.warning
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
results_data = f"\n<iteration_{current_iteration}_results>"
|
||||||
yield result[ChatEvent.STATUS]
|
if tool_result.document_results:
|
||||||
elif isinstance(result, tuple):
|
results_data += f"\n<document_references>\n{yaml.dump(tool_result.document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
|
||||||
document_results = result[0]
|
if tool_result.online_results:
|
||||||
this_iteration.context += document_results
|
results_data += f"\n<online_results>\n{yaml.dump(tool_result.online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
|
||||||
|
if tool_result.code_results:
|
||||||
|
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(tool_result.code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
|
||||||
|
if tool_result.operator_results:
|
||||||
|
results_data += f"\n<browser_operator_results>\n{tool_result.operator_results.response}\n</browser_operator_results>"
|
||||||
|
if tool_result.mcp_results:
|
||||||
|
results_data += f"\n<mcp_tool_results>\n{yaml.dump(tool_result.mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
|
||||||
|
if this_iteration.warning:
|
||||||
|
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
|
||||||
|
results_data += f"\n</iteration_{current_iteration}_results>"
|
||||||
|
|
||||||
if not is_none_or_empty(document_results):
|
this_iteration.summarizedResult = results_data
|
||||||
try:
|
|
||||||
distinct_files = {d["file"] for d in document_results}
|
|
||||||
distinct_headings = set(
|
|
||||||
[d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]
|
|
||||||
)
|
|
||||||
# Strip only leading # from headings
|
|
||||||
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
|
|
||||||
async for result in send_status_func(
|
|
||||||
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
|
|
||||||
):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error extracting document references: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
else:
|
|
||||||
this_iteration.warning = "No matching document references found"
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.SearchWeb:
|
this_iteration.summarizedResult = (
|
||||||
previous_subqueries = {
|
this_iteration.summarizedResult
|
||||||
subquery
|
or f"<iteration_{current_iteration}_results>Failed to get results.</iteration_{current_iteration}_results>"
|
||||||
for iteration in previous_iterations
|
)
|
||||||
if iteration.onlineContext
|
previous_iterations.append(this_iteration)
|
||||||
for subquery in iteration.onlineContext.keys()
|
yield this_iteration
|
||||||
}
|
|
||||||
try:
|
|
||||||
async for result in search_online(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
conversation_history=construct_tool_chat_history(
|
|
||||||
previous_iterations, ConversationCommand.SearchWeb
|
|
||||||
),
|
|
||||||
location=location,
|
|
||||||
user=user,
|
|
||||||
send_status_func=send_status_func,
|
|
||||||
custom_filters=[],
|
|
||||||
max_online_searches=max_online_searches,
|
|
||||||
max_webpages_to_read=0,
|
|
||||||
query_images=query_images,
|
|
||||||
previous_subqueries=previous_subqueries,
|
|
||||||
agent=agent,
|
|
||||||
tracer=tracer,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
elif is_none_or_empty(result):
|
|
||||||
this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different."
|
|
||||||
else:
|
|
||||||
online_results: Dict[str, Dict] = result # type: ignore
|
|
||||||
this_iteration.onlineContext = online_results
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error searching online: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.ReadWebpage:
|
|
||||||
try:
|
|
||||||
async for result in read_webpages_content(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
user=user,
|
|
||||||
send_status_func=send_status_func,
|
|
||||||
agent=agent,
|
|
||||||
tracer=tracer,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
else:
|
|
||||||
direct_web_pages: Dict[str, Dict] = result # type: ignore
|
|
||||||
|
|
||||||
webpages = []
|
|
||||||
for web_query in direct_web_pages:
|
|
||||||
if online_results.get(web_query):
|
|
||||||
online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
|
|
||||||
else:
|
|
||||||
online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
|
|
||||||
|
|
||||||
for webpage in direct_web_pages[web_query]["webpages"]:
|
|
||||||
webpages.append(webpage["link"])
|
|
||||||
this_iteration.onlineContext = online_results
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error reading webpages: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.PythonCoder:
|
|
||||||
try:
|
|
||||||
async for result in run_code(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
conversation_history=construct_tool_chat_history(
|
|
||||||
previous_iterations, ConversationCommand.PythonCoder
|
|
||||||
),
|
|
||||||
context="",
|
|
||||||
location_data=location,
|
|
||||||
user=user,
|
|
||||||
send_status_func=send_status_func,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
|
||||||
agent=agent,
|
|
||||||
tracer=tracer,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
else:
|
|
||||||
code_results: Dict[str, Dict] = result # type: ignore
|
|
||||||
this_iteration.codeContext = code_results
|
|
||||||
async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
|
|
||||||
yield result
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
this_iteration.warning = f"Error running code: {e}"
|
|
||||||
logger.warning(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.OperateComputer:
|
|
||||||
try:
|
|
||||||
async for result in operate_environment(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
user=user,
|
|
||||||
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
|
||||||
location_data=location,
|
|
||||||
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
|
|
||||||
send_status_func=send_status_func,
|
|
||||||
query_images=query_images,
|
|
||||||
agent=agent,
|
|
||||||
query_files=query_files,
|
|
||||||
cancellation_event=cancellation_event,
|
|
||||||
interrupt_queue=interrupt_queue,
|
|
||||||
tracer=tracer,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
elif isinstance(result, OperatorRun):
|
|
||||||
operator_results = result
|
|
||||||
this_iteration.operatorContext = operator_results
|
|
||||||
# Add webpages visited while operating browser to references
|
|
||||||
if result.webpages:
|
|
||||||
if not online_results.get(this_iteration.query):
|
|
||||||
online_results[this_iteration.query] = {"webpages": result.webpages}
|
|
||||||
elif not online_results[this_iteration.query].get("webpages"):
|
|
||||||
online_results[this_iteration.query]["webpages"] = result.webpages
|
|
||||||
else:
|
|
||||||
online_results[this_iteration.query]["webpages"] += result.webpages
|
|
||||||
this_iteration.onlineContext = online_results
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error operating browser: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.ViewFile:
|
|
||||||
try:
|
|
||||||
async for result in view_file_content(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
user=user,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
else:
|
|
||||||
if this_iteration.context is None:
|
|
||||||
this_iteration.context = []
|
|
||||||
document_results: List[Dict[str, str]] = result # type: ignore
|
|
||||||
this_iteration.context += document_results
|
|
||||||
async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error viewing file: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.ListFiles:
|
|
||||||
try:
|
|
||||||
async for result in list_files(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
user=user,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
else:
|
|
||||||
if this_iteration.context is None:
|
|
||||||
this_iteration.context = []
|
|
||||||
document_results: List[Dict[str, str]] = [result] # type: ignore
|
|
||||||
this_iteration.context += document_results
|
|
||||||
async for result in send_status_func(result["query"]):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error listing files: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
|
|
||||||
try:
|
|
||||||
async for result in grep_files(
|
|
||||||
**this_iteration.query.args,
|
|
||||||
user=user,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
else:
|
|
||||||
if this_iteration.context is None:
|
|
||||||
this_iteration.context = []
|
|
||||||
document_results: List[Dict[str, str]] = [result] # type: ignore
|
|
||||||
this_iteration.context += document_results
|
|
||||||
async for result in send_status_func(result["query"]):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error searching with regex: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
elif "/" in this_iteration.query.name:
|
|
||||||
try:
|
|
||||||
# Identify MCP client to use
|
|
||||||
server_name, tool_name = this_iteration.query.name.split("/", 1)
|
|
||||||
mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
|
|
||||||
if not mcp_client:
|
|
||||||
raise ValueError(f"Could not find MCP server with name {server_name}")
|
|
||||||
|
|
||||||
# Invoke tool on the identified MCP server
|
|
||||||
mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
|
|
||||||
|
|
||||||
# Record tool result in context
|
|
||||||
if this_iteration.context is None:
|
|
||||||
this_iteration.context = []
|
|
||||||
this_iteration.context += mcp_results
|
|
||||||
async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
this_iteration.warning = f"Error using MCP tool: {e}"
|
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# No valid tools. This is our exit condition.
|
|
||||||
current_iteration = MAX_ITERATIONS
|
|
||||||
|
|
||||||
current_iteration += 1
|
|
||||||
|
|
||||||
if (
|
|
||||||
document_results
|
|
||||||
or online_results
|
|
||||||
or code_results
|
|
||||||
or operator_results
|
|
||||||
or mcp_results
|
|
||||||
or this_iteration.warning
|
|
||||||
):
|
|
||||||
results_data = f"\n<iteration_{current_iteration}_results>"
|
|
||||||
if document_results:
|
|
||||||
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
|
|
||||||
if online_results:
|
|
||||||
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
|
|
||||||
if code_results:
|
|
||||||
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
|
|
||||||
if operator_results:
|
|
||||||
results_data += (
|
|
||||||
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
|
|
||||||
)
|
|
||||||
if mcp_results:
|
|
||||||
results_data += f"\n<mcp_tool_results>\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
|
|
||||||
if this_iteration.warning:
|
|
||||||
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
|
|
||||||
results_data += f"\n</iteration_{current_iteration}_results>"
|
|
||||||
|
|
||||||
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
|
|
||||||
this_iteration.summarizedResult = results_data
|
|
||||||
|
|
||||||
this_iteration.summarizedResult = (
|
|
||||||
this_iteration.summarizedResult
|
|
||||||
or f"<iteration_{current_iteration}_results>Failed to get results.</iteration_{current_iteration}_results>"
|
|
||||||
)
|
|
||||||
previous_iterations.append(this_iteration)
|
|
||||||
yield this_iteration
|
|
||||||
|
|
||||||
# Close MCP client connections
|
# Close MCP client connections
|
||||||
for mcp_client in mcp_clients:
|
for mcp_client in mcp_clients:
|
||||||
|
|||||||
Reference in New Issue
Block a user