Execute doc search in parallel using asyncio instead of threadpool

This commit is contained in:
Debanjum
2025-06-07 13:06:49 -07:00
parent dc1c3561fe
commit b6ceaeeffc

View File

@@ -1,3 +1,4 @@
import asyncio
import base64
import concurrent.futures
import hashlib
@@ -1362,7 +1363,7 @@ async def execute_search(
user_query = q.strip()
results_count = n or 5
t = t or state.SearchType.All
- search_futures: List[concurrent.futures.Future] = []
+ search_tasks = []
# return cached results, if available
if user:
@@ -1382,33 +1383,33 @@ async def execute_search(
search_model = await sync_to_async(get_default_search_model)()
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
- with concurrent.futures.ThreadPoolExecutor() as executor:
+ # Use asyncio to run searches in parallel
if t.value in [
    SearchType.All.value,
    SearchType.Org.value,
    SearchType.Markdown.value,
    SearchType.Github.value,
    SearchType.Notion.value,
    SearchType.Plaintext.value,
    SearchType.Pdf.value,
]:
    # query markdown notes
- search_futures += [
-     executor.submit(
-         text_search.query,
-         user_query,
-         user,
-         t,
-         question_embedding=encoded_asymmetric_query,
-         max_distance=max_distance,
-         agent=agent,
-     )
- ]
+ search_tasks.append(
+     text_search.query(
+         user_query,
+         user,
+         t,
+         question_embedding=encoded_asymmetric_query,
+         max_distance=max_distance,
+         agent=agent,
+     )
+ )
# Query across each requested content types in parallel
with timer("Query took", logger):
- for search_future in concurrent.futures.as_completed(search_futures):
-     hits = await search_future.result()
+ if search_tasks:
+     hits_list = await asyncio.gather(*search_tasks)
+     for hits in hits_list:
# Collate results
results += text_search.collate_results(hits, dedupe=dedupe)