mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Make filters applied before semantic search configurable
Reason -- This abstraction will simplify adding other pre-search filters. E.g A date-time filter Capabilities -- - Multiple filters can be applied on the query, entries etc before search - The filters to apply are configured for each type in the search controller Details -- - Move `explicit_filters` function into separate module under `search_filter` - Update signature of explicit filter to take and return `query`, `entries`, `embeddings` - Use this `explicit_filter` function from `search_filters` module in `search` method in controller - The asymmetric query method now just applies the passed filters to the `query`, `entries` and `embeddings` before semantic search is performed
This commit is contained in:
@@ -17,6 +17,7 @@ from src.utils.cli import cli
|
|||||||
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||||
from src.utils.rawconfig import FullConfig
|
from src.utils.rawconfig import FullConfig
|
||||||
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
||||||
|
from src.search_filter.explicit_filter import explicit_filter
|
||||||
|
|
||||||
# Application Global State
|
# Application Global State
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
@@ -58,14 +59,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||||||
|
|
||||||
if (t == SearchType.Notes or t == None) and model.notes_search:
|
if (t == SearchType.Notes or t == None) and model.notes_search:
|
||||||
# query notes
|
# query notes
|
||||||
hits, entries = asymmetric.query(user_query, model.notes_search, device=device)
|
hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter])
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, entries, results_count)
|
return asymmetric.collate_results(hits, entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and model.music_search:
|
if (t == SearchType.Music or t == None) and model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
hits, entries = asymmetric.query(user_query, model.music_search, device=device)
|
hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter])
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, entries, results_count)
|
return asymmetric.collate_results(hits, entries, results_count)
|
||||||
|
|||||||
46
src/search_filter/explicit_filter.py
Normal file
46
src/search_filter/explicit_filter.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Standard Packages
|
||||||
|
import re
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def explicit_filter(raw_query, entries, embeddings):
|
||||||
|
# Separate natural query from explicit required, blocked words filters
|
||||||
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
|
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||||
|
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||||
|
|
||||||
|
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||||
|
return query, entries, embeddings
|
||||||
|
|
||||||
|
# convert each entry to a set of words
|
||||||
|
entries_by_word_set = [set(word.lower()
|
||||||
|
for word
|
||||||
|
in re.split(
|
||||||
|
r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
|
||||||
|
entry[0])
|
||||||
|
if word != "")
|
||||||
|
for entry in entries]
|
||||||
|
|
||||||
|
# track id of entries to exclude
|
||||||
|
entries_to_exclude = set()
|
||||||
|
|
||||||
|
# mark entries that do not contain all required_words for exclusion
|
||||||
|
if len(required_words) > 0:
|
||||||
|
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||||
|
if not required_words.issubset(words_in_entry):
|
||||||
|
entries_to_exclude.add(id)
|
||||||
|
|
||||||
|
# mark entries that contain any blocked_words for exclusion
|
||||||
|
if len(blocked_words) > 0:
|
||||||
|
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||||
|
if words_in_entry.intersection(blocked_words):
|
||||||
|
entries_to_exclude.add(id)
|
||||||
|
|
||||||
|
# delete entries (and their embeddings) marked for exclusion
|
||||||
|
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||||
|
del entries[id]
|
||||||
|
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||||
|
|
||||||
|
return query, entries, embeddings
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import json
|
import json
|
||||||
import gzip
|
import gzip
|
||||||
import re
|
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@@ -15,6 +14,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||||
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||||
|
from src.search_filter.explicit_filter import explicit_filter
|
||||||
from src.utils.config import TextSearchModel
|
from src.utils.config import TextSearchModel
|
||||||
from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig
|
from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig
|
||||||
from src.utils.constants import empty_escape_sequences
|
from src.utils.constants import empty_escape_sequences
|
||||||
@@ -94,19 +94,17 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
|||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
|
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []):
|
||||||
"Search all notes for entries that answer the query"
|
"Search all notes for entries that answer the query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
|
||||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
|
||||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
|
||||||
|
|
||||||
# Copy original embeddings, entries to filter them for query
|
# Copy original embeddings, entries to filter them for query
|
||||||
|
query = raw_query
|
||||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||||
entries = deepcopy(model.entries)
|
entries = deepcopy(model.entries)
|
||||||
|
|
||||||
# Filter to entries that contain all required_words and no blocked_words
|
# Filter query, entries and embeddings before semantic search
|
||||||
entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words)
|
for filter in filters:
|
||||||
|
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -133,42 +131,6 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
|
|||||||
return hits, entries
|
return hits, entries
|
||||||
|
|
||||||
|
|
||||||
def explicit_filter(entries, embeddings, required_words, blocked_words):
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
|
||||||
return entries, embeddings
|
|
||||||
|
|
||||||
# convert each entry to a set of words
|
|
||||||
entries_by_word_set = [set(word.lower()
|
|
||||||
for word
|
|
||||||
in re.split(
|
|
||||||
r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
|
|
||||||
entry[0])
|
|
||||||
if word != "")
|
|
||||||
for entry in entries]
|
|
||||||
|
|
||||||
# track id of entries to exclude
|
|
||||||
entries_to_exclude = set()
|
|
||||||
|
|
||||||
# mark entries that do not contain all required_words for exclusion
|
|
||||||
if len(required_words) > 0:
|
|
||||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
|
||||||
if not required_words.issubset(words_in_entry):
|
|
||||||
entries_to_exclude.add(id)
|
|
||||||
|
|
||||||
# mark entries that contain any blocked_words for exclusion
|
|
||||||
if len(blocked_words) > 0:
|
|
||||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
|
||||||
if words_in_entry.intersection(blocked_words):
|
|
||||||
entries_to_exclude.add(id)
|
|
||||||
|
|
||||||
# delete entries (and their embeddings) marked for exclusion
|
|
||||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
|
||||||
del entries[id]
|
|
||||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
|
||||||
|
|
||||||
return entries, embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def render_results(hits, entries, count=5, display_biencoder_results=False):
|
def render_results(hits, entries, count=5, display_biencoder_results=False):
|
||||||
"Render the Results returned by Search for the Query"
|
"Render the Results returned by Search for the Query"
|
||||||
if display_biencoder_results:
|
if display_biencoder_results:
|
||||||
|
|||||||
Reference in New Issue
Block a user