diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 805030d8..a4f24999 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -1,19 +1,16 @@ -import concurrent.futures import json import logging import math import os import threading -import time import uuid -from typing import Any, Callable, List, Optional, Set, Union +from typing import List, Optional, Union import cron_descriptor import openai import pytz from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger -from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile from fastapi.requests import Request from fastapi.responses import Response @@ -22,48 +19,28 @@ from starlette.authentication import has_required_scope, requires from khoj.configure import initialize_content from khoj.database import adapters from khoj.database.adapters import ( - AgentAdapters, AutomationAdapters, ConversationAdapters, EntryAdapters, - get_default_search_model, get_user_photo, ) -from khoj.database.models import ( - Agent, - ChatMessageModel, - ChatModel, - KhojUser, - SpeechToTextModelOptions, -) -from khoj.processor.conversation import prompts -from khoj.processor.conversation.anthropic.anthropic_chat import ( - extract_questions_anthropic, -) -from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini -from khoj.processor.conversation.offline.chat_model import extract_questions_offline +from khoj.database.models import KhojUser, SpeechToTextModelOptions from khoj.processor.conversation.offline.whisper import transcribe_audio_offline -from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio -from khoj.processor.conversation.utils import clean_json, defilter_query +from khoj.processor.conversation.utils import clean_json from khoj.routers.helpers import ( ApiUserRateLimiter, - ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, + execute_search, get_user_config, schedule_automation, schedule_query, update_telemetry_state, ) -from khoj.search_filter.date_filter import DateFilter -from khoj.search_filter.file_filter import FileFilter -from khoj.search_filter.word_filter import WordFilter -from khoj.search_type import text_search from khoj.utils import state -from khoj.utils.config import OfflineChatProcessorModel -from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer -from khoj.utils.rawconfig import LocationData, SearchResponse +from khoj.utils.helpers import is_none_or_empty +from khoj.utils.rawconfig import SearchResponse from khoj.utils.state import SearchType # Initialize Router @@ -116,98 +93,6 @@ async def search( return results -async def execute_search( - user: KhojUser, - q: str, - n: Optional[int] = 5, - t: Optional[SearchType] = SearchType.All, - r: Optional[bool] = False, - max_distance: Optional[Union[float, None]] = None, - dedupe: Optional[bool] = True, - agent: Optional[Agent] = None, -): - # Run validation checks - results: List[SearchResponse] = [] - - start_time = time.time() - - # Ensure the agent, if present, is accessible by the user - if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user): - logger.error(f"Agent {agent.slug} is not accessible by user {user}") - return results - - if q is None or q == "": - logger.warning(f"No query param (q) passed in API call to initiate search") - return results - - # initialize variables - user_query = q.strip() - results_count = n or 5 - search_futures: List[concurrent.futures.Future] = [] - - # return cached results, if available - if user: - query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" - if query_cache_key in state.query_cache[user.uuid]: - logger.debug(f"Return response from query cache") - return state.query_cache[user.uuid][query_cache_key] - - # Encode query with filter terms removed - defiltered_query = user_query - for filter in [DateFilter(), WordFilter(), FileFilter()]: - defiltered_query = filter.defilter(defiltered_query) - - encoded_asymmetric_query = None - if t != SearchType.Image: - with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_default_search_model)() - encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) - - with concurrent.futures.ThreadPoolExecutor() as executor: - if t in [ - SearchType.All, - SearchType.Org, - SearchType.Markdown, - SearchType.Github, - SearchType.Notion, - SearchType.Plaintext, - SearchType.Pdf, - ]: - # query markdown notes - search_futures += [ - executor.submit( - text_search.query, - user_query, - user, - t, - question_embedding=encoded_asymmetric_query, - max_distance=max_distance, - agent=agent, - ) - ] - - # Query across each requested content types in parallel - with timer("Query took", logger): - for search_future in concurrent.futures.as_completed(search_futures): - hits = await search_future.result() - # Collate results - results += text_search.collate_results(hits, dedupe=dedupe) - - # Sort results across all content types and take top results - results = text_search.rerank_and_sort_results( - results, query=defiltered_query, rank_results=r, search_model_name=search_model.name - )[:results_count] - - # Cache results - if user: - state.query_cache[user.uuid][query_cache_key] = results - - end_time = time.time() - logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") - - return results - - @api.get("/update") @requires(["authenticated"]) def update( @@ -357,184 +242,6 @@ def set_user_name( return {"status": "ok"} -async def search_documents( - user: KhojUser, - chat_history: list[ChatMessageModel], - q: str, - n: int, - d: float, - conversation_id: str, - conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], - location_data: LocationData = None, - send_status_func: Optional[Callable] = None, - query_images: Optional[List[str]] = None, - previous_inferred_queries: Set = set(), - agent: Agent = None, - query_files: str = None, - tracer: dict = {}, -): - # Initialize Variables - compiled_references: List[dict[str, str]] = [] - inferred_queries: List[str] = [] - - agent_has_entries = False - - if agent: - agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) - - if ( - not ConversationCommand.Notes in conversation_commands - and not ConversationCommand.Default in conversation_commands - and not agent_has_entries - ): - yield compiled_references, inferred_queries, q - return - - # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base - should_limit_to_agent_knowledge = ( - ConversationCommand.Notes not in conversation_commands - and ConversationCommand.Default not in conversation_commands - ) - - if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): - if not agent_has_entries: - logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") - yield compiled_references, inferred_queries, q - return - - # Extract filter terms from user message - defiltered_query = defilter_query(q) - filters_in_query = q.replace(defiltered_query, "").strip() - conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id) - - if not conversation: - logger.error(f"Conversation with id {conversation_id} not found when extracting references.") - yield compiled_references, inferred_queries, defiltered_query - return - - filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) - using_offline_chat = False - if is_none_or_empty(filters_in_query): - logger.debug(f"Filters in query: {filters_in_query}") - - personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" - - # Infer search queries from user message - with timer("Extracting search queries took", logger): - # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - chat_model = await ConversationAdapters.aget_default_chat_model(user) - vision_enabled = chat_model.vision_enabled - - if chat_model.model_type == ChatModel.ModelType.OFFLINE: - using_offline_chat = True - chat_model_name = chat_model.name - max_tokens = chat_model.max_prompt_size - if state.offline_chat_processor_config is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - - loaded_model = state.offline_chat_processor_config.loaded_model - - inferred_queries = extract_questions_offline( - defiltered_query, - model=chat_model, - loaded_model=loaded_model, - chat_history=chat_history, - should_extract_questions=True, - location_data=location_data, - user=user, - max_prompt_size=chat_model.max_prompt_size, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - elif chat_model.model_type == ChatModel.ModelType.OPENAI: - api_key = chat_model.ai_model_api.api_key - base_url = chat_model.ai_model_api.api_base_url - chat_model_name = chat_model.name - inferred_queries = extract_questions( - defiltered_query, - model=chat_model_name, - api_key=api_key, - api_base_url=base_url, - chat_history=chat_history, - location_data=location_data, - user=user, - query_images=query_images, - vision_enabled=vision_enabled, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - chat_model_name = chat_model.name - inferred_queries = extract_questions_anthropic( - defiltered_query, - query_images=query_images, - model=chat_model_name, - api_key=api_key, - api_base_url=api_base_url, - chat_history=chat_history, - location_data=location_data, - user=user, - vision_enabled=vision_enabled, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - elif chat_model.model_type == ChatModel.ModelType.GOOGLE: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - chat_model_name = chat_model.name - inferred_queries = extract_questions_gemini( - defiltered_query, - query_images=query_images, - model=chat_model_name, - api_key=api_key, - api_base_url=api_base_url, - chat_history=chat_history, - location_data=location_data, - max_tokens=chat_model.max_prompt_size, - user=user, - vision_enabled=vision_enabled, - personality_context=personality_context, - query_files=query_files, - tracer=tracer, - ) - - # Collate search results as context for GPT - inferred_queries = list(set(inferred_queries) - previous_inferred_queries) - with timer("Searching knowledge base took", logger): - search_results = [] - logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") - if send_status_func: - inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) - async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"): - yield {ChatEvent.STATUS: event} - for query in inferred_queries: - n_items = min(n, 3) if using_offline_chat else n - search_results.extend( - await execute_search( - user if not should_limit_to_agent_knowledge else None, - f"{query} {filters_in_query}", - n=n_items, - t=SearchType.All, - r=True, - max_distance=d, - dedupe=False, - agent=agent, - ) - ) - search_results = text_search.deduplicated_search_responses(search_results) - compiled_references = [ - {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]} - for q, item in zip(inferred_queries, search_results) - ] - - yield compiled_references, inferred_queries, defiltered_query - - @api.get("/health", response_class=Response) @requires(["authenticated"], status_code=200) def health_check(request: Request) -> Response: diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 62090052..ce5bfbc5 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -40,7 +40,6 @@ from khoj.processor.tools.online_search import ( search_online, ) from khoj.processor.tools.run_code import run_code -from khoj.routers.api import search_documents from khoj.routers.email import send_query_feedback from khoj.routers.helpers import ( ApiImageRateLimiter, @@ -63,6 +62,7 @@ from khoj.routers.helpers import ( is_query_empty, is_ready_to_chat, read_chat_stream, + search_documents, update_telemetry_state, validate_chat_model, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 4d5082d9..ff31fcad 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,10 +1,12 @@ import base64 +import concurrent.futures import hashlib import json import logging import math import os import re +import time from datetime import datetime, timedelta, timezone from random import random from typing import ( @@ -45,6 +47,7 @@ from khoj.database.adapters import ( aget_user_by_email, ais_user_subscribed, create_khoj_token, + get_default_search_model, get_khoj_tokens, get_user_name, get_user_notion_config, @@ -79,17 +82,21 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.anthropic_chat import ( anthropic_send_message_to_model, converse_anthropic, + extract_questions_anthropic, ) from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, + extract_questions_gemini, gemini_send_message_to_model, ) from khoj.processor.conversation.offline.chat_model import ( converse_offline, + extract_questions_offline, send_message_to_model_offline, ) from khoj.processor.conversation.openai.gpt import ( converse_openai, + extract_questions, send_message_to_model, ) from khoj.processor.conversation.utils import ( @@ -100,11 +107,15 @@ from khoj.processor.conversation.utils import ( clean_json, clean_mermaidjs, construct_chat_history, + defilter_query, generate_chatml_messages_with_context, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.email import is_resend_enabled, send_task_email from khoj.routers.twilio import is_twilio_enabled +from khoj.search_filter.date_filter import DateFilter +from khoj.search_filter.file_filter import FileFilter +from khoj.search_filter.word_filter import WordFilter from khoj.search_type import text_search from khoj.utils import state from khoj.utils.config import OfflineChatProcessorModel @@ -121,7 +132,13 @@ from khoj.utils.helpers import ( timer, tool_descriptions_for_llm, ) -from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData +from khoj.utils.rawconfig import ( + ChatRequestBody, + FileAttachment, + FileData, + LocationData, + SearchResponse, +) logger = logging.getLogger(__name__) @@ -1149,6 +1166,276 @@ async def generate_better_image_prompt( return response +async def search_documents( + user: KhojUser, + chat_history: list[ChatMessageModel], + q: str, + n: int, + d: float, + conversation_id: str, + conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], + location_data: LocationData = None, + send_status_func: Optional[Callable] = None, + query_images: Optional[List[str]] = None, + previous_inferred_queries: Set = set(), + agent: Agent = None, + query_files: str = None, + tracer: dict = {}, +): + # Initialize Variables + compiled_references: List[dict[str, str]] = [] + inferred_queries: List[str] = [] + + agent_has_entries = False + + if agent: + agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) + + if ( + not ConversationCommand.Notes in conversation_commands + and not ConversationCommand.Default in conversation_commands + and not agent_has_entries + ): + yield compiled_references, inferred_queries, q + return + + # If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base + should_limit_to_agent_knowledge = ( + ConversationCommand.Notes not in conversation_commands + and ConversationCommand.Default not in conversation_commands + ) + + if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): + if not agent_has_entries: + logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.") + yield compiled_references, inferred_queries, q + return + + # Extract filter terms from user message + defiltered_query = defilter_query(q) + filters_in_query = q.replace(defiltered_query, "").strip() + conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id) + + if not conversation: + logger.error(f"Conversation with id {conversation_id} not found when extracting references.") + yield compiled_references, inferred_queries, defiltered_query + return + + filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters]) + using_offline_chat = False + if is_none_or_empty(filters_in_query): + logger.debug(f"Filters in query: {filters_in_query}") + + personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" + + # Infer search queries from user message + with timer("Extracting search queries took", logger): + # If we've reached here, either the user has enabled offline chat or the openai model is enabled. + chat_model = await ConversationAdapters.aget_default_chat_model(user) + vision_enabled = chat_model.vision_enabled + + if chat_model.model_type == ChatModel.ModelType.OFFLINE: + using_offline_chat = True + chat_model_name = chat_model.name + max_tokens = chat_model.max_prompt_size + if state.offline_chat_processor_config is None: + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) + + loaded_model = state.offline_chat_processor_config.loaded_model + + inferred_queries = extract_questions_offline( + defiltered_query, + model=chat_model, + loaded_model=loaded_model, + chat_history=chat_history, + should_extract_questions=True, + location_data=location_data, + user=user, + max_prompt_size=chat_model.max_prompt_size, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + elif chat_model.model_type == ChatModel.ModelType.OPENAI: + api_key = chat_model.ai_model_api.api_key + base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name + inferred_queries = extract_questions( + defiltered_query, + model=chat_model_name, + api_key=api_key, + api_base_url=base_url, + chat_history=chat_history, + location_data=location_data, + user=user, + query_images=query_images, + vision_enabled=vision_enabled, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name + inferred_queries = extract_questions_anthropic( + defiltered_query, + query_images=query_images, + model=chat_model_name, + api_key=api_key, + api_base_url=api_base_url, + chat_history=chat_history, + location_data=location_data, + user=user, + vision_enabled=vision_enabled, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name + inferred_queries = extract_questions_gemini( + defiltered_query, + query_images=query_images, + model=chat_model_name, + api_key=api_key, + api_base_url=api_base_url, + chat_history=chat_history, + location_data=location_data, + max_tokens=chat_model.max_prompt_size, + user=user, + vision_enabled=vision_enabled, + personality_context=personality_context, + query_files=query_files, + tracer=tracer, + ) + + # Collate search results as context for GPT + inferred_queries = list(set(inferred_queries) - previous_inferred_queries) + with timer("Searching knowledge base took", logger): + search_results = [] + logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") + if send_status_func: + inferred_queries_str = "\n- " + "\n- ".join(inferred_queries) + async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"): + yield {ChatEvent.STATUS: event} + for query in inferred_queries: + n_items = min(n, 3) if using_offline_chat else n + search_results.extend( + await execute_search( + user if not should_limit_to_agent_knowledge else None, + f"{query} {filters_in_query}", + n=n_items, + t=state.SearchType.All, + r=True, + max_distance=d, + dedupe=False, + agent=agent, + ) + ) + search_results = text_search.deduplicated_search_responses(search_results) + compiled_references = [ + {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]} + for q, item in zip(inferred_queries, search_results) + ] + + yield compiled_references, inferred_queries, defiltered_query + + +async def execute_search( + user: KhojUser, + q: str, + n: Optional[int] = 5, + t: Optional[state.SearchType] = state.SearchType.All, + r: Optional[bool] = False, + max_distance: Optional[Union[float, None]] = None, + dedupe: Optional[bool] = True, + agent: Optional[Agent] = None, +): + # Run validation checks + results: List[SearchResponse] = [] + + start_time = time.time() + + # Ensure the agent, if present, is accessible by the user + if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user): + logger.error(f"Agent {agent.slug} is not accessible by user {user}") + return results + + if q is None or q == "": + logger.warning(f"No query param (q) passed in API call to initiate search") + return results + + # initialize variables + user_query = q.strip() + results_count = n or 5 + search_futures: List[concurrent.futures.Future] = [] + + # return cached results, if available + if user: + query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" + if query_cache_key in state.query_cache[user.uuid]: + logger.debug(f"Return response from query cache") + return state.query_cache[user.uuid][query_cache_key] + + # Encode query with filter terms removed + defiltered_query = user_query + for filter in [DateFilter(), WordFilter(), FileFilter()]: + defiltered_query = filter.defilter(defiltered_query) + + encoded_asymmetric_query = None + if t != state.SearchType.Image: + with timer("Encoding query took", logger=logger): + search_model = await sync_to_async(get_default_search_model)() + encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) + + with concurrent.futures.ThreadPoolExecutor() as executor: + if t in [ + state.SearchType.All, + state.SearchType.Org, + state.SearchType.Markdown, + state.SearchType.Github, + state.SearchType.Notion, + state.SearchType.Plaintext, + state.SearchType.Pdf, + ]: + # query markdown notes + search_futures += [ + executor.submit( + text_search.query, + user_query, + user, + t, + question_embedding=encoded_asymmetric_query, + max_distance=max_distance, + agent=agent, + ) + ] + + # Query across each requested content types in parallel + with timer("Query took", logger): + for search_future in concurrent.futures.as_completed(search_futures): + hits = await search_future.result() + # Collate results + results += text_search.collate_results(hits, dedupe=dedupe) + + # Sort results across all content types and take top results + results = text_search.rerank_and_sort_results( + results, query=defiltered_query, rank_results=r, search_model_name=search_model.name + )[:results_count] + + # Cache results + if user: + state.query_cache[user.uuid][query_cache_key] = results + + end_time = time.time() + logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") + + return results + + async def send_message_to_model_wrapper( query: str, system_message: str = "", diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 43617e6c..881e49b9 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -22,10 +22,10 @@ from khoj.processor.conversation.utils import ( from khoj.processor.operator import operate_environment from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.run_code import run_code -from khoj.routers.api import search_documents from khoj.routers.helpers import ( ChatEvent, generate_summary_from_files, + search_documents, send_message_to_model_wrapper, ) from khoj.utils.helpers import ( diff --git a/tests/test_agents.py b/tests/test_agents.py index 242495e6..f573bed9 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -6,7 +6,7 @@ from asgiref.sync import sync_to_async from khoj.database.adapters import AgentAdapters from khoj.database.models import Agent, ChatModel, Entry, KhojUser -from khoj.routers.api import execute_search +from khoj.routers.helpers import execute_search from khoj.utils.helpers import get_absolute_path from tests.helpers import ChatModelFactory