Execute doc search in parallel using asyncio instead of threadpool

This commit is contained in:
Debanjum
2025-06-07 13:06:49 -07:00
parent dc1c3561fe
commit b6ceaeeffc

View File

@@ -1,3 +1,4 @@
+import asyncio
 import base64
 import concurrent.futures
 import hashlib
@@ -1362,7 +1363,7 @@ async def execute_search(
     user_query = q.strip()
     results_count = n or 5
     t = t or state.SearchType.All
-    search_futures: List[concurrent.futures.Future] = []
+    search_tasks = []

     # return cached results, if available
     if user:
@@ -1382,7 +1383,7 @@ async def execute_search(
         search_model = await sync_to_async(get_default_search_model)()
         encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
-        with concurrent.futures.ThreadPoolExecutor() as executor:
+        # Use asyncio to run searches in parallel
         if t.value in [
             SearchType.All.value,
             SearchType.Org.value,
@@ -1393,9 +1394,8 @@ async def execute_search(
             SearchType.Pdf.value,
         ]:
             # query markdown notes
-            search_futures += [
-                executor.submit(
-                    text_search.query,
+            search_tasks.append(
+                text_search.query(
                     user_query,
                     user,
                     t,
@@ -1403,12 +1403,13 @@ async def execute_search(
                     max_distance=max_distance,
                     agent=agent,
                 )
-            ]
+            )

         # Query across each requested content types in parallel
         with timer("Query took", logger):
-            for search_future in concurrent.futures.as_completed(search_futures):
-                hits = await search_future.result()
+            if search_tasks:
+                hits_list = await asyncio.gather(*search_tasks)
+                for hits in hits_list:
                     # Collate results
                     results += text_search.collate_results(hits, dedupe=dedupe)