From 864e0ac8b5ac2a3cc0fdb80beb83091b7531868e Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 29 May 2025 15:04:35 -0700 Subject: [PATCH] Simplify research iteration and main research function names --- src/khoj/processor/conversation/utils.py | 12 ++++++------ src/khoj/routers/api_chat.py | 15 +++++---------- src/khoj/routers/helpers.py | 4 ++-- src/khoj/routers/research.py | 18 +++++++++--------- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 3a29952d..c6de9557 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -130,7 +130,7 @@ class OperatorRun: } -class InformationCollectionIteration: +class ResearchIteration: def __init__( self, tool: str, @@ -160,7 +160,7 @@ class InformationCollectionIteration: def construct_iteration_history( - previous_iterations: List[InformationCollectionIteration], + previous_iterations: List[ResearchIteration], previous_iteration_prompt: str, query: str = None, ) -> list[dict]: @@ -262,7 +262,7 @@ def construct_question_history( def construct_tool_chat_history( - previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None + previous_iterations: List[ResearchIteration], tool: ConversationCommand = None ) -> Dict[str, list]: """ Construct chat history from previous iterations for a specific tool @@ -271,8 +271,8 @@ def construct_tool_chat_history( If no tool is provided inferred query for all tools used are added. """ chat_history: list = [] - base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] - extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = { + base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: [] + extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = { ConversationCommand.Notes: ( lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] ), @@ -377,7 +377,7 @@ async def save_to_conversation_log( generated_images: List[str] = [], raw_generated_files: List[FileAttachment] = [], generated_mermaidjs_diagram: str = None, - research_results: Optional[List[InformationCollectionIteration]] = None, + research_results: Optional[List[ResearchIteration]] = None, train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 2cade5bc..39599c37 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -66,10 +66,7 @@ from khoj.routers.helpers import ( update_telemetry_state, validate_chat_model, ) -from khoj.routers.research import ( - InformationCollectionIteration, - execute_information_collection, -) +from khoj.routers.research import ResearchIteration, research from khoj.routers.storage import upload_user_image_to_bucket from khoj.utils import state from khoj.utils.helpers import ( @@ -723,7 +720,7 @@ async def chat( for file in raw_query_files: query_files[file.name] = file.content - research_results: List[InformationCollectionIteration] = [] + research_results: List[ResearchIteration] = [] online_results: Dict = dict() code_results: Dict = dict() operator_results: List[OperatorRun] = [] @@ -962,9 +959,7 @@ async def chat( online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} compiled_references = [ref.model_dump() for ref in last_message.context or []] - research_results = [ - InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or [] - ] + research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []] operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] # Drop the interrupted message from conversation history @@ -1011,7 +1006,7 @@ async def chat( file_filters = conversation.file_filters if conversation and conversation.file_filters else [] if conversation_commands == [ConversationCommand.Research]: - async for research_result in execute_information_collection( + async for research_result in research( user=user, query=defiltered_query, conversation_id=conversation_id, @@ -1027,7 +1022,7 @@ async def chat( tracer=tracer, cancellation_event=cancellation_event, ): - if isinstance(research_result, InformationCollectionIteration): + if isinstance(research_result, ResearchIteration): if research_result.summarizedResult: if research_result.onlineContext: online_results.update(research_result.onlineContext) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 28b15f3c..52610726 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -94,8 +94,8 @@ from khoj.processor.conversation.openai.gpt import ( ) from khoj.processor.conversation.utils import ( ChatEvent, - InformationCollectionIteration, OperatorRun, + ResearchIteration, ResponseWithThought, clean_json, clean_mermaidjs, @@ -1357,7 +1357,7 @@ async def agenerate_chat_response( online_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {}, operator_results: List[OperatorRun] = [], - research_results: List[InformationCollectionIteration] = [], + research_results: List[ResearchIteration] = [], inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], user: KhojUser = None, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index ab0c7062..d0f9f4ef 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -13,8 +13,8 @@ from khoj.database.adapters import AgentAdapters, EntryAdapters from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( - InformationCollectionIteration, OperatorRun, + ResearchIteration, construct_iteration_history, construct_tool_chat_history, load_complex_json, @@ -84,7 +84,7 @@ async def apick_next_tool( location: LocationData = None, user_name: str = None, agent: Agent = None, - previous_iterations: List[InformationCollectionIteration] = [], + previous_iterations: List[ResearchIteration] = [], max_iterations: int = 5, query_images: List[str] = [], query_files: str = None, @@ -166,7 +166,7 @@ async def apick_next_tool( ) except Exception as e: logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True) - yield InformationCollectionIteration( + yield ResearchIteration( tool=None, query=None, warning="Failed to infer information sources to refer. Skipping iteration. Try again.", @@ -195,26 +195,26 @@ async def apick_next_tool( async for event in send_status_func(f"{scratchpad}"): yield {ChatEvent.STATUS: event} - yield InformationCollectionIteration( + yield ResearchIteration( tool=selected_tool, query=generated_query, warning=warning, ) except Exception as e: logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True) - yield InformationCollectionIteration( + yield ResearchIteration( tool=None, query=None, warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}", ) -async def execute_information_collection( +async def research( user: KhojUser, query: str, conversation_id: str, conversation_history: dict, - previous_iterations: List[InformationCollectionIteration], + previous_iterations: List[ResearchIteration], query_images: List[str], agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -251,7 +251,7 @@ async def execute_information_collection( document_results: List[Dict[str, str]] = [] operator_results: OperatorRun = None summarize_files: str = "" - this_iteration = InformationCollectionIteration(tool=None, query=query) + this_iteration = ResearchIteration(tool=None, query=query) async for result in apick_next_tool( query, @@ -272,7 +272,7 @@ async def execute_information_collection( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - elif isinstance(result, InformationCollectionIteration): + elif isinstance(result, ResearchIteration): this_iteration = result # Skip running iteration if warning present in iteration