Make filters applied before semantic search configurable

Reason
--
This abstraction will simplify adding other pre-search filters, e.g. a date-time filter

Capabilities
--
- Multiple filters can be applied on the query, entries etc before search
- The filters to apply are configured for each type in the search controller

Details
--
- Move `explicit_filter` function into a separate module under `search_filter`
- Update signature of explicit filter to take and return `query`, `entries`, `embeddings`
- Use this `explicit_filter` function from the `search_filter` module in the
  `search` method in the controller
- The asymmetric query method now just applies the passed filters to the
  `query`, `entries` and `embeddings` before semantic search is performed
This commit is contained in:
Debanjum
2022-07-13 05:53:02 -07:00
committed by GitHub
3 changed files with 55 additions and 46 deletions

View File

@@ -17,6 +17,7 @@ from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import explicit_filter
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
@@ -58,14 +59,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and model.notes_search: if (t == SearchType.Notes or t == None) and model.notes_search:
# query notes # query notes
hits, entries = asymmetric.query(user_query, model.notes_search, device=device) hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter])
# collate and return results # collate and return results
return asymmetric.collate_results(hits, entries, results_count) return asymmetric.collate_results(hits, entries, results_count)
if (t == SearchType.Music or t == None) and model.music_search: if (t == SearchType.Music or t == None) and model.music_search:
# query music library # query music library
hits, entries = asymmetric.query(user_query, model.music_search, device=device) hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter])
# collate and return results # collate and return results
return asymmetric.collate_results(hits, entries, results_count) return asymmetric.collate_results(hits, entries, results_count)

View File

@@ -0,0 +1,46 @@
# Standard Packages
import re
# External Packages
import torch
def explicit_filter(raw_query, entries, embeddings):
    """Filter entries and embeddings using explicit +word / -word query markers.

    Words prefixed with '+' are required: entries missing any are dropped.
    Words prefixed with '-' are blocked: entries containing any are dropped.

    Parameters
    ----------
    raw_query : str
        User query, possibly containing +required and -blocked word markers.
    entries : list
        Entries to filter; entry[0] is assumed to hold the entry's text.
    embeddings : torch.Tensor
        Embeddings row-aligned with entries; rows are removed in lockstep.

    Returns
    -------
    tuple
        (query, entries, embeddings) with the filter markers stripped from
        query and non-matching entries (and their embedding rows) removed.
    """
    # Separate natural query from explicit required, blocked words filters
    query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
    required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
    blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])

    # No explicit filters in query: nothing to do
    if len(required_words) == 0 and len(blocked_words) == 0:
        return query, entries, embeddings

    # convert each entry to a set of words
    # BUGFIX: pattern previously contained r'\[\(' (literal "[(") instead of
    # r'\[|\(', so entries were never split on '[' or '(' alone
    entries_by_word_set = [set(word.lower()
                               for word
                               in re.split(
                                   r',|\.| |\]|\[|\(|\)|\{|\}',  # split on fullstop, comma or any brackets
                                   entry[0])
                               if word != "")
                           for entry in entries]

    # track id of entries to exclude
    entries_to_exclude = set()

    # mark entries that do not contain all required_words for exclusion
    if required_words:
        for idx, words_in_entry in enumerate(entries_by_word_set):
            if not required_words.issubset(words_in_entry):
                entries_to_exclude.add(idx)

    # mark entries that contain any blocked_words for exclusion
    if blocked_words:
        for idx, words_in_entry in enumerate(entries_by_word_set):
            if words_in_entry.intersection(blocked_words):
                entries_to_exclude.add(idx)

    # delete entries (and their embeddings) marked for exclusion;
    # iterate highest-index first so lower indices remain valid after deletion
    for idx in sorted(entries_to_exclude, reverse=True):
        del entries[idx]
        embeddings = torch.cat((embeddings[:idx], embeddings[idx+1:]))

    return query, entries, embeddings

View File

@@ -3,7 +3,6 @@
# Standard Packages # Standard Packages
import json import json
import gzip import gzip
import re
import argparse import argparse
import pathlib import pathlib
from copy import deepcopy from copy import deepcopy
@@ -15,6 +14,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_filter.explicit_filter import explicit_filter
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig
from src.utils.constants import empty_escape_sequences from src.utils.constants import empty_escape_sequences
@@ -94,19 +94,17 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []):
"Search all notes for entries that answer the query" "Search all notes for entries that answer the query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
# Copy original embeddings, entries to filter them for query # Copy original embeddings, entries to filter them for query
query = raw_query
corpus_embeddings = deepcopy(model.corpus_embeddings) corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries) entries = deepcopy(model.entries)
# Filter to entries that contain all required_words and no blocked_words # Filter query, entries and embeddings before semantic search
entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words) for filter in filters:
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
return {} return {}
@@ -133,42 +131,6 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
return hits, entries return hits, entries
def explicit_filter(entries, embeddings, required_words, blocked_words):
    "Drop entries (and their embedding rows) failing the required/blocked word filters"
    # Nothing to filter on: pass everything through untouched
    if not required_words and not blocked_words:
        return entries, embeddings

    # Tokenize each entry's text into a lowercase word set
    word_pattern = r',|\.| |\]|\[\(|\)|\{|\}'  # split on fullstop, comma or any brackets
    word_sets = []
    for entry in entries:
        tokens = re.split(word_pattern, entry[0])
        word_sets.append({token.lower() for token in tokens if token != ""})

    # Collect indices of entries to drop
    excluded = set()
    for index, words in enumerate(word_sets):
        # Entry must contain every required word
        if required_words and not required_words.issubset(words):
            excluded.add(index)
        # Entry must contain no blocked word
        if blocked_words and words.intersection(blocked_words):
            excluded.add(index)

    # Remove excluded entries and embedding rows, highest index first
    # so the remaining indices stay valid
    for index in sorted(excluded, reverse=True):
        del entries[index]
        embeddings = torch.cat((embeddings[:index], embeddings[index+1:]))

    return entries, embeddings
def render_results(hits, entries, count=5, display_biencoder_results=False): def render_results(hits, entries, count=5, display_biencoder_results=False):
"Render the Results returned by Search for the Query" "Render the Results returned by Search for the Query"
if display_biencoder_results: if display_biencoder_results: