mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Rerank Search Results by Default on GPU machines (#668)
Trigger: SentenceTransformer Cross Encoder models now run fast on GPU-enabled machines, including Mac ARM devices, since UKPLab/sentence-transformers#2463.
Details:
- Use the cross-encoder to rerank search results by default on GPU machines and when using an inference server
- Only call the search API after a pause in typing the search query on the web and desktop apps
This commit is contained in:
@@ -50,7 +50,7 @@ dependencies = [
|
|||||||
"pyyaml == 6.0",
|
"pyyaml == 6.0",
|
||||||
"rich >= 13.3.1",
|
"rich >= 13.3.1",
|
||||||
"schedule == 1.1.0",
|
"schedule == 1.1.0",
|
||||||
"sentence-transformers == 2.3.1",
|
"sentence-transformers == 2.5.1",
|
||||||
"transformers >= 4.28.0",
|
"transformers >= 4.28.0",
|
||||||
"torch == 2.0.1",
|
"torch == 2.0.1",
|
||||||
"uvicorn == 0.17.6",
|
"uvicorn == 0.17.6",
|
||||||
|
|||||||
@@ -192,16 +192,17 @@
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let debounceTimeout;
|
||||||
function incrementalSearch(event) {
|
function incrementalSearch(event) {
|
||||||
type = 'all';
|
// Run incremental search only after waitTime passed since the last key press
|
||||||
// Search with reranking on 'Enter'
|
let waitTime = 300;
|
||||||
if (event.key === 'Enter') {
|
clearTimeout(debounceTimeout);
|
||||||
search(rerank=true);
|
debounceTimeout = setTimeout(() => {
|
||||||
}
|
type = 'all';
|
||||||
// Limit incremental search to text types
|
// Search with reranking on 'Enter'
|
||||||
else if (type !== "image") {
|
let should_rerank = event.key === 'Enter';
|
||||||
search(rerank=false);
|
search(rerank=should_rerank);
|
||||||
}
|
}, waitTime);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function populate_type_dropdown() {
|
async function populate_type_dropdown() {
|
||||||
|
|||||||
@@ -193,16 +193,17 @@
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let debounceTimeout;
|
||||||
function incrementalSearch(event) {
|
function incrementalSearch(event) {
|
||||||
type = document.getElementById("type").value;
|
// Run incremental search only after waitTime passed since the last key press
|
||||||
// Search with reranking on 'Enter'
|
let waitTime = 300;
|
||||||
if (event.key === 'Enter') {
|
clearTimeout(debounceTimeout);
|
||||||
search(rerank=true);
|
debounceTimeout = setTimeout(() => {
|
||||||
}
|
type = document.getElementById("type").value;
|
||||||
// Limit incremental search to text types
|
// Search with reranking on 'Enter'
|
||||||
else if (type !== "image") {
|
let should_rerank = event.key === 'Enter';
|
||||||
search(rerank=false);
|
search(rerank=should_rerank);
|
||||||
}
|
}, waitTime);
|
||||||
}
|
}
|
||||||
|
|
||||||
function populate_type_dropdown() {
|
function populate_type_dropdown() {
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ class EmbeddingsModel:
|
|||||||
self.api_key = embeddings_inference_endpoint_api_key
|
self.api_key = embeddings_inference_endpoint_api_key
|
||||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
|
|
||||||
|
def inference_server_enabled(self) -> bool:
|
||||||
|
return self.api_key is not None and self.inference_endpoint is not None
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
if self.api_key is not None and self.inference_endpoint is not None:
|
if self.inference_server_enabled():
|
||||||
return self.embed_with_api([query])[0]
|
return self.embed_with_api([query])[0]
|
||||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||||
|
|
||||||
@@ -62,11 +65,10 @@ class EmbeddingsModel:
|
|||||||
return response.json()["embeddings"]
|
return response.json()["embeddings"]
|
||||||
|
|
||||||
def embed_documents(self, docs):
|
def embed_documents(self, docs):
|
||||||
if self.api_key is not None and self.inference_endpoint is not None:
|
if self.inference_server_enabled():
|
||||||
target_url = f"{self.inference_endpoint}"
|
if "huggingface" not in self.inference_endpoint:
|
||||||
if "huggingface" not in target_url:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
|
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
||||||
)
|
)
|
||||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||||
@@ -93,12 +95,11 @@ class CrossEncoderModel:
|
|||||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||||
|
|
||||||
|
def inference_server_enabled(self) -> bool:
|
||||||
|
return self.api_key is not None and self.inference_endpoint is not None
|
||||||
|
|
||||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||||
if (
|
if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
|
||||||
self.api_key is not None
|
|
||||||
and self.inference_endpoint is not None
|
|
||||||
and "huggingface" in self.inference_endpoint
|
|
||||||
):
|
|
||||||
target_url = f"{self.inference_endpoint}"
|
target_url = f"{self.inference_endpoint}"
|
||||||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
|
|||||||
@@ -177,8 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
|||||||
|
|
||||||
|
|
||||||
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
||||||
# If we have more than one result and reranking is enabled
|
# Rerank results if explicitly requested, if can use inference server or if device has GPU
|
||||||
rank_results = rank_results and len(list(hits)) > 1
|
# AND if we have more than one result
|
||||||
|
rank_results = (
|
||||||
|
rank_results
|
||||||
|
or state.cross_encoder_model[search_model_name].inference_server_enabled()
|
||||||
|
or state.device.type != "cpu"
|
||||||
|
) and len(list(hits)) > 1
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
|
|||||||
@@ -331,7 +331,12 @@ def batcher(iterable, max_n):
|
|||||||
yield (x for x in chunk if x is not None)
|
yield (x for x in chunk if x is not None)
|
||||||
|
|
||||||
|
|
||||||
|
def is_env_var_true(env_var: str, default: str = "false") -> bool:
|
||||||
|
"""Get state of boolean environment variable"""
|
||||||
|
return os.getenv(env_var, default).lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
def in_debug_mode():
|
def in_debug_mode():
|
||||||
"""Check if Khoj is running in debug mode.
|
"""Check if Khoj is running in debug mode.
|
||||||
Set KHOJ_DEBUG environment variable to true to enable debug mode."""
|
Set KHOJ_DEBUG environment variable to true to enable debug mode."""
|
||||||
return os.getenv("KHOJ_DEBUG", "false").lower() == "true"
|
return is_env_var_true("KHOJ_DEBUG")
|
||||||
|
|||||||
Reference in New Issue
Block a user