Rerank search results with cross-encoder when using an inference server

If an inference server is being used, we can expect the cross-encoder
to be running fast enough to rerank search results by default.
This commit is contained in:
Debanjum Singh Solanky
2024-03-10 12:25:19 +05:30
parent 44c8d09342
commit 53d402480c
2 changed files with 17 additions and 12 deletions

View File

@@ -33,8 +33,11 @@ class EmbeddingsModel:
self.api_key = embeddings_inference_endpoint_api_key
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
def embed_query(self, query):
if self.api_key is not None and self.inference_endpoint is not None:
if self.inference_server_enabled():
return self.embed_with_api([query])[0]
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
@@ -62,11 +65,10 @@ class EmbeddingsModel:
return response.json()["embeddings"]
def embed_documents(self, docs):
if self.api_key is not None and self.inference_endpoint is not None:
target_url = f"{self.inference_endpoint}"
if "huggingface" not in target_url:
if self.inference_server_enabled():
if "huggingface" not in self.inference_endpoint:
logger.warning(
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
@@ -93,12 +95,11 @@ class CrossEncoderModel:
self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
if (
self.api_key is not None
and self.inference_endpoint is not None
and "huggingface" in self.inference_endpoint
):
if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
target_url = f"{self.inference_endpoint}"
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

View File

@@ -177,9 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
# Rerank results if explicitly requested or if device has GPU
# Rerank results if explicitly requested, if can use inference server or if device has GPU
# AND if we have more than one result
rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1
rank_results = (
rank_results
or state.cross_encoder_model[search_model_name].inference_server_enabled()
or state.device.type != "cpu"
) and len(list(hits)) > 1
# Score all retrieved entries using the cross-encoder
if rank_results: