diff --git a/src/main.py b/src/main.py index 2b580cb8..00e68612 100644 --- a/src/main.py +++ b/src/main.py @@ -17,6 +17,7 @@ from src.utils.cli import cli from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.rawconfig import FullConfig from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize +from src.search_filter.explicit_filter import explicit_filter # Application Global State config = FullConfig() @@ -58,14 +59,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and model.notes_search: # query notes - hits, entries = asymmetric.query(user_query, model.notes_search, device=device) + hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter]) # collate and return results return asymmetric.collate_results(hits, entries, results_count) if (t == SearchType.Music or t == None) and model.music_search: # query music library - hits, entries = asymmetric.query(user_query, model.music_search, device=device) + hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter]) # collate and return results return asymmetric.collate_results(hits, entries, results_count) diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index 53d41bf4..47526393 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -94,7 +94,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): +def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []): "Search all notes for entries that answer the query" # Copy original embeddings, entries to filter them for query @@ -102,8 +102,9 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): corpus_embeddings = deepcopy(model.corpus_embeddings) entries = deepcopy(model.entries) - # Filter to entries that contain all required_words and no blocked_words - query, entries, corpus_embeddings = explicit_filter(query, entries, corpus_embeddings) + # Filter query, entries and embeddings before semantic search + for filter in filters: + query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings) if entries is None or len(entries) == 0: return {}