Make search type comparison in document search more robust by comparing enum values instead of enum members

This commit is contained in:
Debanjum
2025-06-07 12:52:10 -07:00
parent b9c6252a4a
commit dc1c3561fe

View File

@@ -136,6 +136,7 @@ from khoj.utils.rawconfig import (
LocationData, LocationData,
SearchResponse, SearchResponse,
) )
from khoj.utils.state import SearchType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -1245,7 +1246,7 @@ async def search_documents(
user if not should_limit_to_agent_knowledge else None, user if not should_limit_to_agent_knowledge else None,
f"{query} {filters_in_query}", f"{query} {filters_in_query}",
n=n, n=n,
t=state.SearchType.All, t=SearchType.All,
r=True, r=True,
max_distance=d, max_distance=d,
dedupe=False, dedupe=False,
@@ -1337,7 +1338,7 @@ async def execute_search(
user: KhojUser, user: KhojUser,
q: str, q: str,
n: Optional[int] = 5, n: Optional[int] = 5,
t: Optional[state.SearchType] = None, t: Optional[SearchType] = None,
r: Optional[bool] = False, r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None, max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
@@ -1376,20 +1377,20 @@ async def execute_search(
defiltered_query = filter.defilter(defiltered_query) defiltered_query = filter.defilter(defiltered_query)
encoded_asymmetric_query = None encoded_asymmetric_query = None
if t != state.SearchType.Image: if t.value != SearchType.Image.value:
with timer("Encoding query took", logger=logger): with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_default_search_model)() search_model = await sync_to_async(get_default_search_model)()
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if t in [ if t.value in [
state.SearchType.All, SearchType.All.value,
state.SearchType.Org, SearchType.Org.value,
state.SearchType.Markdown, SearchType.Markdown.value,
state.SearchType.Github, SearchType.Github.value,
state.SearchType.Notion, SearchType.Notion.value,
state.SearchType.Plaintext, SearchType.Plaintext.value,
state.SearchType.Pdf, SearchType.Pdf.value,
]: ]:
# query markdown notes # query markdown notes
search_futures += [ search_futures += [