mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Encode the asymmetric and symmetric search queries in parallel for speed
Use a timer to measure both the query encoding time and the total search time
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
import yaml
|
import yaml
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -118,6 +119,8 @@ def search(
|
|||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
results: List[SearchResponse] = []
|
results: List[SearchResponse] = []
|
||||||
if q is None or q == "":
|
if q is None or q == "":
|
||||||
logger.warn(f"No query param (q) passed in API call to initiate search")
|
logger.warn(f"No query param (q) passed in API call to initiate search")
|
||||||
@@ -139,15 +142,26 @@ def search(
|
|||||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||||
defiltered_query = filter.defilter(user_query)
|
defiltered_query = filter.defilter(user_query)
|
||||||
|
|
||||||
encoded_asymmetric_query = state.model.org_search.bi_encoder.encode(
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
[defiltered_query], convert_to_tensor=True, device=state.device
|
with timer("Encoding query for asymmetric search took", logger=logger):
|
||||||
)
|
encode_asymmetric_futures = executor.submit(
|
||||||
encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query)
|
state.model.org_search.bi_encoder.encode,
|
||||||
|
[defiltered_query],
|
||||||
|
convert_to_tensor=True,
|
||||||
|
device=state.device,
|
||||||
|
)
|
||||||
|
|
||||||
encoded_symmetric_query = state.model.org_search.bi_encoder.encode(
|
with timer("Encoding query for symmetric search took", logger=logger):
|
||||||
[defiltered_query], convert_to_tensor=True, device=state.device
|
encode_symmetric_futures = executor.submit(
|
||||||
)
|
state.model.org_search.bi_encoder.encode,
|
||||||
encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query)
|
[defiltered_query],
|
||||||
|
convert_to_tensor=True,
|
||||||
|
device=state.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
with timer("Normalizing query embeddings took", logger=logger):
|
||||||
|
encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result())
|
||||||
|
encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result())
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
if (t == SearchType.Org or t == None) and state.model.org_search:
|
if (t == SearchType.Org or t == None) and state.model.org_search:
|
||||||
@@ -279,6 +293,9 @@ def search(
|
|||||||
]
|
]
|
||||||
state.previous_query = user_query
|
state.previous_query = user_query
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user