From 53d402480c650f8a4c34514f1c9904ef478aeebd Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 10 Mar 2024 12:25:19 +0530
Subject: [PATCH] Rerank search results with cross-encoder when using an
 inference server

If an inference server is being used, we can expect the cross encoder to
be running fast enough to rerank search results by default
---
 src/khoj/processor/embeddings.py    | 21 +++++++++++----------
 src/khoj/search_type/text_search.py |  8 ++++++--
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index cada1532..ec8e08f0 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -33,8 +33,11 @@ class EmbeddingsModel:
         self.api_key = embeddings_inference_endpoint_api_key
         self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
 
+    def inference_server_enabled(self) -> bool:
+        return self.api_key is not None and self.inference_endpoint is not None
+
     def embed_query(self, query):
-        if self.api_key is not None and self.inference_endpoint is not None:
+        if self.inference_server_enabled():
             return self.embed_with_api([query])[0]
         return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
 
@@ -62,11 +65,10 @@ class EmbeddingsModel:
         return response.json()["embeddings"]
 
     def embed_documents(self, docs):
-        if self.api_key is not None and self.inference_endpoint is not None:
-            target_url = f"{self.inference_endpoint}"
-            if "huggingface" not in target_url:
+        if self.inference_server_enabled():
+            if "huggingface" not in self.inference_endpoint:
                 logger.warning(
-                    f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
+                    f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace is supported. Generating embeddings on device instead."
                 )
                 return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
         # break up the docs payload in chunks of 1000 to avoid hitting rate limits
@@ -93,12 +95,11 @@ class CrossEncoderModel:
         self.inference_endpoint = cross_encoder_inference_endpoint
         self.api_key = cross_encoder_inference_endpoint_api_key
 
+    def inference_server_enabled(self) -> bool:
+        return self.api_key is not None and self.inference_endpoint is not None
+
     def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
-        if (
-            self.api_key is not None
-            and self.inference_endpoint is not None
-            and "huggingface" in self.inference_endpoint
-        ):
+        if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
             target_url = f"{self.inference_endpoint}"
             payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
             headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 48bc9e46..a172529f 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -177,9 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
 
 
 def rerank_and_sort_results(hits, query, rank_results, search_model_name):
-    # Rerank results if explicitly requested or if device has GPU
+    # Rerank results if explicitly requested, if can use inference server or if device has GPU
     # AND if we have more than one result
-    rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1
+    rank_results = (
+        rank_results
+        or state.cross_encoder_model[search_model_name].inference_server_enabled()
+        or state.device.type != "cpu"
+    ) and len(list(hits)) > 1
 
     # Score all retrieved entries using the cross-encoder
     if rank_results: