Pre-compute entry word sets to improve explicit filter query performance

This commit is contained in:
Debanjum Singh Solanky
2022-09-03 16:01:54 +03:00
parent 094bd18e57
commit c7de57b8ea
6 changed files with 81 additions and 26 deletions

View File

@@ -8,11 +8,13 @@ from copy import deepcopy
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.search_filter.date_filter import DateFilter
from src.search_filter.explicit_filter import ExplicitFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel
from src.utils.config import SearchType, TextSearchModel
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
@@ -73,13 +75,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
def query(raw_query: str, model: TextSearchModel, rank_results=False):
"Search for entries that answer the query"
query = raw_query
# Use deep copy of original embeddings, entries to filter if query contains filters
start = time.time()
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
if filters_in_query:
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
@@ -92,7 +94,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
# Filter query, entries and embeddings before semantic search
start = time.time()
for filter in filters_in_query:
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
end = time.time()
logger.debug(f"Filter Time: {end - start:.3f} seconds")
@@ -163,7 +165,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
@@ -180,7 +182,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
for filter in filters:
filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
if __name__ == '__main__':