From ad41ef39918120463f6af7e53de0a5b50a1a8230 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 16 Jul 2023 02:16:33 -0700 Subject: [PATCH] Make normalizing embeddings configurable --- src/khoj/search_type/text_search.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index ed3be33c..c123974a 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -58,7 +58,11 @@ def extract_entries(jsonl_file) -> List[Entry]: def compute_embeddings( - entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False + entries_with_ids: List[Tuple[int, Entry]], + bi_encoder: BaseEncoder, + embeddings_file: Path, + regenerate=False, + normalize=True, ): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] @@ -87,8 +91,11 @@ def compute_embeddings( existing_embeddings = torch.tensor([], device=state.device) corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) + if normalize: + # Normalize embeddings for faster lookup via dot product when querying + corpus_embeddings = util.normalize_embeddings(corpus_embeddings) + # Save regenerated or updated embeddings to file - corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) logger.info(f"📩 Saved computed text embeddings to {embeddings_file}") @@ -169,6 +176,7 @@ def setup( bi_encoder: BaseEncoder, regenerate: bool, filters: List[BaseFilter] = [], + normalize: bool = True, ) -> TextContent: # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) @@ -186,7 +194,7 @@ def setup( # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings( - entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate + entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize ) for filter in filters: