Use new Text Entry class to track text entries in Intermediate Format

- Context
  - The app maintains all text content in a standard, intermediate format
  - The intermediate format was initially loaded and passed around as a
    dictionary for easier, faster updates to the intermediate format schema
  - The intermediate format is reasonably stable now, given its usage
    by all 3 text content types currently implemented

- Changes
  - Concretize text entries into `Entry' class instead of using dictionaries
    - Code is updated to load and pass around entries as `Entry' objects
      instead of as dictionaries
    - `text_search' and `text_to_jsonl' methods are annotated with
      type hints for the new `Entry' type
    - Code and Tests referencing entries are updated to use class style
      access patterns instead of the previous dictionary access patterns

  - Move `mark_entries_for_update' method into `TextToJsonl' base class
    - This is a more natural location for the method as it is only
      (to be) used by `text_to_jsonl' classes
    - Avoid circular reference issues on importing `Entry' class
This commit is contained in:
Debanjum Singh Solanky
2022-09-15 23:34:43 +03:00
parent 99754970ab
commit 7e9298f315
15 changed files with 161 additions and 131 deletions

View File

@@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
from src.utils.jsonl import load_jsonl
@@ -50,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file):
def extract_entries(jsonl_file) -> list[Entry]:
"Load entries from compressed jsonl"
return load_jsonl(jsonl_file)
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
@@ -64,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate
logger.info(f"Loaded embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
if new_entries:
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch
else:
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
new_entries = [entry.compiled for _, entry in entries_with_ids]
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
# Save regenerated or updated embeddings to file
@@ -133,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
# Score all retrieved entries using the cross-encoder
if rank_results:
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
@@ -153,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
return hits, entries
def render_results(hits, entries, count=5, display_biencoder_results=False):
def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False):
"Render the Results returned by Search for the Query"
if display_biencoder_results:
# Output of top hits from bi-encoder
@@ -161,20 +161,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
print(f"Top-{count} Bi-Encoder Retrieval hits")
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits[0:count]:
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}")
# Output of top hits from re-ranker
print("\n-------------------------\n")
print(f"Top-{count} Cross-Encoder Re-ranker hits")
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
for hit in hits[0:count]:
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}")
def collate_results(hits, entries, count=5) -> list[SearchResponse]:
def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
return [SearchResponse.parse_obj(
{
"entry": entries[hit['corpus_id']]['raw'],
"entry": entries[hit['corpus_id']].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
})
for hit