Rerank Search Results by Default on GPU machines (#668)

- Trigger
   SentenceTransformer Cross Encoder models now run fast on GPU enabled machines, including Mac ARM devices since UKPLab/sentence-transformers#2463

- Details
  - Use cross-encoder to rerank search results by default on GPU machines and when using an inference server
  - Only call the search API after a pause in typing the search query on the web and desktop apps
This commit is contained in:
Debanjum
2024-03-10 15:15:25 +05:30
committed by GitHub
6 changed files with 45 additions and 32 deletions

View File

@@ -50,7 +50,7 @@ dependencies = [
"pyyaml == 6.0", "pyyaml == 6.0",
"rich >= 13.3.1", "rich >= 13.3.1",
"schedule == 1.1.0", "schedule == 1.1.0",
"sentence-transformers == 2.3.1", "sentence-transformers == 2.5.1",
"transformers >= 4.28.0", "transformers >= 4.28.0",
"torch == 2.0.1", "torch == 2.0.1",
"uvicorn == 0.17.6", "uvicorn == 0.17.6",

View File

@@ -192,16 +192,17 @@
}); });
} }
let debounceTimeout;
function incrementalSearch(event) { function incrementalSearch(event) {
type = 'all'; // Run incremental search only after waitTime passed since the last key press
// Search with reranking on 'Enter' let waitTime = 300;
if (event.key === 'Enter') { clearTimeout(debounceTimeout);
search(rerank=true); debounceTimeout = setTimeout(() => {
} type = 'all';
// Limit incremental search to text types // Search with reranking on 'Enter'
else if (type !== "image") { let should_rerank = event.key === 'Enter';
search(rerank=false); search(rerank=should_rerank);
} }, waitTime);
} }
async function populate_type_dropdown() { async function populate_type_dropdown() {

View File

@@ -193,16 +193,17 @@
}); });
} }
let debounceTimeout;
function incrementalSearch(event) { function incrementalSearch(event) {
type = document.getElementById("type").value; // Run incremental search only after waitTime passed since the last key press
// Search with reranking on 'Enter' let waitTime = 300;
if (event.key === 'Enter') { clearTimeout(debounceTimeout);
search(rerank=true); debounceTimeout = setTimeout(() => {
} type = document.getElementById("type").value;
// Limit incremental search to text types // Search with reranking on 'Enter'
else if (type !== "image") { let should_rerank = event.key === 'Enter';
search(rerank=false); search(rerank=should_rerank);
} }, waitTime);
} }
function populate_type_dropdown() { function populate_type_dropdown() {

View File

@@ -33,8 +33,11 @@ class EmbeddingsModel:
self.api_key = embeddings_inference_endpoint_api_key self.api_key = embeddings_inference_endpoint_api_key
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
def embed_query(self, query): def embed_query(self, query):
if self.api_key is not None and self.inference_endpoint is not None: if self.inference_server_enabled():
return self.embed_with_api([query])[0] return self.embed_with_api([query])[0]
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
@@ -62,11 +65,10 @@ class EmbeddingsModel:
return response.json()["embeddings"] return response.json()["embeddings"]
def embed_documents(self, docs): def embed_documents(self, docs):
if self.api_key is not None and self.inference_endpoint is not None: if self.inference_server_enabled():
target_url = f"{self.inference_endpoint}" if "huggingface" not in self.inference_endpoint:
if "huggingface" not in target_url:
logger.warning( logger.warning(
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint." f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
) )
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits # break up the docs payload in chunks of 1000 to avoid hitting rate limits
@@ -93,12 +95,11 @@ class CrossEncoderModel:
self.inference_endpoint = cross_encoder_inference_endpoint self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key self.api_key = cross_encoder_inference_endpoint_api_key
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"): def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
if ( if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
self.api_key is not None
and self.inference_endpoint is not None
and "huggingface" in self.inference_endpoint
):
target_url = f"{self.inference_endpoint}" target_url = f"{self.inference_endpoint}"
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

View File

@@ -177,8 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
def rerank_and_sort_results(hits, query, rank_results, search_model_name): def rerank_and_sort_results(hits, query, rank_results, search_model_name):
# If we have more than one result and reranking is enabled # Rerank results if explicitly requested, if can use inference server or if device has GPU
rank_results = rank_results and len(list(hits)) > 1 # AND if we have more than one result
rank_results = (
rank_results
or state.cross_encoder_model[search_model_name].inference_server_enabled()
or state.device.type != "cpu"
) and len(list(hits)) > 1
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results: if rank_results:

View File

@@ -331,7 +331,12 @@ def batcher(iterable, max_n):
yield (x for x in chunk if x is not None) yield (x for x in chunk if x is not None)
def is_env_var_true(env_var: str, default: str = "false") -> bool:
"""Get state of boolean environment variable"""
return os.getenv(env_var, default).lower() == "true"
def in_debug_mode(): def in_debug_mode():
"""Check if Khoj is running in debug mode. """Check if Khoj is running in debug mode.
Set KHOJ_DEBUG environment variable to true to enable debug mode.""" Set KHOJ_DEBUG environment variable to true to enable debug mode."""
return os.getenv("KHOJ_DEBUG", "false").lower() == "true" return is_env_var_true("KHOJ_DEBUG")