From 93f39dbd432a8ad15b475a13cd1f2ff0ebe1d1f6 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 9 Jan 2023 18:53:23 -0300 Subject: [PATCH] Add typing to text_search. Reformat code to set existing_embedding --- src/search_type/text_search.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index a7d1b7dd..43d6d187 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -1,5 +1,6 @@ # Standard Packages import logging +from pathlib import Path import time from typing import Type @@ -57,7 +58,7 @@ def extract_entries(jsonl_file) -> list[Entry]: return list(map(Entry.from_dict, load_jsonl(jsonl_file))) -def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False): +def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] # Load pre-computed embeddings from file if exists and update them if required @@ -70,7 +71,10 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba if new_entries: new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] - existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) if existing_entry_ids else torch.tensor([], device=state.device) + if existing_entry_ids: + existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) + else: + existing_embeddings = torch.tensor([], device=state.device) corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) # Else compute the corpus embeddings from scratch else: @@ -86,7 +90,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results: bool = False): +def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> tuple[list[dict], list[Entry]]: "Search for entries that answer the query" query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings