mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Encode the user query once and reuse it across search types to speed up query time
- Add a new abstract filter method, `defilter`, to remove filter terms from the query
- Use the `defilter` method to remove filter terms, encode the defiltered query and pass it to the query method of each search type

TODO: Encoding the query still takes 100-200 ms, unlike before. Need to investigate why
This commit is contained in:
@@ -10,12 +10,16 @@ from typing import List, Optional, Union
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from sentence_transformers import util
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log, message_to_prompt
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.helpers import log_telemetry, timer
|
||||
from khoj.utils.rawconfig import (
|
||||
FullConfig,
|
||||
@@ -131,6 +135,20 @@ def search(
|
||||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[query_cache_key]
|
||||
|
||||
# Encode query with filter terms removed.
# Chain the filters: each one strips its own terms from the output of the
# previous one. Calling defilter on the raw user_query every iteration would
# keep only the last filter's result, leaving the other filters' terms in
# the text that gets encoded.
defiltered_query = user_query
for search_filter in [DateFilter(), WordFilter(), FileFilter()]:
    defiltered_query = search_filter.defilter(defiltered_query)
|
||||
|
||||
encoded_asymmetric_query = state.model.org_search.bi_encoder.encode(
|
||||
[defiltered_query], convert_to_tensor=True, device=state.device
|
||||
)
|
||||
encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query)
|
||||
|
||||
encoded_symmetric_query = state.model.org_search.bi_encoder.encode(
|
||||
[defiltered_query], convert_to_tensor=True, device=state.device
|
||||
)
|
||||
encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == None) and state.model.org_search:
|
||||
# query org-mode notes
|
||||
@@ -139,6 +157,7 @@ def search(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.org_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
@@ -152,6 +171,7 @@ def search(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.markdown_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
@@ -165,6 +185,7 @@ def search(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.pdf_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
@@ -178,6 +199,7 @@ def search(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.ledger_search,
|
||||
question_embedding=encoded_symmetric_query,
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
@@ -191,6 +213,7 @@ def search(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.music_search,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
@@ -217,6 +240,7 @@ def search(
|
||||
user_query,
|
||||
# Get plugin search model for specified search type, or the first one if none specified
|
||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe,
|
||||
|
||||
@@ -18,3 +18,7 @@ class BaseFilter(ABC):
|
||||
@abstractmethod
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
    """Apply this filter to `query` over `entries`.

    Returns a `(query, indices)` tuple. Presumably the query comes back with
    this filter's terms removed and `indices` identifies the entries that
    satisfy the filter -- confirm against the concrete implementations.
    """
    ...
|
||||
|
||||
@abstractmethod
def defilter(self, query: str) -> str:
    """Return `query` with this filter's terms removed.

    Concrete filters must override this; the base implementation does nothing.
    """
    ...
|
||||
|
||||
@@ -49,6 +49,12 @@ class DateFilter(BaseFilter):
|
||||
"Check if query contains date filters"
|
||||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
def defilter(self, query):
    "Return query with date range filter terms removed and whitespace collapsed"
    # Swap every whitespace-prefixed date filter term for a single space
    date_filter_pattern = rf"\s+{self.date_regex}"
    without_date_filters = re.sub(date_filter_pattern, " ", query)
    # Collapse runs of whitespace left behind and trim both ends
    return re.sub(r"\s{2,}", " ", without_date_filters).strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
@@ -59,9 +65,7 @@ class DateFilter(BaseFilter):
|
||||
if query_daterange is None:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
||||
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
||||
query = self.defilter(query)
|
||||
|
||||
# return results from cache if exists
|
||||
cache_key = tuple(query_daterange)
|
||||
|
||||
@@ -28,6 +28,9 @@ class FileFilter(BaseFilter):
|
||||
def can_filter(self, raw_query):
    "Check if query contains a file filter term"
    first_match = re.search(self.file_filter_regex, raw_query)
    return first_match is not None
|
||||
|
||||
def defilter(self, query: str) -> str:
    "Return query with file filter terms removed and surrounding whitespace trimmed"
    without_file_filters = re.sub(self.file_filter_regex, "", query)
    return without_file_filters.strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
# Extract file filters from raw query
|
||||
with timer("Extract files_to_search from query", logger):
|
||||
@@ -44,8 +47,10 @@ class FileFilter(BaseFilter):
|
||||
else:
|
||||
files_to_search += [file]
|
||||
|
||||
# Remove filter terms from original query
|
||||
query = self.defilter(query)
|
||||
|
||||
# Return item from cache if exists
|
||||
query = re.sub(self.file_filter_regex, "", query).strip()
|
||||
cache_key = tuple(files_to_search)
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"Return file filter results from cache")
|
||||
|
||||
@@ -43,13 +43,16 @@ class WordFilter(BaseFilter):
|
||||
|
||||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
def defilter(self, query: str) -> str:
    "Return query with required and blocked word filter terms removed"
    # Strip required word terms first, then blocked word terms
    without_required = re.sub(self.required_regex, "", query)
    without_blocked = re.sub(self.blocked_regex, "", without_required)
    return without_blocked.strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from required, blocked words filters
|
||||
with timer("Extract required, blocked filters from query", logger):
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||
query = self.defilter(query)
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
@@ -105,6 +105,7 @@ def compute_embeddings(
|
||||
def query(
|
||||
raw_query: str,
|
||||
model: TextSearchModel,
|
||||
question_embedding: torch.Tensor = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
dedupe: bool = True,
|
||||
@@ -124,9 +125,10 @@ def query(
|
||||
return hits, entries
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
with timer("Search Time", logger, state.device):
|
||||
|
||||
Reference in New Issue
Block a user