Deep copy entries, embeddings in filters. Defer till actual filtering

- Only the filter knows when entries, embeddings are to be manipulated.
  So move the responsibility to deep copy before manipulating entries,
  embeddings to the filters

- Create deep copy in filters. Avoids creating deep copy of entries,
  embeddings when filter results are being loaded from cache etc
This commit is contained in:
Debanjum Singh Solanky
2022-09-04 02:22:42 +03:00
parent 3308e68edf
commit 28d3dc1434
3 changed files with 18 additions and 17 deletions

View File

@@ -3,6 +3,7 @@ import re
from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO
from math import inf
from copy import deepcopy
# External Packages
import torch
@@ -31,19 +32,23 @@ class DateFilter:
return self.extract_date_range(raw_query) is not None
def apply(self, query, entries, embeddings):
def apply(self, query, raw_entries, raw_embeddings):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange is None:
return query, entries, embeddings
return query, raw_entries, raw_embeddings
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# deep copy original embeddings, entries before filtering
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
# find entries containing any dates that fall with date range specified in query
entries_to_include = set()
for id, entry in enumerate(entries):

View File

@@ -3,6 +3,7 @@ import re
import time
import pickle
import logging
from copy import deepcopy
# External Packages
import torch
@@ -61,7 +62,7 @@ class ExplicitFilter:
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, entries, embeddings):
def apply(self, raw_query, raw_entries, raw_embeddings):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters
start = time.time()
@@ -83,6 +84,13 @@ class ExplicitFilter:
entries, embeddings = self.cache[cache_key]
return query, entries, embeddings
# deep copy original embeddings, entries before filtering
start = time.time()
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
end = time.time()
logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
if not self.entries_by_word_set:
self.load(entries, regenerate=False)