Simplify research iteration and main research function names

This commit is contained in:
Debanjum
2025-05-29 15:04:35 -07:00
parent 6c9d569a22
commit 864e0ac8b5
4 changed files with 22 additions and 27 deletions

View File

@@ -130,7 +130,7 @@ class OperatorRun:
}
class InformationCollectionIteration:
class ResearchIteration:
def __init__(
self,
tool: str,
@@ -160,7 +160,7 @@ class InformationCollectionIteration:
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration],
previous_iterations: List[ResearchIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]:
@@ -262,7 +262,7 @@ def construct_question_history(
def construct_tool_chat_history(
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
"""
Construct chat history from previous iterations for a specific tool
@@ -271,8 +271,8 @@ def construct_tool_chat_history(
If no tool is provided inferred query for all tools used are added.
"""
chat_history: list = []
base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
ConversationCommand.Notes: (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
),
@@ -377,7 +377,7 @@ async def save_to_conversation_log(
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_mermaidjs_diagram: str = None,
research_results: Optional[List[InformationCollectionIteration]] = None,
research_results: Optional[List[ResearchIteration]] = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):

View File

@@ -66,10 +66,7 @@ from khoj.routers.helpers import (
update_telemetry_state,
validate_chat_model,
)
from khoj.routers.research import (
InformationCollectionIteration,
execute_information_collection,
)
from khoj.routers.research import ResearchIteration, research
from khoj.routers.storage import upload_user_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
@@ -723,7 +720,7 @@ async def chat(
for file in raw_query_files:
query_files[file.name] = file.content
research_results: List[InformationCollectionIteration] = []
research_results: List[ResearchIteration] = []
online_results: Dict = dict()
code_results: Dict = dict()
operator_results: List[OperatorRun] = []
@@ -962,9 +959,7 @@ async def chat(
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
compiled_references = [ref.model_dump() for ref in last_message.context or []]
research_results = [
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
]
research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []]
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
@@ -1011,7 +1006,7 @@ async def chat(
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
async for research_result in research(
user=user,
query=defiltered_query,
conversation_id=conversation_id,
@@ -1027,7 +1022,7 @@ async def chat(
tracer=tracer,
cancellation_event=cancellation_event,
):
if isinstance(research_result, InformationCollectionIteration):
if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult:
if research_result.onlineContext:
online_results.update(research_result.onlineContext)

View File

@@ -94,8 +94,8 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
ChatEvent,
InformationCollectionIteration,
OperatorRun,
ResearchIteration,
ResponseWithThought,
clean_json,
clean_mermaidjs,
@@ -1357,7 +1357,7 @@ async def agenerate_chat_response(
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
operator_results: List[OperatorRun] = [],
research_results: List[InformationCollectionIteration] = [],
research_results: List[ResearchIteration] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,

View File

@@ -13,8 +13,8 @@ from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
OperatorRun,
ResearchIteration,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
@@ -84,7 +84,7 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations: List[InformationCollectionIteration] = [],
previous_iterations: List[ResearchIteration] = [],
max_iterations: int = 5,
query_images: List[str] = [],
query_files: str = None,
@@ -166,7 +166,7 @@ async def apick_next_tool(
)
except Exception as e:
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
yield InformationCollectionIteration(
yield ResearchIteration(
tool=None,
query=None,
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
@@ -195,26 +195,26 @@ async def apick_next_tool(
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
yield InformationCollectionIteration(
yield ResearchIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
yield ResearchIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)
async def execute_information_collection(
async def research(
user: KhojUser,
query: str,
conversation_id: str,
conversation_history: dict,
previous_iterations: List[InformationCollectionIteration],
previous_iterations: List[ResearchIteration],
query_images: List[str],
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -251,7 +251,7 @@ async def execute_information_collection(
document_results: List[Dict[str, str]] = []
operator_results: OperatorRun = None
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
this_iteration = ResearchIteration(tool=None, query=query)
async for result in apick_next_tool(
query,
@@ -272,7 +272,7 @@ async def execute_information_collection(
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, InformationCollectionIteration):
elif isinstance(result, ResearchIteration):
this_iteration = result
# Skip running iteration if warning present in iteration