Rerank Search Results by Default on GPU machines (#668)

- Trigger
  SentenceTransformer cross-encoder models now run fast on GPU-enabled machines, including Mac ARM devices, since UKPLab/sentence-transformers#2463

- Details
  - Use the cross-encoder to rerank search results by default on GPU machines (see the device check sketch below) and when using an inference server
  - Only call the search API once there is a pause in typing the search query on the web and desktop apps
Debanjum committed on 2024-03-10 15:15:25 +05:30 (via GitHub)
6 changed files with 45 additions and 32 deletions
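The "GPU machines" default above hinges on which compute device the models were loaded onto. A minimal sketch of such a device check with torch (the get_device helper is illustrative, not this commit's code; the diff below consults the equivalent state.device):

    import torch

    def get_device() -> torch.device:
        """Pick the best available compute device (illustrative helper)."""
        if torch.cuda.is_available():
            return torch.device("cuda")  # NVIDIA GPU
        if torch.backends.mps.is_available():
            return torch.device("mps")   # Apple Silicon (Mac ARM) GPU
        return torch.device("cpu")

    # Reranking now defaults to on whenever this is not a "cpu" device.
    print(get_device().type)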


@@ -50,7 +50,7 @@ dependencies = [
     "pyyaml == 6.0",
     "rich >= 13.3.1",
     "schedule == 1.1.0",
-    "sentence-transformers == 2.3.1",
+    "sentence-transformers == 2.5.1",
     "transformers >= 4.28.0",
     "torch == 2.0.1",
     "uvicorn == 0.17.6",


@@ -192,16 +192,17 @@
         });
     }
 
+    let debounceTimeout;
     function incrementalSearch(event) {
-        type = 'all';
-        // Search with reranking on 'Enter'
-        if (event.key === 'Enter') {
-            search(rerank=true);
-        }
-        // Limit incremental search to text types
-        else if (type !== "image") {
-            search(rerank=false);
-        }
+        // Run incremental search only after waitTime passed since the last key press
+        let waitTime = 300;
+        clearTimeout(debounceTimeout);
+        debounceTimeout = setTimeout(() => {
+            type = 'all';
+            // Search with reranking on 'Enter'
+            let should_rerank = event.key === 'Enter';
+            search(rerank=should_rerank);
+        }, waitTime);
     }
 
     async function populate_type_dropdown() {
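For readers less familiar with the JavaScript debounce idiom above, here is the same pattern sketched in Python with threading.Timer; the search stub and function names are illustrative, and the 0.3 second wait mirrors waitTime:

    import threading

    debounce_timer = None

    def search(rerank: bool = False):
        """Illustrative stand-in for the app's search call."""
        print(f"searching (rerank={rerank})")

    def incremental_search(key: str, wait_time: float = 0.3):
        """Call search only after wait_time seconds pass without another keypress."""
        global debounce_timer
        if debounce_timer is not None:
            debounce_timer.cancel()  # a new keypress resets the pending search
        should_rerank = key == "Enter"  # rerank only on explicit 'Enter'
        debounce_timer = threading.Timer(wait_time, search, kwargs={"rerank": should_rerank})
        debounce_timer.start()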


@@ -193,16 +193,17 @@
         });
     }
 
+    let debounceTimeout;
     function incrementalSearch(event) {
-        type = document.getElementById("type").value;
-        // Search with reranking on 'Enter'
-        if (event.key === 'Enter') {
-            search(rerank=true);
-        }
-        // Limit incremental search to text types
-        else if (type !== "image") {
-            search(rerank=false);
-        }
+        // Run incremental search only after waitTime passed since the last key press
+        let waitTime = 300;
+        clearTimeout(debounceTimeout);
+        debounceTimeout = setTimeout(() => {
+            type = document.getElementById("type").value;
+            // Search with reranking on 'Enter'
+            let should_rerank = event.key === 'Enter';
+            search(rerank=should_rerank);
+        }, waitTime);
     }
 
     function populate_type_dropdown() {


@@ -33,8 +33,11 @@ class EmbeddingsModel:
         self.api_key = embeddings_inference_endpoint_api_key
         self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
 
+    def inference_server_enabled(self) -> bool:
+        return self.api_key is not None and self.inference_endpoint is not None
+
     def embed_query(self, query):
-        if self.api_key is not None and self.inference_endpoint is not None:
+        if self.inference_server_enabled():
             return self.embed_with_api([query])[0]
         return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]

@@ -62,11 +65,10 @@
         return response.json()["embeddings"]
 
     def embed_documents(self, docs):
-        if self.api_key is not None and self.inference_endpoint is not None:
-            target_url = f"{self.inference_endpoint}"
-            if "huggingface" not in target_url:
+        if self.inference_server_enabled():
+            if "huggingface" not in self.inference_endpoint:
                 logger.warning(
-                    f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
+                    f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
                 )
                 return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
             # break up the docs payload in chunks of 1000 to avoid hitting rate limits

@@ -93,12 +95,11 @@ class CrossEncoderModel:
         self.inference_endpoint = cross_encoder_inference_endpoint
         self.api_key = cross_encoder_inference_endpoint_api_key
 
+    def inference_server_enabled(self) -> bool:
+        return self.api_key is not None and self.inference_endpoint is not None
+
     def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
-        if (
-            self.api_key is not None
-            and self.inference_endpoint is not None
-            and "huggingface" in self.inference_endpoint
-        ):
+        if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
             target_url = f"{self.inference_endpoint}"
             payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
             headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
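When no inference server is configured, cross-encoder scoring runs locally via sentence-transformers. A minimal sketch of that scoring step (the checkpoint name is an example, not necessarily the model Khoj ships with):

    from sentence_transformers import CrossEncoder

    # Example cross-encoder checkpoint; any such model scores (query, passage) pairs.
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    query = "how does khoj rerank results"
    passages = ["Reranking reorders hits by cross-encoder score.", "GPUs speed up inference."]

    # Higher score means the passage is more relevant to the query.
    scores = model.predict([(query, passage) for passage in passages])
    reranked = sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)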


@@ -177,8 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
 def rerank_and_sort_results(hits, query, rank_results, search_model_name):
-    # If we have more than one result and reranking is enabled
-    rank_results = rank_results and len(list(hits)) > 1
+    # Rerank results if explicitly requested, if can use inference server or if device has GPU
+    # AND if we have more than one result
+    rank_results = (
+        rank_results
+        or state.cross_encoder_model[search_model_name].inference_server_enabled()
+        or state.device.type != "cpu"
+    ) and len(list(hits)) > 1
 
     # Score all retrieved entries using the cross-encoder
     if rank_results:
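The new rank_results expression reads as: rerank when explicitly requested, when an inference server is usable, or when a non-CPU device is present, and only if there is more than one hit. Restated as a standalone pure function (names are illustrative):

    def should_rerank(requested: bool, server_enabled: bool, device_type: str, num_hits: int) -> bool:
        """Mirror of the rank_results expression above."""
        return (requested or server_enabled or device_type != "cpu") and num_hits > 1

    assert should_rerank(False, False, "cuda", 5)     # GPU machine: rerank by default
    assert not should_rerank(False, False, "cpu", 5)  # CPU, not requested: skip reranking
    assert not should_rerank(True, True, "cuda", 1)   # a single hit needs no reranking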


@@ -331,7 +331,12 @@ def batcher(iterable, max_n):
         yield (x for x in chunk if x is not None)
 
 
+def is_env_var_true(env_var: str, default: str = "false") -> bool:
+    """Get state of boolean environment variable"""
+    return os.getenv(env_var, default).lower() == "true"
+
+
 def in_debug_mode():
     """Check if Khoj is running in debug mode.
     Set KHOJ_DEBUG environment variable to true to enable debug mode."""
-    return os.getenv("KHOJ_DEBUG", "false").lower() == "true"
+    return is_env_var_true("KHOJ_DEBUG")
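The comparison in the new helper is case-insensitive, so "true", "True", and "TRUE" all count as enabled. A quick self-contained usage sketch:

    import os

    def is_env_var_true(env_var: str, default: str = "false") -> bool:
        """Get state of boolean environment variable"""
        return os.getenv(env_var, default).lower() == "true"

    os.environ["KHOJ_DEBUG"] = "True"
    assert is_env_var_true("KHOJ_DEBUG")      # case-insensitive match
    del os.environ["KHOJ_DEBUG"]
    assert not is_env_var_true("KHOJ_DEBUG")  # unset falls back to the "false" default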