Files
khoj/src/khoj/routers/research.py

561 lines
26 KiB
Python

import asyncio
import logging
import os
from copy import deepcopy
from datetime import datetime
from typing import Callable, Dict, List, Optional
import yaml
from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.models import Agent, ChatMessageModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
ResearchIteration,
ToolCall,
construct_iteration_history,
construct_structured_message,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.operator import operate_environment
from khoj.processor.tools.online_search import read_webpages_content, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import (
ChatEvent,
generate_summary_from_files,
get_message_from_queue,
grep_files,
list_files,
search_documents,
send_message_to_model_wrapper,
view_file_content,
)
from khoj.utils.helpers import (
ConversationCommand,
ToolDefinition,
dict_to_tuple,
is_none_or_empty,
is_operator_enabled,
timer,
tools_for_research_llm,
truncate_code_context,
)
from khoj.utils.rawconfig import LocationData
# Module-level logger, named after this module so log records can be filtered by source.
logger = logging.getLogger(__name__)
async def apick_next_tool(
    query: str,
    conversation_history: List[ChatMessageModel],
    user: KhojUser = None,
    location: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
    previous_iterations: Optional[List[ResearchIteration]] = None,
    max_iterations: int = 5,
    query_images: Optional[List[str]] = None,
    query_files: str = None,
    max_document_searches: int = 7,
    max_online_searches: int = 3,
    max_webpages_to_read: int = 3,
    send_status_func: Optional[Callable] = None,
    tracer: Optional[dict] = None,
):
    """Given a query, determine which of the available tools the agent should use in order to answer appropriately.

    This is an async generator. It yields zero or more ``{ChatEvent.STATUS: event}`` dicts
    (the model's thoughts relayed via ``send_status_func``) followed by exactly one
    ``ResearchIteration`` describing the tool call to run next.

    Args:
        query: The user query driving the current research run.
        conversation_history: Prior chat messages; the research iteration history is appended to it.
        user: User on whose behalf research runs. Used to check for indexed entries and pick the chat model.
        location: User location, rendered into the planning prompt ("Unknown" when unset).
        user_name: User name, rendered into the planning prompt ("Unknown" when unset).
        agent: Optional agent. Its input tools restrict the tool options and its personality is added to the prompt.
        previous_iterations: Research iterations already run. Defaults to an empty list.
        max_iterations: Iteration budget advertised to the model in the planning prompt.
        query_images: Images attached to the query. Defaults to an empty list.
        query_files: Files attached to the query.
        max_document_searches: Limit formatted into the semantic document search tool description.
        max_online_searches: Limit formatted into the web search tool description.
        max_webpages_to_read: Limit formatted into the read webpage tool description.
        send_status_func: Optional async generator function used to stream status updates.
        tracer: Tracing metadata passed to the model wrapper. Defaults to a fresh dict per call.

    Yields:
        ``{ChatEvent.STATUS: event}`` dicts and a final ``ResearchIteration``. The iteration carries a
        ``warning`` when the model call failed or the same tool, query combination was already used.
    """
    # Use fresh containers per call rather than mutable default arguments shared across calls
    previous_iterations = [] if previous_iterations is None else previous_iterations
    query_images = [] if query_images is None else query_images
    tracer = {} if tracer is None else tracer

    # Continue with previous iteration if a multi-step tool use is in progress
    if (
        previous_iterations
        and previous_iterations[-1].query
        and isinstance(previous_iterations[-1].query, ToolCall)
        and previous_iterations[-1].query.name == ConversationCommand.Operator
        and not previous_iterations[-1].summarizedResult
    ):
        previous_iteration = previous_iterations[-1]
        yield ResearchIteration(
            query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id),  # type: ignore
            context=previous_iteration.context,
            onlineContext=previous_iteration.onlineContext,
            codeContext=previous_iteration.codeContext,
            operatorContext=previous_iteration.operatorContext,
            warning=previous_iteration.warning,
        )
        return

    # Construct tool options for the agent to choose from
    tools = []
    tool_options_str = ""
    agent_input_tools = agent.input_tools if agent else []
    agent_tools = []

    # Map agent user facing tools to research tools to include in agents toolbox
    document_research_tools = [
        ConversationCommand.SemanticSearchFiles,
        ConversationCommand.RegexSearchFiles,
        ConversationCommand.ViewFile,
        ConversationCommand.ListFiles,
    ]
    input_tools_to_research_tools = {
        ConversationCommand.Notes.value: [tool.value for tool in document_research_tools],
        ConversationCommand.Webpage.value: [ConversationCommand.ReadWebpage.value],
        ConversationCommand.Online.value: [ConversationCommand.SearchWeb.value],
        ConversationCommand.Code.value: [ConversationCommand.RunCode.value],
        ConversationCommand.Operator.value: [ConversationCommand.OperateComputer.value],
    }
    for input_tool, research_tools in input_tools_to_research_tools.items():
        if input_tool in agent_input_tools:
            agent_tools += research_tools

    user_has_entries = await EntryAdapters.auser_has_entries(user)
    for tool, tool_data in tools_for_research_llm.items():
        # Skip showing operator tool as an option if not enabled
        if tool == ConversationCommand.OperateComputer and not is_operator_enabled():
            continue
        # Skip showing document related tools if user has no documents
        if tool in document_research_tools and not user_has_entries:
            continue

        if tool == ConversationCommand.SemanticSearchFiles:
            description = tool_data.description.format(max_search_queries=max_document_searches)
        elif tool == ConversationCommand.ReadWebpage:
            description = tool_data.description.format(max_webpages_to_read=max_webpages_to_read)
        elif tool == ConversationCommand.SearchWeb:
            description = tool_data.description.format(max_search_queries=max_online_searches)
        else:
            description = tool_data.description

        # Add tool if agent does not have any tools defined or the tool is supported by the agent.
        if not agent_tools or tool.value in agent_tools:
            tool_options_str += f'- "{tool.value}": "{description}"\n'
            tools.append(
                ToolDefinition(
                    name=tool.value,
                    description=description,
                    schema=tool_data.schema,
                )
            )

    today = datetime.today()
    location_data = f"{location}" if location else "Unknown"
    agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    function_planning_prompt = prompts.plan_function_execution.format(
        tools=tool_options_str,
        personality_context=personality_context,
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        username=user_name or "Unknown",
        location=location_data,
        max_iterations=max_iterations,
    )

    # Construct chat history with user and iteration history with researcher agent for context
    iteration_chat_history = construct_iteration_history(previous_iterations, query, query_images, query_files)
    chat_and_research_history = conversation_history + iteration_chat_history

    try:
        with timer("Chat actor: Infer information sources to refer", logger):
            response = await send_message_to_model_wrapper(
                query="",
                system_message=function_planning_prompt,
                chat_history=chat_and_research_history,
                tools=tools,
                deepthought=True,
                user=user,
                query_images=query_images,
                query_files=query_files,
                agent_chat_model=agent_chat_model,
                tracer=tracer,
            )
    except Exception as e:
        logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
        yield ResearchIteration(
            query=None,
            warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
        )
        return

    # Read the response text before the try block so the fallback below can always reference it
    response_text = response.text
    try:
        # Try parse the response as function call response to infer next tool to use.
        # TODO: Handle multiple tool calls.
        parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
    except Exception:
        # Otherwise assume the model has decided to end the research run and respond to the user.
        parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)

    # If we have a valid response, extract the tool and query.
    warning = None
    logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})")

    # Detect selection of previously used query, tool combination.
    previous_tool_query_combinations = {
        (i.query.name, dict_to_tuple(i.query.args))
        for i in previous_iterations
        if i.warning is None and isinstance(i.query, ToolCall)
    }
    if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
        warning = "Repeated tool, query combination detected. Skipping iteration. Try something different."
    # Only send client status updates if we'll execute this iteration and model has thoughts to share.
    elif send_status_func and not is_none_or_empty(response.thought):
        async for event in send_status_func(response.thought):
            yield {ChatEvent.STATUS: event}

    yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
