From 1105d8814fcd83ae6fa1357547042e3c6a29d34e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 14 Feb 2024 18:37:53 +0530 Subject: [PATCH 1/3] Use cross-encoder to rerank search results by default on GPU machines Latest sentence-transformers package uses GPU for cross-encoder. This makes it fast enough to enable reranking on machines with GPU. Enabling search reranking by default allows (at least) users with GPUs to side-step learning the UI affordance to rerank results (i.e. hitting Cmd/Ctrl-Enter or ENTER). --- pyproject.toml | 2 +- src/khoj/search_type/text_search.py | 5 +++-- src/khoj/utils/helpers.py | 7 ++++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17003c6c..8bab7876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ "pyyaml == 6.0", "rich >= 13.3.1", "schedule == 1.1.0", - "sentence-transformers == 2.3.1", + "sentence-transformers == 2.5.1", "transformers >= 4.28.0", "torch == 2.0.1", "uvicorn == 0.17.6", diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index d5ea35e6..48bc9e46 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -177,8 +177,9 @@ def deduplicated_search_responses(hits: List[SearchResponse]): def rerank_and_sort_results(hits, query, rank_results, search_model_name): - # If we have more than one result and reranking is enabled - rank_results = rank_results and len(list(hits)) > 1 + # Rerank results if explicitly requested or if device has GPU + # AND if we have more than one result + rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1 # Score all retrieved entries using the cross-encoder if rank_results: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index d2b64296..f30ddd04 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -331,7 +331,12 @@ def batcher(iterable, max_n): yield (x for 
x in chunk if x is not None) +def is_env_var_true(env_var: str, default: str = "false") -> bool: + """Get state of boolean environment variable""" + return os.getenv(env_var, default).lower() == "true" + + def in_debug_mode(): """Check if Khoj is running in debug mode. Set KHOJ_DEBUG environment variable to true to enable debug mode.""" - return os.getenv("KHOJ_DEBUG", "false").lower() == "true" + return is_env_var_true("KHOJ_DEBUG") From 44c8d09342d3995854b53b1360645a660dcebddb Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 9 Mar 2024 17:53:04 +0530 Subject: [PATCH 2/3] Only call search API after a pause in typing search query on web, desktop apps Wait for 300ms after the user stops typing before calling the search API. This smooths out UI jitter when rendering search results, especially now that we're reranking for every search query on GPU-enabled devices Emacs already has 300ms debounce time. More convoluted to add debounce time to Obsidian search modal, so not updating that yet --- src/interface/desktop/search.html | 19 ++++++++++--------- src/khoj/interface/web/search.html | 19 ++++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/interface/desktop/search.html b/src/interface/desktop/search.html index 4ca248c5..ffc1bf64 100644 --- a/src/interface/desktop/search.html +++ b/src/interface/desktop/search.html @@ -192,16 +192,17 @@ }); } + let debounceTimeout; function incrementalSearch(event) { - type = 'all'; - // Search with reranking on 'Enter' - if (event.key === 'Enter') { - search(rerank=true); - } - // Limit incremental search to text types - else if (type !== "image") { - search(rerank=false); - } + // Run incremental search only after waitTime passed since the last key press + let waitTime = 300; + clearTimeout(debounceTimeout); + debounceTimeout = setTimeout(() => { + type = 'all'; + // Search with reranking on 'Enter' + let should_rerank = event.key === 'Enter'; + search(rerank=should_rerank); + }, waitTime); } async 
function populate_type_dropdown() { diff --git a/src/khoj/interface/web/search.html b/src/khoj/interface/web/search.html index 98e37cb8..8bbd9f32 100644 --- a/src/khoj/interface/web/search.html +++ b/src/khoj/interface/web/search.html @@ -193,16 +193,17 @@ }); } + let debounceTimeout; function incrementalSearch(event) { - type = document.getElementById("type").value; - // Search with reranking on 'Enter' - if (event.key === 'Enter') { - search(rerank=true); - } - // Limit incremental search to text types - else if (type !== "image") { - search(rerank=false); - } + // Run incremental search only after waitTime passed since the last key press + let waitTime = 300; + clearTimeout(debounceTimeout); + debounceTimeout = setTimeout(() => { + type = document.getElementById("type").value; + // Search with reranking on 'Enter' + let should_rerank = event.key === 'Enter'; + search(rerank=should_rerank); + }, waitTime); } function populate_type_dropdown() { From 53d402480c650f8a4c34514f1c9904ef478aeebd Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 10 Mar 2024 12:25:19 +0530 Subject: [PATCH 3/3] Rerank search results with cross-encoder when using an inference server If an inference server is being used, we can expect the cross encoder to be running fast enough to rerank search results by default --- src/khoj/processor/embeddings.py | 21 +++++++++++---------- src/khoj/search_type/text_search.py | 8 ++++++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index cada1532..ec8e08f0 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -33,8 +33,11 @@ class EmbeddingsModel: self.api_key = embeddings_inference_endpoint_api_key self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) + def inference_server_enabled(self) -> bool: + return self.api_key is not None and self.inference_endpoint is not None + def embed_query(self, 
query): - if self.api_key is not None and self.inference_endpoint is not None: + if self.inference_server_enabled(): return self.embed_with_api([query])[0] return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] @@ -62,11 +65,10 @@ class EmbeddingsModel: return response.json()["embeddings"] def embed_documents(self, docs): - if self.api_key is not None and self.inference_endpoint is not None: - target_url = f"{self.inference_endpoint}" - if "huggingface" not in target_url: + if self.inference_server_enabled(): + if "huggingface" not in self.inference_endpoint: logger.warning( - f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint." + f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." ) return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() # break up the docs payload in chunks of 1000 to avoid hitting rate limits @@ -93,12 +95,11 @@ class CrossEncoderModel: self.inference_endpoint = cross_encoder_inference_endpoint self.api_key = cross_encoder_inference_endpoint_api_key + def inference_server_enabled(self) -> bool: + return self.api_key is not None and self.inference_endpoint is not None + def predict(self, query, hits: List[SearchResponse], key: str = "compiled"): - if ( - self.api_key is not None - and self.inference_endpoint is not None - and "huggingface" in self.inference_endpoint - ): + if self.inference_server_enabled() and "huggingface" in self.inference_endpoint: target_url = f"{self.inference_endpoint}" payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 48bc9e46..a172529f 100644 --- a/src/khoj/search_type/text_search.py 
+++ b/src/khoj/search_type/text_search.py @@ -177,9 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]): def rerank_and_sort_results(hits, query, rank_results, search_model_name): - # Rerank results if explicitly requested or if device has GPU + # Rerank results if explicitly requested, if can use inference server or if device has GPU # AND if we have more than one result - rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1 + rank_results = ( + rank_results + or state.cross_encoder_model[search_model_name].inference_server_enabled() + or state.device.type != "cpu" + ) and len(list(hits)) > 1 # Score all retrieved entries using the cross-encoder if rank_results: