From 5b8d663cf178372a75ce92d07bc85836f607fe5b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 17:40:56 -0700 Subject: [PATCH] Add intermediate summarization of results when planning with o1 --- src/khoj/processor/conversation/prompts.py | 9 +- src/khoj/routers/api_chat.py | 16 +-- src/khoj/routers/research.py | 137 +++++++++++---------- 3 files changed, 84 insertions(+), 78 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 1246a43a..6ef1cb94 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -490,9 +490,11 @@ plan_function_execution = PromptTemplate.from_template( You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query. {personality_context} - You have access to a variety of data sources to help you answer the user's question -- You can use the data sources listed below to collect more relevant information, one at a time +- You can use the data sources listed below to collect more relevant information, one at a time. The outputs will be chained. - You are given multiple iterations to with these data sources to answer the user's question - You are provided with additional context. If you have enough context to answer the question, then exit execution +- Each query is self-contained and you can use the data source to answer the user's question. There will be no additional data injected between queries, so make sure the query you're asking is answered in the current iteration. +- Limit each query to a *single* intention. For example, do not say "Look up the top city by population and output the GDP." Instead, say "Look up the top city by population." and then "Tell me the GDP of ." If you already know the answer to the question, return an empty response, e.g., {{}}. @@ -500,7 +502,7 @@ Which of the data sources listed below you would use to answer the user's questi {tools} -Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else. +Provide the data source and associated query in a JSON object. Do not say anything else. Previous Iterations: {previous_iterations} @@ -520,8 +522,7 @@ previous_iteration = PromptTemplate.from_template( """ data_source: {data_source} query: {query} -context: {context} -onlineContext: {onlineContext} +summary: {summary} --- """.strip() ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index f4061c50..158509bd 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -718,18 +718,20 @@ async def chat( ): if type(research_result) == InformationCollectionIteration: pending_research = False - if research_result.onlineContext: - researched_results += str(research_result.onlineContext) - online_results.update(research_result.onlineContext) + # if research_result.onlineContext: + # researched_results += str(research_result.onlineContext) + # online_results.update(research_result.onlineContext) - if research_result.context: - researched_results += str(research_result.context) - compiled_references.extend(research_result.context) + # if research_result.context: + # researched_results += str(research_result.context) + # compiled_references.extend(research_result.context) + + researched_results += research_result.summarizedResult else: yield research_result - researched_results = await extract_relevant_info(q, researched_results, agent) + # researched_results = await extract_relevant_info(q, researched_results, agent) logger.info(f"Researched Results: {researched_results}") diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index f6eae48d..3921cb78 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -13,6 +13,7 @@ from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ChatEvent, construct_chat_history, + extract_relevant_info, generate_summary_from_files, send_message_to_model_wrapper, ) @@ -27,11 +28,19 @@ logger = logging.getLogger(__name__) class InformationCollectionIteration: - def __init__(self, data_source: str, query: str, context: str = None, onlineContext: dict = None): + def __init__( + self, + data_source: str, + query: str, + context: str = None, + onlineContext: dict = None, + summarizedResult: str = None, + ): self.data_source = data_source self.query = query self.context = context self.onlineContext = onlineContext + self.summarizedResult = summarizedResult async def apick_next_tool( @@ -63,8 +72,7 @@ async def apick_next_tool( iteration_data = prompts.previous_iteration.format( query=iteration.query, data_source=iteration.data_source, - context=str(iteration.context), - onlineContext=str(iteration.onlineContext), + summary=iteration.summarizedResult, ) previous_iterations_history += iteration_data @@ -138,7 +146,8 @@ async def execute_information_collection( compiled_references: List[Any] = [] inferred_queries: List[Any] = [] - defiltered_query = None + + result: str = "" this_iteration = await apick_next_tool( query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations @@ -165,13 +174,7 @@ async def execute_information_collection( compiled_references.extend(result[0]) inferred_queries.extend(result[1]) defiltered_query = result[2] - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context=str(compiled_references), - ) - ) + this_iteration.context = str(compiled_references) elif this_iteration.data_source == ConversationCommand.Online: async for result in search_online( @@ -189,13 +192,7 @@ async def execute_information_collection( yield result[ChatEvent.STATUS] else: online_results = result - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - onlineContext=online_results, - ) - ) + this_iteration.onlineContext = online_results elif this_iteration.data_source == ConversationCommand.Webpage: async for result in read_webpages( @@ -224,57 +221,63 @@ async def execute_information_collection( webpages.append(webpage["link"]) yield send_status_func(f"**Read web pages**: {webpages}") - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - onlineContext=online_results, - ) - ) + this_iteration.onlineContext = online_results - elif this_iteration.data_source == ConversationCommand.Summarize: - response_log = "" - agent_has_entries = await EntryAdapters.aagent_has_entries(agent) - if len(file_filters) == 0 and not agent_has_entries: - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context="No files selected for summarization.", - ) - ) - elif len(file_filters) > 1 and not agent_has_entries: - response_log = "Only one file can be selected for summarization." - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context=response_log, - ) - ) - else: - async for response in generate_summary_from_files( - q=query, - user=user, - file_filters=file_filters, - meta_log=conversation_history, - subscribed=subscribed, - send_status_func=send_status_func, - ): - if isinstance(response, dict) and ChatEvent.STATUS in response: - yield response[ChatEvent.STATUS] - else: - response_log = response # type: ignore - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context=response_log, - ) - ) + # TODO: Fix summarize later + # elif this_iteration.data_source == ConversationCommand.Summarize: + # response_log = "" + # agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + # if len(file_filters) == 0 and not agent_has_entries: + # previous_iterations.append( + # InformationCollectionIteration( + # data_source=this_iteration.data_source, + # query=this_iteration.query, + # context="No files selected for summarization.", + # ) + # ) + # elif len(file_filters) > 1 and not agent_has_entries: + # response_log = "Only one file can be selected for summarization." + # previous_iterations.append( + # InformationCollectionIteration( + # data_source=this_iteration.data_source, + # query=this_iteration.query, + # context=response_log, + # ) + # ) + # else: + # async for response in generate_summary_from_files( + # q=query, + # user=user, + # file_filters=file_filters, + # meta_log=conversation_history, + # subscribed=subscribed, + # send_status_func=send_status_func, + # ): + # if isinstance(response, dict) and ChatEvent.STATUS in response: + # yield response[ChatEvent.STATUS] + # else: + # response_log = response # type: ignore + # previous_iterations.append( + # InformationCollectionIteration( + # data_source=this_iteration.data_source, + # query=this_iteration.query, + # context=response_log, + # ) + # ) else: iteration = MAX_ITERATIONS iteration += 1 - for completed_iter in previous_iterations: - yield completed_iter + + if compiled_references or online_results: + results_data = f"**Results**:\n" + if compiled_references: + results_data += f"**Document References**: {compiled_references}\n" + if online_results: + results_data += f"**Online Results**: {online_results}\n" + + intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) + this_iteration.summarizedResult = intermediate_result + + previous_iterations.append(this_iteration) + yield this_iteration