async def research(
    user: KhojUser,
    query: str,
    conversation_id: str,
    conversation_history: List[ChatMessageModel],
    previous_iterations: List[ResearchIteration],
    query_images: List[str],
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    user_name: str = None,
    location: LocationData = None,
    file_filters: Optional[List[str]] = None,
    tracer: Optional[dict] = None,
    query_files: str = None,
    cancellation_event: Optional[asyncio.Event] = None,
    interrupt_queue: Optional[asyncio.Queue] = None,
    abort_message: str = "␃🔚␗",
):
    """Run the iterative research loop: pick a tool, run it, record its results, repeat.

    This is an async generator. Each loop iteration asks ``apick_next_tool`` for the next tool call,
    runs the matching tool, stores the results on the ``ResearchIteration`` and appends it to
    ``previous_iterations`` (the caller's list is mutated). The loop ends when the iteration budget
    is used up, the model picks the text tool or no valid tool, or the run is cancelled.

    The iteration budget is ``KHOJ_RESEARCH_ITERATIONS`` (default 5). Iterations passed in via
    ``previous_iterations`` count against that budget.

    Args:
        user: User on whose behalf research runs.
        query: The user query to research. Replaced by any new instruction read from ``interrupt_queue``.
        conversation_id: Conversation identifier, forwarded to document search.
        conversation_history: Prior chat messages. Messages without content are dropped from a deep copy.
        previous_iterations: Earlier research iterations to continue from. Appended to in place.
        query_images: Images attached to the query.
        agent: Optional agent forwarded to tool selection and tools.
        send_status_func: Optional async generator function used to stream status updates.
        user_name: User name forwarded to tool selection.
        location: User location forwarded to tool selection and tools.
        file_filters: Not used by this function.
        tracer: Tracing metadata forwarded to tools. Defaults to a fresh dict per call.
        query_files: Files attached to the query.
        cancellation_event: Optional event checked at the start of every iteration to stop research.
        interrupt_queue: Optional queue polled for new user instructions or the abort message.
        abort_message: Sentinel message on ``interrupt_queue`` that cancels research.

    Yields:
        Status events from ``send_status_func`` and the tools, plus each ``ResearchIteration``:
        once when the tool is picked and again after its results are recorded.
    """
    # Use fresh containers per call rather than mutable default arguments shared across calls
    file_filters = [] if file_filters is None else file_filters
    tracer = {} if tracer is None else tracer

    async def send_status(message: str):
        """Relay a status message via send_status_func. Yields nothing when no callback is set."""
        if send_status_func is None:
            return
        async for status_event in send_status_func(message):
            yield status_event

    max_document_searches = 7
    max_online_searches = 3
    max_webpages_to_read = 1
    current_iteration = 0
    MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))

    # Incorporate previous partial research into current research chat history
    research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message]
    # Parenthesize the walrus so current_iteration holds the iteration count, not the comparison result
    if (current_iteration := len(previous_iterations)) > 0:
        logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
        previous_iterations_history = construct_iteration_history(previous_iterations)
        research_conversation_history += previous_iterations_history

    while current_iteration < MAX_ITERATIONS:
        # Check for cancellation at the start of each iteration
        if cancellation_event and cancellation_event.is_set():
            logger.debug(f"Research cancelled. User {user} disconnected client.")
            break

        # Update the query for the current research iteration
        if interrupt_query := get_message_from_queue(interrupt_queue):
            if interrupt_query == abort_message:
                # The cancellation event is optional. Only signal it when the caller provided one.
                if cancellation_event:
                    cancellation_event.set()
                logger.debug(f"Research cancelled by user {user} via interrupt queue.")
                break
            # Add the interrupt query as a new user message to the research conversation history
            logger.info(
                f"Continuing research for user {user} with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
            )
            previous_iterations_history = construct_iteration_history(
                previous_iterations, query, query_images, query_files
            )
            research_conversation_history += previous_iterations_history
            query = interrupt_query
            previous_iterations = []

            async for status_event in send_status(f"**Incorporate New Instruction**: {interrupt_query}"):
                yield status_event

        online_results: Dict = dict()
        code_results: Dict = dict()
        document_results: List[Dict[str, str]] = []
        operator_results: Optional[OperatorRun] = None
        this_iteration = ResearchIteration(query=query)

        async for result in apick_next_tool(
            query,
            research_conversation_history,
            user,
            location,
            user_name,
            agent,
            previous_iterations,
            MAX_ITERATIONS,
            query_images=query_images,
            query_files=query_files,
            max_document_searches=max_document_searches,
            max_online_searches=max_online_searches,
            max_webpages_to_read=max_webpages_to_read,
            send_status_func=send_status_func,
            tracer=tracer,
        ):
            if isinstance(result, dict) and ChatEvent.STATUS in result:
                yield result[ChatEvent.STATUS]
            elif isinstance(result, ResearchIteration):
                this_iteration = result
                yield this_iteration

        # Skip running iteration if warning present in iteration
        if this_iteration.warning:
            logger.warning(f"Research mode: {this_iteration.warning}.")

        # Terminate research if selected text tool or query, tool not set for next iteration
        elif (
            not this_iteration.query
            or isinstance(this_iteration.query, str)
            or this_iteration.query.name == ConversationCommand.Text
        ):
            current_iteration = MAX_ITERATIONS

        elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
            this_iteration.context = []
            document_results = []
            previous_inferred_queries = {
                c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
            }
            async for result in search_documents(
                **this_iteration.query.args,
                n=max_document_searches,
                d=None,
                user=user,
                chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
                conversation_id=conversation_id,
                conversation_commands=[ConversationCommand.Default],
                location_data=location,
                send_status_func=send_status_func,
                query_images=query_images,
                previous_inferred_queries=previous_inferred_queries,
                agent=agent,
                tracer=tracer,
                query_files=query_files,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                elif isinstance(result, tuple):
                    document_results = result[0]
                    this_iteration.context += document_results

            if not is_none_or_empty(document_results):
                try:
                    distinct_files = {d["file"] for d in document_results}
                    distinct_headings = {d["compiled"].split("\n")[0] for d in document_results if "compiled" in d}
                    # Strip only leading # from headings
                    headings_str = "\n- " + "\n- ".join(heading.lstrip("#") for heading in distinct_headings)
                    async for status_event in send_status(
                        f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
                    ):
                        yield status_event
                except Exception as e:
                    this_iteration.warning = f"Error extracting document references: {e}"
                    logger.error(this_iteration.warning, exc_info=True)
            else:
                this_iteration.warning = "No matching document references found"

        elif this_iteration.query.name == ConversationCommand.SearchWeb:
            previous_subqueries = {
                subquery
                for iteration in previous_iterations
                if iteration.onlineContext
                for subquery in iteration.onlineContext.keys()
            }
            try:
                async for result in search_online(
                    **this_iteration.query.args,
                    conversation_history=construct_tool_chat_history(
                        previous_iterations, ConversationCommand.SearchWeb
                    ),
                    location=location,
                    user=user,
                    send_status_func=send_status_func,
                    custom_filters=[],
                    max_online_searches=max_online_searches,
                    max_webpages_to_read=0,
                    query_images=query_images,
                    previous_subqueries=previous_subqueries,
                    agent=agent,
                    tracer=tracer,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    elif is_none_or_empty(result):
                        this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different."
                    else:
                        online_results = result  # type: ignore
                        this_iteration.onlineContext = online_results
            except Exception as e:
                this_iteration.warning = f"Error searching online: {e}"
                logger.error(this_iteration.warning, exc_info=True)

        elif this_iteration.query.name == ConversationCommand.ReadWebpage:
            try:
                async for result in read_webpages_content(
                    **this_iteration.query.args,
                    user=user,
                    send_status_func=send_status_func,
                    # max_webpages_to_read=max_webpages_to_read,
                    agent=agent,
                    tracer=tracer,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        direct_web_pages: Dict[str, Dict] = result  # type: ignore

                        # Merge the read webpages into the online results, keyed by web query
                        for web_query in direct_web_pages:
                            if online_results.get(web_query):
                                online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
                            else:
                                online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
                        this_iteration.onlineContext = online_results
            except Exception as e:
                this_iteration.warning = f"Error reading webpages: {e}"
                logger.error(this_iteration.warning, exc_info=True)

        elif this_iteration.query.name == ConversationCommand.RunCode:
            try:
                async for result in run_code(
                    **this_iteration.query.args,
                    conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.RunCode),
                    context="",
                    location_data=location,
                    user=user,
                    send_status_func=send_status_func,
                    query_images=query_images,
                    agent=agent,
                    query_files=query_files,
                    tracer=tracer,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        code_results = result  # type: ignore
                        this_iteration.codeContext = code_results
                # codeContext is unset when the code tool produced no results. Count it as zero snippets.
                async for status_event in send_status(
                    f"**Ran code snippets**: {len(this_iteration.codeContext or {})}"
                ):
                    yield status_event
            # NOTE(review): only ValueError is handled here, unlike the other tools which catch Exception.
            except ValueError as e:
                this_iteration.warning = f"Error running code: {e}"
                logger.warning(this_iteration.warning, exc_info=True)

        elif this_iteration.query.name == ConversationCommand.OperateComputer:
            try:
                async for result in operate_environment(
                    **this_iteration.query.args,
                    user=user,
                    conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
                    location_data=location,
                    previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
                    send_status_func=send_status_func,
                    query_images=query_images,
                    agent=agent,
                    query_files=query_files,
                    cancellation_event=cancellation_event,
                    interrupt_queue=interrupt_queue,
                    tracer=tracer,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    elif isinstance(result, OperatorRun):
                        operator_results = result
                        this_iteration.operatorContext = operator_results
                        # Add webpages visited while operating browser to references
                        if result.webpages:
                            # Key the webpages by the operator's query text, not the ToolCall object.
                            # Online context keys are read as sub-query strings by the web search tool
                            # branch and are serialized with yaml.dump below.
                            # Assumes the operator tool call carries its query under args["query"];
                            # falls back to the tool name otherwise.
                            operator_query = str(this_iteration.query.args.get("query") or this_iteration.query.name)
                            if not online_results.get(operator_query):
                                online_results[operator_query] = {"webpages": result.webpages}
                            elif not online_results[operator_query].get("webpages"):
                                online_results[operator_query]["webpages"] = result.webpages
                            else:
                                online_results[operator_query]["webpages"] += result.webpages
                            this_iteration.onlineContext = online_results
            except Exception as e:
                this_iteration.warning = f"Error operating browser: {e}"
                logger.error(this_iteration.warning, exc_info=True)

        elif this_iteration.query.name == ConversationCommand.ViewFile:
            try:
                async for result in view_file_content(
                    **this_iteration.query.args,
                    user=user,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        if this_iteration.context is None:
                            this_iteration.context = []
                        document_results = result  # type: ignore
                        this_iteration.context += document_results
                async for status_event in send_status(f"**Viewed file**: {this_iteration.query.args['path']}"):
                    yield status_event
            except Exception as e:
                this_iteration.warning = f"Error viewing file: {e}"
                logger.error(this_iteration.warning, exc_info=True)

        elif this_iteration.query.name == ConversationCommand.ListFiles:
            try:
                # Track the last file listing explicitly instead of reading the loop variable after the loop
                file_listing = None
                async for result in list_files(
                    **this_iteration.query.args,
                    user=user,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        if this_iteration.context is None:
                            this_iteration.context = []
                        file_listing = result
                        document_results = [result]  # type: ignore
                        this_iteration.context += document_results
                if file_listing is not None:
                    async for status_event in send_status(file_listing["query"]):
                        yield status_event
            except Exception as e:
                this_iteration.warning = f"Error listing files: {e}"
                logger.error(this_iteration.warning, exc_info=True)

        elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
            try:
                # Track the last regex match result explicitly instead of reading the loop variable after the loop
                regex_matches = None
                async for result in grep_files(
                    **this_iteration.query.args,
                    user=user,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        if this_iteration.context is None:
                            this_iteration.context = []
                        regex_matches = result
                        document_results = [result]  # type: ignore
                        this_iteration.context += document_results
                if regex_matches is not None:
                    async for status_event in send_status(regex_matches["query"]):
                        yield status_event
            except Exception as e:
                this_iteration.warning = f"Error searching with regex: {e}"
                logger.error(this_iteration.warning, exc_info=True)

        else:
            # No valid tools. This is our exit condition.
            current_iteration = MAX_ITERATIONS

        current_iteration += 1

        if document_results or online_results or code_results or operator_results or this_iteration.warning:
            results_data = f"\n<iteration_{current_iteration}_results>"
            if document_results:
                results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
            if online_results:
                results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
            if code_results:
                results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
            if operator_results:
                results_data += (
                    f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
                )
            if this_iteration.warning:
                results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
            # NOTE(review): "</results>" has no matching opening tag. Kept as-is to leave the prompt text unchanged.
            results_data += f"\n</results>\n</iteration_{current_iteration}_results>"

            # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
            this_iteration.summarizedResult = results_data

        this_iteration.summarizedResult = this_iteration.summarizedResult or "Failed to get results."
        previous_iterations.append(this_iteration)
        yield this_iteration