diff --git a/pyproject.toml b/pyproject.toml index 17003c6c..8bab7876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ "pyyaml == 6.0", "rich >= 13.3.1", "schedule == 1.1.0", - "sentence-transformers == 2.3.1", + "sentence-transformers == 2.5.1", "transformers >= 4.28.0", "torch == 2.0.1", "uvicorn == 0.17.6", diff --git a/src/interface/desktop/search.html b/src/interface/desktop/search.html index 4ca248c5..ffc1bf64 100644 --- a/src/interface/desktop/search.html +++ b/src/interface/desktop/search.html @@ -192,16 +192,17 @@ }); } + let debounceTimeout; function incrementalSearch(event) { - type = 'all'; - // Search with reranking on 'Enter' - if (event.key === 'Enter') { - search(rerank=true); - } - // Limit incremental search to text types - else if (type !== "image") { - search(rerank=false); - } + // Run incremental search only after waitTime passed since the last key press + let waitTime = 300; + clearTimeout(debounceTimeout); + debounceTimeout = setTimeout(() => { + type = 'all'; + // Search with reranking on 'Enter' + let should_rerank = event.key === 'Enter'; + search(rerank=should_rerank); + }, waitTime); } async function populate_type_dropdown() { diff --git a/src/khoj/interface/web/search.html b/src/khoj/interface/web/search.html index 98e37cb8..8bbd9f32 100644 --- a/src/khoj/interface/web/search.html +++ b/src/khoj/interface/web/search.html @@ -193,16 +193,17 @@ }); } + let debounceTimeout; function incrementalSearch(event) { - type = document.getElementById("type").value; - // Search with reranking on 'Enter' - if (event.key === 'Enter') { - search(rerank=true); - } - // Limit incremental search to text types - else if (type !== "image") { - search(rerank=false); - } + // Run incremental search only after waitTime passed since the last key press + let waitTime = 300; + clearTimeout(debounceTimeout); + debounceTimeout = setTimeout(() => { + type = document.getElementById("type").value; + // Search with reranking on 
'Enter' + let should_rerank = event.key === 'Enter'; + search(rerank=should_rerank); + }, waitTime); } function populate_type_dropdown() { diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index cada1532..ec8e08f0 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -33,8 +33,11 @@ class EmbeddingsModel: self.api_key = embeddings_inference_endpoint_api_key self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) + def inference_server_enabled(self) -> bool: + return self.api_key is not None and self.inference_endpoint is not None + def embed_query(self, query): - if self.api_key is not None and self.inference_endpoint is not None: + if self.inference_server_enabled(): return self.embed_with_api([query])[0] return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] @@ -62,11 +65,10 @@ class EmbeddingsModel: return response.json()["embeddings"] def embed_documents(self, docs): - if self.api_key is not None and self.inference_endpoint is not None: - target_url = f"{self.inference_endpoint}" - if "huggingface" not in target_url: + if self.inference_server_enabled(): + if "huggingface" not in self.inference_endpoint: logger.warning( - f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint." + f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." 
) return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() # break up the docs payload in chunks of 1000 to avoid hitting rate limits @@ -93,12 +95,11 @@ class CrossEncoderModel: self.inference_endpoint = cross_encoder_inference_endpoint self.api_key = cross_encoder_inference_endpoint_api_key + def inference_server_enabled(self) -> bool: + return self.api_key is not None and self.inference_endpoint is not None + def predict(self, query, hits: List[SearchResponse], key: str = "compiled"): - if ( - self.api_key is not None - and self.inference_endpoint is not None - and "huggingface" in self.inference_endpoint - ): + if self.inference_server_enabled() and "huggingface" in self.inference_endpoint: target_url = f"{self.inference_endpoint}" payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index d5ea35e6..a172529f 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -177,8 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]): def rerank_and_sort_results(hits, query, rank_results, search_model_name): - # If we have more than one result and reranking is enabled - rank_results = rank_results and len(list(hits)) > 1 + # Rerank results if explicitly requested, if an inference server can be used, or if the device has a GPU + # AND if we have more than one result + rank_results = ( + rank_results + or state.cross_encoder_model[search_model_name].inference_server_enabled() + or state.device.type != "cpu" + ) and len(list(hits)) > 1 # Score all retrieved entries using the cross-encoder if rank_results: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index d2b64296..f30ddd04 100644 @@ -331,7 +331,12 
@@ def batcher(iterable, max_n): yield (x for x in chunk if x is not None) +def is_env_var_true(env_var: str, default: str = "false") -> bool: + """Get state of boolean environment variable""" + return os.getenv(env_var, default).lower() == "true" + + def in_debug_mode(): """Check if Khoj is running in debug mode. Set KHOJ_DEBUG environment variable to true to enable debug mode.""" - return os.getenv("KHOJ_DEBUG", "false").lower() == "true" + return is_env_var_true("KHOJ_DEBUG")