Allow LLMs to make parallel tool call requests

Why
--
- Models are now generally smart enough to understand which tools to
  call in parallel and when.

- The LLM can request more work per call to it, which is usually the
  slowest step. This speeds up the research agent, even though each
  tool is still executed in sequence (for now).
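
For illustration, a planning response that requests two tools in one
turn could be parsed from model output shaped like this (tool names
and arguments here are hypothetical):

    [
        {"name": "semantic_search_files", "args": {"query": "q3 roadmap"}, "id": "call_1"},
        {"name": "search_web", "args": {"query": "q3 roadmap news"}, "id": "call_2"}
    ]

Each entry becomes its own ToolCall and research iteration.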
Debanjum
2025-12-13 21:37:49 -08:00
parent f4c519a9d0
commit 054ed79fdf
3 changed files with 383 additions and 328 deletions


@@ -85,8 +85,7 @@ def anthropic_completion_with_backoff(
         # Cache tool definitions
         last_tool = model_kwargs["tools"][-1]
         last_tool["cache_control"] = {"type": "ephemeral"}
-        # Disable parallel tool call until we add support for it
-        model_kwargs["tool_choice"] = {"type": "auto", "disable_parallel_tool_use": True}
+        model_kwargs["tool_choice"] = {"type": "auto"}
     elif response_schema:
         tool = create_tool_definition(response_schema)
         model_kwargs["tools"] = [

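For context, a minimal sketch of how these kwargs reach the Anthropic SDK and what enabling parallel tool use yields. The client setup, model name, and message are illustrative, not from this commit:

    import anthropic

    client = anthropic.Anthropic()
    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=1024,
        tools=model_kwargs["tools"],
        tool_choice={"type": "auto"},  # parallel tool use no longer disabled
        messages=[{"role": "user", "content": "Compare my Q3 notes with recent news."}],
    )
    # With parallel tool use allowed, a single assistant turn may now
    # contain multiple tool_use blocks instead of exactly one.
    tool_uses = [block for block in response.content if block.type == "tool_use"]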

@@ -245,8 +245,37 @@ def construct_iteration_history(
     if query_message_content:
         iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
+
+    # Group iterations: parallel tool calls share the same raw_response (only first has it)
+    # We need to group them so one assistant message has all tool_use blocks and
+    # one user message has all tool_results
+    current_group_raw_response = None
+    current_group_tool_results = []
+
+    def flush_group():
+        """Output the current group as assistant message + user message with tool results"""
+        nonlocal current_group_raw_response, current_group_tool_results
+        if current_group_raw_response and current_group_tool_results:
+            iteration_history.append(
+                ChatMessageModel(
+                    by="khoj",
+                    message=current_group_raw_response,
+                    intent=Intent(type="tool_call", query=query),
+                )
+            )
+            iteration_history.append(
+                ChatMessageModel(
+                    by="you",
+                    intent=Intent(type="tool_result"),
+                    message=current_group_tool_results,
+                )
+            )
+        current_group_raw_response = None
+        current_group_tool_results = []
+
     for iteration in previous_iterations:
         if not iteration.query or isinstance(iteration.query, str):
+            # Flush any pending group before adding non-tool message
+            flush_group()
             iteration_history.append(
                 ChatMessageModel(
                     by="you",
@@ -256,25 +285,36 @@ def construct_iteration_history(
                 )
             )
             continue

-        iteration_history += [
-            ChatMessageModel(
-                by="khoj",
-                message=iteration.raw_response or [iteration.query.__dict__],
-                intent=Intent(type="tool_call", query=query),
-            ),
-            ChatMessageModel(
-                by="you",
-                intent=Intent(type="tool_result"),
-                message=[
+        # If this iteration has raw_response, it starts a new group of parallel tool calls
+        if iteration.raw_response:
+            # Flush previous group if exists
+            flush_group()
+            current_group_raw_response = iteration.raw_response
+        # If no raw_response and no current group, create a fallback single-tool response
+        elif not current_group_raw_response:
+            current_group_raw_response = [
+                {
+                    "type": "tool_use",
+                    "id": iteration.query.id,
+                    "name": iteration.query.name,
+                    "input": iteration.query.args,
+                }
+            ]
+        # Add tool result to current group
+        current_group_tool_results.append(
             {
                 "type": "tool_result",
                 "id": iteration.query.id,
                 "name": iteration.query.name,
                 "content": iteration.summarizedResult,
             }
-            ],
-        ),
-    ]
+        )
+
+    # Flush any remaining group
+    flush_group()
     return iteration_history

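To see the invariant this grouping buys, here is an illustrative sketch with plain dicts standing in for ChatMessageModel (ids, tool names, and content are made up): two parallel tool calls collapse into one assistant message carrying both tool_use blocks and one user message carrying both tool_results.

    raw_response = [
        {"type": "tool_use", "id": "call_1", "name": "search_docs", "input": {"q": "roadmap"}},
        {"type": "tool_use", "id": "call_2", "name": "search_web", "input": {"q": "roadmap"}},
    ]
    tool_results = [
        {"type": "tool_result", "id": "call_1", "name": "search_docs", "content": "..."},
        {"type": "tool_result", "id": "call_2", "name": "search_web", "content": "..."},
    ]
    # flush_group() appends exactly two messages for the whole group
    iteration_history = [
        {"by": "khoj", "message": raw_response},  # one assistant turn, all tool_use blocks
        {"by": "you", "message": tool_results},   # one user turn, all tool_results
    ]

Without grouping, each iteration would emit its own assistant/user pair, but the Anthropic API expects every tool_use id in an assistant turn to be answered by a tool_result in the immediately following user message.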

@@ -208,32 +208,41 @@ async def apick_next_tool(
         return

     try:
-        # Try parse the response as function call response to infer next tool to use.
-        # TODO: Handle multiple tool calls.
+        # Try parse the response as function call response to infer next tools to use.
         response_text = response.text
-        parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
+        parsed_responses = [ToolCall(**item) for item in load_complex_json(response_text)]
     except Exception:
         # Otherwise assume the model has decided to end the research run and respond to the user.
-        parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
+        parsed_responses = [ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)]

-    # If we have a valid response, extract the tool and query.
-    warning = None
-    logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})")
-    # Detect selection of previously used query, tool combination.
+    # Detect selection of previously used query, tool combinations.
     previous_tool_query_combinations = {
         (i.query.name, dict_to_tuple(i.query.args))
         for i in previous_iterations
         if i.warning is None and isinstance(i.query, ToolCall)
     }
-    if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
-        warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different."
-    # Only send client status updates if we'll execute this iteration and model has thoughts to share.
-    elif send_status_func and not is_none_or_empty(response.thought):
+    # Send status update with model's thoughts if available
+    if send_status_func and not is_none_or_empty(response.thought):
         async for event in send_status_func(response.thought):
             yield {ChatEvent.STATUS: event}
-    yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
+
+    # Yield a ResearchIteration for each tool call to enable parallel execution
+    for idx, parsed_response in enumerate(parsed_responses):
+        warning = None
+        logger.info(
+            f"Response for determining relevant tools ({idx + 1}/{len(parsed_responses)}): {parsed_response.name}({parsed_response.args})"
+        )
+        if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
+            warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different."
+        # Include raw_response only for the first tool call to avoid duplication in history
+        yield ResearchIteration(
+            query=parsed_response,
+            warning=warning,
+            raw_response=response.raw_content if idx == 0 else None,
+        )

async def research(
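
The repeated-call guard needs a hashable key per (tool, args) pair. A minimal sketch of that check, assuming dict_to_tuple canonicalizes a dict of scalar (or nested dict) values into a sorted tuple; the helper's exact implementation is not shown in this diff:

    def dict_to_tuple(d: dict) -> tuple:
        # Recursively convert a dict into a sorted, hashable tuple of items
        return tuple(sorted((k, dict_to_tuple(v) if isinstance(v, dict) else v) for k, v in d.items()))

    seen = {("search_web", dict_to_tuple({"q": "khoj"}))}
    repeat = ("search_web", dict_to_tuple({"q": "khoj"})) in seen  # True -> warn, skip execution

Note that iterations which already carry a warning are excluded from previous_tool_query_combinations, so only clean runs count as repeats.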
@@ -296,13 +305,8 @@ async def research(
             async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
                 yield result

-        online_results: Dict = dict()
-        code_results: Dict = dict()
-        document_results: List[Dict[str, str]] = []
-        operator_results: OperatorRun = None
-        mcp_results: List = []
-        this_iteration = ResearchIteration(query=query)
+        # Collect all tool calls from apick_next_tool
+        iterations_to_process: List[ResearchIteration] = []

         async for result in apick_next_tool(
             query,
             research_conversation_history,
@@ -324,8 +328,16 @@ async def research(
             if isinstance(result, dict) and ChatEvent.STATUS in result:
                 yield result[ChatEvent.STATUS]
             elif isinstance(result, ResearchIteration):
-                this_iteration = result
-                yield this_iteration
+                iterations_to_process.append(result)
+                yield result

+        # Process all tool calls from this planning step
+        for this_iteration in iterations_to_process:
+            online_results: Dict = dict()
+            code_results: Dict = dict()
+            document_results: List[Dict[str, str]] = []
+            operator_results: OperatorRun = None
+            mcp_results: List = []

             # Skip running iteration if warning present in iteration
             if this_iteration.warning:
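
As the commit message notes, the collected tool calls still execute sequentially for now. A hedged sketch of the collect-then-process shape, with a hypothetical run_tool executor, and the concurrent variant a later change could swap in:

    import asyncio

    async def run_tool(iteration):
        # Hypothetical executor for one ResearchIteration's tool call
        ...

    async def process(iterations_to_process):
        # Current behavior: run each requested tool one after another
        results = [await run_tool(it) for it in iterations_to_process]
        # Possible future: run them concurrently once tools are safe to interleave
        # results = await asyncio.gather(*(run_tool(it) for it in iterations_to_process))
        return results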
@@ -350,7 +362,9 @@ async def research(
                         n=max_document_searches,
                         d=None,
                         user=user,
-                        chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
+                        chat_history=construct_tool_chat_history(
+                            previous_iterations, ConversationCommand.SemanticSearchFiles
+                        ),
                         conversation_id=conversation_id,
                         conversation_commands=[ConversationCommand.Notes],
                         location_data=location,
@@ -370,7 +384,9 @@ async def research(
             if not is_none_or_empty(document_results):
                 try:
                     distinct_files = {d["file"] for d in document_results}
-                    distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d])
+                    distinct_headings = set(
+                        [d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]
+                    )
                     # Strip only leading # from headings
                     headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
                     async for result in send_status_func(