mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Pre-compute entry word sets to improve explicit filter query performance
This commit is contained in:
@@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (t == SearchType.Org or t == None) and config.content_type.org:
|
if (t == SearchType.Org or t == None) and config.content_type.org:
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
|
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
|
||||||
|
|
||||||
# Initialize Org Music Search
|
# Initialize Org Music Search
|
||||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||||
# Extract Entries, Generate Music Embeddings
|
# Extract Entries, Generate Music Embeddings
|
||||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
|
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
|
||||||
|
|
||||||
# Initialize Markdown Search
|
# Initialize Markdown Search
|
||||||
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||||
# Extract Entries, Generate Markdown Embeddings
|
# Extract Entries, Generate Markdown Embeddings
|
||||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
|
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
|
||||||
|
|
||||||
# Initialize Ledger Search
|
# Initialize Ledger Search
|
||||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||||
# Extract Entries, Generate Ledger Embeddings
|
# Extract Entries, Generate Ledger Embeddings
|
||||||
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
|
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
|
||||||
|
|
||||||
# Initialize Image Search
|
# Initialize Image Search
|
||||||
if (t == SearchType.Image or t == None) and config.content_type.image:
|
if (t == SearchType.Image or t == None) and config.content_type.image:
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||||
# query markdown files
|
# query markdown files
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
|||||||
@@ -17,12 +17,21 @@ class DateFilter:
|
|||||||
# - dt:"2 years ago"
|
# - dt:"2 years ago"
|
||||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, entry_key='raw'):
|
||||||
|
self.entry_key = entry_key
|
||||||
|
|
||||||
|
|
||||||
|
def load(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
"Check if query contains date filters"
|
"Check if query contains date filters"
|
||||||
return self.extract_date_range(raw_query) is not None
|
return self.extract_date_range(raw_query) is not None
|
||||||
|
|
||||||
|
|
||||||
def filter(self, query, entries, embeddings, entry_key='raw'):
|
def apply(self, query, entries, embeddings):
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
"Find entries containing any dates that fall within date range specified in query"
|
||||||
# extract date range specified in date filter of query
|
# extract date range specified in date filter of query
|
||||||
query_daterange = self.extract_date_range(query)
|
query_daterange = self.extract_date_range(query)
|
||||||
@@ -39,7 +48,7 @@ class DateFilter:
|
|||||||
entries_to_include = set()
|
entries_to_include = set()
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
# Extract dates from entry
|
# Extract dates from entry
|
||||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
|
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
|
||||||
# Convert date string in entry to unix timestamp
|
# Convert date string in entry to unix timestamp
|
||||||
try:
|
try:
|
||||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||||
|
|||||||
@@ -1,11 +1,46 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
import pickle
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from src.utils.helpers import resolve_absolute_path
|
||||||
|
from src.utils.config import SearchType
|
||||||
|
|
||||||
|
|
||||||
class ExplicitFilter:
|
class ExplicitFilter:
|
||||||
|
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
|
||||||
|
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
|
||||||
|
self.entry_key = entry_key
|
||||||
|
self.search_type = search_type
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, entries, regenerate=False):
|
||||||
|
if self.filter_file.exists() and not regenerate:
|
||||||
|
start = time.time()
|
||||||
|
with self.filter_file.open('rb') as f:
|
||||||
|
entries_by_word_set = pickle.load(f)
|
||||||
|
end = time.time()
|
||||||
|
print(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
|
||||||
|
else:
|
||||||
|
start = time.time()
|
||||||
|
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
||||||
|
entries_by_word_set = [set(word.lower()
|
||||||
|
for word
|
||||||
|
in re.split(entry_splitter, entry[self.entry_key])
|
||||||
|
if word != "")
|
||||||
|
for entry in entries]
|
||||||
|
with self.filter_file.open('wb') as f:
|
||||||
|
pickle.dump(entries_by_word_set, f)
|
||||||
|
end = time.time()
|
||||||
|
print(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
|
||||||
|
|
||||||
|
return entries_by_word_set
|
||||||
|
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
"Check if query contains explicit filters"
|
"Check if query contains explicit filters"
|
||||||
# Extract explicit query portion with required, blocked words to filter from natural query
|
# Extract explicit query portion with required, blocked words to filter from natural query
|
||||||
@@ -15,26 +50,24 @@ class ExplicitFilter:
|
|||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
return len(required_words) != 0 or len(blocked_words) != 0
|
||||||
|
|
||||||
|
|
||||||
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
|
def apply(self, raw_query, entries, embeddings):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
|
start = time.time()
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||||
|
end = time.time()
|
||||||
|
print(f"Time to extract required, blocked words: {end - start} seconds")
|
||||||
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||||
return query, entries, embeddings
|
return query, entries, embeddings
|
||||||
|
|
||||||
# convert each entry to a set of words
|
# load or generate word set for each entry
|
||||||
# split on fullstop, comma, colon, tab, newline or any brackets
|
entries_by_word_set = self.load(entries, regenerate=False)
|
||||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
|
||||||
entries_by_word_set = [set(word.lower()
|
|
||||||
for word
|
|
||||||
in re.split(entry_splitter, entry[entry_key])
|
|
||||||
if word != "")
|
|
||||||
for entry in entries]
|
|
||||||
|
|
||||||
# track id of entries to exclude
|
# track id of entries to exclude
|
||||||
|
start = time.time()
|
||||||
entries_to_exclude = set()
|
entries_to_exclude = set()
|
||||||
|
|
||||||
# mark entries that do not contain all required_words for exclusion
|
# mark entries that do not contain all required_words for exclusion
|
||||||
@@ -48,10 +81,15 @@ class ExplicitFilter:
|
|||||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||||
if words_in_entry.intersection(blocked_words):
|
if words_in_entry.intersection(blocked_words):
|
||||||
entries_to_exclude.add(id)
|
entries_to_exclude.add(id)
|
||||||
|
end = time.time()
|
||||||
|
print(f"Mark entries to filter: {end - start} seconds")
|
||||||
|
|
||||||
# delete entries (and their embeddings) marked for exclusion
|
# delete entries (and their embeddings) marked for exclusion
|
||||||
|
start = time.time()
|
||||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||||
del entries[id]
|
del entries[id]
|
||||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||||
|
end = time.time()
|
||||||
|
print(f"Remove entries to filter from embeddings: {end - start} seconds")
|
||||||
|
|
||||||
return query, entries, embeddings
|
return query, entries, embeddings
|
||||||
|
|||||||
@@ -8,11 +8,13 @@ from copy import deepcopy
|
|||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||||
|
from src.search_filter.date_filter import DateFilter
|
||||||
|
from src.search_filter.explicit_filter import ExplicitFilter
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils import state
|
from src.utils import state
|
||||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||||
from src.utils.config import TextSearchModel
|
from src.utils.config import SearchType, TextSearchModel
|
||||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||||
from src.utils.jsonl import load_jsonl
|
from src.utils.jsonl import load_jsonl
|
||||||
|
|
||||||
@@ -73,13 +75,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
|||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
|
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query = raw_query
|
query = raw_query
|
||||||
|
|
||||||
# Use deep copy of original embeddings, entries to filter if query contains filters
|
# Use deep copy of original embeddings, entries to filter if query contains filters
|
||||||
start = time.time()
|
start = time.time()
|
||||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
||||||
if filters_in_query:
|
if filters_in_query:
|
||||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||||
entries = deepcopy(model.entries)
|
entries = deepcopy(model.entries)
|
||||||
@@ -92,7 +94,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
|||||||
# Filter query, entries and embeddings before semantic search
|
# Filter query, entries and embeddings before semantic search
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for filter in filters_in_query:
|
for filter in filters_in_query:
|
||||||
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
|
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Filter Time: {end - start:.3f} seconds")
|
logger.debug(f"Filter Time: {end - start:.3f} seconds")
|
||||||
|
|
||||||
@@ -163,7 +165,7 @@ def collate_results(hits, entries, count=5):
|
|||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||||
|
|
||||||
@@ -180,7 +182,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
|||||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||||
|
|
||||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
|
||||||
|
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
|
||||||
|
for filter in filters:
|
||||||
|
filter.load(entries, regenerate=regenerate)
|
||||||
|
|
||||||
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -20,11 +20,12 @@ class ProcessorType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class TextSearchModel():
|
class TextSearchModel():
|
||||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
|
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.corpus_embeddings = corpus_embeddings
|
self.corpus_embeddings = corpus_embeddings
|
||||||
self.bi_encoder = bi_encoder
|
self.bi_encoder = bi_encoder
|
||||||
self.cross_encoder = cross_encoder
|
self.cross_encoder = cross_encoder
|
||||||
|
self.filters = filters
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user