NOTE(review): text before the first "diff --git" header is skipped by `git apply` and `patch`.
NOTE(review): indentation of the Python lines below was reconstructed assuming 4-space indents;
if context lines fail to match the target files, apply with `git apply --ignore-whitespace`.

diff --git a/src/main.py b/src/main.py
index addd55d5..ef40c148 100644
--- a/src/main.py
+++ b/src/main.py
@@ -74,10 +74,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
 
     if (t == SearchType.Ledger or t == None) and model.ledger_search:
         # query transactions
-        hits = symmetric_ledger.query(user_query, model.ledger_search)
+        hits, entries = symmetric_ledger.query(user_query, model.ledger_search)
 
         # collate and return results
-        return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count)
+        return symmetric_ledger.collate_results(hits, entries, results_count)
 
     if (t == SearchType.Image or t == None) and model.image_search:
         # query transactions
diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py
index 5243c1aa..814813a1 100644
--- a/src/search_type/symmetric_ledger.py
+++ b/src/search_type/symmetric_ledger.py
@@ -1,9 +1,7 @@
 # Standard Packages
-import json
-import gzip
-import re
 import argparse
 import pathlib
+from copy import deepcopy
 
 # External Packages
 import torch
@@ -62,27 +60,37 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
     return corpus_embeddings
 
 
-def query(raw_query, model: TextSearchModel):
-    "Search all notes for entries that answer the query"
-    # Separate natural query from explicit required, blocked words filters
-    query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
-    required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-    blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+def query(raw_query, model: TextSearchModel, filters=None):
+    """Search the model's entries for those that answer the query.
+
+    Each callable in `filters` is called as
+    `filter(query, entries, corpus_embeddings)` and must return the
+    updated `(query, entries, corpus_embeddings)` triple. The default,
+    `filters=None`, applies no filters.
+
+    Returns a `(hits, entries)` tuple, where each hit's `corpus_id`
+    indexes into the returned, filtered `entries`. Returns `([], [])`
+    when the filters leave no entries to search.
+    """
+    # Copy original embeddings, entries to filter them for query
+    query = raw_query
+    corpus_embeddings = deepcopy(model.corpus_embeddings)
+    entries = deepcopy(model.entries)
+
+    # Filter query, entries and embeddings before semantic search
+    for entry_filter in filters or ():
+        query, entries, corpus_embeddings = entry_filter(query, entries, corpus_embeddings)
+    if entries is None or len(entries) == 0:
+        return [], []
 
     # Encode the query using the bi-encoder
     question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
 
     # Find relevant entries for the query
-    hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
-    hits = hits[0] # Get the hits for the first query
-
-    # Filter results using explicit filters
-    hits = explicit_filter(hits, model.entries, required_words, blocked_words)
-    if hits is None or len(hits) == 0:
-        return hits
+    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0]
 
     # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, model.entries[hit['corpus_id']]] for hit in hits]
+    cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
     cross_scores = model.cross_encoder.predict(cross_inp)
 
     # Store cross-encoder scores in results dictionary for ranking
@@ -93,28 +101,7 @@ def query(raw_query, model: TextSearchModel):
     hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score
     hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score
 
-    return hits
-
-
-def explicit_filter(hits, entries, required_words, blocked_words):
-    hits_by_word_set = [(set(word.lower()
-                             for word
-                             in re.split(
-                                 r',|\.| |\]|\[\(|\)|\{|\}',
-                                 entries[hit['corpus_id']])
-                             if word != ""),
-                         hit)
-                        for hit in hits]
-
-    if len(required_words) == 0 and len(blocked_words) == 0:
-        return hits
-    if len(required_words) > 0:
-        return [hit for (words_in_entry, hit) in hits_by_word_set
-                if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
-    if len(blocked_words) > 0:
-        return [hit for (words_in_entry, hit) in hits_by_word_set
-                if not blocked_words.intersection(words_in_entry)]
-    return hits
+    return hits, entries
 
 
 def render_results(hits, entries, count=5, display_biencoder_results=False):