Encode user query once and reuse it across search types to speed up query time

- Add a new abstract method, defilter, to the filters to remove filter terms
  from a query
- Use the defilter method to remove filter terms, encode this defiltered
  query and pass it to the query method of each search type

TODO: Encoding the query still takes 100-200 ms, unlike before. Need to
investigate why.
This commit is contained in:
Debanjum Singh Solanky
2023-06-08 13:37:19 +05:30
parent 285d17af2a
commit db07362ca3
6 changed files with 50 additions and 8 deletions

View File

@@ -10,12 +10,16 @@ from typing import List, Optional, Union
# External Packages
from fastapi import APIRouter
from fastapi import HTTPException
from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_processor, configure_search
from khoj.processor.conversation.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log, message_to_prompt
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.helpers import log_telemetry, timer
from khoj.utils.rawconfig import (
FullConfig,
@@ -131,6 +135,20 @@ def search(
logger.debug(f"Return response from query cache")
return state.query_cache[query_cache_key]
# Encode query with filter terms removed.
# Chain the filters: each one strips its own terms from the output of the
# previous one. Calling defilter on the raw user query every iteration would
# keep only the last filter's result and leave the other filter terms in.
defiltered_query = user_query
for search_filter in [DateFilter(), WordFilter(), FileFilter()]:
    defiltered_query = search_filter.defilter(defiltered_query)

# NOTE(review): the check `state.model.org_search` further below suggests it
# can be unset; if so, this attribute access would fail - confirm with caller.
encoded_asymmetric_query = state.model.org_search.bi_encoder.encode(
    [defiltered_query], convert_to_tensor=True, device=state.device
)
encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query)

# NOTE(review): this uses the same encoder and inputs as the asymmetric query
# above, so the query is encoded twice - confirm whether a different model
# was intended here.
encoded_symmetric_query = state.model.org_search.bi_encoder.encode(
    [defiltered_query], convert_to_tensor=True, device=state.device
)
encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == None) and state.model.org_search:
# query org-mode notes
@@ -139,6 +157,7 @@ def search(
text_search.query,
user_query,
state.model.org_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,
@@ -152,6 +171,7 @@ def search(
text_search.query,
user_query,
state.model.markdown_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,
@@ -165,6 +185,7 @@ def search(
text_search.query,
user_query,
state.model.pdf_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,
@@ -178,6 +199,7 @@ def search(
text_search.query,
user_query,
state.model.ledger_search,
question_embedding=encoded_symmetric_query,
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,
@@ -191,6 +213,7 @@ def search(
text_search.query,
user_query,
state.model.music_search,
question_embedding=encoded_asymmetric_query,
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,
@@ -217,6 +240,7 @@ def search(
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
question_embedding=encoded_asymmetric_query,
rank_results=r,
score_threshold=score_threshold,
dedupe=dedupe,

View File

@@ -18,3 +18,7 @@ class BaseFilter(ABC):
@abstractmethod
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
    # Implemented by each filter. Returns a (query, set of ints) pair.
    # NOTE(review): the set presumably holds indices into entries that pass
    # the filter - confirm against the concrete filters.
    ...
@abstractmethod
def defilter(self, query: str) -> str:
    # Implemented by each filter: return the query with that filter's own
    # terms removed.
    ...

View File

@@ -49,6 +49,12 @@ class DateFilter(BaseFilter):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def defilter(self, query):
    "Return query with date range filter terms removed"
    # a date filter is only matched when preceded by whitespace
    date_filter_pattern = rf"\s+{self.date_regex}"
    without_date_filters = re.sub(date_filter_pattern, " ", query)
    # collapse the runs of whitespace left behind, then trim the ends
    single_spaced = re.sub(r"\s{2,}", " ", without_date_filters)
    return single_spaced.strip()
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
@@ -59,9 +65,7 @@ class DateFilter(BaseFilter):
if query_daterange is None:
return query, set(range(len(entries)))
# remove date range filter from query
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
query = self.defilter(query)
# return results from cache if exists
cache_key = tuple(query_daterange)

View File

@@ -28,6 +28,9 @@ class FileFilter(BaseFilter):
def can_filter(self, raw_query):
    "Check if query contains file filters"
    file_filter_match = re.search(self.file_filter_regex, raw_query)
    return file_filter_match is not None
def defilter(self, query: str) -> str:
    "Return query with file filter terms removed"
    without_file_filters = re.sub(self.file_filter_regex, "", query)
    return without_file_filters.strip()
def apply(self, query, entries):
# Extract file filters from raw query
with timer("Extract files_to_search from query", logger):
@@ -44,8 +47,10 @@ class FileFilter(BaseFilter):
else:
files_to_search += [file]
# Remove filter terms from original query
query = self.defilter(query)
# Return item from cache if exists
query = re.sub(self.file_filter_regex, "", query).strip()
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.debug(f"Return file filter results from cache")

View File

@@ -43,13 +43,16 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0
def defilter(self, query: str) -> str:
    "Return query with required and blocked word filter terms removed"
    # strip required word filters first, then blocked word filters
    without_required = re.sub(self.required_regex, "", query)
    without_blocked = re.sub(self.blocked_regex, "", without_required)
    return without_blocked.strip()
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
query = self.defilter(query)
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries)))

View File

@@ -105,6 +105,7 @@ def compute_embeddings(
def query(
raw_query: str,
model: TextSearchModel,
question_embedding: torch.Tensor = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,
@@ -124,9 +125,10 @@ def query(
return hits, entries
# Encode the query using the bi-encoder
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query
with timer("Search Time", logger, state.device):