Encode the asymmetric and symmetric search queries in parallel for speed

Use a timer to measure the time taken to encode queries and the total search time
This commit is contained in:
Debanjum Singh Solanky
2023-06-20 01:17:21 -07:00
parent db07362ca3
commit 6d94d6e75a

View File

@@ -2,6 +2,7 @@
from collections import defaultdict from collections import defaultdict
import concurrent.futures import concurrent.futures
import math import math
import time
import yaml import yaml
import logging import logging
from datetime import datetime from datetime import datetime
@@ -118,6 +119,8 @@ def search(
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
client: Optional[str] = None, client: Optional[str] = None,
): ):
start_time = time.time()
results: List[SearchResponse] = [] results: List[SearchResponse] = []
if q is None or q == "": if q is None or q == "":
logger.warn(f"No query param (q) passed in API call to initiate search") logger.warn(f"No query param (q) passed in API call to initiate search")
@@ -139,15 +142,26 @@ def search(
for filter in [DateFilter(), WordFilter(), FileFilter()]: for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(user_query) defiltered_query = filter.defilter(user_query)
encoded_asymmetric_query = state.model.org_search.bi_encoder.encode( with concurrent.futures.ThreadPoolExecutor() as executor:
[defiltered_query], convert_to_tensor=True, device=state.device with timer("Encoding query for asymmetric search took", logger=logger):
) encode_asymmetric_futures = executor.submit(
encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query) state.model.org_search.bi_encoder.encode,
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
encoded_symmetric_query = state.model.org_search.bi_encoder.encode( with timer("Encoding query for symmetric search took", logger=logger):
[defiltered_query], convert_to_tensor=True, device=state.device encode_symmetric_futures = executor.submit(
) state.model.org_search.bi_encoder.encode,
encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query) [defiltered_query],
convert_to_tensor=True,
device=state.device,
)
with timer("Normalizing query embeddings took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result())
encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result())
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == None) and state.model.org_search: if (t == SearchType.Org or t == None) and state.model.org_search:
@@ -279,6 +293,9 @@ def search(
] ]
state.previous_query = user_query state.previous_query = user_query
end_time = time.time()
logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds")
return results return results