diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 93fa0fda..35216343 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -10,12 +10,16 @@ from typing import List, Optional, Union
 # External Packages
 from fastapi import APIRouter
 from fastapi import HTTPException
+from sentence_transformers import util
 
 # Internal Packages
 from khoj.configure import configure_processor, configure_search
 from khoj.processor.conversation.gpt import converse, extract_questions
 from khoj.processor.conversation.utils import message_to_log, message_to_prompt
 from khoj.search_type import image_search, text_search
+from khoj.search_filter.date_filter import DateFilter
+from khoj.search_filter.file_filter import FileFilter
+from khoj.search_filter.word_filter import WordFilter
 from khoj.utils.helpers import log_telemetry, timer
 from khoj.utils.rawconfig import (
     FullConfig,
@@ -131,6 +135,24 @@ def search(
         logger.debug(f"Return response from query cache")
         return state.query_cache[query_cache_key]
 
+    # Encode query with filter terms removed. Chain the filters so each one
+    # strips its terms from the previous filter's output, not the raw query.
+    defiltered_query = user_query
+    for search_filter in [DateFilter(), WordFilter(), FileFilter()]:
+        defiltered_query = search_filter.defilter(defiltered_query)
+
+    # NOTE(review): both embeddings are encoded with the org_search bi-encoder, so they are identical and
+    # this fails when org_search is not configured (it is truthiness-checked below) - TODO confirm intent.
+    encoded_asymmetric_query = state.model.org_search.bi_encoder.encode(
+        [defiltered_query], convert_to_tensor=True, device=state.device
+    )
+    encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query)
+
+    encoded_symmetric_query = state.model.org_search.bi_encoder.encode(
+        [defiltered_query], convert_to_tensor=True, device=state.device
+    )
+    encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query)
+
     with concurrent.futures.ThreadPoolExecutor() as executor:
         if (t == SearchType.Org or t == None) and state.model.org_search:
             # query org-mode notes
