Move deep copying of entries and embeddings out of text_search.query() and
into DateFilter.apply() and ExplicitFilter.apply().

query() now passes model.entries and model.corpus_embeddings straight to the
filters. DateFilter.apply() returns its inputs untouched when the query has
no date range and copies them only after that check; ExplicitFilter.apply()
copies them only after its cache lookup has missed, and logs the copy time
at debug level.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Blank lines were placed so every hunk matches its @@ line counts;
indentation and intra-line spacing are reconstructed -- confirm with
`git apply --check` (add --ignore-whitespace if context lines do not match).

diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py
index d91ebd83..cab47cbb 100644
--- a/src/search_filter/date_filter.py
+++ b/src/search_filter/date_filter.py
@@ -3,6 +3,7 @@ import re
 from datetime import timedelta, datetime
 from dateutil.relativedelta import relativedelta, MO
 from math import inf
+from copy import deepcopy
 
 # External Packages
 import torch
@@ -31,19 +32,23 @@ class DateFilter:
         return self.extract_date_range(raw_query) is not None
 
 
-    def apply(self, query, entries, embeddings):
+    def apply(self, query, raw_entries, raw_embeddings):
         "Find entries containing any dates that fall within date range specified in query"
         # extract date range specified in date filter of query
         query_daterange = self.extract_date_range(query)
 
         # if no date in query, return all entries
         if query_daterange is None:
-            return query, entries, embeddings
+            return query, raw_entries, raw_embeddings
 
         # remove date range filter from query
         query = re.sub(rf'\s+{self.date_regex}', ' ', query)
         query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
 
+        # deep copy original embeddings, entries before filtering
+        embeddings= deepcopy(raw_embeddings)
+        entries = deepcopy(raw_entries)
+
         # find entries containing any dates that fall with date range specified in query
         entries_to_include = set()
         for id, entry in enumerate(entries):
diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py
index 2707155b..e715e8b6 100644
--- a/src/search_filter/explicit_filter.py
+++ b/src/search_filter/explicit_filter.py
@@ -3,6 +3,7 @@ import re
 import time
 import pickle
 import logging
+from copy import deepcopy
 
 # External Packages
 import torch
@@ -61,7 +62,7 @@ class ExplicitFilter:
         return len(required_words) != 0 or len(blocked_words) != 0
 
 
-    def apply(self, raw_query, entries, embeddings):
+    def apply(self, raw_query, raw_entries, raw_embeddings):
         "Find entries containing required and not blocked words specified in query"
         # Separate natural query from explicit required, blocked words filters
         start = time.time()
@@ -83,6 +84,13 @@ class ExplicitFilter:
             entries, embeddings = self.cache[cache_key]
             return query, entries, embeddings
 
+        # deep copy original embeddings, entries before filtering
+        start = time.time()
+        embeddings= deepcopy(raw_embeddings)
+        entries = deepcopy(raw_entries)
+        end = time.time()
+        logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
+
         if not self.entries_by_word_set:
             self.load(entries, regenerate=False)
 
diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py
index 4f83236e..742ff5ed 100644
--- a/src/search_type/text_search.py
+++ b/src/search_type/text_search.py
@@ -3,7 +3,6 @@ import argparse
 import pathlib
 import logging
 import time
-from copy import deepcopy
 
 # External Packages
 import torch
@@ -77,22 +76,11 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
 
 def query(raw_query: str, model: TextSearchModel, rank_results=False):
     "Search for entries that answer the query"
-    query = raw_query
-
-    # Use deep copy of original embeddings, entries to filter if query contains filters
-    start = time.time()
-    filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
-    if filters_in_query:
-        corpus_embeddings = deepcopy(model.corpus_embeddings)
-        entries = deepcopy(model.entries)
-    else:
-        corpus_embeddings = model.corpus_embeddings
-        entries = model.entries
-    end = time.time()
-    logger.debug(f"Copy Time: {end - start:.3f} seconds")
+    query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
 
     # Filter query, entries and embeddings before semantic search
     start = time.time()
+    filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
     for filter in filters_in_query:
         query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
     end = time.time()