From dc1c3561fe9b992cc46ed9ff8d53fe4689b7b5e8 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 7 Jun 2025 12:52:10 -0700 Subject: [PATCH] Make search type comparison in document search more robust --- src/khoj/routers/helpers.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 65d4d7b1..92cfaeea 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -136,6 +136,7 @@ from khoj.utils.rawconfig import ( LocationData, SearchResponse, ) +from khoj.utils.state import SearchType logger = logging.getLogger(__name__) @@ -1245,7 +1246,7 @@ async def search_documents( user if not should_limit_to_agent_knowledge else None, f"{query} {filters_in_query}", n=n, - t=state.SearchType.All, + t=SearchType.All, r=True, max_distance=d, dedupe=False, @@ -1337,7 +1338,7 @@ async def execute_search( user: KhojUser, q: str, n: Optional[int] = 5, - t: Optional[state.SearchType] = None, + t: Optional[SearchType] = None, r: Optional[bool] = False, max_distance: Optional[Union[float, None]] = None, dedupe: Optional[bool] = True, @@ -1376,20 +1377,20 @@ async def execute_search( defiltered_query = filter.defilter(defiltered_query) encoded_asymmetric_query = None - if t != state.SearchType.Image: + if t.value != SearchType.Image.value: with timer("Encoding query took", logger=logger): search_model = await sync_to_async(get_default_search_model)() encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: - if t in [ - state.SearchType.All, - state.SearchType.Org, - state.SearchType.Markdown, - state.SearchType.Github, - state.SearchType.Notion, - state.SearchType.Plaintext, - state.SearchType.Pdf, + if t.value in [ + SearchType.All.value, + SearchType.Org.value, + SearchType.Markdown.value, + SearchType.Github.value, + SearchType.Notion.value, + SearchType.Plaintext.value, + SearchType.Pdf.value, ]: # query markdown notes search_futures += [