Make researcher pick next tool using model function calling feature

The pick next tool function requests the next tool to call from the model
in function calling / tool use format.
This commit is contained in:
Debanjum
2025-06-11 17:21:53 -07:00
parent b888d5e65e
commit 80522e370e
2 changed files with 46 additions and 39 deletions

View File

@@ -678,7 +678,6 @@ Create a multi-step plan and intelligently iterate on the plan based on the retr
- Ensure that all required context is passed to the tool AIs for successful execution. Include any relevant stuff that has previously been attempted. They only know the context provided in your query.
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with the "tool" field set to "text" and "query" field empty. E.g., {{"scratchpad": "I have all I need", "tool": "text", "query": ""}}
# Examples
Assuming you can search the user's notes and the internet.
@@ -704,10 +703,6 @@ Assuming you can search the user's notes and the internet.
You decide which of the tool AIs listed below would you use to answer the user's question. You **only** have access to the following tool AIs:
{tools}
Your response should always be a valid JSON object with keys: "scratchpad" (str), "tool" (str) and "query" (str). Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
""".strip()
)

View File

@@ -15,8 +15,10 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
ResearchIteration,
ToolDefinition,
construct_iteration_history,
construct_tool_chat_history,
create_tool_definition,
load_complex_json,
)
from khoj.processor.operator import operate_environment
@@ -144,6 +146,13 @@ async def apick_next_tool(
# Create planning response model with dynamically populated tool enum class
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
tools = [
create_tool_definition(
name="infer_information_sources_to_refer",
description="Infer which tool to use next and the query to send to the tool.",
schema=planning_response_model,
)
]
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
@@ -174,12 +183,13 @@ async def apick_next_tool(
try:
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
raw_response = await send_message_to_model_wrapper(
query=query,
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
response_type="json_object",
response_schema=planning_response_model,
tools=tools,
deepthought=True,
user=user,
query_images=query_images,
@@ -197,41 +207,43 @@ async def apick_next_tool(
return
try:
response = load_complex_json(response)
if not isinstance(response, dict):
raise ValueError(f"Expected dict response, got {type(response).__name__}: {response}")
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}")
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations if i.warning is None}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += (
f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond."
)
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
yield ResearchIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
# Try parse the response as function call response to infer next tool to use.
response = load_complex_json(load_complex_json(raw_response)[0]["args"])
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield ResearchIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
try:
# Else try parse the text response as JSON to infer next tool to use.
response = load_complex_json(raw_response)
except Exception as e:
# Otherwise assume the model has decided to end the research run and respond to the user.
response = {"tool": ConversationCommand.Text, "query": None, "scratchpad": raw_response}
# If we have a valid response, extract the tool and query.
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}")
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations if i.warning is None}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func and scratchpad:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += (
f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond."
)
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
yield ResearchIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
async def research(