mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Allow LLMs to make parallel tool call requests
Why -- - The models are now smart enough to usually understand which tools to call in parallel and when. - The LLM can request more work in each call made to it, which is usually the slowest step. This speeds up work by the research agent, even though each tool is still executed in sequence (for now).
This commit is contained in:
@@ -85,8 +85,7 @@ def anthropic_completion_with_backoff(
|
|||||||
# Cache tool definitions
|
# Cache tool definitions
|
||||||
last_tool = model_kwargs["tools"][-1]
|
last_tool = model_kwargs["tools"][-1]
|
||||||
last_tool["cache_control"] = {"type": "ephemeral"}
|
last_tool["cache_control"] = {"type": "ephemeral"}
|
||||||
# Disable parallel tool call until we add support for it
|
model_kwargs["tool_choice"] = {"type": "auto"}
|
||||||
model_kwargs["tool_choice"] = {"type": "auto", "disable_parallel_tool_use": True}
|
|
||||||
elif response_schema:
|
elif response_schema:
|
||||||
tool = create_tool_definition(response_schema)
|
tool = create_tool_definition(response_schema)
|
||||||
model_kwargs["tools"] = [
|
model_kwargs["tools"] = [
|
||||||
|
|||||||
@@ -245,8 +245,37 @@ def construct_iteration_history(
|
|||||||
if query_message_content:
|
if query_message_content:
|
||||||
iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
|
iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
|
||||||
|
|
||||||
|
# Group iterations: parallel tool calls share the same raw_response (only first has it)
|
||||||
|
# We need to group them so one assistant message has all tool_use blocks and
|
||||||
|
# one user message has all tool_results
|
||||||
|
current_group_raw_response = None
|
||||||
|
current_group_tool_results = []
|
||||||
|
|
||||||
|
def flush_group():
|
||||||
|
"""Output the current group as assistant message + user message with tool results"""
|
||||||
|
nonlocal current_group_raw_response, current_group_tool_results
|
||||||
|
if current_group_raw_response and current_group_tool_results:
|
||||||
|
iteration_history.append(
|
||||||
|
ChatMessageModel(
|
||||||
|
by="khoj",
|
||||||
|
message=current_group_raw_response,
|
||||||
|
intent=Intent(type="tool_call", query=query),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
iteration_history.append(
|
||||||
|
ChatMessageModel(
|
||||||
|
by="you",
|
||||||
|
intent=Intent(type="tool_result"),
|
||||||
|
message=current_group_tool_results,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_group_raw_response = None
|
||||||
|
current_group_tool_results = []
|
||||||
|
|
||||||
for iteration in previous_iterations:
|
for iteration in previous_iterations:
|
||||||
if not iteration.query or isinstance(iteration.query, str):
|
if not iteration.query or isinstance(iteration.query, str):
|
||||||
|
# Flush any pending group before adding non-tool message
|
||||||
|
flush_group()
|
||||||
iteration_history.append(
|
iteration_history.append(
|
||||||
ChatMessageModel(
|
ChatMessageModel(
|
||||||
by="you",
|
by="you",
|
||||||
@@ -256,25 +285,36 @@ def construct_iteration_history(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
iteration_history += [
|
|
||||||
ChatMessageModel(
|
# If this iteration has raw_response, it starts a new group of parallel tool calls
|
||||||
by="khoj",
|
if iteration.raw_response:
|
||||||
message=iteration.raw_response or [iteration.query.__dict__],
|
# Flush previous group if exists
|
||||||
intent=Intent(type="tool_call", query=query),
|
flush_group()
|
||||||
),
|
current_group_raw_response = iteration.raw_response
|
||||||
ChatMessageModel(
|
|
||||||
by="you",
|
# If no raw_response and no current group, create a fallback single-tool response
|
||||||
intent=Intent(type="tool_result"),
|
elif not current_group_raw_response:
|
||||||
message=[
|
current_group_raw_response = [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": iteration.query.id,
|
||||||
|
"name": iteration.query.name,
|
||||||
|
"input": iteration.query.args,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add tool result to current group
|
||||||
|
current_group_tool_results.append(
|
||||||
{
|
{
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"id": iteration.query.id,
|
"id": iteration.query.id,
|
||||||
"name": iteration.query.name,
|
"name": iteration.query.name,
|
||||||
"content": iteration.summarizedResult,
|
"content": iteration.summarizedResult,
|
||||||
}
|
}
|
||||||
],
|
)
|
||||||
),
|
|
||||||
]
|
# Flush any remaining group
|
||||||
|
flush_group()
|
||||||
|
|
||||||
return iteration_history
|
return iteration_history
|
||||||
|
|
||||||
|
|||||||
@@ -208,32 +208,41 @@ async def apick_next_tool(
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try parse the response as function call response to infer next tool to use.
|
# Try parse the response as function call response to infer next tools to use.
|
||||||
# TODO: Handle multiple tool calls.
|
|
||||||
response_text = response.text
|
response_text = response.text
|
||||||
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
|
parsed_responses = [ToolCall(**item) for item in load_complex_json(response_text)]
|
||||||
except Exception:
|
except Exception:
|
||||||
# Otherwise assume the model has decided to end the research run and respond to the user.
|
# Otherwise assume the model has decided to end the research run and respond to the user.
|
||||||
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
|
parsed_responses = [ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)]
|
||||||
|
|
||||||
# If we have a valid response, extract the tool and query.
|
# Detect selection of previously used query, tool combinations.
|
||||||
warning = None
|
|
||||||
logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})")
|
|
||||||
|
|
||||||
# Detect selection of previously used query, tool combination.
|
|
||||||
previous_tool_query_combinations = {
|
previous_tool_query_combinations = {
|
||||||
(i.query.name, dict_to_tuple(i.query.args))
|
(i.query.name, dict_to_tuple(i.query.args))
|
||||||
for i in previous_iterations
|
for i in previous_iterations
|
||||||
if i.warning is None and isinstance(i.query, ToolCall)
|
if i.warning is None and isinstance(i.query, ToolCall)
|
||||||
}
|
}
|
||||||
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
|
|
||||||
warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different."
|
# Send status update with model's thoughts if available
|
||||||
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
|
if send_status_func and not is_none_or_empty(response.thought):
|
||||||
elif send_status_func and not is_none_or_empty(response.thought):
|
|
||||||
async for event in send_status_func(response.thought):
|
async for event in send_status_func(response.thought):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
|
# Yield a ResearchIteration for each tool call to enable parallel execution
|
||||||
|
for idx, parsed_response in enumerate(parsed_responses):
|
||||||
|
warning = None
|
||||||
|
logger.info(
|
||||||
|
f"Response for determining relevant tools ({idx + 1}/{len(parsed_responses)}): {parsed_response.name}({parsed_response.args})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
|
||||||
|
warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different."
|
||||||
|
|
||||||
|
# Include raw_response only for the first tool call to avoid duplication in history
|
||||||
|
yield ResearchIteration(
|
||||||
|
query=parsed_response,
|
||||||
|
warning=warning,
|
||||||
|
raw_response=response.raw_content if idx == 0 else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def research(
|
async def research(
|
||||||
@@ -296,13 +305,8 @@ async def research(
|
|||||||
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
# Collect all tool calls from apick_next_tool
|
||||||
code_results: Dict = dict()
|
iterations_to_process: List[ResearchIteration] = []
|
||||||
document_results: List[Dict[str, str]] = []
|
|
||||||
operator_results: OperatorRun = None
|
|
||||||
mcp_results: List = []
|
|
||||||
this_iteration = ResearchIteration(query=query)
|
|
||||||
|
|
||||||
async for result in apick_next_tool(
|
async for result in apick_next_tool(
|
||||||
query,
|
query,
|
||||||
research_conversation_history,
|
research_conversation_history,
|
||||||
@@ -324,8 +328,16 @@ async def research(
|
|||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
elif isinstance(result, ResearchIteration):
|
elif isinstance(result, ResearchIteration):
|
||||||
this_iteration = result
|
iterations_to_process.append(result)
|
||||||
yield this_iteration
|
yield result
|
||||||
|
|
||||||
|
# Process all tool calls from this planning step
|
||||||
|
for this_iteration in iterations_to_process:
|
||||||
|
online_results: Dict = dict()
|
||||||
|
code_results: Dict = dict()
|
||||||
|
document_results: List[Dict[str, str]] = []
|
||||||
|
operator_results: OperatorRun = None
|
||||||
|
mcp_results: List = []
|
||||||
|
|
||||||
# Skip running iteration if warning present in iteration
|
# Skip running iteration if warning present in iteration
|
||||||
if this_iteration.warning:
|
if this_iteration.warning:
|
||||||
@@ -350,7 +362,9 @@ async def research(
|
|||||||
n=max_document_searches,
|
n=max_document_searches,
|
||||||
d=None,
|
d=None,
|
||||||
user=user,
|
user=user,
|
||||||
chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
|
chat_history=construct_tool_chat_history(
|
||||||
|
previous_iterations, ConversationCommand.SemanticSearchFiles
|
||||||
|
),
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
conversation_commands=[ConversationCommand.Notes],
|
conversation_commands=[ConversationCommand.Notes],
|
||||||
location_data=location,
|
location_data=location,
|
||||||
@@ -370,7 +384,9 @@ async def research(
|
|||||||
if not is_none_or_empty(document_results):
|
if not is_none_or_empty(document_results):
|
||||||
try:
|
try:
|
||||||
distinct_files = {d["file"] for d in document_results}
|
distinct_files = {d["file"] for d in document_results}
|
||||||
distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d])
|
distinct_headings = set(
|
||||||
|
[d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]
|
||||||
|
)
|
||||||
# Strip only leading # from headings
|
# Strip only leading # from headings
|
||||||
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
|
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
|
||||||
async for result in send_status_func(
|
async for result in send_status_func(
|
||||||
|
|||||||
Reference in New Issue
Block a user