diff --git a/src/main.py b/src/main.py
index 4666e30f..2b580cb8 100644
--- a/src/main.py
+++ b/src/main.py
@@ -58,17 +58,17 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
 
     if (t == SearchType.Notes or t == None) and model.notes_search:
         # query notes
-        hits = asymmetric.query(user_query, model.notes_search, device=device)
+        hits, entries = asymmetric.query(user_query, model.notes_search, device=device)
 
         # collate and return results
-        return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
+        return asymmetric.collate_results(hits, entries, results_count)
 
     if (t == SearchType.Music or t == None) and model.music_search:
         # query music library
-        hits = asymmetric.query(user_query, model.music_search, device=device)
+        hits, entries = asymmetric.query(user_query, model.music_search, device=device)
 
         # collate and return results
-        return asymmetric.collate_results(hits, model.music_search.entries, results_count)
+        return asymmetric.collate_results(hits, entries, results_count)
 
     if (t == SearchType.Ledger or t == None) and model.ledger_search:
         # query transactions
diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py
index 524eb9a1..611501f7 100644
--- a/src/search_type/asymmetric.py
+++ b/src/search_type/asymmetric.py
@@ -6,6 +6,7 @@ import gzip
 import re
 import argparse
 import pathlib
+from copy import deepcopy
 
 # External Packages
 import torch
@@ -100,24 +101,25 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
     required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
     blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
 
+    # Copy original embeddings, entries to filter them for query
+    corpus_embeddings = deepcopy(model.corpus_embeddings)
+    entries = deepcopy(model.entries)
+
+    # Filter to entries that contain all required_words and no blocked_words
+    entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words)
+    if entries is None or len(entries) == 0:
+        return [], []  # empty (hits, entries) pair: callers unpack two values
+
     # Encode the query using the bi-encoder
     question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
     question_embedding.to(device)
     question_embedding = util.normalize_embeddings(question_embedding)
 
     # Find relevant entries for the query
-    hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)
-    hits = hits[0] # Get the hits for the first query
-
-    # Filter out entries that contain required words and do not contain blocked words
-    hits = explicit_filter(hits,
-                           [entry[0] for entry in model.entries],
-                           required_words,blocked_words)
-    if hits is None or len(hits) == 0:
-        return hits
+    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
 
     # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits]
+    cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits]
     cross_scores = model.cross_encoder.predict(cross_inp)
 
     # Store cross-encoder scores in results dictionary for ranking
@@ -128,28 +130,44 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
     hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
     hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
 
-    return hits
+    return hits, entries
 
 
-def explicit_filter(hits, entries, required_words, blocked_words):
-    hits_by_word_set = [(set(word.lower()
+def explicit_filter(entries, embeddings, required_words, blocked_words):
+    """Drop entries (and their embeddings) that lack a required word or contain a blocked word."""
+    if len(required_words) == 0 and len(blocked_words) == 0:
+        return entries, embeddings
+
+    # convert each entry to a set of words
+    entries_by_word_set = [set(word.lower()
                            for word
                            in re.split(
-                             r',|\.| |\]|\[\(|\)|\{|\}',
-                             entries[hit['corpus_id']])
-                           if word != ""),
-                        hit)
-                        for hit in hits]
+                             r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
+                             entry[0])
+                           if word != "")
+                        for entry in entries]
 
-    if len(required_words) == 0 and len(blocked_words) == 0:
-        return hits
+    # track id of entries to exclude
+    entries_to_exclude = set()
+
+    # mark entries that do not contain all required_words for exclusion
     if len(required_words) > 0:
-        return [hit for (words_in_entry, hit) in hits_by_word_set
-                if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
+        for entry_id, words_in_entry in enumerate(entries_by_word_set):
+            if not required_words.issubset(words_in_entry):
+                entries_to_exclude.add(entry_id)
+
+    # mark entries that contain any blocked_words for exclusion
     if len(blocked_words) > 0:
-        return [hit for (words_in_entry, hit) in hits_by_word_set
-                if not blocked_words.intersection(words_in_entry)]
-    return hits
+        for entry_id, words_in_entry in enumerate(entries_by_word_set):
+            if words_in_entry.intersection(blocked_words):
+                entries_to_exclude.add(entry_id)
+
+    # delete entries (and their embeddings) marked for exclusion, highest index first
+    for entry_id in sorted(entries_to_exclude, reverse=True):
+        del entries[entry_id]
+        embeddings = torch.cat((embeddings[:entry_id], embeddings[entry_id+1:]))
+
+    return entries, embeddings
 
 
 def render_results(hits, entries, count=5, display_biencoder_results=False):