mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Rerank Search Results by Default on GPU machines (#668)
Trigger:
- SentenceTransformer cross-encoder models now run fast on GPU-enabled machines, including Mac ARM devices, since UKPLab/sentence-transformers#2463

Details:
- Use the cross-encoder to rerank search results by default on GPU machines and when using an inference server
- Only call the search API after a pause in typing the search query on the web and desktop apps
This commit is contained in:
@@ -50,7 +50,7 @@ dependencies = [
|
||||
"pyyaml == 6.0",
|
||||
"rich >= 13.3.1",
|
||||
"schedule == 1.1.0",
|
||||
"sentence-transformers == 2.3.1",
|
||||
"sentence-transformers == 2.5.1",
|
||||
"transformers >= 4.28.0",
|
||||
"torch == 2.0.1",
|
||||
"uvicorn == 0.17.6",
|
||||
|
||||
@@ -192,16 +192,17 @@
|
||||
});
|
||||
}
|
||||
|
||||
let debounceTimeout;
|
||||
function incrementalSearch(event) {
|
||||
type = 'all';
|
||||
// Search with reranking on 'Enter'
|
||||
if (event.key === 'Enter') {
|
||||
search(rerank=true);
|
||||
}
|
||||
// Limit incremental search to text types
|
||||
else if (type !== "image") {
|
||||
search(rerank=false);
|
||||
}
|
||||
// Run incremental search only after waitTime passed since the last key press
|
||||
let waitTime = 300;
|
||||
clearTimeout(debounceTimeout);
|
||||
debounceTimeout = setTimeout(() => {
|
||||
type = 'all';
|
||||
// Search with reranking on 'Enter'
|
||||
let should_rerank = event.key === 'Enter';
|
||||
search(rerank=should_rerank);
|
||||
}, waitTime);
|
||||
}
|
||||
|
||||
async function populate_type_dropdown() {
|
||||
|
||||
@@ -193,16 +193,17 @@
|
||||
});
|
||||
}
|
||||
|
||||
let debounceTimeout;
|
||||
function incrementalSearch(event) {
|
||||
type = document.getElementById("type").value;
|
||||
// Search with reranking on 'Enter'
|
||||
if (event.key === 'Enter') {
|
||||
search(rerank=true);
|
||||
}
|
||||
// Limit incremental search to text types
|
||||
else if (type !== "image") {
|
||||
search(rerank=false);
|
||||
}
|
||||
// Run incremental search only after waitTime passed since the last key press
|
||||
let waitTime = 300;
|
||||
clearTimeout(debounceTimeout);
|
||||
debounceTimeout = setTimeout(() => {
|
||||
type = document.getElementById("type").value;
|
||||
// Search with reranking on 'Enter'
|
||||
let should_rerank = event.key === 'Enter';
|
||||
search(rerank=should_rerank);
|
||||
}, waitTime);
|
||||
}
|
||||
|
||||
function populate_type_dropdown() {
|
||||
|
||||
@@ -33,8 +33,11 @@ class EmbeddingsModel:
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
|
||||
def embed_query(self, query):
|
||||
if self.api_key is not None and self.inference_endpoint is not None:
|
||||
if self.inference_server_enabled():
|
||||
return self.embed_with_api([query])[0]
|
||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||
|
||||
@@ -62,11 +65,10 @@ class EmbeddingsModel:
|
||||
return response.json()["embeddings"]
|
||||
|
||||
def embed_documents(self, docs):
|
||||
if self.api_key is not None and self.inference_endpoint is not None:
|
||||
target_url = f"{self.inference_endpoint}"
|
||||
if "huggingface" not in target_url:
|
||||
if self.inference_server_enabled():
|
||||
if "huggingface" not in self.inference_endpoint:
|
||||
logger.warning(
|
||||
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
|
||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
||||
)
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||
@@ -93,12 +95,11 @@ class CrossEncoderModel:
|
||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
|
||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||
if (
|
||||
self.api_key is not None
|
||||
and self.inference_endpoint is not None
|
||||
and "huggingface" in self.inference_endpoint
|
||||
):
|
||||
if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
|
||||
target_url = f"{self.inference_endpoint}"
|
||||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
@@ -177,8 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
||||
# If we have more than one result and reranking is enabled
|
||||
rank_results = rank_results and len(list(hits)) > 1
|
||||
# Rerank results if explicitly requested, if can use inference server or if device has GPU
|
||||
# AND if we have more than one result
|
||||
rank_results = (
|
||||
rank_results
|
||||
or state.cross_encoder_model[search_model_name].inference_server_enabled()
|
||||
or state.device.type != "cpu"
|
||||
) and len(list(hits)) > 1
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
|
||||
@@ -331,7 +331,12 @@ def batcher(iterable, max_n):
|
||||
yield (x for x in chunk if x is not None)
|
||||
|
||||
|
||||
def is_env_var_true(env_var: str, default: str = "false") -> bool:
|
||||
"""Get state of boolean environment variable"""
|
||||
return os.getenv(env_var, default).lower() == "true"
|
||||
|
||||
|
||||
def in_debug_mode():
|
||||
"""Check if Khoj is running in debug mode.
|
||||
Set KHOJ_DEBUG environment variable to true to enable debug mode."""
|
||||
return os.getenv("KHOJ_DEBUG", "false").lower() == "true"
|
||||
return is_env_var_true("KHOJ_DEBUG")
|
||||
|
||||
Reference in New Issue
Block a user