Move document search tool into helpers module with other tools

Document search was, due to its age, the only tool implemented directly
inside an api router. Move it into the helpers module so that all the
(mini) tools live in one place.
This commit is contained in:
Debanjum
2025-06-05 00:12:42 -07:00
parent 1dbe60a8a2
commit 7d59688729
5 changed files with 297 additions and 303 deletions

View File

@@ -1,19 +1,16 @@
import concurrent.futures
import json
import logging
import math
import os
import threading
import time
import uuid
from typing import Any, Callable, List, Optional, Set, Union
from typing import List, Optional, Union
import cron_descriptor
import openai
import pytz
from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.requests import Request
from fastapi.responses import Response
@@ -22,48 +19,28 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import (
AgentAdapters,
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
get_default_search_model,
get_user_photo,
)
from khoj.database.models import (
Agent,
ChatMessageModel,
ChatModel,
KhojUser,
SpeechToTextModelOptions,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
)
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
from khoj.database.models import KhojUser, SpeechToTextModelOptions
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.utils import clean_json, defilter_query
from khoj.processor.conversation.utils import clean_json
from khoj.routers.helpers import (
ApiUserRateLimiter,
ChatEvent,
CommonQueryParams,
ConversationCommandRateLimiter,
execute_search,
get_user_config,
schedule_automation,
schedule_query,
update_telemetry_state,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer
from khoj.utils.rawconfig import LocationData, SearchResponse
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import SearchResponse
from khoj.utils.state import SearchType
# Initialize Router
@@ -116,98 +93,6 @@ async def search(
return results
async def execute_search(
    user: KhojUser,
    q: str,
    n: Optional[int] = 5,
    t: Optional[SearchType] = SearchType.All,
    r: Optional[bool] = False,
    max_distance: Optional[Union[float, None]] = None,
    dedupe: Optional[bool] = True,
    agent: Optional[Agent] = None,
):
    """Search the indexed knowledge base for the query `q`.

    Args:
        user: Owner of the entries to search. May be None to search only an agent's knowledge base.
        q: Raw query string, possibly containing date/word/file filter terms.
        n: Maximum number of results to return (defaults to 5 when falsy).
        t: Content type to search; SearchType.All fans out across text content types.
        r: When True, rerank results with the cross-encoder before truncating.
        max_distance: Maximum embedding distance for a hit to qualify.
        dedupe: Whether to collapse duplicate hits during collation.
        agent: Optional agent whose knowledge base should be searched.

    Returns:
        Up to `n` SearchResponse results, sorted best-first. Empty list on
        validation failure (inaccessible agent or empty query).
    """
    results: List[SearchResponse] = []
    start_time = time.time()

    # Run validation checks: the agent, if present, must be accessible by the user
    if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
        logger.error(f"Agent {agent.slug} is not accessible by user {user}")
        return results

    if q is None or q == "":
        logger.warning("No query param (q) passed in API call to initiate search")
        return results

    # Initialize variables
    user_query = q.strip()
    results_count = n or 5
    search_futures: List[concurrent.futures.Future] = []

    # Return cached results, if available. Cache key covers every parameter that affects results.
    if user:
        query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
        if query_cache_key in state.query_cache[user.uuid]:
            logger.debug("Return response from query cache")
            return state.query_cache[user.uuid][query_cache_key]

    # Encode query with filter terms removed
    defiltered_query = user_query
    for query_filter in [DateFilter(), WordFilter(), FileFilter()]:
        defiltered_query = query_filter.defilter(defiltered_query)

    # Fix: bind search_model before the conditional so non-text search types
    # (e.g. SearchType.Image) don't hit a NameError at the rerank step below.
    search_model = None
    encoded_asymmetric_query = None
    if t != SearchType.Image:
        with timer("Encoding query took", logger=logger):
            search_model = await sync_to_async(get_default_search_model)()
            encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        if t in [
            SearchType.All,
            SearchType.Org,
            SearchType.Markdown,
            SearchType.Github,
            SearchType.Notion,
            SearchType.Plaintext,
            SearchType.Pdf,
        ]:
            # Query text content types (notes, docs, etc.)
            search_futures += [
                executor.submit(
                    text_search.query,
                    user_query,
                    user,
                    t,
                    question_embedding=encoded_asymmetric_query,
                    max_distance=max_distance,
                    agent=agent,
                )
            ]

        # Query across each requested content types in parallel
        with timer("Query took", logger):
            for search_future in concurrent.futures.as_completed(search_futures):
                hits = await search_future.result()
                # Collate results
                results += text_search.collate_results(hits, dedupe=dedupe)

            # Sort results across all content types and take top results.
            # Skip reranking when no search model was loaded (non-text search types).
            if search_model is not None:
                results = text_search.rerank_and_sort_results(
                    results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
                )[:results_count]
            else:
                results = results[:results_count]

    # Cache results
    if user:
        state.query_cache[user.uuid][query_cache_key] = results

    end_time = time.time()
    logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")

    return results
@api.get("/update")
@requires(["authenticated"])
def update(
@@ -357,184 +242,6 @@ def set_user_name(
return {"status": "ok"}
async def search_documents(
    user: KhojUser,
    chat_history: list[ChatMessageModel],
    q: str,
    n: int,
    d: float,
    conversation_id: str,
    # NOTE(review): mutable default arguments ([...], set(), {}) are shared across
    # calls; safe only if never mutated by callees — verify `tracer` in particular.
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
    query_images: Optional[List[str]] = None,
    previous_inferred_queries: Set = set(),
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
):
    """Infer search queries from the user message with a chat model, then search the knowledge base.

    Async generator. Yields intermediate {ChatEvent.STATUS: ...} dicts while
    searching (only when `send_status_func` is given), then finally yields a
    (compiled_references, inferred_queries, defiltered_query) tuple.
    Returns early with empty references when no notes were requested, when the
    user/agent has no indexed entries, or when the conversation id is unknown.
    """
    # Initialize Variables
    compiled_references: List[dict[str, str]] = []
    inferred_queries: List[str] = []

    agent_has_entries = False
    if agent:
        agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)

    # Nothing to search: notes were not requested and the agent has no entries of its own
    if (
        not ConversationCommand.Notes in conversation_commands
        and not ConversationCommand.Default in conversation_commands
        and not agent_has_entries
    ):
        yield compiled_references, inferred_queries, q
        return

    # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base
    should_limit_to_agent_knowledge = (
        ConversationCommand.Notes not in conversation_commands
        and ConversationCommand.Default not in conversation_commands
    )

    if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
        if not agent_has_entries:
            logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
            yield compiled_references, inferred_queries, q
            return

    # Extract filter terms from user message; filters are the part of `q` removed by defiltering
    defiltered_query = defilter_query(q)
    filters_in_query = q.replace(defiltered_query, "").strip()
    conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)

    if not conversation:
        logger.error(f"Conversation with id {conversation_id} not found when extracting references.")
        yield compiled_references, inferred_queries, defiltered_query
        return

    # Apply any file filters pinned on the conversation itself
    filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])

    using_offline_chat = False
    # NOTE(review): this logs the filters exactly when they are empty —
    # looks like the condition may be inverted; confirm intent.
    if is_none_or_empty(filters_in_query):
        logger.debug(f"Filters in query: {filters_in_query}")
    personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""

    # Infer search queries from user message
    with timer("Extracting search queries took", logger):
        # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
        chat_model = await ConversationAdapters.aget_default_chat_model(user)
        vision_enabled = chat_model.vision_enabled

        # Dispatch query extraction to the provider backing the configured chat model
        if chat_model.model_type == ChatModel.ModelType.OFFLINE:
            using_offline_chat = True
            chat_model_name = chat_model.name
            max_tokens = chat_model.max_prompt_size
            # Lazily load the offline model once and reuse it across calls
            if state.offline_chat_processor_config is None:
                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
            loaded_model = state.offline_chat_processor_config.loaded_model

            inferred_queries = extract_questions_offline(
                defiltered_query,
                model=chat_model,
                loaded_model=loaded_model,
                chat_history=chat_history,
                should_extract_questions=True,
                location_data=location_data,
                user=user,
                max_prompt_size=chat_model.max_prompt_size,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.OPENAI:
            api_key = chat_model.ai_model_api.api_key
            base_url = chat_model.ai_model_api.api_base_url
            chat_model_name = chat_model.name
            inferred_queries = extract_questions(
                defiltered_query,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=base_url,
                chat_history=chat_history,
                location_data=location_data,
                user=user,
                query_images=query_images,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
            api_key = chat_model.ai_model_api.api_key
            api_base_url = chat_model.ai_model_api.api_base_url
            chat_model_name = chat_model.name
            inferred_queries = extract_questions_anthropic(
                defiltered_query,
                query_images=query_images,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=api_base_url,
                chat_history=chat_history,
                location_data=location_data,
                user=user,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
            api_key = chat_model.ai_model_api.api_key
            api_base_url = chat_model.ai_model_api.api_base_url
            chat_model_name = chat_model.name
            inferred_queries = extract_questions_gemini(
                defiltered_query,
                query_images=query_images,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=api_base_url,
                chat_history=chat_history,
                location_data=location_data,
                max_tokens=chat_model.max_prompt_size,
                user=user,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )

    # Collate search results as context for GPT. Drop queries already searched earlier in this turn.
    inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
    with timer("Searching knowledge base took", logger):
        search_results = []
        logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
        if send_status_func:
            inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
            async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
                yield {ChatEvent.STATUS: event}
        for query in inferred_queries:
            # Offline chat models have a smaller context window; cap results per query
            n_items = min(n, 3) if using_offline_chat else n
            search_results.extend(
                await execute_search(
                    # Pass no user to restrict the search to the agent's knowledge base
                    user if not should_limit_to_agent_knowledge else None,
                    f"{query} {filters_in_query}",
                    n=n_items,
                    t=SearchType.All,
                    r=True,
                    max_distance=d,
                    dedupe=False,
                    agent=agent,
                )
            )
        search_results = text_search.deduplicated_search_responses(search_results)
        # NOTE(review): zip pairs queries with deduplicated results positionally;
        # lengths can differ after dedup, so some pairings may be lossy — confirm.
        compiled_references = [
            {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
            for q, item in zip(inferred_queries, search_results)
        ]

    yield compiled_references, inferred_queries, defiltered_query
@api.get("/health", response_class=Response)
@requires(["authenticated"], status_code=200)
def health_check(request: Request) -> Response:

View File

@@ -40,7 +40,6 @@ from khoj.processor.tools.online_search import (
search_online,
)
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import search_documents
from khoj.routers.email import send_query_feedback
from khoj.routers.helpers import (
ApiImageRateLimiter,
@@ -63,6 +62,7 @@ from khoj.routers.helpers import (
is_query_empty,
is_ready_to_chat,
read_chat_stream,
search_documents,
update_telemetry_state,
validate_chat_model,
)

View File

@@ -1,10 +1,12 @@
import base64
import concurrent.futures
import hashlib
import json
import logging
import math
import os
import re
import time
from datetime import datetime, timedelta, timezone
from random import random
from typing import (
@@ -45,6 +47,7 @@ from khoj.database.adapters import (
aget_user_by_email,
ais_user_subscribed,
create_khoj_token,
get_default_search_model,
get_khoj_tokens,
get_user_name,
get_user_notion_config,
@@ -79,17 +82,21 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
anthropic_send_message_to_model,
converse_anthropic,
extract_questions_anthropic,
)
from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
extract_questions_gemini,
gemini_send_message_to_model,
)
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import (
converse_openai,
extract_questions,
send_message_to_model,
)
from khoj.processor.conversation.utils import (
@@ -100,11 +107,15 @@ from khoj.processor.conversation.utils import (
clean_json,
clean_mermaidjs,
construct_chat_history,
defilter_query,
generate_chatml_messages_with_context,
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email
from khoj.routers.twilio import is_twilio_enabled
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
@@ -121,7 +132,13 @@ from khoj.utils.helpers import (
timer,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData
from khoj.utils.rawconfig import (
ChatRequestBody,
FileAttachment,
FileData,
LocationData,
SearchResponse,
)
logger = logging.getLogger(__name__)
@@ -1149,6 +1166,276 @@ async def generate_better_image_prompt(
return response
async def search_documents(
    user: KhojUser,
    chat_history: list[ChatMessageModel],
    q: str,
    n: int,
    d: float,
    conversation_id: str,
    conversation_commands: Optional[List[ConversationCommand]] = None,
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
    query_images: Optional[List[str]] = None,
    previous_inferred_queries: Optional[Set] = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: Optional[dict] = None,
):
    """Infer search queries from the user message with a chat model, then search the knowledge base.

    Async generator. Yields intermediate {ChatEvent.STATUS: ...} dicts while
    searching (only when `send_status_func` is given), then finally yields a
    (compiled_references, inferred_queries, defiltered_query) tuple.
    Returns early with empty references when no notes were requested, when the
    user/agent has no indexed entries, or when the conversation id is unknown.
    """
    # Fix: avoid shared mutable default arguments; bind fresh objects per call.
    if conversation_commands is None:
        conversation_commands = [ConversationCommand.Default]
    if previous_inferred_queries is None:
        previous_inferred_queries = set()
    if tracer is None:
        tracer = {}

    # Initialize Variables
    compiled_references: List[dict[str, str]] = []
    inferred_queries: List[str] = []

    agent_has_entries = False
    if agent:
        agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)

    # Nothing to search: notes were not requested and the agent has no entries of its own
    if (
        not ConversationCommand.Notes in conversation_commands
        and not ConversationCommand.Default in conversation_commands
        and not agent_has_entries
    ):
        yield compiled_references, inferred_queries, q
        return

    # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base
    should_limit_to_agent_knowledge = (
        ConversationCommand.Notes not in conversation_commands
        and ConversationCommand.Default not in conversation_commands
    )

    if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
        if not agent_has_entries:
            logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
            yield compiled_references, inferred_queries, q
            return

    # Extract filter terms from user message; filters are the part of `q` removed by defiltering
    defiltered_query = defilter_query(q)
    filters_in_query = q.replace(defiltered_query, "").strip()
    conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)

    if not conversation:
        logger.error(f"Conversation with id {conversation_id} not found when extracting references.")
        yield compiled_references, inferred_queries, defiltered_query
        return

    # Apply any file filters pinned on the conversation itself
    filters_in_query += " ".join([f'file:"{file_filter}"' for file_filter in conversation.file_filters])

    using_offline_chat = False
    # Fix: log the extracted filters when they are present (condition was inverted)
    if not is_none_or_empty(filters_in_query):
        logger.debug(f"Filters in query: {filters_in_query}")
    personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""

    # Infer search queries from user message
    with timer("Extracting search queries took", logger):
        # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
        chat_model = await ConversationAdapters.aget_default_chat_model(user)
        vision_enabled = chat_model.vision_enabled

        # Dispatch query extraction to the provider backing the configured chat model
        if chat_model.model_type == ChatModel.ModelType.OFFLINE:
            using_offline_chat = True
            chat_model_name = chat_model.name
            max_tokens = chat_model.max_prompt_size
            # Lazily load the offline model once and reuse it across calls
            if state.offline_chat_processor_config is None:
                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
            loaded_model = state.offline_chat_processor_config.loaded_model

            inferred_queries = extract_questions_offline(
                defiltered_query,
                model=chat_model,
                loaded_model=loaded_model,
                chat_history=chat_history,
                should_extract_questions=True,
                location_data=location_data,
                user=user,
                max_prompt_size=chat_model.max_prompt_size,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.OPENAI:
            api_key = chat_model.ai_model_api.api_key
            base_url = chat_model.ai_model_api.api_base_url
            chat_model_name = chat_model.name
            inferred_queries = extract_questions(
                defiltered_query,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=base_url,
                chat_history=chat_history,
                location_data=location_data,
                user=user,
                query_images=query_images,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
            api_key = chat_model.ai_model_api.api_key
            api_base_url = chat_model.ai_model_api.api_base_url
            chat_model_name = chat_model.name
            inferred_queries = extract_questions_anthropic(
                defiltered_query,
                query_images=query_images,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=api_base_url,
                chat_history=chat_history,
                location_data=location_data,
                user=user,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )
        elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
            api_key = chat_model.ai_model_api.api_key
            api_base_url = chat_model.ai_model_api.api_base_url
            chat_model_name = chat_model.name
            inferred_queries = extract_questions_gemini(
                defiltered_query,
                query_images=query_images,
                model=chat_model_name,
                api_key=api_key,
                api_base_url=api_base_url,
                chat_history=chat_history,
                location_data=location_data,
                max_tokens=chat_model.max_prompt_size,
                user=user,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
                query_files=query_files,
                tracer=tracer,
            )

    # Collate search results as context for GPT. Drop queries already searched earlier in this turn.
    inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
    with timer("Searching knowledge base took", logger):
        search_results = []
        logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
        if send_status_func:
            inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
            async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
                yield {ChatEvent.STATUS: event}
        for query in inferred_queries:
            # Offline chat models have a smaller context window; cap results per query
            n_items = min(n, 3) if using_offline_chat else n
            search_results.extend(
                await execute_search(
                    # Pass no user to restrict the search to the agent's knowledge base
                    user if not should_limit_to_agent_knowledge else None,
                    f"{query} {filters_in_query}",
                    n=n_items,
                    t=state.SearchType.All,
                    r=True,
                    max_distance=d,
                    dedupe=False,
                    agent=agent,
                )
            )
        search_results = text_search.deduplicated_search_responses(search_results)
        # NOTE(review): zip pairs queries with deduplicated results positionally;
        # lengths can differ after dedup, so some pairings may be lossy — confirm.
        compiled_references = [
            {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
            for q, item in zip(inferred_queries, search_results)
        ]

    yield compiled_references, inferred_queries, defiltered_query
async def execute_search(
    user: KhojUser,
    q: str,
    n: Optional[int] = 5,
    t: Optional[state.SearchType] = state.SearchType.All,
    r: Optional[bool] = False,
    max_distance: Optional[Union[float, None]] = None,
    dedupe: Optional[bool] = True,
    agent: Optional[Agent] = None,
):
    """Search the indexed knowledge base for the query `q`.

    Args:
        user: Owner of the entries to search. May be None to search only an agent's knowledge base.
        q: Raw query string, possibly containing date/word/file filter terms.
        n: Maximum number of results to return (defaults to 5 when falsy).
        t: Content type to search; SearchType.All fans out across text content types.
        r: When True, rerank results with the cross-encoder before truncating.
        max_distance: Maximum embedding distance for a hit to qualify.
        dedupe: Whether to collapse duplicate hits during collation.
        agent: Optional agent whose knowledge base should be searched.

    Returns:
        Up to `n` SearchResponse results, sorted best-first. Empty list on
        validation failure (inaccessible agent or empty query).
    """
    results: List[SearchResponse] = []
    start_time = time.time()

    # Run validation checks: the agent, if present, must be accessible by the user
    if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
        logger.error(f"Agent {agent.slug} is not accessible by user {user}")
        return results

    if q is None or q == "":
        logger.warning("No query param (q) passed in API call to initiate search")
        return results

    # Initialize variables
    user_query = q.strip()
    results_count = n or 5
    search_futures: List[concurrent.futures.Future] = []

    # Return cached results, if available. Cache key covers every parameter that affects results.
    if user:
        query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
        if query_cache_key in state.query_cache[user.uuid]:
            logger.debug("Return response from query cache")
            return state.query_cache[user.uuid][query_cache_key]

    # Encode query with filter terms removed
    defiltered_query = user_query
    for query_filter in [DateFilter(), WordFilter(), FileFilter()]:
        defiltered_query = query_filter.defilter(defiltered_query)

    # Fix: bind search_model before the conditional so non-text search types
    # (e.g. SearchType.Image) don't hit a NameError at the rerank step below.
    search_model = None
    encoded_asymmetric_query = None
    if t != state.SearchType.Image:
        with timer("Encoding query took", logger=logger):
            search_model = await sync_to_async(get_default_search_model)()
            encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        if t in [
            state.SearchType.All,
            state.SearchType.Org,
            state.SearchType.Markdown,
            state.SearchType.Github,
            state.SearchType.Notion,
            state.SearchType.Plaintext,
            state.SearchType.Pdf,
        ]:
            # Query text content types (notes, docs, etc.)
            search_futures += [
                executor.submit(
                    text_search.query,
                    user_query,
                    user,
                    t,
                    question_embedding=encoded_asymmetric_query,
                    max_distance=max_distance,
                    agent=agent,
                )
            ]

        # Query across each requested content types in parallel
        with timer("Query took", logger):
            for search_future in concurrent.futures.as_completed(search_futures):
                hits = await search_future.result()
                # Collate results
                results += text_search.collate_results(hits, dedupe=dedupe)

            # Sort results across all content types and take top results.
            # Skip reranking when no search model was loaded (non-text search types).
            if search_model is not None:
                results = text_search.rerank_and_sort_results(
                    results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
                )[:results_count]
            else:
                results = results[:results_count]

    # Cache results
    if user:
        state.query_cache[user.uuid][query_cache_key] = results

    end_time = time.time()
    logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")

    return results
async def send_message_to_model_wrapper(
query: str,
system_message: str = "",

View File

@@ -22,10 +22,10 @@ from khoj.processor.conversation.utils import (
from khoj.processor.operator import operate_environment
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import search_documents
from khoj.routers.helpers import (
ChatEvent,
generate_summary_from_files,
search_documents,
send_message_to_model_wrapper,
)
from khoj.utils.helpers import (

View File

@@ -6,7 +6,7 @@ from asgiref.sync import sync_to_async
from khoj.database.adapters import AgentAdapters
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
from khoj.routers.api import execute_search
from khoj.routers.helpers import execute_search
from khoj.utils.helpers import get_absolute_path
from tests.helpers import ChatModelFactory