# Standard Packages
import re

# External Packages
import torch


# Characters that separate words inside an entry: any whitespace, full stop,
# comma, or any round, square or curly bracket. A character class is used so
# every bracket splits on its own (the previous alternation only matched the
# literal pair "[(" and never a lone "[" or "(").
_ENTRY_WORD_SEPARATORS = re.compile(r'[\s,.\[\](){}]+')


def explicit_filter(raw_query, entries, embeddings):
    """Apply the explicit "+required" / "-blocked" word filters found in a query.

    Whitespace separated query words starting with "+" are required words and
    words starting with "-" are blocked words. Both are compared
    case-insensitively against the words of each entry's text (``entry[0]``).
    An entry is kept only if it contains every required word and none of the
    blocked words.

    Args:
        raw_query: Query string as typed, possibly containing +/- filter words.
        entries: Mutable list of entries; the text to match is read from index
            0 of each entry. The list is filtered IN PLACE, so pass a copy if
            the original must be preserved.
        embeddings: torch tensor whose first dimension is aligned with
            ``entries`` (row ``i`` belongs to ``entries[i]``).

    Returns:
        A ``(query, entries, embeddings)`` tuple: the query with all +/- words
        removed, the filtered ``entries`` list (same object that was passed
        in), and the embeddings restricted to the rows of the kept entries.
        When no entry is dropped the original ``embeddings`` object is returned.
    """
    query_words = raw_query.split()

    # Separate natural query from explicit required, blocked words filters
    query = " ".join(word for word in query_words if not word.startswith(("+", "-")))
    # A bare "+" or "-" carries no word; discard the resulting empty string,
    # otherwise a lone "+" would require the (never present) empty word and
    # thereby exclude every entry.
    required_words = {word[1:].lower() for word in query_words if word.startswith("+")} - {""}
    blocked_words = {word[1:].lower() for word in query_words if word.startswith("-")} - {""}

    if not required_words and not blocked_words:
        return query, entries, embeddings

    # Collect positions of entries that satisfy both filters
    kept_ids = []
    for entry_id, entry in enumerate(entries):
        words_in_entry = {word.lower()
                          for word in _ENTRY_WORD_SEPARATORS.split(entry[0])
                          if word}
        if required_words <= words_in_entry and words_in_entry.isdisjoint(blocked_words):
            kept_ids.append(entry_id)

    # Nothing to exclude: hand back the inputs untouched
    if len(kept_ids) == len(entries):
        return query, entries, embeddings

    # Drop excluded entries in place and select the matching embedding rows in
    # a single operation instead of one torch.cat per excluded entry.
    entries[:] = [entries[entry_id] for entry_id in kept_ids]
    kept_index = torch.tensor(kept_ids, dtype=torch.long, device=embeddings.device)
    embeddings = embeddings.index_select(0, kept_index)

    return query, entries, embeddings
src.utils.constants import empty_escape_sequences @@ -94,19 +94,17 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): +def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []): "Search all notes for entries that answer the query" - # Separate natural query from explicit required, blocked words filters - query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) # Copy original embeddings, entries to filter them for query + query = raw_query corpus_embeddings = deepcopy(model.corpus_embeddings) entries = deepcopy(model.entries) - # Filter to entries that contain all required_words and no blocked_words - entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words) + # Filter query, entries and embeddings before semantic search + for filter in filters: + query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings) if entries is None or len(entries) == 0: return {} @@ -133,42 +131,6 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): return hits, entries -def explicit_filter(entries, embeddings, required_words, blocked_words): - if len(required_words) == 0 and len(blocked_words) == 0: - return entries, embeddings - - # convert each entry to a set of words - entries_by_word_set = [set(word.lower() - for word - in re.split( - r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets - entry[0]) - if word != "") - for entry in entries] - - # track id of entries to exclude - entries_to_exclude = set() - - # mark entries that do not contain all 
required_words for exclusion - if len(required_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if not required_words.issubset(words_in_entry): - entries_to_exclude.add(id) - - # mark entries that contain any blocked_words for exclusion - if len(blocked_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if words_in_entry.intersection(blocked_words): - entries_to_exclude.add(id) - - # delete entries (and their embeddings) marked for exclusion - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - - return entries, embeddings - - def render_results(hits, entries, count=5, display_biencoder_results=False): "Render the Results returned by Search for the Query" if display_biencoder_results: