diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index c6b3bc2f..0aaaa4b3 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -112,6 +112,7 @@ class InformationCollectionIteration:
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
+ warning: str = None,
):
self.tool = tool
self.query = query
@@ -119,6 +120,7 @@ class InformationCollectionIteration:
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
+ self.warning = warning
def construct_iteration_history(
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 34c4911a..d2d8c685 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -4,7 +4,7 @@ import logging
import os
import urllib.parse
from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp
from bs4 import BeautifulSoup
@@ -66,6 +66,7 @@ async def search_online(
custom_filters: List[str] = [],
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
+ previous_subqueries: Set = set(),
agent: Agent = None,
tracer: dict = {},
):
@@ -76,19 +77,24 @@ async def search_online(
return
# Breakdown the query into subqueries to get the correct answer
- subqueries = await generate_online_subqueries(
+ new_subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
- response_dict = {}
+ subqueries = list(new_subqueries - previous_subqueries)
+ response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
- if subqueries:
- logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
- if send_status_func:
- subqueries_str = "\n- " + "\n- ".join(list(subqueries))
- async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
- yield {ChatEvent.STATUS: event}
+ if is_none_or_empty(subqueries):
+ logger.info("No new subqueries to search online")
+ yield response_dict
+ return
- with timer(f"Internet searches for {list(subqueries)} took", logger):
+ logger.info(f"🌐 Searching the Internet for {subqueries}")
+ if send_status_func:
+ subqueries_str = "\n- " + "\n- ".join(subqueries)
+ async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
+ yield {ChatEvent.STATUS: event}
+
+ with timer(f"Internet searches for {subqueries} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
@@ -119,7 +125,9 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
- read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
+ read_webpage_and_extract_content(
+ data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
+ )
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index bed7c27b..3eb2dea5 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -6,7 +6,7 @@ import os
import threading
import time
import uuid
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Set, Union
import cron_descriptor
import pytz
@@ -349,6 +349,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
+ previous_inferred_queries: Set = set(),
agent: Agent = None,
tracer: dict = {},
):
@@ -477,6 +478,7 @@ async def extract_references_and_questions(
)
# Collate search results as context for GPT
+ inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 8646a695..c30f4cf8 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -778,7 +778,8 @@ async def chat(
yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent)
- logger.info(f"Researched Results: {researched_results}")
+ if state.verbose > 1:
+ logger.debug(f"Researched Results: {researched_results}")
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index c7a90fc3..d89ed147 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -20,6 +20,7 @@ from typing import (
Iterator,
List,
Optional,
+ Set,
Tuple,
Union,
)
@@ -494,7 +495,7 @@ async def generate_online_subqueries(
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
-) -> List[str]:
+) -> Set[str]:
"""
Generate subqueries from the given query
"""
@@ -529,14 +530,14 @@ async def generate_online_subqueries(
try:
response = clean_json(response)
response = json.loads(response)
- response = [q.strip() for q in response["queries"] if q.strip()]
- if not isinstance(response, list) or not response or len(response) == 0:
+ response = {q.strip() for q in response["queries"] if q.strip()}
+ if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
- return [q]
+ return {q}
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
- return [q]
+ return {q}
async def schedule_query(
@@ -1128,9 +1129,6 @@ def generate_chat_response(
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
- query_to_run = q
- if meta_research:
- query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@@ -1148,6 +1146,13 @@ def generate_chat_response(
train_of_thought=train_of_thought,
)
+ query_to_run = q
+ if meta_research:
+ query_to_run = f"{q}\n\n{meta_research}\n"
+ compiled_references = []
+ online_results = {}
+ code_results = {}
+
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and query_images:
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index 46d4c424..1caf7c96 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -43,38 +43,35 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
- previous_iterations_history: str = None,
+ previous_iterations: List[InformationCollectionIteration] = [],
max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
- """
- Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
- """
+ """Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
+ # Construct tool options for the agent to choose from
tool_options = dict()
tool_options_str = ""
-
agent_tools = agent.input_tools if agent else []
-
for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'
+ # Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
+ previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
if query_images:
query = f"[placeholder for user attached images]\n{query}"
+ today = datetime.today()
+ location_data = f"{location}" if location else "Unknown"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
- # Extract Past User Message and Inferred Questions from Conversation Log
- today = datetime.today()
- location_data = f"{location}" if location else "Unknown"
-
function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str,
chat_history=chat_history,
@@ -112,8 +109,15 @@ async def apick_next_tool(
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
+ warning = None
logger.info(f"Response for determining relevant tools: {response}")
- if send_status_func:
+
+ # Detect selection of a previously used (tool, query) combination.
+ previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
+ if (selected_tool, generated_query) in previous_tool_query_combinations:
+ warning = "Repeated tool, query combination detected. Skipping iteration. Try something different."
+ # Only send client status updates if we'll execute this iteration
+ elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
@@ -123,13 +127,14 @@ async def apick_next_tool(
yield InformationCollectionIteration(
tool=selected_tool,
query=generated_query,
+ warning=warning,
)
-
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
tool=None,
query=None,
+ warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)
@@ -156,7 +161,6 @@ async def execute_information_collection(
document_results: List[Dict[str, str]] = []
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
- previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
async for result in apick_next_tool(
query,
@@ -166,7 +170,7 @@ async def execute_information_collection(
location,
user_name,
agent,
- previous_iterations_history,
+ previous_iterations,
MAX_ITERATIONS,
send_status_func,
tracer=tracer,
@@ -176,9 +180,16 @@ async def execute_information_collection(
elif isinstance(result, InformationCollectionIteration):
this_iteration = result
- if this_iteration.tool == ConversationCommand.Notes:
+ # Skip running iteration if warning present in iteration
+ if this_iteration.warning:
+ logger.warning(f"Research mode: {this_iteration.warning}")
+
+ elif this_iteration.tool == ConversationCommand.Notes:
this_iteration.context = []
document_results = []
+ previous_inferred_queries = {
+ c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
+ }
async for result in extract_references_and_questions(
request,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
@@ -190,6 +201,7 @@ async def execute_information_collection(
location,
send_status_func,
query_images,
+ previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
):
@@ -213,6 +225,12 @@ async def execute_information_collection(
logger.error(f"Error extracting document references: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Online:
+ previous_subqueries = {
+ subquery
+ for iteration in previous_iterations
+ if iteration.onlineContext
+ for subquery in iteration.onlineContext.keys()
+ }
async for result in search_online(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
@@ -222,11 +240,16 @@ async def execute_information_collection(
[],
max_webpages_to_read=0,
query_images=query_images,
+ previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
+ elif is_none_or_empty(result):
+ this_iteration.warning = (
+ "Detected previously run online search queries. Skipping iteration. Try something different."
+ )
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
@@ -311,16 +334,19 @@ async def execute_information_collection(
current_iteration += 1
- if document_results or online_results or code_results or summarize_files:
- results_data = f"**Results**:\n"
+ if document_results or online_results or code_results or summarize_files or this_iteration.warning:
+ results_data = f"\n{current_iteration}\n{this_iteration.tool}\n{this_iteration.query}\n"
if document_results:
- results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if online_results:
- results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ results_data += f"\n\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if code_results:
- results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ results_data += f"\n\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if summarize_files:
- results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ results_data += f"\n\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ if this_iteration.warning:
+ results_data += f"\n\n{this_iteration.warning}\n"
+ results_data += "\n\n"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data