mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Simplify research iteration and main research function names
This commit is contained in:
@@ -130,7 +130,7 @@ class OperatorRun:
|
||||
}
|
||||
|
||||
|
||||
class InformationCollectionIteration:
|
||||
class ResearchIteration:
|
||||
def __init__(
|
||||
self,
|
||||
tool: str,
|
||||
@@ -160,7 +160,7 @@ class InformationCollectionIteration:
|
||||
|
||||
|
||||
def construct_iteration_history(
|
||||
previous_iterations: List[InformationCollectionIteration],
|
||||
previous_iterations: List[ResearchIteration],
|
||||
previous_iteration_prompt: str,
|
||||
query: str = None,
|
||||
) -> list[dict]:
|
||||
@@ -262,7 +262,7 @@ def construct_question_history(
|
||||
|
||||
|
||||
def construct_tool_chat_history(
|
||||
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
||||
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
|
||||
) -> Dict[str, list]:
|
||||
"""
|
||||
Construct chat history from previous iterations for a specific tool
|
||||
@@ -271,8 +271,8 @@ def construct_tool_chat_history(
|
||||
If no tool is provided inferred query for all tools used are added.
|
||||
"""
|
||||
chat_history: list = []
|
||||
base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
|
||||
extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
|
||||
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
|
||||
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
|
||||
ConversationCommand.Notes: (
|
||||
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
||||
),
|
||||
@@ -377,7 +377,7 @@ async def save_to_conversation_log(
|
||||
generated_images: List[str] = [],
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
research_results: Optional[List[InformationCollectionIteration]] = None,
|
||||
research_results: Optional[List[ResearchIteration]] = None,
|
||||
train_of_thought: List[Any] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
|
||||
@@ -66,10 +66,7 @@ from khoj.routers.helpers import (
|
||||
update_telemetry_state,
|
||||
validate_chat_model,
|
||||
)
|
||||
from khoj.routers.research import (
|
||||
InformationCollectionIteration,
|
||||
execute_information_collection,
|
||||
)
|
||||
from khoj.routers.research import ResearchIteration, research
|
||||
from khoj.routers.storage import upload_user_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
@@ -723,7 +720,7 @@ async def chat(
|
||||
for file in raw_query_files:
|
||||
query_files[file.name] = file.content
|
||||
|
||||
research_results: List[InformationCollectionIteration] = []
|
||||
research_results: List[ResearchIteration] = []
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
operator_results: List[OperatorRun] = []
|
||||
@@ -962,9 +959,7 @@ async def chat(
|
||||
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
|
||||
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
|
||||
compiled_references = [ref.model_dump() for ref in last_message.context or []]
|
||||
research_results = [
|
||||
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
|
||||
]
|
||||
research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []]
|
||||
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
|
||||
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
|
||||
# Drop the interrupted message from conversation history
|
||||
@@ -1011,7 +1006,7 @@ async def chat(
|
||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||
|
||||
if conversation_commands == [ConversationCommand.Research]:
|
||||
async for research_result in execute_information_collection(
|
||||
async for research_result in research(
|
||||
user=user,
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
@@ -1027,7 +1022,7 @@ async def chat(
|
||||
tracer=tracer,
|
||||
cancellation_event=cancellation_event,
|
||||
):
|
||||
if isinstance(research_result, InformationCollectionIteration):
|
||||
if isinstance(research_result, ResearchIteration):
|
||||
if research_result.summarizedResult:
|
||||
if research_result.onlineContext:
|
||||
online_results.update(research_result.onlineContext)
|
||||
|
||||
@@ -94,8 +94,8 @@ from khoj.processor.conversation.openai.gpt import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
InformationCollectionIteration,
|
||||
OperatorRun,
|
||||
ResearchIteration,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
@@ -1357,7 +1357,7 @@ async def agenerate_chat_response(
|
||||
online_results: Dict[str, Dict] = {},
|
||||
code_results: Dict[str, Dict] = {},
|
||||
operator_results: List[OperatorRun] = [],
|
||||
research_results: List[InformationCollectionIteration] = [],
|
||||
research_results: List[ResearchIteration] = [],
|
||||
inferred_queries: List[str] = [],
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
user: KhojUser = None,
|
||||
|
||||
@@ -13,8 +13,8 @@ from khoj.database.adapters import AgentAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
InformationCollectionIteration,
|
||||
OperatorRun,
|
||||
ResearchIteration,
|
||||
construct_iteration_history,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
@@ -84,7 +84,7 @@ async def apick_next_tool(
|
||||
location: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
previous_iterations: List[InformationCollectionIteration] = [],
|
||||
previous_iterations: List[ResearchIteration] = [],
|
||||
max_iterations: int = 5,
|
||||
query_images: List[str] = [],
|
||||
query_files: str = None,
|
||||
@@ -166,7 +166,7 @@ async def apick_next_tool(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
|
||||
yield InformationCollectionIteration(
|
||||
yield ResearchIteration(
|
||||
tool=None,
|
||||
query=None,
|
||||
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
|
||||
@@ -195,26 +195,26 @@ async def apick_next_tool(
|
||||
async for event in send_status_func(f"{scratchpad}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
yield InformationCollectionIteration(
|
||||
yield ResearchIteration(
|
||||
tool=selected_tool,
|
||||
query=generated_query,
|
||||
warning=warning,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
|
||||
yield InformationCollectionIteration(
|
||||
yield ResearchIteration(
|
||||
tool=None,
|
||||
query=None,
|
||||
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
|
||||
)
|
||||
|
||||
|
||||
async def execute_information_collection(
|
||||
async def research(
|
||||
user: KhojUser,
|
||||
query: str,
|
||||
conversation_id: str,
|
||||
conversation_history: dict,
|
||||
previous_iterations: List[InformationCollectionIteration],
|
||||
previous_iterations: List[ResearchIteration],
|
||||
query_images: List[str],
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
@@ -251,7 +251,7 @@ async def execute_information_collection(
|
||||
document_results: List[Dict[str, str]] = []
|
||||
operator_results: OperatorRun = None
|
||||
summarize_files: str = ""
|
||||
this_iteration = InformationCollectionIteration(tool=None, query=query)
|
||||
this_iteration = ResearchIteration(tool=None, query=query)
|
||||
|
||||
async for result in apick_next_tool(
|
||||
query,
|
||||
@@ -272,7 +272,7 @@ async def execute_information_collection(
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
elif isinstance(result, InformationCollectionIteration):
|
||||
elif isinstance(result, ResearchIteration):
|
||||
this_iteration = result
|
||||
|
||||
# Skip running iteration if warning present in iteration
|
||||
|
||||
Reference in New Issue
Block a user