From b6ceaeeffc962135bc2af5963b3707cf5c7031bc Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 7 Jun 2025 13:06:49 -0700 Subject: [PATCH] Execute doc search in parallel using asyncio instead of threadpool --- src/khoj/routers/helpers.py | 55 +++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 92cfaeea..e1a7f812 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,3 +1,4 @@ +import asyncio import base64 import concurrent.futures import hashlib @@ -1362,7 +1363,7 @@ async def execute_search( user_query = q.strip() results_count = n or 5 t = t or state.SearchType.All - search_futures: List[concurrent.futures.Future] = [] + search_tasks = [] # return cached results, if available if user: @@ -1382,33 +1383,33 @@ async def execute_search( search_model = await sync_to_async(get_default_search_model)() encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) - with concurrent.futures.ThreadPoolExecutor() as executor: - if t.value in [ - SearchType.All.value, - SearchType.Org.value, - SearchType.Markdown.value, - SearchType.Github.value, - SearchType.Notion.value, - SearchType.Plaintext.value, - SearchType.Pdf.value, - ]: - # query markdown notes - search_futures += [ - executor.submit( - text_search.query, - user_query, - user, - t, - question_embedding=encoded_asymmetric_query, - max_distance=max_distance, - agent=agent, - ) - ] + # Use asyncio to run searches in parallel + if t.value in [ + SearchType.All.value, + SearchType.Org.value, + SearchType.Markdown.value, + SearchType.Github.value, + SearchType.Notion.value, + SearchType.Plaintext.value, + SearchType.Pdf.value, + ]: + # query markdown notes + search_tasks.append( + text_search.query( + user_query, + user, + t, + question_embedding=encoded_asymmetric_query, + max_distance=max_distance, + agent=agent, + ) + ) - # Query across each requested content types in parallel - with timer("Query took", logger): - for search_future in concurrent.futures.as_completed(search_futures): - hits = await search_future.result() + # Query across each requested content types in parallel + with timer("Query took", logger): + if search_tasks: + hits_list = await asyncio.gather(*search_tasks) + for hits in hits_list: # Collate results results += text_search.collate_results(hits, dedupe=dedupe)