diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index c6b3bc2f..0aaaa4b3 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -112,6 +112,7 @@ class InformationCollectionIteration: onlineContext: dict = None, codeContext: dict = None, summarizedResult: str = None, + warning: str = None, ): self.tool = tool self.query = query @@ -119,6 +120,7 @@ class InformationCollectionIteration: self.onlineContext = onlineContext self.codeContext = codeContext self.summarizedResult = summarizedResult + self.warning = warning def construct_iteration_history( diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 34c4911a..d2d8c685 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -4,7 +4,7 @@ import logging import os import urllib.parse from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import aiohttp from bs4 import BeautifulSoup @@ -66,6 +66,7 @@ async def search_online( custom_filters: List[str] = [], max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, query_images: List[str] = None, + previous_subqueries: Set = set(), agent: Agent = None, tracer: dict = {}, ): @@ -76,19 +77,24 @@ async def search_online( return # Breakdown the query into subqueries to get the correct answer - subqueries = await generate_online_subqueries( + new_subqueries = await generate_online_subqueries( query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer ) - response_dict = {} + subqueries = list(new_subqueries - previous_subqueries) + response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {} - if subqueries: - logger.info(f"🌐 Searching the Internet for {list(subqueries)}") - if send_status_func: - subqueries_str = "\n- " + "\n- ".join(list(subqueries)) - async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"): - yield {ChatEvent.STATUS: event} + if is_none_or_empty(subqueries): + logger.info("No new subqueries to search online") + yield response_dict + return - with timer(f"Internet searches for {list(subqueries)} took", logger): + logger.info(f"🌐 Searching the Internet for {subqueries}") + if send_status_func: + subqueries_str = "\n- " + "\n- ".join(subqueries) + async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"): + yield {ChatEvent.STATUS: event} + + with timer(f"Internet searches for {subqueries} took", logger): search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina search_tasks = [search_func(subquery, location) for subquery in subqueries] search_results = await asyncio.gather(*search_tasks) @@ -119,7 +125,9 @@ async def search_online( async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} tasks = [ - read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer) + read_webpage_and_extract_content( + data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer + ) for link, data in webpages.items() ] results = await asyncio.gather(*tasks) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index bed7c27b..3eb2dea5 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,7 +6,7 @@ import os import threading import time import uuid -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Set, Union import cron_descriptor import pytz @@ -349,6 +349,7 @@ async def extract_references_and_questions( location_data: LocationData = None, send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, + previous_inferred_queries: Set = set(), agent: Agent = None, tracer: dict = {}, ): @@ -477,6 +478,7 @@ async def extract_references_and_questions( ) # Collate search results as context for GPT + inferred_queries = list(set(inferred_queries) - previous_inferred_queries) with timer("Searching knowledge base took", logger): search_results = [] logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 8646a695..c30f4cf8 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -778,7 +778,8 @@ async def chat( yield research_result # researched_results = await extract_relevant_info(q, researched_results, agent) - logger.info(f"Researched Results: {researched_results}") + if state.verbose > 1: + logger.debug(f"Researched Results: {researched_results}") used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] file_filters = conversation.file_filters if conversation else [] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c7a90fc3..d89ed147 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -20,6 +20,7 @@ from typing import ( Iterator, List, Optional, + Set, Tuple, Union, ) @@ -494,7 +495,7 @@ async def generate_online_subqueries( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, -) -> List[str]: +) -> Set[str]: """ Generate subqueries from the given query """ @@ -529,14 +530,14 @@ async def generate_online_subqueries( try: response = clean_json(response) response = json.loads(response) - response = [q.strip() for q in response["queries"] if q.strip()] - if not isinstance(response, list) or not response or len(response) == 0: + response = {q.strip() for q in response["queries"] if q.strip()} + if not isinstance(response, set) or not response or len(response) == 0: logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") - return [q] + return {q} return response except Exception as e: logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") - return [q] + return {q} async def schedule_query( @@ -1128,9 +1129,6 @@ def generate_chat_response( metadata = {} agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None - query_to_run = q - if meta_research: - query_to_run = f"AI Research: {meta_research} {q}" try: partial_completion = partial( save_to_conversation_log, @@ -1148,6 +1146,13 @@ def generate_chat_response( train_of_thought=train_of_thought, ) + query_to_run = q + if meta_research: + query_to_run = f"{q}\n\n{meta_research}\n" + compiled_references = [] + online_results = {} + code_results = {} + conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) vision_available = conversation_config.vision_enabled if not vision_available and query_images: diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 46d4c424..1caf7c96 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -43,38 +43,35 @@ async def apick_next_tool( location: LocationData = None, user_name: str = None, agent: Agent = None, - previous_iterations_history: str = None, + previous_iterations: List[InformationCollectionIteration] = [], max_iterations: int = 5, send_status_func: Optional[Callable] = None, tracer: dict = {}, ): - """ - Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. - """ + """Given a query, determine which of the available tools the agent should use in order to answer appropriately.""" + # Construct tool options for the agent to choose from tool_options = dict() tool_options_str = "" - agent_tools = agent.input_tools if agent else [] - for tool, description in function_calling_description_for_llm.items(): tool_options[tool.value] = description if len(agent_tools) == 0 or tool.value in agent_tools: tool_options_str += f'- "{tool.value}": "{description}"\n' + # Construct chat history with user and iteration history with researcher agent for context chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") + previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) if query_images: query = f"[placeholder for user attached images]\n{query}" + today = datetime.today() + location_data = f"{location}" if location else "Unknown" personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) - # Extract Past User Message and Inferred Questions from Conversation Log - today = datetime.today() - location_data = f"{location}" if location else "Unknown" - function_planning_prompt = prompts.plan_function_execution.format( tools=tool_options_str, chat_history=chat_history, @@ -112,8 +109,15 @@ async def apick_next_tool( selected_tool = response.get("tool", None) generated_query = response.get("query", None) scratchpad = response.get("scratchpad", None) + warning = None logger.info(f"Response for determining relevant tools: {response}") - if send_status_func: + + # Detect selection of previously used query, tool combination. + previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations} + if (selected_tool, generated_query) in previous_tool_query_combinations: + warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." + # Only send client status updates if we'll execute this iteration + elif send_status_func: determined_tool_message = "**Determined Tool**: " determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond." determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else "" @@ -123,13 +127,14 @@ async def apick_next_tool( yield InformationCollectionIteration( tool=selected_tool, query=generated_query, + warning=warning, ) - except Exception as e: logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True) yield InformationCollectionIteration( tool=None, query=None, + warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}", ) @@ -156,7 +161,6 @@ async def execute_information_collection( document_results: List[Dict[str, str]] = [] summarize_files: str = "" this_iteration = InformationCollectionIteration(tool=None, query=query) - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) async for result in apick_next_tool( query, @@ -166,7 +170,7 @@ async def execute_information_collection( location, user_name, agent, - previous_iterations_history, + previous_iterations, MAX_ITERATIONS, send_status_func, tracer=tracer, @@ -176,9 +180,16 @@ async def execute_information_collection( elif isinstance(result, InformationCollectionIteration): this_iteration = result - if this_iteration.tool == ConversationCommand.Notes: + # Skip running iteration if warning present in iteration + if this_iteration.warning: + logger.warning(f"Research mode: {this_iteration.warning}.") + + elif this_iteration.tool == ConversationCommand.Notes: this_iteration.context = [] document_results = [] + previous_inferred_queries = { + c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context + } async for result in extract_references_and_questions( request, construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), @@ -190,6 +201,7 @@ async def execute_information_collection( location, send_status_func, query_images, + previous_inferred_queries=previous_inferred_queries, agent=agent, tracer=tracer, ): @@ -213,6 +225,12 @@ async def execute_information_collection( logger.error(f"Error extracting document references: {e}", exc_info=True) elif this_iteration.tool == ConversationCommand.Online: + previous_subqueries = { + subquery + for iteration in previous_iterations + if iteration.onlineContext + for subquery in iteration.onlineContext.keys() + } async for result in search_online( this_iteration.query, construct_tool_chat_history(previous_iterations, ConversationCommand.Online), @@ -222,11 +240,16 @@ async def execute_information_collection( [], max_webpages_to_read=0, query_images=query_images, + previous_subqueries=previous_subqueries, agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] + elif is_none_or_empty(result): + this_iteration.warning = ( + "Detected previously run online search queries. Skipping iteration. Try something different." + ) else: online_results: Dict[str, Dict] = result # type: ignore this_iteration.onlineContext = online_results @@ -311,16 +334,19 @@ async def execute_information_collection( current_iteration += 1 - if document_results or online_results or code_results or summarize_files: - results_data = f"**Results**:\n" + if document_results or online_results or code_results or summarize_files or this_iteration.warning: + results_data = f"\n{current_iteration}\n{this_iteration.tool}\n{this_iteration.query}\n" if document_results: - results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if online_results: - results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"\n\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if code_results: - results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"\n\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if summarize_files: - results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"\n\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if this_iteration.warning: + results_data += f"\n\n{this_iteration.warning}\n" + results_data += "\n\n" # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) this_iteration.summarizedResult = results_data