mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 05:39:11 +00:00
Pre-compute entry word sets to improve explicit filter query performance
This commit is contained in:
@@ -8,11 +8,13 @@ from copy import deepcopy
|
||||
# External Packages
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
from src.search_filter.explicit_filter import ExplicitFilter
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.config import SearchType, TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
@@ -73,13 +75,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
# Use deep copy of original embeddings, entries to filter if query contains filters
|
||||
start = time.time()
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
||||
if filters_in_query:
|
||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||
entries = deepcopy(model.entries)
|
||||
@@ -92,7 +94,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
start = time.time()
|
||||
for filter in filters_in_query:
|
||||
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
|
||||
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
|
||||
end = time.time()
|
||||
logger.debug(f"Filter Time: {end - start:.3f} seconds")
|
||||
|
||||
@@ -163,7 +165,7 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
@@ -180,7 +182,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
||||
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
|
||||
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user