Run Explicit Filter on Entries, Embeddings before Semantic Search for Query

- Issue
  - Explicit filtering was earlier being done after search by bi-encoder
    but before re-ranking by cross-encoder

  - This was limiting the quality of results being returned. As the
    bi-encoder returned results which were going to be excluded. So the
    burden of improving those limited results post filtering was on the
    cross-encoder by re-ranking the remaining results based on query

- Fix
  - Given the embeddings corresponding to an entry are at the same index
    in their respective lists. We can run the filter for blocked,
    required words before the search by the bi-encoder model. And limit
    entries, embeddings being considered for the current query

- Result
  - Semantic search by the bi-encoder gets to return most relevant
    results for the query, knowing that the results aren't going to be
    filtered out after. So the cross-encoder shoulders less of the
    burden of improving results

- Corollary
  - This pre-filtering technique allows us to apply other explicit
    filters on entries relevant for the current query
    - E.g limit search for entries within date/time specified in query
This commit is contained in:
Debanjum Singh Solanky
2022-07-12 13:58:32 +04:00
parent 8bb9a49994
commit 6d7ab50113
2 changed files with 46 additions and 29 deletions

View File

@@ -58,17 +58,17 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and model.notes_search: if (t == SearchType.Notes or t == None) and model.notes_search:
# query notes # query notes
hits = asymmetric.query(user_query, model.notes_search, device=device) hits, entries = asymmetric.query(user_query, model.notes_search, device=device)
# collate and return results # collate and return results
return asymmetric.collate_results(hits, model.notes_search.entries, results_count) return asymmetric.collate_results(hits, entries, results_count)
if (t == SearchType.Music or t == None) and model.music_search: if (t == SearchType.Music or t == None) and model.music_search:
# query music library # query music library
hits = asymmetric.query(user_query, model.music_search, device=device) hits, entries = asymmetric.query(user_query, model.music_search, device=device)
# collate and return results # collate and return results
return asymmetric.collate_results(hits, model.music_search.entries, results_count) return asymmetric.collate_results(hits, entries, results_count)
if (t == SearchType.Ledger or t == None) and model.ledger_search: if (t == SearchType.Ledger or t == None) and model.ledger_search:
# query transactions # query transactions

View File

@@ -6,6 +6,7 @@ import gzip
import re import re
import argparse import argparse
import pathlib import pathlib
from copy import deepcopy
# External Packages # External Packages
import torch import torch
@@ -100,24 +101,25 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
# Copy original embeddings, entries to filter them for query
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
# Filter to entries that contain all required_words and no blocked_words
entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words)
if entries is None or len(entries) == 0:
return {}
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True) question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
question_embedding.to(device) question_embedding.to(device)
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query # Find relevant entries for the query
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score) hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
hits = hits[0] # Get the hits for the first query
# Filter out entries that contain required words and do not contain blocked words
hits = explicit_filter(hits,
[entry[0] for entry in model.entries],
required_words,blocked_words)
if hits is None or len(hits) == 0:
return hits
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits] cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp) cross_scores = model.cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking # Store cross-encoder scores in results dictionary for ranking
@@ -128,28 +130,43 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
return hits return hits, entries
def explicit_filter(hits, entries, required_words, blocked_words): def explicit_filter(entries, embeddings, required_words, blocked_words):
hits_by_word_set = [(set(word.lower() if len(required_words) == 0 and len(blocked_words) == 0:
return entries, embeddings
# convert each entry to a set of words
entries_by_word_set = [set(word.lower()
for word for word
in re.split( in re.split(
r',|\.| |\]|\[\(|\)|\{|\}', r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
entries[hit['corpus_id']]) entry[0])
if word != ""), if word != "")
hit) for entry in entries]
for hit in hits]
if len(required_words) == 0 and len(blocked_words) == 0: # track id of entries to exclude
return hits entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
if len(required_words) > 0: if len(required_words) > 0:
return [hit for (words_in_entry, hit) in hits_by_word_set for id, words_in_entry in enumerate(entries_by_word_set):
if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)] if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
# mark entries that contain any blocked_words for exclusion
if len(blocked_words) > 0: if len(blocked_words) > 0:
return [hit for (words_in_entry, hit) in hits_by_word_set for id, words_in_entry in enumerate(entries_by_word_set):
if not blocked_words.intersection(words_in_entry)] if words_in_entry.intersection(blocked_words):
return hits entries_to_exclude.add(id)
# delete entries (and their embeddings) marked for exclusion
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return entries, embeddings
def render_results(hits, entries, count=5, display_biencoder_results=False): def render_results(hits, entries, count=5, display_biencoder_results=False):