mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Pre-compute entry word sets to improve explicit filter query performance
This commit is contained in:
@@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||
# Initialize Org Notes Search
|
||||
if (t == SearchType.Org or t == None) and config.content_type.org:
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
|
||||
|
||||
# Initialize Org Music Search
|
||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||
# Extract Entries, Generate Music Embeddings
|
||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
|
||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
|
||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
|
||||
|
||||
# Initialize Ledger Search
|
||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
# Extract Entries, Generate Ledger Embeddings
|
||||
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
|
||||
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == SearchType.Image or t == None) and config.content_type.image:
|
||||
|
||||
@@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||
# query org-mode notes
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||
# query music library
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||
# query markdown files
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||
# query transactions
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
||||
@@ -17,12 +17,21 @@ class DateFilter:
|
||||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||
|
||||
|
||||
def __init__(self, entry_key='raw'):
|
||||
self.entry_key = entry_key
|
||||
|
||||
|
||||
def load(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains date filters"
|
||||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
|
||||
def filter(self, query, entries, embeddings, entry_key='raw'):
|
||||
def apply(self, query, entries, embeddings):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
query_daterange = self.extract_date_range(query)
|
||||
@@ -39,7 +48,7 @@ class DateFilter:
|
||||
entries_to_include = set()
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||
|
||||
@@ -1,11 +1,46 @@
|
||||
# Standard Packages
|
||||
import re
|
||||
import time
|
||||
import pickle
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
from src.utils.config import SearchType
|
||||
|
||||
|
||||
class ExplicitFilter:
|
||||
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
|
||||
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
|
||||
self.entry_key = entry_key
|
||||
self.search_type = search_type
|
||||
|
||||
|
||||
def load(self, entries, regenerate=False):
|
||||
if self.filter_file.exists() and not regenerate:
|
||||
start = time.time()
|
||||
with self.filter_file.open('rb') as f:
|
||||
entries_by_word_set = pickle.load(f)
|
||||
end = time.time()
|
||||
print(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
|
||||
else:
|
||||
start = time.time()
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
||||
entries_by_word_set = [set(word.lower()
|
||||
for word
|
||||
in re.split(entry_splitter, entry[self.entry_key])
|
||||
if word != "")
|
||||
for entry in entries]
|
||||
with self.filter_file.open('wb') as f:
|
||||
pickle.dump(entries_by_word_set, f)
|
||||
end = time.time()
|
||||
print(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
|
||||
|
||||
return entries_by_word_set
|
||||
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains explicit filters"
|
||||
# Extract explicit query portion with required, blocked words to filter from natural query
|
||||
@@ -15,26 +50,24 @@ class ExplicitFilter:
|
||||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
|
||||
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
|
||||
def apply(self, raw_query, entries, embeddings):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
start = time.time()
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
end = time.time()
|
||||
print(f"Time to extract required, blocked words: {end - start} seconds")
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, entries, embeddings
|
||||
|
||||
# convert each entry to a set of words
|
||||
# split on fullstop, comma, colon, tab, newline or any brackets
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
||||
entries_by_word_set = [set(word.lower()
|
||||
for word
|
||||
in re.split(entry_splitter, entry[entry_key])
|
||||
if word != "")
|
||||
for entry in entries]
|
||||
# load or generate word set for each entry
|
||||
entries_by_word_set = self.load(entries, regenerate=False)
|
||||
|
||||
# track id of entries to exclude
|
||||
start = time.time()
|
||||
entries_to_exclude = set()
|
||||
|
||||
# mark entries that do not contain all required_words for exclusion
|
||||
@@ -48,10 +81,15 @@ class ExplicitFilter:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if words_in_entry.intersection(blocked_words):
|
||||
entries_to_exclude.add(id)
|
||||
end = time.time()
|
||||
print(f"Mark entries to filter: {end - start} seconds")
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
start = time.time()
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
end = time.time()
|
||||
print(f"Remove entries to filter from embeddings: {end - start} seconds")
|
||||
|
||||
return query, entries, embeddings
|
||||
|
||||
@@ -8,11 +8,13 @@ from copy import deepcopy
|
||||
# External Packages
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
from src.search_filter.explicit_filter import ExplicitFilter
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.config import SearchType, TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
@@ -73,13 +75,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
# Use deep copy of original embeddings, entries to filter if query contains filters
|
||||
start = time.time()
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
||||
if filters_in_query:
|
||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||
entries = deepcopy(model.entries)
|
||||
@@ -92,7 +94,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
start = time.time()
|
||||
for filter in filters_in_query:
|
||||
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
|
||||
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
|
||||
end = time.time()
|
||||
logger.debug(f"Filter Time: {end - start:.3f} seconds")
|
||||
|
||||
@@ -163,7 +165,7 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
@@ -180,7 +182,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
||||
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
|
||||
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -20,11 +20,12 @@ class ProcessorType(str, Enum):
|
||||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
self.cross_encoder = cross_encoder
|
||||
self.filters = filters
|
||||
self.top_k = top_k
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user