Make filters to apply before semantic search configurable

Details
--
- The filters to apply are configured for each type in the search controller
- Multiple filters can be applied to the query, entries, etc. before search
- The asymmetric query method now just applies the passed filters to the
  query, entries and embeddings before semantic search is performed

Reason
--
This abstraction will simplify adding other pre-search filters, e.g. a datetime filter.
This commit is contained in:
Debanjum Singh Solanky
2022-07-13 16:29:23 +04:00
parent c92789d20a
commit b82aef26bf
2 changed files with 7 additions and 5 deletions

View File

@@ -17,6 +17,7 @@ from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import explicit_filter
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
@@ -58,14 +59,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and model.notes_search: if (t == SearchType.Notes or t == None) and model.notes_search:
# query notes # query notes
hits, entries = asymmetric.query(user_query, model.notes_search, device=device) hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter])
# collate and return results # collate and return results
return asymmetric.collate_results(hits, entries, results_count) return asymmetric.collate_results(hits, entries, results_count)
if (t == SearchType.Music or t == None) and model.music_search: if (t == SearchType.Music or t == None) and model.music_search:
# query music library # query music library
hits, entries = asymmetric.query(user_query, model.music_search, device=device) hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter])
# collate and return results # collate and return results
return asymmetric.collate_results(hits, entries, results_count) return asymmetric.collate_results(hits, entries, results_count)

View File

@@ -94,7 +94,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []):
"Search all notes for entries that answer the query" "Search all notes for entries that answer the query"
# Copy original embeddings, entries to filter them for query # Copy original embeddings, entries to filter them for query
@@ -102,8 +102,9 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
corpus_embeddings = deepcopy(model.corpus_embeddings) corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries) entries = deepcopy(model.entries)
# Filter to entries that contain all required_words and no blocked_words # Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = explicit_filter(query, entries, corpus_embeddings) for filter in filters:
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
return {} return {}