Add intermediate summarization of results when planning with o1

This commit is contained in:
sabaimran
2024-10-09 17:40:56 -07:00
parent 7b288a1179
commit 5b8d663cf1
3 changed files with 84 additions and 78 deletions

View File

@@ -490,9 +490,11 @@ plan_function_execution = PromptTemplate.from_template(
You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query. You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query.
{personality_context} {personality_context}
- You have access to a variety of data sources to help you answer the user's question - You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information, one at a time - You can use the data sources listed below to collect more relevant information, one at a time. The outputs will be chained.
- You are given multiple iterations to work with these data sources to answer the user's question - You are given multiple iterations to work with these data sources to answer the user's question
- You are provided with additional context. If you have enough context to answer the question, then exit execution - You are provided with additional context. If you have enough context to answer the question, then exit execution
- Each query is self-contained and you can use the data source to answer the user's question. There will be no additional data injected between queries, so make sure the query you're asking is answered in the current iteration.
- Limit each query to a *single* intention. For example, do not say "Look up the top city by population and output the GDP." Instead, say "Look up the top city by population." and then "Tell me the GDP of <the city>."
If you already know the answer to the question, return an empty response, e.g., {{}}. If you already know the answer to the question, return an empty response, e.g., {{}}.
@@ -500,7 +502,7 @@ Which of the data sources listed below you would use to answer the user's questi
{tools} {tools}
Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else. Provide the data source and associated query in a JSON object. Do not say anything else.
Previous Iterations: Previous Iterations:
{previous_iterations} {previous_iterations}
@@ -520,8 +522,7 @@ previous_iteration = PromptTemplate.from_template(
""" """
data_source: {data_source} data_source: {data_source}
query: {query} query: {query}
context: {context} summary: {summary}
onlineContext: {onlineContext}
--- ---
""".strip() """.strip()
) )

View File

@@ -718,18 +718,20 @@ async def chat(
): ):
if type(research_result) == InformationCollectionIteration: if type(research_result) == InformationCollectionIteration:
pending_research = False pending_research = False
if research_result.onlineContext: # if research_result.onlineContext:
researched_results += str(research_result.onlineContext) # researched_results += str(research_result.onlineContext)
online_results.update(research_result.onlineContext) # online_results.update(research_result.onlineContext)
if research_result.context: # if research_result.context:
researched_results += str(research_result.context) # researched_results += str(research_result.context)
compiled_references.extend(research_result.context) # compiled_references.extend(research_result.context)
researched_results += research_result.summarizedResult
else: else:
yield research_result yield research_result
researched_results = await extract_relevant_info(q, researched_results, agent) # researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}") logger.info(f"Researched Results: {researched_results}")

View File