@@ -139,6 +161,7 @@ def search(
                     text_search.query,
                     user_query,
                     state.model.org_search,
+                    question_embedding=encoded_asymmetric_query,
                     rank_results=r,
                     score_threshold=score_threshold,
                     dedupe=dedupe,
@@ -152,6 +175,7 @@ def search(
                     text_search.query,
                     user_query,
                     state.model.markdown_search,
+                    question_embedding=encoded_asymmetric_query,
                     rank_results=r,
                     score_threshold=score_threshold,
                     dedupe=dedupe,
@@ -165,6 +189,7 @@ def search(
                     text_search.query,
                     user_query,
                     state.model.pdf_search,
+                    question_embedding=encoded_asymmetric_query,
                     rank_results=r,
                     score_threshold=score_threshold,
                     dedupe=dedupe,
@@ -178,6 +203,7 @@ def search(
                     text_search.query,
                     user_query,
                     state.model.ledger_search,
+                    question_embedding=encoded_symmetric_query,
                     rank_results=r,
                     score_threshold=score_threshold,
                     dedupe=dedupe,
@@ -191,6 +217,7 @@ def search(
                     text_search.query,
                     user_query,
                     state.model.music_search,
+                    question_embedding=encoded_asymmetric_query,
                     rank_results=r,
                     score_threshold=score_threshold,
                     dedupe=dedupe,
@@ -217,6 +244,7 @@ def search(
                     user_query,
                     # Get plugin search model for specified search type, or the first one if none specified
                     state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
+                    question_embedding=encoded_asymmetric_query,
                     rank_results=r,
                     score_threshold=score_threshold,
                     dedupe=dedupe,
diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py
index c273f9b8..aa4fa2e4 100644
--- a/src/khoj/search_filter/base_filter.py
+++ b/src/khoj/search_filter/base_filter.py
@@ -18,3 +18,8 @@ class BaseFilter(ABC):
     @abstractmethod
     def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
         ...
+
+    @abstractmethod
+    def defilter(self, query: str) -> str:
+        "Return query with this filter's terms removed"
+        ...
diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py
index 36dc7974..be07eefd 100644
--- a/src/khoj/search_filter/date_filter.py
+++ b/src/khoj/search_filter/date_filter.py
@@ -49,6 +49,12 @@ class DateFilter(BaseFilter):
         "Check if query contains date filters"
         return self.extract_date_range(raw_query) is not None
 
+    def defilter(self, query: str) -> str:
+        "Remove date range filter from query and collapse repeated whitespace"
+        query = re.sub(rf"\s+{self.date_regex}", " ", query)
+        query = re.sub(r"\s{2,}", " ", query).strip()  # remove multiple spaces
+        return query
+
     def apply(self, query, entries):
         "Find entries containing any dates that fall within date range specified in query"
         # extract date range specified in date filter of query
@@ -59,9 +65,7 @@ class DateFilter(BaseFilter):
         if query_daterange is None:
             return query, set(range(len(entries)))
 
-        # remove date range filter from query
-        query = re.sub(rf"\s+{self.date_regex}", " ", query)
-        query = re.sub(r"\s{2,}", " ", query).strip()  # remove multiple spaces
+        query = self.defilter(query)
 
         # return results from cache if exists
         cache_key = tuple(query_daterange)
diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py
index 28610796..26f416fe 100644
--- a/src/khoj/search_filter/file_filter.py
+++ b/src/khoj/search_filter/file_filter.py
@@ -28,6 +28,10 @@ class FileFilter(BaseFilter):
     def can_filter(self, raw_query):
         return re.search(self.file_filter_regex, raw_query) is not None
 
+    def defilter(self, query: str) -> str:
+        "Remove file filter terms from query"
+        return re.sub(self.file_filter_regex, "", query).strip()
+
     def apply(self, query, entries):
         # Extract file filters from raw query
         with timer("Extract files_to_search from query", logger):
@@ -44,8 +48,10 @@ class FileFilter(BaseFilter):
                 else:
                     files_to_search += [file]
 
+        # Remove filter terms from original query
+        query = self.defilter(query)
+
         # Return item from cache if exists
-        query = re.sub(self.file_filter_regex, "", query).strip()
         cache_key = tuple(files_to_search)
         if cache_key in self.cache:
             logger.debug(f"Return file filter results from cache")
diff --git a/src/khoj/search_filter/word_filter.py b/src/khoj/search_filter/word_filter.py
index 9ee81b21..9c98e848 100644
--- a/src/khoj/search_filter/word_filter.py
+++ b/src/khoj/search_filter/word_filter.py
@@ -43,13 +43,17 @@ class WordFilter(BaseFilter):
 
         return len(required_words) != 0 or len(blocked_words) != 0
 
+    def defilter(self, query: str) -> str:
+        "Remove required and blocked word filter terms from query"
+        return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
+
     def apply(self, query, entries):
         "Find entries containing required and not blocked words specified in query"
         # Separate natural query from required, blocked words filters
         with timer("Extract required, blocked filters from query", logger):
             required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
             blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
-            query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
+            query = self.defilter(query)
 
         if len(required_words) == 0 and len(blocked_words) == 0:
             return query, set(range(len(entries)))
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 9d8d5c3a..96ffac7a 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -105,6 +105,7 @@ def compute_embeddings(
 def query(
     raw_query: str,
     model: TextSearchModel,
+    question_embedding: torch.Tensor = None,
     rank_results: bool = False,
     score_threshold: float = -math.inf,
     dedupe: bool = True,
@@ -124,9 +125,10 @@ def query(
         return hits, entries
 
     # Encode the query using the bi-encoder
-    with timer("Query Encode Time", logger, state.device):
-        question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
-        question_embedding = util.normalize_embeddings(question_embedding)
+    if question_embedding is None:
+        with timer("Query Encode Time", logger, state.device):
+            question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
+            question_embedding = util.normalize_embeddings(question_embedding)
 
     # Find relevant entries for the query
     with timer("Search Time", logger, state.device):