Show tool to use decision for next iteration in train of thought

This commit is contained in:
Debanjum Singh Solanky
2024-10-15 01:08:48 -07:00
parent abcd11cfc0
commit 336c6c3689

View File

@@ -44,6 +44,7 @@ async def apick_next_tool(
agent: Agent = None, agent: Agent = None,
previous_iterations_history: str = None, previous_iterations_history: str = None,
max_iterations: int = 5, max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
): ):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
@@ -103,17 +104,22 @@ async def apick_next_tool(
selected_tool = response.get("tool", None) selected_tool = response.get("tool", None)
generated_query = response.get("query", None) generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None) scratchpad = response.get("scratchpad", None)
logger.info(f"Response for determining relevant tools: {response}") logger.info(f"Response for determining relevant tools: {response}")
if send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
return InformationCollectionIteration( yield InformationCollectionIteration(
tool=selected_tool, tool=selected_tool,
query=generated_query, query=generated_query,
) )
except Exception as e: except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True) logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
return InformationCollectionIteration( yield InformationCollectionIteration(
tool=None, tool=None,
query=None, query=None,
) )
@@ -143,9 +149,7 @@ async def execute_information_collection(
inferred_queries: List[Any] = [] inferred_queries: List[Any] = []
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
result: str = "" async for result in apick_next_tool(
this_iteration = await apick_next_tool(
query, query,
conversation_history, conversation_history,
user, user,
@@ -155,7 +159,13 @@ async def execute_information_collection(
agent, agent,
previous_iterations_history, previous_iterations_history,
MAX_ITERATIONS, MAX_ITERATIONS,
) send_status_func,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
this_iteration = result
if this_iteration.tool == ConversationCommand.Notes: if this_iteration.tool == ConversationCommand.Notes:
## Extract Document References ## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None compiled_references, inferred_queries, defiltered_query = [], [], None