mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Move document search tool into helpers module with other tools
Document search (because of its age) was the only tool directly within an api router. Put it into helpers to have all the (mini) tools in one place.
This commit is contained in:
@@ -1,19 +1,16 @@
|
|||||||
import concurrent.futures
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, List, Optional, Set, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import openai
|
import openai
|
||||||
import pytz
|
import pytz
|
||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
@@ -22,48 +19,28 @@ from starlette.authentication import has_required_scope, requires
|
|||||||
from khoj.configure import initialize_content
|
from khoj.configure import initialize_content
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
AgentAdapters,
|
|
||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
get_default_search_model,
|
|
||||||
get_user_photo,
|
get_user_photo,
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import KhojUser, SpeechToTextModelOptions
|
||||||
Agent,
|
|
||||||
ChatMessageModel,
|
|
||||||
ChatModel,
|
|
||||||
KhojUser,
|
|
||||||
SpeechToTextModelOptions,
|
|
||||||
)
|
|
||||||
from khoj.processor.conversation import prompts
|
|
||||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
|
||||||
extract_questions_anthropic,
|
|
||||||
)
|
|
||||||
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
|
|
||||||
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
|
||||||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
|
||||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||||
from khoj.processor.conversation.utils import clean_json, defilter_query
|
from khoj.processor.conversation.utils import clean_json
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
ChatEvent,
|
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
|
execute_search,
|
||||||
get_user_config,
|
get_user_config,
|
||||||
schedule_automation,
|
schedule_automation,
|
||||||
schedule_query,
|
schedule_query,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
)
|
)
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
|
||||||
from khoj.search_type import text_search
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import OfflineChatProcessorModel
|
from khoj.utils.helpers import is_none_or_empty
|
||||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
from khoj.utils.rawconfig import LocationData, SearchResponse
|
|
||||||
from khoj.utils.state import SearchType
|
from khoj.utils.state import SearchType
|
||||||
|
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
@@ -116,98 +93,6 @@ async def search(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def execute_search(
|
|
||||||
user: KhojUser,
|
|
||||||
q: str,
|
|
||||||
n: Optional[int] = 5,
|
|
||||||
t: Optional[SearchType] = SearchType.All,
|
|
||||||
r: Optional[bool] = False,
|
|
||||||
max_distance: Optional[Union[float, None]] = None,
|
|
||||||
dedupe: Optional[bool] = True,
|
|
||||||
agent: Optional[Agent] = None,
|
|
||||||
):
|
|
||||||
# Run validation checks
|
|
||||||
results: List[SearchResponse] = []
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Ensure the agent, if present, is accessible by the user
|
|
||||||
if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
|
|
||||||
logger.error(f"Agent {agent.slug} is not accessible by user {user}")
|
|
||||||
return results
|
|
||||||
|
|
||||||
if q is None or q == "":
|
|
||||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
|
||||||
return results
|
|
||||||
|
|
||||||
# initialize variables
|
|
||||||
user_query = q.strip()
|
|
||||||
results_count = n or 5
|
|
||||||
search_futures: List[concurrent.futures.Future] = []
|
|
||||||
|
|
||||||
# return cached results, if available
|
|
||||||
if user:
|
|
||||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
|
||||||
if query_cache_key in state.query_cache[user.uuid]:
|
|
||||||
logger.debug(f"Return response from query cache")
|
|
||||||
return state.query_cache[user.uuid][query_cache_key]
|
|
||||||
|
|
||||||
# Encode query with filter terms removed
|
|
||||||
defiltered_query = user_query
|
|
||||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
|
||||||
defiltered_query = filter.defilter(defiltered_query)
|
|
||||||
|
|
||||||
encoded_asymmetric_query = None
|
|
||||||
if t != SearchType.Image:
|
|
||||||
with timer("Encoding query took", logger=logger):
|
|
||||||
search_model = await sync_to_async(get_default_search_model)()
|
|
||||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
if t in [
|
|
||||||
SearchType.All,
|
|
||||||
SearchType.Org,
|
|
||||||
SearchType.Markdown,
|
|
||||||
SearchType.Github,
|
|
||||||
SearchType.Notion,
|
|
||||||
SearchType.Plaintext,
|
|
||||||
SearchType.Pdf,
|
|
||||||
]:
|
|
||||||
# query markdown notes
|
|
||||||
search_futures += [
|
|
||||||
executor.submit(
|
|
||||||
text_search.query,
|
|
||||||
user_query,
|
|
||||||
user,
|
|
||||||
t,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
max_distance=max_distance,
|
|
||||||
agent=agent,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Query across each requested content types in parallel
|
|
||||||
with timer("Query took", logger):
|
|
||||||
for search_future in concurrent.futures.as_completed(search_futures):
|
|
||||||
hits = await search_future.result()
|
|
||||||
# Collate results
|
|
||||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
|
||||||
|
|
||||||
# Sort results across all content types and take top results
|
|
||||||
results = text_search.rerank_and_sort_results(
|
|
||||||
results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
|
|
||||||
)[:results_count]
|
|
||||||
|
|
||||||
# Cache results
|
|
||||||
if user:
|
|
||||||
state.query_cache[user.uuid][query_cache_key] = results
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/update")
|
@api.get("/update")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
def update(
|
def update(
|
||||||
@@ -357,184 +242,6 @@ def set_user_name(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
async def search_documents(
|
|
||||||
user: KhojUser,
|
|
||||||
chat_history: list[ChatMessageModel],
|
|
||||||
q: str,
|
|
||||||
n: int,
|
|
||||||
d: float,
|
|
||||||
conversation_id: str,
|
|
||||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
|
||||||
location_data: LocationData = None,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
|
||||||
query_images: Optional[List[str]] = None,
|
|
||||||
previous_inferred_queries: Set = set(),
|
|
||||||
agent: Agent = None,
|
|
||||||
query_files: str = None,
|
|
||||||
tracer: dict = {},
|
|
||||||
):
|
|
||||||
# Initialize Variables
|
|
||||||
compiled_references: List[dict[str, str]] = []
|
|
||||||
inferred_queries: List[str] = []
|
|
||||||
|
|
||||||
agent_has_entries = False
|
|
||||||
|
|
||||||
if agent:
|
|
||||||
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not ConversationCommand.Notes in conversation_commands
|
|
||||||
and not ConversationCommand.Default in conversation_commands
|
|
||||||
and not agent_has_entries
|
|
||||||
):
|
|
||||||
yield compiled_references, inferred_queries, q
|
|
||||||
return
|
|
||||||
|
|
||||||
# If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base
|
|
||||||
should_limit_to_agent_knowledge = (
|
|
||||||
ConversationCommand.Notes not in conversation_commands
|
|
||||||
and ConversationCommand.Default not in conversation_commands
|
|
||||||
)
|
|
||||||
|
|
||||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
|
||||||
if not agent_has_entries:
|
|
||||||
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
|
||||||
yield compiled_references, inferred_queries, q
|
|
||||||
return
|
|
||||||
|
|
||||||
# Extract filter terms from user message
|
|
||||||
defiltered_query = defilter_query(q)
|
|
||||||
filters_in_query = q.replace(defiltered_query, "").strip()
|
|
||||||
conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
|
|
||||||
|
|
||||||
if not conversation:
|
|
||||||
logger.error(f"Conversation with id {conversation_id} not found when extracting references.")
|
|
||||||
yield compiled_references, inferred_queries, defiltered_query
|
|
||||||
return
|
|
||||||
|
|
||||||
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
|
|
||||||
using_offline_chat = False
|
|
||||||
if is_none_or_empty(filters_in_query):
|
|
||||||
logger.debug(f"Filters in query: {filters_in_query}")
|
|
||||||
|
|
||||||
personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""
|
|
||||||
|
|
||||||
# Infer search queries from user message
|
|
||||||
with timer("Extracting search queries took", logger):
|
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
|
||||||
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
|
||||||
vision_enabled = chat_model.vision_enabled
|
|
||||||
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
|
||||||
using_offline_chat = True
|
|
||||||
chat_model_name = chat_model.name
|
|
||||||
max_tokens = chat_model.max_prompt_size
|
|
||||||
if state.offline_chat_processor_config is None:
|
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
|
||||||
|
|
||||||
inferred_queries = extract_questions_offline(
|
|
||||||
defiltered_query,
|
|
||||||
model=chat_model,
|
|
||||||
loaded_model=loaded_model,
|
|
||||||
chat_history=chat_history,
|
|
||||||
should_extract_questions=True,
|
|
||||||
location_data=location_data,
|
|
||||||
user=user,
|
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
|
||||||
personality_context=personality_context,
|
|
||||||
query_files=query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
|
||||||
api_key = chat_model.ai_model_api.api_key
|
|
||||||
base_url = chat_model.ai_model_api.api_base_url
|
|
||||||
chat_model_name = chat_model.name
|
|
||||||
inferred_queries = extract_questions(
|
|
||||||
defiltered_query,
|
|
||||||
model=chat_model_name,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base_url=base_url,
|
|
||||||
chat_history=chat_history,
|
|
||||||
location_data=location_data,
|
|
||||||
user=user,
|
|
||||||
query_images=query_images,
|
|
||||||
vision_enabled=vision_enabled,
|
|
||||||
personality_context=personality_context,
|
|
||||||
query_files=query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
|
||||||
api_key = chat_model.ai_model_api.api_key
|
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
|
||||||
chat_model_name = chat_model.name
|
|
||||||
inferred_queries = extract_questions_anthropic(
|
|
||||||
defiltered_query,
|
|
||||||
query_images=query_images,
|
|
||||||
model=chat_model_name,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base_url=api_base_url,
|
|
||||||
chat_history=chat_history,
|
|
||||||
location_data=location_data,
|
|
||||||
user=user,
|
|
||||||
vision_enabled=vision_enabled,
|
|
||||||
personality_context=personality_context,
|
|
||||||
query_files=query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
|
||||||
api_key = chat_model.ai_model_api.api_key
|
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
|
||||||
chat_model_name = chat_model.name
|
|
||||||
inferred_queries = extract_questions_gemini(
|
|
||||||
defiltered_query,
|
|
||||||
query_images=query_images,
|
|
||||||
model=chat_model_name,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base_url=api_base_url,
|
|
||||||
chat_history=chat_history,
|
|
||||||
location_data=location_data,
|
|
||||||
max_tokens=chat_model.max_prompt_size,
|
|
||||||
user=user,
|
|
||||||
vision_enabled=vision_enabled,
|
|
||||||
personality_context=personality_context,
|
|
||||||
query_files=query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
|
||||||
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
|
|
||||||
with timer("Searching knowledge base took", logger):
|
|
||||||
search_results = []
|
|
||||||
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
|
||||||
if send_status_func:
|
|
||||||
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
|
||||||
async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
|
|
||||||
yield {ChatEvent.STATUS: event}
|
|
||||||
for query in inferred_queries:
|
|
||||||
n_items = min(n, 3) if using_offline_chat else n
|
|
||||||
search_results.extend(
|
|
||||||
await execute_search(
|
|
||||||
user if not should_limit_to_agent_knowledge else None,
|
|
||||||
f"{query} {filters_in_query}",
|
|
||||||
n=n_items,
|
|
||||||
t=SearchType.All,
|
|
||||||
r=True,
|
|
||||||
max_distance=d,
|
|
||||||
dedupe=False,
|
|
||||||
agent=agent,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
search_results = text_search.deduplicated_search_responses(search_results)
|
|
||||||
compiled_references = [
|
|
||||||
{"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
|
|
||||||
for q, item in zip(inferred_queries, search_results)
|
|
||||||
]
|
|
||||||
|
|
||||||
yield compiled_references, inferred_queries, defiltered_query
|
|
||||||
|
|
||||||
|
|
||||||
@api.get("/health", response_class=Response)
|
@api.get("/health", response_class=Response)
|
||||||
@requires(["authenticated"], status_code=200)
|
@requires(["authenticated"], status_code=200)
|
||||||
def health_check(request: Request) -> Response:
|
def health_check(request: Request) -> Response:
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from khoj.processor.tools.online_search import (
|
|||||||
search_online,
|
search_online,
|
||||||
)
|
)
|
||||||
from khoj.processor.tools.run_code import run_code
|
from khoj.processor.tools.run_code import run_code
|
||||||
from khoj.routers.api import search_documents
|
|
||||||
from khoj.routers.email import send_query_feedback
|
from khoj.routers.email import send_query_feedback
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiImageRateLimiter,
|
ApiImageRateLimiter,
|
||||||
@@ -63,6 +62,7 @@ from khoj.routers.helpers import (
|
|||||||
is_query_empty,
|
is_query_empty,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
read_chat_stream,
|
read_chat_stream,
|
||||||
|
search_documents,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_chat_model,
|
validate_chat_model,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import concurrent.futures
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from random import random
|
from random import random
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -45,6 +47,7 @@ from khoj.database.adapters import (
|
|||||||
aget_user_by_email,
|
aget_user_by_email,
|
||||||
ais_user_subscribed,
|
ais_user_subscribed,
|
||||||
create_khoj_token,
|
create_khoj_token,
|
||||||
|
get_default_search_model,
|
||||||
get_khoj_tokens,
|
get_khoj_tokens,
|
||||||
get_user_name,
|
get_user_name,
|
||||||
get_user_notion_config,
|
get_user_notion_config,
|
||||||
@@ -79,17 +82,21 @@ from khoj.processor.conversation import prompts
|
|||||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
anthropic_send_message_to_model,
|
anthropic_send_message_to_model,
|
||||||
converse_anthropic,
|
converse_anthropic,
|
||||||
|
extract_questions_anthropic,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.google.gemini_chat import (
|
from khoj.processor.conversation.google.gemini_chat import (
|
||||||
converse_gemini,
|
converse_gemini,
|
||||||
|
extract_questions_gemini,
|
||||||
gemini_send_message_to_model,
|
gemini_send_message_to_model,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.offline.chat_model import (
|
from khoj.processor.conversation.offline.chat_model import (
|
||||||
converse_offline,
|
converse_offline,
|
||||||
|
extract_questions_offline,
|
||||||
send_message_to_model_offline,
|
send_message_to_model_offline,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.openai.gpt import (
|
from khoj.processor.conversation.openai.gpt import (
|
||||||
converse_openai,
|
converse_openai,
|
||||||
|
extract_questions,
|
||||||
send_message_to_model,
|
send_message_to_model,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
@@ -100,11 +107,15 @@ from khoj.processor.conversation.utils import (
|
|||||||
clean_json,
|
clean_json,
|
||||||
clean_mermaidjs,
|
clean_mermaidjs,
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
|
defilter_query,
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
)
|
)
|
||||||
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
||||||
from khoj.routers.email import is_resend_enabled, send_task_email
|
from khoj.routers.email import is_resend_enabled, send_task_email
|
||||||
from khoj.routers.twilio import is_twilio_enabled
|
from khoj.routers.twilio import is_twilio_enabled
|
||||||
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import OfflineChatProcessorModel
|
from khoj.utils.config import OfflineChatProcessorModel
|
||||||
@@ -121,7 +132,13 @@ from khoj.utils.helpers import (
|
|||||||
timer,
|
timer,
|
||||||
tool_descriptions_for_llm,
|
tool_descriptions_for_llm,
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData
|
from khoj.utils.rawconfig import (
|
||||||
|
ChatRequestBody,
|
||||||
|
FileAttachment,
|
||||||
|
FileData,
|
||||||
|
LocationData,
|
||||||
|
SearchResponse,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1149,6 +1166,276 @@ async def generate_better_image_prompt(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def search_documents(
|
||||||
|
user: KhojUser,
|
||||||
|
chat_history: list[ChatMessageModel],
|
||||||
|
q: str,
|
||||||
|
n: int,
|
||||||
|
d: float,
|
||||||
|
conversation_id: str,
|
||||||
|
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||||
|
location_data: LocationData = None,
|
||||||
|
send_status_func: Optional[Callable] = None,
|
||||||
|
query_images: Optional[List[str]] = None,
|
||||||
|
previous_inferred_queries: Set = set(),
|
||||||
|
agent: Agent = None,
|
||||||
|
query_files: str = None,
|
||||||
|
tracer: dict = {},
|
||||||
|
):
|
||||||
|
# Initialize Variables
|
||||||
|
compiled_references: List[dict[str, str]] = []
|
||||||
|
inferred_queries: List[str] = []
|
||||||
|
|
||||||
|
agent_has_entries = False
|
||||||
|
|
||||||
|
if agent:
|
||||||
|
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not ConversationCommand.Notes in conversation_commands
|
||||||
|
and not ConversationCommand.Default in conversation_commands
|
||||||
|
and not agent_has_entries
|
||||||
|
):
|
||||||
|
yield compiled_references, inferred_queries, q
|
||||||
|
return
|
||||||
|
|
||||||
|
# If Notes or Default is not in the conversation command, then the search should be restricted to the agent's knowledge base
|
||||||
|
should_limit_to_agent_knowledge = (
|
||||||
|
ConversationCommand.Notes not in conversation_commands
|
||||||
|
and ConversationCommand.Default not in conversation_commands
|
||||||
|
)
|
||||||
|
|
||||||
|
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||||
|
if not agent_has_entries:
|
||||||
|
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
|
||||||
|
yield compiled_references, inferred_queries, q
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract filter terms from user message
|
||||||
|
defiltered_query = defilter_query(q)
|
||||||
|
filters_in_query = q.replace(defiltered_query, "").strip()
|
||||||
|
conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
logger.error(f"Conversation with id {conversation_id} not found when extracting references.")
|
||||||
|
yield compiled_references, inferred_queries, defiltered_query
|
||||||
|
return
|
||||||
|
|
||||||
|
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
|
||||||
|
using_offline_chat = False
|
||||||
|
if is_none_or_empty(filters_in_query):
|
||||||
|
logger.debug(f"Filters in query: {filters_in_query}")
|
||||||
|
|
||||||
|
personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""
|
||||||
|
|
||||||
|
# Infer search queries from user message
|
||||||
|
with timer("Extracting search queries took", logger):
|
||||||
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
|
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||||
|
vision_enabled = chat_model.vision_enabled
|
||||||
|
|
||||||
|
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
|
using_offline_chat = True
|
||||||
|
chat_model_name = chat_model.name
|
||||||
|
max_tokens = chat_model.max_prompt_size
|
||||||
|
if state.offline_chat_processor_config is None:
|
||||||
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||||
|
|
||||||
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
|
|
||||||
|
inferred_queries = extract_questions_offline(
|
||||||
|
defiltered_query,
|
||||||
|
model=chat_model,
|
||||||
|
loaded_model=loaded_model,
|
||||||
|
chat_history=chat_history,
|
||||||
|
should_extract_questions=True,
|
||||||
|
location_data=location_data,
|
||||||
|
user=user,
|
||||||
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
|
personality_context=personality_context,
|
||||||
|
query_files=query_files,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||||
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
base_url = chat_model.ai_model_api.api_base_url
|
||||||
|
chat_model_name = chat_model.name
|
||||||
|
inferred_queries = extract_questions(
|
||||||
|
defiltered_query,
|
||||||
|
model=chat_model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base_url=base_url,
|
||||||
|
chat_history=chat_history,
|
||||||
|
location_data=location_data,
|
||||||
|
user=user,
|
||||||
|
query_images=query_images,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
personality_context=personality_context,
|
||||||
|
query_files=query_files,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
|
chat_model_name = chat_model.name
|
||||||
|
inferred_queries = extract_questions_anthropic(
|
||||||
|
defiltered_query,
|
||||||
|
query_images=query_images,
|
||||||
|
model=chat_model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
chat_history=chat_history,
|
||||||
|
location_data=location_data,
|
||||||
|
user=user,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
personality_context=personality_context,
|
||||||
|
query_files=query_files,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
|
api_key = chat_model.ai_model_api.api_key
|
||||||
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
|
chat_model_name = chat_model.name
|
||||||
|
inferred_queries = extract_questions_gemini(
|
||||||
|
defiltered_query,
|
||||||
|
query_images=query_images,
|
||||||
|
model=chat_model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
chat_history=chat_history,
|
||||||
|
location_data=location_data,
|
||||||
|
max_tokens=chat_model.max_prompt_size,
|
||||||
|
user=user,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
personality_context=personality_context,
|
||||||
|
query_files=query_files,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collate search results as context for GPT
|
||||||
|
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
|
||||||
|
with timer("Searching knowledge base took", logger):
|
||||||
|
search_results = []
|
||||||
|
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
||||||
|
if send_status_func:
|
||||||
|
inferred_queries_str = "\n- " + "\n- ".join(inferred_queries)
|
||||||
|
async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
for query in inferred_queries:
|
||||||
|
n_items = min(n, 3) if using_offline_chat else n
|
||||||
|
search_results.extend(
|
||||||
|
await execute_search(
|
||||||
|
user if not should_limit_to_agent_knowledge else None,
|
||||||
|
f"{query} {filters_in_query}",
|
||||||
|
n=n_items,
|
||||||
|
t=state.SearchType.All,
|
||||||
|
r=True,
|
||||||
|
max_distance=d,
|
||||||
|
dedupe=False,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
search_results = text_search.deduplicated_search_responses(search_results)
|
||||||
|
compiled_references = [
|
||||||
|
{"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
|
||||||
|
for q, item in zip(inferred_queries, search_results)
|
||||||
|
]
|
||||||
|
|
||||||
|
yield compiled_references, inferred_queries, defiltered_query
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_search(
|
||||||
|
user: KhojUser,
|
||||||
|
q: str,
|
||||||
|
n: Optional[int] = 5,
|
||||||
|
t: Optional[state.SearchType] = state.SearchType.All,
|
||||||
|
r: Optional[bool] = False,
|
||||||
|
max_distance: Optional[Union[float, None]] = None,
|
||||||
|
dedupe: Optional[bool] = True,
|
||||||
|
agent: Optional[Agent] = None,
|
||||||
|
):
|
||||||
|
# Run validation checks
|
||||||
|
results: List[SearchResponse] = []
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ensure the agent, if present, is accessible by the user
|
||||||
|
if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
|
||||||
|
logger.error(f"Agent {agent.slug} is not accessible by user {user}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
if q is None or q == "":
|
||||||
|
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||||
|
return results
|
||||||
|
|
||||||
|
# initialize variables
|
||||||
|
user_query = q.strip()
|
||||||
|
results_count = n or 5
|
||||||
|
search_futures: List[concurrent.futures.Future] = []
|
||||||
|
|
||||||
|
# return cached results, if available
|
||||||
|
if user:
|
||||||
|
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||||
|
if query_cache_key in state.query_cache[user.uuid]:
|
||||||
|
logger.debug(f"Return response from query cache")
|
||||||
|
return state.query_cache[user.uuid][query_cache_key]
|
||||||
|
|
||||||
|
# Encode query with filter terms removed
|
||||||
|
defiltered_query = user_query
|
||||||
|
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||||
|
defiltered_query = filter.defilter(defiltered_query)
|
||||||
|
|
||||||
|
encoded_asymmetric_query = None
|
||||||
|
if t != state.SearchType.Image:
|
||||||
|
with timer("Encoding query took", logger=logger):
|
||||||
|
search_model = await sync_to_async(get_default_search_model)()
|
||||||
|
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
if t in [
|
||||||
|
state.SearchType.All,
|
||||||
|
state.SearchType.Org,
|
||||||
|
state.SearchType.Markdown,
|
||||||
|
state.SearchType.Github,
|
||||||
|
state.SearchType.Notion,
|
||||||
|
state.SearchType.Plaintext,
|
||||||
|
state.SearchType.Pdf,
|
||||||
|
]:
|
||||||
|
# query markdown notes
|
||||||
|
search_futures += [
|
||||||
|
executor.submit(
|
||||||
|
text_search.query,
|
||||||
|
user_query,
|
||||||
|
user,
|
||||||
|
t,
|
||||||
|
question_embedding=encoded_asymmetric_query,
|
||||||
|
max_distance=max_distance,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Query across each requested content types in parallel
|
||||||
|
with timer("Query took", logger):
|
||||||
|
for search_future in concurrent.futures.as_completed(search_futures):
|
||||||
|
hits = await search_future.result()
|
||||||
|
# Collate results
|
||||||
|
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||||
|
|
||||||
|
# Sort results across all content types and take top results
|
||||||
|
results = text_search.rerank_and_sort_results(
|
||||||
|
results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
|
||||||
|
)[:results_count]
|
||||||
|
|
||||||
|
# Cache results
|
||||||
|
if user:
|
||||||
|
state.query_cache[user.uuid][query_cache_key] = results
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def send_message_to_model_wrapper(
|
async def send_message_to_model_wrapper(
|
||||||
query: str,
|
query: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
|
|||||||
@@ -22,10 +22,10 @@ from khoj.processor.conversation.utils import (
|
|||||||
from khoj.processor.operator import operate_environment
|
from khoj.processor.operator import operate_environment
|
||||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
from khoj.processor.tools.run_code import run_code
|
from khoj.processor.tools.run_code import run_code
|
||||||
from khoj.routers.api import search_documents
|
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
generate_summary_from_files,
|
generate_summary_from_files,
|
||||||
|
search_documents,
|
||||||
send_message_to_model_wrapper,
|
send_message_to_model_wrapper,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from asgiref.sync import sync_to_async
|
|||||||
|
|
||||||
from khoj.database.adapters import AgentAdapters
|
from khoj.database.adapters import AgentAdapters
|
||||||
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||||
from khoj.routers.api import execute_search
|
from khoj.routers.helpers import execute_search
|
||||||
from khoj.utils.helpers import get_absolute_path
|
from khoj.utils.helpers import get_absolute_path
|
||||||
from tests.helpers import ChatModelFactory
|
from tests.helpers import ChatModelFactory
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user