Deep copy entries, embeddings in filters. Defer till actual filtering

- Only the filter knows when entries, embeddings are to be manipulated.
  So move the responsibility to deep copy before manipulating entries,
  embeddings to the filters

- Create deep copy in filters. Avoids creating deep copy of entries,
  embeddings when filter results are being loaded from cache etc
This commit is contained in:
Debanjum Singh Solanky
2022-09-04 02:22:42 +03:00
parent 3308e68edf
commit 28d3dc1434
3 changed files with 18 additions and 17 deletions

View File

@@ -3,6 +3,7 @@ import re
from datetime import timedelta, datetime from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO from dateutil.relativedelta import relativedelta, MO
from math import inf from math import inf
from copy import deepcopy
# External Packages # External Packages
import torch import torch
@@ -31,19 +32,23 @@ class DateFilter:
return self.extract_date_range(raw_query) is not None return self.extract_date_range(raw_query) is not None
def apply(self, query, entries, embeddings): def apply(self, query, raw_entries, raw_embeddings):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
query_daterange = self.extract_date_range(query) query_daterange = self.extract_date_range(query)
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange is None:
return query, entries, embeddings return query, raw_entries, raw_embeddings
# remove date range filter from query # remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# deep copy original embeddings, entries before filtering
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
# find entries containing any dates that fall with date range specified in query # find entries containing any dates that fall with date range specified in query
entries_to_include = set() entries_to_include = set()
for id, entry in enumerate(entries): for id, entry in enumerate(entries):

View File

@@ -3,6 +3,7 @@ import re
import time import time
import pickle import pickle
import logging import logging
from copy import deepcopy
# External Packages # External Packages
import torch import torch
@@ -61,7 +62,7 @@ class ExplicitFilter:
return len(required_words) != 0 or len(blocked_words) != 0 return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, entries, embeddings): def apply(self, raw_query, raw_entries, raw_embeddings):
"Find entries containing required and not blocked words specified in query" "Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters # Separate natural query from explicit required, blocked words filters
start = time.time() start = time.time()
@@ -83,6 +84,13 @@ class ExplicitFilter:
entries, embeddings = self.cache[cache_key] entries, embeddings = self.cache[cache_key]
return query, entries, embeddings return query, entries, embeddings
# deep copy original embeddings, entries before filtering
start = time.time()
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
end = time.time()
logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
if not self.entries_by_word_set: if not self.entries_by_word_set:
self.load(entries, regenerate=False) self.load(entries, regenerate=False)

View File

@@ -3,7 +3,6 @@ import argparse
import pathlib import pathlib
import logging import logging
import time import time
from copy import deepcopy
# External Packages # External Packages
import torch import torch
@@ -77,22 +76,11 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
def query(raw_query: str, model: TextSearchModel, rank_results=False): def query(raw_query: str, model: TextSearchModel, rank_results=False):
"Search for entries that answer the query" "Search for entries that answer the query"
query = raw_query query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Use deep copy of original embeddings, entries to filter if query contains filters
start = time.time()
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
if filters_in_query:
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
else:
corpus_embeddings = model.corpus_embeddings
entries = model.entries
end = time.time()
logger.debug(f"Copy Time: {end - start:.3f} seconds")
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
start = time.time() start = time.time()
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query: for filter in filters_in_query:
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
end = time.time() end = time.time()