@@ -13,6 +13,7 @@ from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
construct_chat_history, construct_chat_history,
extract_relevant_info,
generate_summary_from_files, generate_summary_from_files,
send_message_to_model_wrapper, send_message_to_model_wrapper,
) )
@@ -27,11 +28,19 @@ logger = logging.getLogger(__name__)
class InformationCollectionIteration: class InformationCollectionIteration:
def __init__(self, data_source: str, query: str, context: str = None, onlineContext: dict = None): def __init__(
self,
data_source: str,
query: str,
context: str = None,
onlineContext: dict = None,
summarizedResult: str = None,
):
self.data_source = data_source self.data_source = data_source
self.query = query self.query = query
self.context = context self.context = context
self.onlineContext = onlineContext self.onlineContext = onlineContext
self.summarizedResult = summarizedResult
async def apick_next_tool( async def apick_next_tool(
@@ -63,8 +72,7 @@ async def apick_next_tool(
iteration_data = prompts.previous_iteration.format( iteration_data = prompts.previous_iteration.format(
query=iteration.query, query=iteration.query,
data_source=iteration.data_source, data_source=iteration.data_source,
context=str(iteration.context), summary=iteration.summarizedResult,
onlineContext=str(iteration.onlineContext),
) )
previous_iterations_history += iteration_data previous_iterations_history += iteration_data
@@ -138,7 +146,8 @@ async def execute_information_collection(
compiled_references: List[Any] = [] compiled_references: List[Any] = []
inferred_queries: List[Any] = [] inferred_queries: List[Any] = []
defiltered_query = None
result: str = ""
this_iteration = await apick_next_tool( this_iteration = await apick_next_tool(
query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
@@ -165,13 +174,7 @@ async def execute_information_collection(
compiled_references.extend(result[0]) compiled_references.extend(result[0])
inferred_queries.extend(result[1]) inferred_queries.extend(result[1])
defiltered_query = result[2] defiltered_query = result[2]
previous_iterations.append( this_iteration.context = str(compiled_references)
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
context=str(compiled_references),
)
)
elif this_iteration.data_source == ConversationCommand.Online: elif this_iteration.data_source == ConversationCommand.Online:
async for result in search_online( async for result in search_online(
@@ -189,13 +192,7 @@ async def execute_information_collection(
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
else: else:
online_results = result online_results = result
previous_iterations.append( this_iteration.onlineContext = online_results
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
onlineContext=online_results,
)
)
elif this_iteration.data_source == ConversationCommand.Webpage: elif this_iteration.data_source == ConversationCommand.Webpage:
async for result in read_webpages( async for result in read_webpages(
@@ -224,57 +221,63 @@ async def execute_information_collection(
webpages.append(webpage["link"]) webpages.append(webpage["link"])
yield send_status_func(f"**Read web pages**: {webpages}") yield send_status_func(f"**Read web pages**: {webpages}")
previous_iterations.append( this_iteration.onlineContext = online_results
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
onlineContext=online_results,
)
)
elif this_iteration.data_source == ConversationCommand.Summarize: # TODO: Fix summarize later
response_log = "" # elif this_iteration.data_source == ConversationCommand.Summarize:
agent_has_entries = await EntryAdapters.aagent_has_entries(agent) # response_log = ""
if len(file_filters) == 0 and not agent_has_entries: # agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
previous_iterations.append( # if len(file_filters) == 0 and not agent_has_entries:
InformationCollectionIteration( # previous_iterations.append(
data_source=this_iteration.data_source, # InformationCollectionIteration(
query=this_iteration.query, # data_source=this_iteration.data_source,
context="No files selected for summarization.", # query=this_iteration.query,
) # context="No files selected for summarization.",
) # )
elif len(file_filters) > 1 and not agent_has_entries: # )
response_log = "Only one file can be selected for summarization." # elif len(file_filters) > 1 and not agent_has_entries:
previous_iterations.append( # response_log = "Only one file can be selected for summarization."
InformationCollectionIteration( # previous_iterations.append(
data_source=this_iteration.data_source, # InformationCollectionIteration(
query=this_iteration.query, # data_source=this_iteration.data_source,
context=response_log, # query=this_iteration.query,
) # context=response_log,
) # )
else: # )
async for response in generate_summary_from_files( # else:
q=query, # async for response in generate_summary_from_files(
user=user, # q=query,
file_filters=file_filters, # user=user,
meta_log=conversation_history, # file_filters=file_filters,
subscribed=subscribed, # meta_log=conversation_history,
send_status_func=send_status_func, # subscribed=subscribed,
): # send_status_func=send_status_func,
if isinstance(response, dict) and ChatEvent.STATUS in response: # ):
yield response[ChatEvent.STATUS] # if isinstance(response, dict) and ChatEvent.STATUS in response:
else: # yield response[ChatEvent.STATUS]
response_log = response # type: ignore # else:
previous_iterations.append( # response_log = response # type: ignore
InformationCollectionIteration( # previous_iterations.append(
data_source=this_iteration.data_source, # InformationCollectionIteration(
query=this_iteration.query, # data_source=this_iteration.data_source,
context=response_log, # query=this_iteration.query,
) # context=response_log,
) # )
# )
else: else:
iteration = MAX_ITERATIONS iteration = MAX_ITERATIONS
iteration += 1 iteration += 1
for completed_iter in previous_iterations:
yield completed_iter if compiled_references or online_results:
results_data = f"**Results**:\n"
if compiled_references:
results_data += f"**Document References**: {compiled_references}\n"
if online_results:
results_data += f"**Online Results**: {online_results}\n"
intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = intermediate_result
previous_iterations.append(this_iteration)
yield this_iteration