From 7d59688729559bd2ab5bbb3d62f88ecec849a897 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 5 Jun 2025 00:12:42 -0700 Subject: [PATCH] Move document search tool into helpers module with other tools Document search (because of its age) was the only tool directly within an api router. Put it into helpers to have all the (mini) tools in one place. --- src/khoj/routers/api.py | 305 +---------------------------------- src/khoj/routers/api_chat.py | 2 +- src/khoj/routers/helpers.py | 289 ++++++++++++++++++++++++++++++++- src/khoj/routers/research.py | 2 +- tests/test_agents.py | 2 +- 5 files changed, 297 insertions(+), 303 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 805030d8..a4f24999 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -1,19 +1,16 @@ -import concurrent.futures import json import logging import math import os import threading -import time import uuid -from typing import Any, Callable, List, Optional, Set, Union +from typing import List, Optional, Union import cron_descriptor import openai import pytz from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger -from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile from fastapi.requests import Request from fastapi.responses import Response @@ -22,48 +19,28 @@ from starlette.authentication import has_required_scope, requires from khoj.configure import initialize_content from khoj.database import adapters from khoj.database.adapters import ( - AgentAdapters, AutomationAdapters, ConversationAdapters, EntryAdapters, - get_default_search_model, get_user_photo, ) -from khoj.database.models import ( - Agent, - ChatMessageModel, - ChatModel, - KhojUser, - SpeechToTextModelOptions, -) -from khoj.processor.conversation import prompts -from khoj.processor.conversation.anthropic.anthropic_chat import ( - extract_questions_anthropic, -) -from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini -from khoj.processor.conversation.offline.chat_model import extract_questions_offline +from khoj.database.models import KhojUser, SpeechToTextModelOptions from khoj.processor.conversation.offline.whisper import transcribe_audio_offline -from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio -from khoj.processor.conversation.utils import clean_json, defilter_query +from khoj.processor.conversation.utils import clean_json from khoj.routers.helpers import ( ApiUserRateLimiter, - ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, + execute_search, get_user_config, schedule_automation, schedule_query, update_telemetry_state, ) -from khoj.search_filter.date_filter import DateFilter -from khoj.search_filter.file_filter import FileFilter -from khoj.search_filter.word_filter import WordFilter -from khoj.search_type import text_search from khoj.utils import state -from khoj.utils.config import OfflineChatProcessorModel -from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer -from khoj.utils.rawconfig import LocationData, SearchResponse +from khoj.utils.helpers import is_none_or_empty +from khoj.utils.rawconfig import SearchResponse from khoj.utils.state import SearchType # Initialize Router @@ -116,98 +93,6 @@ async def search( return results -async def execute_search( - user: KhojUser, - q: str, - n: Optional[int] = 5, - t: Optional[SearchType] = SearchType.All, - r: Optional[bool] = False, - max_distance: Optional[Union[float, None]] = None, - dedupe: Optional[bool] = True, - agent: Optional[Agent] = None, -): - # Run validation checks - results: List[SearchResponse] = [] - - start_time = time.time() - - # Ensure the agent, if present, is accessible by the user - if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user): - logger.error(f"Agent {agent.slug} is not accessible by user {user}") - return results - - if q is None or q == "": - logger.warning(f"No query param (q) passed in API call to initiate search") - return results - - # initialize variables - user_query = q.strip() - results_count = n or 5 - search_futures: List[concurrent.futures.Future] = [] - - # return cached results, if available - if user: - query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" - if query_cache_key in state.query_cache[user.uuid]: - logger.debug(f"Return response from query cache") - return state.query_cache[user.uuid][query_cache_key] - - # Encode query with filter terms removed - defiltered_query = user_query - for filter in [DateFilter(), WordFilter(), FileFilter()]: - defiltered_query = filter.defilter(defiltered_query) - - encoded_asymmetric_query = None - if t != SearchType.Image: - with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_default_search_model)() - encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) - - with concurrent.futures.ThreadPoolExecutor() as executor: - if t in [ - SearchType.All, - SearchType.Org, - SearchType.Markdown, - SearchType.Github, - SearchType.Notion, - SearchType.Plaintext, - SearchType.Pdf, - ]: - # query markdown notes - search_futures += [ - executor.submit( - text_search.query, - user_query, - user, - t, - question_embedding=encoded_asymmetric_query, - max_distance=max_distance, - agent=agent, - ) - ] - - # Query across each requested content types in parallel - with timer("Query took", logger): - for search_future in concurrent.futures.as_completed(search_futures): - hits = await search_future.result() - # Collate results - results += text_search.collate_results(hits, dedupe=dedupe) - - # Sort results across all content types and take top results - results = text_search.rerank_and_sort_results( - results, query=defiltered_query, rank_results=r, search_model_name=search_model.name - )[:results_count] - - # Cache results - if user: - state.query_cache[user.uuid][query_cache_key] = results - - end_time = time.time() - logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") - - return results - - @api.get("/update") @requires(["authenticated"]) def update( @@ -357,184 +242,6 @@ def set_user_name( return {"status": "ok"} -async def search_documents( - user: KhojUser, - chat_history: list[ChatMessageModel], - q: str, - n: int, - d: float, - conversation_id: str, - conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], - location_data: LocationData = None, - send_status_func: Optional[Callable] = None, - query_images: Optional[List[str]] = None, - previous_inferred_queries: Set = set(), - agent: Agent = None, - query_files: str = None, - tracer: dict = {}, -): - # Initialize Variables - compiled_references: List[dict[str, str]] = [] - inferred_queries: List[str] = [] - - agent_has_entries = False - - if agent: - agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) - - if ( - not ConversationCommand.Notes in conversation_commands - and not ConversationCommand.Default in conversation_commands - and not agent_has_entries - ): - yield compiled_references, inferred_queries, q - return - - # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base - should_limit_to_agent_knowledge = ( - ConversationCommand.Notes not in conversation_commands - and ConversationCommand.Default not in conversation_commands - ) - - if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): - if not agent_has_entries: - logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") - yield compiled_references, inferred_queries, q - return - - # Extract filter terms from user message - defiltered_query = defilter_query(q) - filters_in_query = q.replace(defiltered_query, "").strip() - conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id) - - if not conversation: - logger.error(f"Conversation with id {conversation_id} not found when extracting references.") - yield compiled_references, inferred_queries, defiltered_query - return - - filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) - using_offline_chat = False - if is_none_or_empty(filters_in_query): - logger.debug(f"Filters in query: {filters_in_query}") - - personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" - - # Infer search queries from user message - with timer("Extracting search queries took", logger): - # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - chat_model = await ConversationAdapters.aget_default_chat_model(user) - vision_enabled = chat_model.vision_enabled - - if chat_model.model_type == ChatModel.ModelType.OFFLINE: - using_offline_chat = True - chat_model_name = chat_model.name - max_tokens = chat_model.max_prompt_size - if state.offline_chat_processor_config is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - - loaded_model = state.offline_chat_processor_config.loaded_model - - inferred_queries = extract_questions_offline( - defiltered_query, - model=chat_model, - loaded_model=loaded_model, - chat_history=chat_history, - should_extract_questions=True, - location_data=location_data, - user=user, - max_prompt_size=chat_model.max_prompt_size, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - elif chat_model.model_type == ChatModel.ModelType.OPENAI: - api_key = chat_model.ai_model_api.api_key - base_url = chat_model.ai_model_api.api_base_url - chat_model_name = chat_model.name - inferred_queries = extract_questions( - defiltered_query, - model=chat_model_name, - api_key=api_key, - api_base_url=base_url, - chat_history=chat_history, - location_data=location_data, - user=user, - query_images=query_images, - vision_enabled=vision_enabled, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - chat_model_name = chat_model.name - inferred_queries = extract_questions_anthropic( - defiltered_query, - query_images=query_images, - model=chat_model_name, - api_key=api_key, - api_base_url=api_base_url, - chat_history=chat_history, - location_data=location_data, - user=user, - vision_enabled=vision_enabled, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - elif chat_model.model_type == ChatModel.ModelType.GOOGLE: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - chat_model_name = chat_model.name - inferred_queries = extract_questions_gemini( - defiltered_query, - query_images=query_images, - model=chat_model_name, - api_key=api_key, - api_base_url=api_base_url, - chat_history=chat_history, - location_data=location_data, - max_tokens=chat_model.max_prompt_size, - user=user, - vision_enabled=vision_enabled, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - - # Collate search results as context for GPT - inferred_queries = list(set(inferred_queries) - previous_inferred_queries) - with timer("Searching knowledge base took", logger): - search_results = [] - logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") - if send_status_func: - inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) - async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"): - yield {ChatEvent.STATUS: event} - for query in inferred_queries: - n_items = min(n, 3) if using_offline_chat else n - search_results.extend( - await execute_search( - user if not should_limit_to_agent_knowledge else None, - f"{query} {filters_in_query}", - n=n_items, - t=SearchType.All, - r=True, - max_distance=d, - dedupe=False, - agent=agent, - ) - ) - search_results = text_search.deduplicated_search_responses(search_results) - compiled_references = [ - {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]} - for q, item in zip(inferred_queries, search_results) - ] - - yield compiled_references, inferred_queries, defiltered_query - - @api.get("/health", response_class=Response) @requires(["authenticated"], status_code=200) def health_check(request: Request) -> Response: diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 62090052..ce5bfbc5 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -40,7 +40,6 @@ from khoj.processor.tools.online_search import ( search_online, ) from khoj.processor.tools.run_code import run_code -from khoj.routers.api import search_documents from khoj.routers.email import send_query_feedback from khoj.routers.helpers import ( ApiImageRateLimiter, @@ -63,6 +62,7 @@ from khoj.routers.helpers import ( is_query_empty, is_ready_to_chat, read_chat_stream, + search_documents, update_telemetry_state, validate_chat_model, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 4d5082d9..ff31fcad 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,10 +1,12 @@ import base64 +import concurrent.futures import hashlib import json import logging import math import os import re +import time from datetime import datetime, timedelta, timezone from random import random from typing import ( @@ -45,6 +47,7 @@ from khoj.database.adapters import ( aget_user_by_email, ais_user_subscribed, create_khoj_token, + get_default_search_model, get_khoj_tokens, get_user_name, get_user_notion_config, @@ -79,17 +82,21 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.anthropic_chat import ( anthropic_send_message_to_model, converse_anthropic, + extract_questions_anthropic, ) from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, + extract_questions_gemini, gemini_send_message_to_model, ) from khoj.processor.conversation.offline.chat_model import ( converse_offline, + extract_questions_offline, send_message_to_model_offline, ) from khoj.processor.conversation.openai.gpt import ( converse_openai, + extract_questions, send_message_to_model, ) from khoj.processor.conversation.utils import ( @@ -100,11 +107,15 @@ from khoj.processor.conversation.utils import ( clean_json, clean_mermaidjs, construct_chat_history, + defilter_query, generate_chatml_messages_with_context, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.email import is_resend_enabled, send_task_email from khoj.routers.twilio import is_twilio_enabled +from khoj.search_filter.date_filter import DateFilter +from khoj.search_filter.file_filter import FileFilter +from khoj.search_filter.word_filter import WordFilter from khoj.search_type import text_search from khoj.utils import state from khoj.utils.config import OfflineChatProcessorModel @@ -121,7 +132,13 @@ from khoj.utils.helpers import ( timer, tool_descriptions_for_llm, ) -from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData +from khoj.utils.rawconfig import ( + ChatRequestBody, + FileAttachment, + FileData, + LocationData, + SearchResponse, +) logger = logging.getLogger(__name__) @@ -1149,6 +1166,276 @@ async def generate_better_image_prompt( return response +async def search_documents( + user: KhojUser, + chat_history: list[ChatMessageModel], + q: str, + n: int, + d: float, + conversation_id: str, + conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], + location_data: LocationData = None, + send_status_func: Optional[Callable] = None, + query_images: Optional[List[str]] = None, + previous_inferred_queries: Set = set(), + agent: Agent = None, + query_files: str = None, + tracer: dict = {}, +): + # Initialize Variables + compiled_references: List[dict[str, str]] = [] + inferred_queries: List[str] = [] + + agent_has_entries = False + + if agent: + agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) + + if ( + not ConversationCommand.Notes in conversation_commands + and not ConversationCommand.Default in conversation_commands + and not agent_has_entries + ): + yield compiled_references, inferred_queries, q + return + + # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base + should_limit_to_agent_knowledge = ( + ConversationCommand.Notes not in conversation_commands + and ConversationCommand.Default not in conversation_commands + ) + + if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): + if not agent_has_entries: + logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") + yield compiled_references, inferred_queries, q + return + + # Extract filter terms from user message + defiltered_query = defilter_query(q) + filters_in_query = q.replace(defiltered_query, "").strip() + conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id) + + if not conversation: + logger.error(f"Conversation with id {conversation_id} not found when extracting references.") + yield compiled_references, inferred_queries, defiltered_query + return + + filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) + using_offline_chat = False + if is_none_or_empty(filters_in_query): + logger.debug(f"Filters in query: {filters_in_query}") + + personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" + + # Infer search queries from user message + with timer("Extracting search queries took", logger): + # If we've reached here, either the user has enabled offline chat or the openai model is enabled. + chat_model = await ConversationAdapters.aget_default_chat_model(user) + vision_enabled = chat_model.vision_enabled + + if chat_model.model_type == ChatModel.ModelType.OFFLINE: + using_offline_chat = True + chat_model_name = chat_model.name + max_tokens = chat_model.max_prompt_size + if state.offline_chat_processor_config is None: + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) + + loaded_model = state.offline_chat_processor_config.loaded_model + + inferred_queries = extract_questions_offline( + defiltered_query, + model=chat_model, + loaded_model=loaded_model, + chat_history=chat_history, + should_extract_questions=True, + location_data=location_data, + user=user, + max_prompt_size=chat_model.max_prompt_size, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + elif chat_model.model_type == ChatModel.ModelType.OPENAI: + api_key = chat_model.ai_model_api.api_key + base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name + inferred_queries = extract_questions( + defiltered_query, + model=chat_model_name, + api_key=api_key, + api_base_url=base_url, + chat_history=chat_history, + location_data=location_data, + user=user, + query_images=query_images, + vision_enabled=vision_enabled, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name + inferred_queries = extract_questions_anthropic( + defiltered_query, + query_images=query_images, + model=chat_model_name, + api_key=api_key, + api_base_url=api_base_url, + chat_history=chat_history, + location_data=location_data, + user=user, + vision_enabled=vision_enabled, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name + inferred_queries = extract_questions_gemini( + defiltered_query, + query_images=query_images, + model=chat_model_name, + api_key=api_key, + api_base_url=api_base_url, + chat_history=chat_history, + location_data=location_data, + max_tokens=chat_model.max_prompt_size, + user=user, + vision_enabled=vision_enabled, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + + # Collate search results as context for GPT + inferred_queries = list(set(inferred_queries) - previous_inferred_queries) + with timer("Searching knowledge base took", logger): + search_results = [] + logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") + if send_status_func: + inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) + async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"): + yield {ChatEvent.STATUS: event} + for query in inferred_queries: + n_items = min(n, 3) if using_offline_chat else n + search_results.extend( + await execute_search( + user if not should_limit_to_agent_knowledge else None, + f"{query} {filters_in_query}", + n=n_items, + t=state.SearchType.All, + r=True, + max_distance=d, + dedupe=False, + agent=agent, + ) + ) + search_results = text_search.deduplicated_search_responses(search_results) + compiled_references = [ + {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]} + for q, item in zip(inferred_queries, search_results) + ] + + yield compiled_references, inferred_queries, defiltered_query + + +async def execute_search( + user: KhojUser, + q: str, + n: Optional[int] = 5, + t: Optional[state.SearchType] = state.SearchType.All, + r: Optional[bool] = False, + max_distance: Optional[Union[float, None]] = None, + dedupe: Optional[bool] = True, + agent: Optional[Agent] = None, +): + # Run validation checks + results: List[SearchResponse] = [] + + start_time = time.time() + + # Ensure the agent, if present, is accessible by the user + if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user): + logger.error(f"Agent {agent.slug} is not accessible by user {user}") + return results + + if q is None or q == "": + logger.warning(f"No query param (q) passed in API call to initiate search") + return results + + # initialize variables + user_query = q.strip() + results_count = n or 5 + search_futures: List[concurrent.futures.Future] = [] + + # return cached results, if available + if user: + query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" + if query_cache_key in state.query_cache[user.uuid]: + logger.debug(f"Return response from query cache") + return state.query_cache[user.uuid][query_cache_key] + + # Encode query with filter terms removed + defiltered_query = user_query + for filter in [DateFilter(), WordFilter(), FileFilter()]: + defiltered_query = filter.defilter(defiltered_query) + + encoded_asymmetric_query = None + if t != state.SearchType.Image: + with timer("Encoding query took", logger=logger): + search_model = await sync_to_async(get_default_search_model)() + encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) + + with concurrent.futures.ThreadPoolExecutor() as executor: + if t in [ + state.SearchType.All, + state.SearchType.Org, + state.SearchType.Markdown, + state.SearchType.Github, + state.SearchType.Notion, + state.SearchType.Plaintext, + state.SearchType.Pdf, + ]: + # query markdown notes + search_futures += [ + executor.submit( + text_search.query, + user_query, + user, + t, + question_embedding=encoded_asymmetric_query, + max_distance=max_distance, + agent=agent, + ) + ] + + # Query across each requested content types in parallel + with timer("Query took", logger): + for search_future in concurrent.futures.as_completed(search_futures): + hits = await search_future.result() + # Collate results + results += text_search.collate_results(hits, dedupe=dedupe) + + # Sort results across all content types and take top results + results = text_search.rerank_and_sort_results( + results, query=defiltered_query, rank_results=r, search_model_name=search_model.name + )[:results_count] + + # Cache results + if user: + state.query_cache[user.uuid][query_cache_key] = results + + end_time = time.time() + logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") + + return results + + async def send_message_to_model_wrapper( query: str, system_message: str = "", diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 43617e6c..881e49b9 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -22,10 +22,10 @@ from khoj.processor.conversation.utils import ( from khoj.processor.operator import operate_environment from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.run_code import run_code -from khoj.routers.api import search_documents from khoj.routers.helpers import ( ChatEvent, generate_summary_from_files, + search_documents, send_message_to_model_wrapper, ) from khoj.utils.helpers import ( diff --git a/tests/test_agents.py b/tests/test_agents.py index 242495e6..f573bed9 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -6,7 +6,7 @@ from asgiref.sync import sync_to_async from khoj.database.adapters import AgentAdapters from khoj.database.models import Agent, ChatModel, Entry, KhojUser -from khoj.routers.api import execute_search +from khoj.routers.helpers import execute_search from khoj.utils.helpers import get_absolute_path from tests.helpers import ChatModelFactory