mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Simplify research iteration and main research function names
This commit is contained in:
@@ -130,7 +130,7 @@ class OperatorRun:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class InformationCollectionIteration:
|
class ResearchIteration:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tool: str,
|
tool: str,
|
||||||
@@ -160,7 +160,7 @@ class InformationCollectionIteration:
|
|||||||
|
|
||||||
|
|
||||||
def construct_iteration_history(
|
def construct_iteration_history(
|
||||||
previous_iterations: List[InformationCollectionIteration],
|
previous_iterations: List[ResearchIteration],
|
||||||
previous_iteration_prompt: str,
|
previous_iteration_prompt: str,
|
||||||
query: str = None,
|
query: str = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
@@ -262,7 +262,7 @@ def construct_question_history(
|
|||||||
|
|
||||||
|
|
||||||
def construct_tool_chat_history(
|
def construct_tool_chat_history(
|
||||||
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
|
||||||
) -> Dict[str, list]:
|
) -> Dict[str, list]:
|
||||||
"""
|
"""
|
||||||
Construct chat history from previous iterations for a specific tool
|
Construct chat history from previous iterations for a specific tool
|
||||||
@@ -271,8 +271,8 @@ def construct_tool_chat_history(
|
|||||||
If no tool is provided inferred query for all tools used are added.
|
If no tool is provided inferred query for all tools used are added.
|
||||||
"""
|
"""
|
||||||
chat_history: list = []
|
chat_history: list = []
|
||||||
base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
|
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
|
||||||
extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
|
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
|
||||||
ConversationCommand.Notes: (
|
ConversationCommand.Notes: (
|
||||||
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
||||||
),
|
),
|
||||||
@@ -377,7 +377,7 @@ async def save_to_conversation_log(
|
|||||||
generated_images: List[str] = [],
|
generated_images: List[str] = [],
|
||||||
raw_generated_files: List[FileAttachment] = [],
|
raw_generated_files: List[FileAttachment] = [],
|
||||||
generated_mermaidjs_diagram: str = None,
|
generated_mermaidjs_diagram: str = None,
|
||||||
research_results: Optional[List[InformationCollectionIteration]] = None,
|
research_results: Optional[List[ResearchIteration]] = None,
|
||||||
train_of_thought: List[Any] = [],
|
train_of_thought: List[Any] = [],
|
||||||
tracer: Dict[str, Any] = {},
|
tracer: Dict[str, Any] = {},
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -66,10 +66,7 @@ from khoj.routers.helpers import (
|
|||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_chat_model,
|
validate_chat_model,
|
||||||
)
|
)
|
||||||
from khoj.routers.research import (
|
from khoj.routers.research import ResearchIteration, research
|
||||||
InformationCollectionIteration,
|
|
||||||
execute_information_collection,
|
|
||||||
)
|
|
||||||
from khoj.routers.storage import upload_user_image_to_bucket
|
from khoj.routers.storage import upload_user_image_to_bucket
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
@@ -723,7 +720,7 @@ async def chat(
|
|||||||
for file in raw_query_files:
|
for file in raw_query_files:
|
||||||
query_files[file.name] = file.content
|
query_files[file.name] = file.content
|
||||||
|
|
||||||
research_results: List[InformationCollectionIteration] = []
|
research_results: List[ResearchIteration] = []
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
code_results: Dict = dict()
|
code_results: Dict = dict()
|
||||||
operator_results: List[OperatorRun] = []
|
operator_results: List[OperatorRun] = []
|
||||||
@@ -962,9 +959,7 @@ async def chat(
|
|||||||
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
|
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
|
||||||
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
|
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
|
||||||
compiled_references = [ref.model_dump() for ref in last_message.context or []]
|
compiled_references = [ref.model_dump() for ref in last_message.context or []]
|
||||||
research_results = [
|
research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []]
|
||||||
InformationCollectionIteration(**iter_dict) for iter_dict in last_message.researchContext or []
|
|
||||||
]
|
|
||||||
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
|
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
|
||||||
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
|
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
|
||||||
# Drop the interrupted message from conversation history
|
# Drop the interrupted message from conversation history
|
||||||
@@ -1011,7 +1006,7 @@ async def chat(
|
|||||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Research]:
|
if conversation_commands == [ConversationCommand.Research]:
|
||||||
async for research_result in execute_information_collection(
|
async for research_result in research(
|
||||||
user=user,
|
user=user,
|
||||||
query=defiltered_query,
|
query=defiltered_query,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
@@ -1027,7 +1022,7 @@ async def chat(
|
|||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
cancellation_event=cancellation_event,
|
cancellation_event=cancellation_event,
|
||||||
):
|
):
|
||||||
if isinstance(research_result, InformationCollectionIteration):
|
if isinstance(research_result, ResearchIteration):
|
||||||
if research_result.summarizedResult:
|
if research_result.summarizedResult:
|
||||||
if research_result.onlineContext:
|
if research_result.onlineContext:
|
||||||
online_results.update(research_result.onlineContext)
|
online_results.update(research_result.onlineContext)
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ from khoj.processor.conversation.openai.gpt import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
InformationCollectionIteration,
|
|
||||||
OperatorRun,
|
OperatorRun,
|
||||||
|
ResearchIteration,
|
||||||
ResponseWithThought,
|
ResponseWithThought,
|
||||||
clean_json,
|
clean_json,
|
||||||
clean_mermaidjs,
|
clean_mermaidjs,
|
||||||
@@ -1357,7 +1357,7 @@ async def agenerate_chat_response(
|
|||||||
online_results: Dict[str, Dict] = {},
|
online_results: Dict[str, Dict] = {},
|
||||||
code_results: Dict[str, Dict] = {},
|
code_results: Dict[str, Dict] = {},
|
||||||
operator_results: List[OperatorRun] = [],
|
operator_results: List[OperatorRun] = [],
|
||||||
research_results: List[InformationCollectionIteration] = [],
|
research_results: List[ResearchIteration] = [],
|
||||||
inferred_queries: List[str] = [],
|
inferred_queries: List[str] = [],
|
||||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from khoj.database.adapters import AgentAdapters, EntryAdapters
|
|||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
InformationCollectionIteration,
|
|
||||||
OperatorRun,
|
OperatorRun,
|
||||||
|
ResearchIteration,
|
||||||
construct_iteration_history,
|
construct_iteration_history,
|
||||||
construct_tool_chat_history,
|
construct_tool_chat_history,
|
||||||
load_complex_json,
|
load_complex_json,
|
||||||
@@ -84,7 +84,7 @@ async def apick_next_tool(
|
|||||||
location: LocationData = None,
|
location: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
previous_iterations: List[InformationCollectionIteration] = [],
|
previous_iterations: List[ResearchIteration] = [],
|
||||||
max_iterations: int = 5,
|
max_iterations: int = 5,
|
||||||
query_images: List[str] = [],
|
query_images: List[str] = [],
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
@@ -166,7 +166,7 @@ async def apick_next_tool(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
|
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
|
||||||
yield InformationCollectionIteration(
|
yield ResearchIteration(
|
||||||
tool=None,
|
tool=None,
|
||||||
query=None,
|
query=None,
|
||||||
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
|
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
|
||||||
@@ -195,26 +195,26 @@ async def apick_next_tool(
|
|||||||
async for event in send_status_func(f"{scratchpad}"):
|
async for event in send_status_func(f"{scratchpad}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
yield InformationCollectionIteration(
|
yield ResearchIteration(
|
||||||
tool=selected_tool,
|
tool=selected_tool,
|
||||||
query=generated_query,
|
query=generated_query,
|
||||||
warning=warning,
|
warning=warning,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
|
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
|
||||||
yield InformationCollectionIteration(
|
yield ResearchIteration(
|
||||||
tool=None,
|
tool=None,
|
||||||
query=None,
|
query=None,
|
||||||
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
|
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def execute_information_collection(
|
async def research(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query: str,
|
query: str,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
previous_iterations: List[InformationCollectionIteration],
|
previous_iterations: List[ResearchIteration],
|
||||||
query_images: List[str],
|
query_images: List[str],
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
@@ -251,7 +251,7 @@ async def execute_information_collection(
|
|||||||
document_results: List[Dict[str, str]] = []
|
document_results: List[Dict[str, str]] = []
|
||||||
operator_results: OperatorRun = None
|
operator_results: OperatorRun = None
|
||||||
summarize_files: str = ""
|
summarize_files: str = ""
|
||||||
this_iteration = InformationCollectionIteration(tool=None, query=query)
|
this_iteration = ResearchIteration(tool=None, query=query)
|
||||||
|
|
||||||
async for result in apick_next_tool(
|
async for result in apick_next_tool(
|
||||||
query,
|
query,
|
||||||
@@ -272,7 +272,7 @@ async def execute_information_collection(
|
|||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
elif isinstance(result, InformationCollectionIteration):
|
elif isinstance(result, ResearchIteration):
|
||||||
this_iteration = result
|
this_iteration = result
|
||||||
|
|
||||||
# Skip running iteration if warning present in iteration
|
# Skip running iteration if warning present in iteration
|
||||||
|
|||||||
Reference in New Issue
Block a user