mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Rerank search results with cross-encoder when using an inference server
If an inference server is being used, we can expect the cross encoder to be running fast enough to rerank search results by default
This commit is contained in:
@@ -33,8 +33,11 @@ class EmbeddingsModel:
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
|
||||
def embed_query(self, query):
|
||||
if self.api_key is not None and self.inference_endpoint is not None:
|
||||
if self.inference_server_enabled():
|
||||
return self.embed_with_api([query])[0]
|
||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||
|
||||
@@ -62,11 +65,10 @@ class EmbeddingsModel:
|
||||
return response.json()["embeddings"]
|
||||
|
||||
def embed_documents(self, docs):
|
||||
if self.api_key is not None and self.inference_endpoint is not None:
|
||||
target_url = f"{self.inference_endpoint}"
|
||||
if "huggingface" not in target_url:
|
||||
if self.inference_server_enabled():
|
||||
if "huggingface" not in self.inference_endpoint:
|
||||
logger.warning(
|
||||
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
|
||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
||||
)
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||
@@ -93,12 +95,11 @@ class CrossEncoderModel:
|
||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
|
||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||
if (
|
||||
self.api_key is not None
|
||||
and self.inference_endpoint is not None
|
||||
and "huggingface" in self.inference_endpoint
|
||||
):
|
||||
if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
|
||||
target_url = f"{self.inference_endpoint}"
|
||||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
@@ -177,9 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
||||
# Rerank results if explicitly requested or if device has GPU
|
||||
# Rerank results if explicitly requested, if can use inference server or if device has GPU
|
||||
# AND if we have more than one result
|
||||
rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1
|
||||
rank_results = (
|
||||
rank_results
|
||||
or state.cross_encoder_model[search_model_name].inference_server_enabled()
|
||||
or state.device.type != "cpu"
|
||||
) and len(list(hits)) > 1
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
|
||||
Reference in New Issue
Block a user