From f93032435042d4fd3129cca6ae1f37da83a422be Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 17:18:47 +0300 Subject: [PATCH 01/13] Rename explicit filter to word filter to be more specific --- Readme.md | 1 - src/router.py | 2 -- .../{explicit_filter.py => word_filter.py} | 15 ++++---- src/search_type/text_search.py | 4 +-- tests/test_client.py | 4 +-- ...explicit_filter.py => test_word_filter.py} | 34 +++++++++---------- 6 files changed, 28 insertions(+), 32 deletions(-) rename src/search_filter/{explicit_filter.py => word_filter.py} (88%) rename tests/{test_explicit_filter.py => test_word_filter.py} (56%) diff --git a/Readme.md b/Readme.md index 76182560..628ce458 100644 --- a/Readme.md +++ b/Readme.md @@ -125,7 +125,6 @@ pip install --upgrade khoj-assistant - Semantic search using the bi-encoder is fairly fast at \<50 ms - Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results -- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc ### Indexing performance diff --git a/src/router.py b/src/router.py index 127623c6..a4bd2f84 100644 --- a/src/router.py +++ b/src/router.py @@ -16,8 +16,6 @@ from fastapi.templating import Jinja2Templates from src.configure import configure_search from src.search_type import image_search, text_search from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize -from src.search_filter.explicit_filter import ExplicitFilter -from src.search_filter.date_filter import DateFilter from src.utils.rawconfig import FullConfig from src.utils.config import SearchType from src.utils.helpers import get_absolute_path, get_from_dict diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/word_filter.py similarity index 88% rename from src/search_filter/explicit_filter.py rename to src/search_filter/word_filter.py index 7a26f830..f47ae6b7 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/word_filter.py @@ -15,13 +15,13 @@ from src.utils.config import SearchType logger = logging.getLogger(__name__) -class ExplicitFilter: +class WordFilter: # Filter Regex required_regex = r'\+"(\w+)" ?' blocked_regex = r'\-"(\w+)" ?' 
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): - self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") + self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl") self.entry_key = entry_key self.search_type = search_type self.word_to_entry_index = dict() @@ -34,7 +34,7 @@ class ExplicitFilter: with self.filter_file.open('rb') as f: self.word_to_entry_index = pickle.load(f) end = time.time() - logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") + logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds") else: start = time.time() self.cache = {} # Clear cache on (re-)generating entries_by_word_set @@ -51,14 +51,13 @@ class ExplicitFilter: with self.filter_file.open('wb') as f: pickle.dump(self.word_to_entry_index, f) end = time.time() - logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") + logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds") return self.word_to_entry_index def can_filter(self, raw_query): - "Check if query contains explicit filters" - # Extract explicit query portion with required, blocked words to filter from natural query + "Check if query contains word filters" required_words = re.findall(self.required_regex, raw_query) blocked_words = re.findall(self.blocked_regex, raw_query) @@ -67,7 +66,7 @@ class ExplicitFilter: def apply(self, raw_query, raw_entries, raw_embeddings): "Find entries containing required and not blocked words specified in query" - # Separate natural query from explicit required, blocked words filters + # Separate natural query from required, blocked words filters start = time.time() required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) @@ -83,7 +82,7 @@ class ExplicitFilter: # Return item from cache if exists cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) if cache_key in self.cache: - logger.info(f"Explicit filter results from cache") + logger.info(f"Return word filter results from cache") entries, embeddings = self.cache[cache_key] return query, entries, embeddings diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 742ff5ed..a674d712 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -8,7 +8,7 @@ import time import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util from src.search_filter.date_filter import DateFilter -from src.search_filter.explicit_filter import ExplicitFilter +from src.search_filter.word_filter import WordFilter # Internal Packages from src.utils import state @@ -171,7 +171,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) filter_directory = resolve_absolute_path(config.compressed_jsonl.parent) - filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)] + filters = [DateFilter(), WordFilter(filter_directory, search_type=search_type)] for filter in filters: filter.load(entries, regenerate=regenerate) diff --git a/tests/test_client.py b/tests/test_client.py index e9b632be..e7ddac33 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -140,7 +140,7 @@ def 
test_notes_search_with_include_filter(content_config: ContentConfig, search_ # Assert assert response.status_code == 200 - # assert actual_data contains explicitly included word "Emacs" + # assert actual_data contains word "Emacs" search_result = response.json()[0]["entry"] assert "Emacs" in search_result @@ -156,6 +156,6 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_ # Assert assert response.status_code == 200 - # assert actual_data does not contains explicitly excluded word "Emacs" + # assert actual_data does not contains word "Emacs" search_result = response.json()[0]["entry"] assert "clone" not in search_result diff --git a/tests/test_explicit_filter.py b/tests/test_word_filter.py similarity index 56% rename from tests/test_explicit_filter.py rename to tests/test_word_filter.py index 5f34b0ac..3d584077 100644 --- a/tests/test_explicit_filter.py +++ b/tests/test_word_filter.py @@ -2,19 +2,19 @@ import torch # Application Packages -from src.search_filter.explicit_filter import ExplicitFilter +from src.search_filter.word_filter import WordFilter from src.utils.config import SearchType -def test_no_explicit_filter(tmp_path): +def test_no_word_filter(tmp_path): # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + word_filter = WordFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() q_with_no_filter = 'head tail' # Act - can_filter = explicit_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) + can_filter = word_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == False @@ -23,15 +23,15 @@ def test_no_explicit_filter(tmp_path): assert ret_entries == entries -def test_explicit_exclude_filter(tmp_path): +def test_word_exclude_filter(tmp_path): # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + word_filter = WordFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() q_with_exclude_filter = 'head -"exclude_word" tail' # Act - can_filter = explicit_filter.can_filter(q_with_exclude_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + can_filter = word_filter.can_filter(q_with_exclude_filter) + ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True @@ -40,15 +40,15 @@ def test_explicit_exclude_filter(tmp_path): assert ret_entries == [entries[0], entries[2]] -def test_explicit_include_filter(tmp_path): +def test_word_include_filter(tmp_path): # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + word_filter = WordFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() query_with_include_filter = 'head +"include_word" tail' # Act - can_filter = explicit_filter.can_filter(query_with_include_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) + can_filter = word_filter.can_filter(query_with_include_filter) + ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) # Assert assert can_filter == True @@ -57,15 +57,15 @@ def test_explicit_include_filter(tmp_path): assert ret_entries == [entries[2], entries[3]] -def test_explicit_include_and_exclude_filter(tmp_path): +def 
test_word_include_and_exclude_filter(tmp_path): # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + word_filter = WordFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' # Act - can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) + ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True From c9f620000754674bb0489f05c2c352dce017fddf Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 17:19:22 +0300 Subject: [PATCH 02/13] Ignore pytest_cache directory from git using .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9d33a849..a2e89c26 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ src/.data /dist/ /khoj_assistant.egg-info/ /config/khoj*.yml +.pytest_cache From e4418746f22b16e0d7fbcfa7f6b1657bc44f921f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 18:05:38 +0300 Subject: [PATCH 03/13] Create Abstract Base Class for Filters. Make Word, Date Filter Child of BaseFilter --- src/search_filter/base_filter.py | 20 ++++++++++++++++++++ src/search_filter/date_filter.py | 7 +++++-- src/search_filter/word_filter.py | 3 ++- src/utils/config.py | 4 +++- 4 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 src/search_filter/base_filter.py diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py new file mode 100644 index 00000000..dc079b45 --- /dev/null +++ b/src/search_filter/base_filter.py @@ -0,0 +1,20 @@ +# Standard Packages +from abc import ABC, abstractmethod +from typing import List, Tuple + +# External Packages +import torch + + +class BaseFilter(ABC): + @abstractmethod + def load(self, *args, **kwargs): + pass + + @abstractmethod + def can_filter(self, raw_query:str) -> bool: + pass + + @abstractmethod + def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]: + pass \ No newline at end of file diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index cab47cbb..54a8b625 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -1,7 +1,7 @@ # Standard Packages import re from datetime import timedelta, datetime -from dateutil.relativedelta import relativedelta, MO +from dateutil.relativedelta import relativedelta from math import inf from copy import deepcopy @@ -9,8 +9,11 @@ from copy import deepcopy import torch import dateparser as dtparse +# Internal Packages +from src.search_filter.base_filter import BaseFilter -class DateFilter: + +class DateFilter(BaseFilter): # Date Range Filter Regexes # Example filter queries: # - dt>="yesterday" dt<"tomorrow" diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index f47ae6b7..9f46edd2 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -8,6 +8,7 @@ import logging import torch # Internal Packages +from src.search_filter.base_filter import BaseFilter from src.utils.helpers import LRU, resolve_absolute_path from src.utils.config import SearchType @@ -15,7 +16,7 @@ from 
src.utils.config import SearchType logger = logging.getLogger(__name__) -class WordFilter: +class WordFilter(BaseFilter): # Filter Regex required_regex = r'\+"(\w+)" ?' blocked_regex = r'\-"(\w+)" ?' diff --git a/src/utils/config.py b/src/utils/config.py index a4de6b81..c163e22f 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -2,9 +2,11 @@ from enum import Enum from dataclasses import dataclass from pathlib import Path +from typing import List # Internal Packages from src.utils.rawconfig import ConversationProcessorConfig +from src.search_filter.base_filter import BaseFilter class SearchType(str, Enum): @@ -20,7 +22,7 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder From 1f9fd28b341d6255dc0b358dd103a76e252418b5 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 19:38:29 +0300 Subject: [PATCH 04/13] Create File Filter to filter files to query. Add tests for file filter --- src/search_filter/file_filter.py | 37 +++++++++++ tests/test_file_filter.py | 101 +++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 src/search_filter/file_filter.py create mode 100644 tests/test_file_filter.py diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py new file mode 100644 index 00000000..065badc0 --- /dev/null +++ b/src/search_filter/file_filter.py @@ -0,0 +1,37 @@ +# Standard Packages +import re +import fnmatch + +# External Packages +import torch + +# Internal Packages +from src.search_filter.base_filter import BaseFilter + + +class FileFilter(BaseFilter): + file_filter_regex = r'file:"(.+?)" ?' 
+ + def __init__(self, entry_key='file'): + self.entry_key = entry_key + + def load(self, *args, **kwargs): + pass + + def can_filter(self, raw_query): + return re.search(self.file_filter_regex, raw_query) is not None + + def apply(self, raw_query, raw_entries, raw_embeddings): + files_to_search = re.findall(self.file_filter_regex, raw_query) + if not files_to_search: + return raw_query, raw_entries, raw_embeddings + + query = re.sub(self.file_filter_regex, '', raw_query).strip() + included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)] + if not included_entry_indices: + return query, [], torch.empty(0) + + entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] + embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) + + return query, entries, embeddings diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py new file mode 100644 index 00000000..401adfc7 --- /dev/null +++ b/tests/test_file_filter.py @@ -0,0 +1,101 @@ +# External Packages +import torch + +# Application Packages +from src.search_filter.file_filter import FileFilter + + +def test_no_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == False + assert ret_query == 'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + +def test_file_filter_with_partial_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"*.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + +def test_file_filter_with_non_existent_file(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"nonexistent.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 0 + assert ret_entries == [] + + +def test_single_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"file1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[0], entries[2]] + + +def test_multiple_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail file:"file1.org" file:"file2.org"' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 
'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + +def arrange_content(): + embeddings = torch.randn(4, 10) + entries = [ + {'compiled': '', 'raw': 'First Entry', 'file': 'file1.org'}, + {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, + {'compiled': '', 'raw': 'Third Entry', 'file': 'file1.org'}, + {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] + + return embeddings, entries From 092b9e329d88c3645ef737d04c2233d181e51007 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 01:05:13 +0300 Subject: [PATCH 05/13] Setup Filters when configuring Text Search for each Search Type - Allows enabling different filters for different Text Search Types - Use FileFilter in Text Search on Org Files --- src/configure.py | 13 +++++++++---- src/search_type/text_search.py | 9 +++------ tests/conftest.py | 14 +++++++++----- tests/test_client.py | 12 +++++++----- tests/test_text_search.py | 10 +++++----- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/configure.py b/src/configure.py index 938062eb..3ee594b4 100644 --- a/src/configure.py +++ b/src/configure.py @@ -14,6 +14,9 @@ from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, Con from src.utils import state from src.utils.helpers import resolve_absolute_path from src.utils.rawconfig import FullConfig, ProcessorConfig +from src.search_filter.date_filter import DateFilter +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter logger = logging.getLogger(__name__) @@ -39,23 +42,25 @@ def configure_server(args, required=False): def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None): # Initialize Org Notes Search if (t == SearchType.Org or t == None) and config.content_type.org: + filter_directory = resolve_absolute_path(config.content_type.org.compressed_jsonl.parent) + filters = [DateFilter(), WordFilter(filter_directory, search_type=SearchType.Org), FileFilter()] # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate) + model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, filters=filters) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate) + model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate) + model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger 
Embeddings - model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate) + model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index a674d712..1dc30ef2 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -7,13 +7,12 @@ import time # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util -from src.search_filter.date_filter import DateFilter -from src.search_filter.word_filter import WordFilter +from src.search_filter.base_filter import BaseFilter # Internal Packages from src.utils import state from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model -from src.utils.config import SearchType, TextSearchModel +from src.utils.config import TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl @@ -153,7 +152,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel: +def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) @@ -170,8 +169,6 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) - filter_directory = resolve_absolute_path(config.compressed_jsonl.parent) - filters = [DateFilter(), WordFilter(filter_directory, search_type=search_type)] for filter in filters: filter.load(entries, regenerate=regenerate) diff --git a/tests/conftest.py b/tests/conftest.py index 930ec734..7545527f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Standard Packages +# External Packages import pytest # Internal Packages @@ -6,10 +6,13 @@ from src.search_type import image_search, text_search from src.utils.config import SearchType from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.search_filter.date_filter import DateFilter +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter @pytest.fixture(scope='session') -def search_config(tmp_path_factory): +def search_config(tmp_path_factory) -> SearchConfig: model_dir = tmp_path_factory.mktemp('data') search_config = SearchConfig() @@ -35,7 +38,7 @@ def search_config(tmp_path_factory): @pytest.fixture(scope='session') -def model_dir(search_config): +def model_dir(search_config: SearchConfig): model_dir = search_config.asymmetric.model_directory # Generate Image Embeddings from Test Images @@ -55,13 +58,14 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = 
model_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + filters = [DateFilter(), WordFilter(model_dir, search_type=SearchType.Org), FileFilter()] + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) return model_dir @pytest.fixture(scope='session') -def content_config(model_dir): +def content_config(model_dir) -> ContentConfig: content_config = ContentConfig() content_config.org = TextContentConfig( input_files = None, diff --git a/tests/test_client.py b/tests/test_client.py index e7ddac33..578c789c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,6 @@ from PIL import Image # External Packages from fastapi.testclient import TestClient -import pytest # Internal Packages from src.main import app @@ -12,7 +11,8 @@ from src.utils.config import SearchType from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig -from src.processor.org_mode import org_to_jsonl +from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.search_filter.word_filter import WordFilter # Arrange @@ -116,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application?" # Act @@ -132,7 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = 'How to git install application? +"Emacs"' # Act @@ -148,7 +149,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = 'How to git install application? 
-"clone"' # Act diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 84f16df5..d56d304d 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -14,7 +14,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) + notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -24,7 +24,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) query = "How to git install application?" # Act @@ -47,7 +47,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -60,11 +60,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) + regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) # Assert assert len(regenerated_notes_model.entries) == 11 From f634399f23bce23bb8867cf2f699c4625ec21411 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 01:45:18 +0300 Subject: [PATCH 06/13] Convert simple file filters with no path separator into regex - Specify just file name to get all notes associated with file at path - E.g `query` with `file:"file1.org"` will return `entry1` if `entry1` is in `file1.org` at `~/notes/file.org` - Test - Test converting simple file name filter to regex for path match - Test file filter with space in file name --- src/search_filter/file_filter.py | 13 +++++-- tests/test_file_filter.py | 59 
++++++++++++++++++++------------ 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 065badc0..674d88ed 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -22,10 +22,19 @@ class FileFilter(BaseFilter): return re.search(self.file_filter_regex, raw_query) is not None def apply(self, raw_query, raw_entries, raw_embeddings): - files_to_search = re.findall(self.file_filter_regex, raw_query) - if not files_to_search: + # Extract file filters from raw query + raw_files_to_search = re.findall(self.file_filter_regex, raw_query) + if not raw_files_to_search: return raw_query, raw_entries, raw_embeddings + # Convert simple file filters with no path separator into regex + # e.g. "file:notes.org" -> "file:.*notes.org" + files_to_search = [] + for file in sorted(raw_files_to_search): + if '/' not in file and '\\' not in file and '*' not in file: + files_to_search += [f'*{file}'] + else: + files_to_search += [file] query = re.sub(self.file_filter_regex, '', raw_query).strip() included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)] if not included_entry_indices: diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 401adfc7..b15b8a69 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -22,23 +22,6 @@ def test_no_file_filter(): assert ret_entries == entries -def test_file_filter_with_partial_match(): - # Arrange - file_filter = FileFilter() - embeddings, entries = arrange_content() - q_with_no_filter = 'head file:"*.org" tail' - - # Act - can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) - - # Assert - assert can_filter == True - assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries - - def test_file_filter_with_non_existent_file(): # Arrange file_filter = FileFilter() @@ -60,7 +43,7 @@ def test_single_file_filter(): # Arrange file_filter = FileFilter() embeddings, entries = arrange_content() - q_with_no_filter = 'head file:"file1.org" tail' + q_with_no_filter = 'head file:"file 1.org" tail' # Act can_filter = file_filter.can_filter(q_with_no_filter) @@ -73,11 +56,45 @@ def test_single_file_filter(): assert ret_entries == [entries[0], entries[2]] +def test_file_filter_with_partial_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[0], entries[2]] + + +def test_file_filter_with_regex_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"*.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + def test_multiple_file_filter(): # Arrange file_filter = FileFilter() embeddings, entries = arrange_content() 
- q_with_no_filter = 'head tail file:"file1.org" file:"file2.org"' + q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' # Act can_filter = file_filter.can_filter(q_with_no_filter) @@ -93,9 +110,9 @@ def test_multiple_file_filter(): def arrange_content(): embeddings = torch.randn(4, 10) entries = [ - {'compiled': '', 'raw': 'First Entry', 'file': 'file1.org'}, + {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, - {'compiled': '', 'raw': 'Third Entry', 'file': 'file1.org'}, + {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'}, {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] return embeddings, entries From 7e083d3e96b0db42219649b5b54d82dd075b315d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 01:51:11 +0300 Subject: [PATCH 07/13] Cache results for file filters passed in query for faster filtering --- src/search_filter/file_filter.py | 33 ++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 674d88ed..2c565982 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -1,12 +1,18 @@ # Standard Packages import re import fnmatch +import time +import logging # External Packages import torch # Internal Packages from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU + + +logger = logging.getLogger(__name__) class FileFilter(BaseFilter): @@ -14,6 +20,7 @@ class FileFilter(BaseFilter): def __init__(self, entry_key='file'): self.entry_key = entry_key + self.cache = LRU() def load(self, *args, **kwargs): pass @@ -23,6 +30,7 @@ class FileFilter(BaseFilter): def apply(self, raw_query, raw_entries, raw_embeddings): # Extract file filters from raw query + start = time.time() raw_files_to_search = re.findall(self.file_filter_regex, raw_query) if not raw_files_to_search: return raw_query, raw_entries, raw_embeddings @@ -35,12 +43,37 @@ class FileFilter(BaseFilter): files_to_search += [f'*{file}'] else: files_to_search += [file] + end = time.time() + logger.debug(f"Extract files_to_search from query: {end - start} seconds") + + # Return item from cache if exists query = re.sub(self.file_filter_regex, '', raw_query).strip() + cache_key = tuple(files_to_search) + if cache_key in self.cache: + logger.info(f"Return file filter results from cache") + entries, embeddings = self.cache[cache_key] + return query, entries, embeddings + + # Mark entries that contain any blocked_words for exclusion + start = time.time() + included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)] if not included_entry_indices: return query, [], torch.empty(0) + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # Get entries (and associated embeddings) satisfying file filters + start = time.time() + entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) + end = time.time() + logger.debug(f"Keep entries satisfying filter: {end - start} seconds") + + # Cache results + self.cache[cache_key] = entries, embeddings + return query, entries, embeddings From 7606724dbc395d92aa7231f291b0476aa27fb4ef Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 
01:57:17 +0300 Subject: [PATCH 08/13] Add file of each entry to entry dict in org_to_jsonl converter - This will help filter query to org content type using file filter - Do not explicitly specify items being extracted from json of each entry in text_search as all text search content types do not have file being set in jsonl converters --- src/processor/org_mode/org_to_jsonl.py | 18 ++++++++++-------- src/search_type/text_search.py | 6 ++---- tests/test_org_to_jsonl.py | 8 ++++---- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index a41705e0..131ef919 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -28,10 +28,10 @@ def org_to_jsonl(org_files, org_file_filter, output_file): org_files = get_org_files(org_files, org_file_filter) # Extract Entries from specified Org files - entries = extract_org_entries(org_files) + entries, file_to_entries = extract_org_entries(org_files) # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries) + jsonl_data = convert_org_entries_to_jsonl(entries, file_to_entries) # Compress JSONL formatted Data if output_file.suffix == ".gz": @@ -66,18 +66,19 @@ def get_org_files(org_files=None, org_file_filter=None): def extract_org_entries(org_files): "Extract entries from specified Org files" entries = [] + entry_to_file_map = [] for org_file in org_files: - entries.extend( - orgnode.makelist( - str(org_file))) + org_file_entries = orgnode.makelist(str(org_file)) + entry_to_file_map += [org_file]*len(org_file_entries) + entries.extend(org_file_entries) - return entries + return entries, entry_to_file_map -def convert_org_entries_to_jsonl(entries) -> str: +def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str: "Convert each Org-Mode entries to JSON and collate as JSONL" jsonl = '' - for entry in entries: + for entry_id, entry in enumerate(entries): entry_dict = dict() # Ignore title notes i.e notes with just headings and empty body @@ -106,6 +107,7 @@ def convert_org_entries_to_jsonl(entries) -> str: if entry_dict: entry_dict["raw"] = f'{entry}' + entry_dict["file"] = f'{entry_to_file_map[entry_id]}' # Convert Dictionary to JSON and Append to JSONL string jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 1dc30ef2..3b050bf0 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -52,9 +52,7 @@ def initialize_model(search_config: TextSearchConfig): def extract_entries(jsonl_file): "Load entries from compressed jsonl" - return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'} - for entry - in load_jsonl(jsonl_file)] + return load_jsonl(jsonl_file) def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False): @@ -83,7 +81,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): for filter in filters_in_query: query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) end = time.time() - logger.debug(f"Filter Time: {end - start:.3f} seconds") + logger.debug(f"Total Filter Time: {end - start:.3f} seconds") if entries is None or len(entries) == 0: return [], [] diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index cadd4a6a..6a626299 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -21,10 +21,10 @@ def 
test_entry_with_empty_body_line_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries = extract_org_entries(org_files=[orgfile]) + entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries) + jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map) # Assert assert is_none_or_empty(jsonl_data) @@ -43,10 +43,10 @@ def test_entry_with_body_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries = extract_org_entries(org_files=[orgfile]) + entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_string = convert_org_entries_to_jsonl(entries) + jsonl_string = convert_org_entries_to_jsonl(entries, entry_to_file_map) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert From 2890b4cd4425686de9c191fcfdbea358c764cf4e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 02:09:36 +0300 Subject: [PATCH 09/13] Simplify extracting entries satisfying file filter --- src/search_filter/file_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 2c565982..6aa7db78 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -67,7 +67,7 @@ class FileFilter(BaseFilter): # Get entries (and associated embeddings) satisfying file filters start = time.time() - entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] + entries = [raw_entries[id] for id in included_entry_indices] embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) end = time.time() From 7dd20d764c53bae81c4ed50027b6422b8eb06b6b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 02:51:15 +0300 Subject: [PATCH 10/13] Pre-compute file to entry map in file filter to mark ids to include faster --- src/search_filter/file_filter.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 6aa7db78..45677bf6 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -3,6 +3,7 @@ import re import fnmatch import time import logging +from collections import defaultdict # External Packages import torch @@ -20,10 +21,15 @@ class FileFilter(BaseFilter): def __init__(self, entry_key='file'): self.entry_key = entry_key + self.file_to_entry_map = defaultdict(set) self.cache = LRU() - def load(self, *args, **kwargs): - pass + def load(self, entries, *args, **kwargs): + start = time.time() + for id, entry in enumerate(entries): + self.file_to_entry_map[entry[self.entry_key]].add(id) + end = time.time() + logger.debug(f"Created file filter index: {end - start} seconds") def can_filter(self, raw_query): return re.search(self.file_filter_regex, raw_query) is not None @@ -57,7 +63,10 @@ class FileFilter(BaseFilter): # Mark entries that contain any blocked_words for exclusion start = time.time() - included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)] + included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] + for entry_file in self.file_to_entry_map.keys() + for search_file in files_to_search + if 
fnmatch.fnmatch(entry_file, search_file)], set()) if not included_entry_indices: return query, [], torch.empty(0) From 965bd052f1eccdebc2388cb7eab1dddbbedf6c7c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 03:17:41 +0300 Subject: [PATCH 11/13] Make search filters return entry ids satisfying filter - Filter entries, embeddings by ids satisfying all filters in query func, after each filter has returned entry ids satisfying their individual acceptance criteria - Previously each filter would return a filtered list of entries. Each filter would be applied on entries filtered by previous filters. This made the filtering order dependent - Benefits - Filters can be applied independent of their order of execution - Precomputed indexes for each filter is not in danger of running into index out of bound errors, as filters run on original entries instead of on entries filtered by filters that have run before it - Extract entries satisfying filter only once instead of doing this for each filter - Costs - Each filter has to process all entries even if previous filters may have already marked them as non-satisfactory --- src/search_filter/date_filter.py | 16 +++------------- src/search_filter/file_filter.py | 27 +++++++++------------------ src/search_filter/word_filter.py | 19 ++++++------------- src/search_type/text_search.py | 15 ++++++++++++++- tests/test_date_filter.py | 30 ++++++++++++------------------ tests/test_file_filter.py | 30 ++++++++++++------------------ tests/test_word_filter.py | 20 ++++++++------------ 7 files changed, 64 insertions(+), 93 deletions(-) diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 54a8b625..73feaeed 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -42,19 +42,15 @@ class DateFilter(BaseFilter): # if no date in query, return all entries if query_daterange is None: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces - # deep copy original embeddings, entries before filtering - embeddings= deepcopy(raw_embeddings) - entries = deepcopy(raw_entries) - # find entries containing any dates that fall with date range specified in query entries_to_include = set() - for id, entry in enumerate(entries): + for id, entry in enumerate(raw_entries): # Extract dates from entry for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): # Convert date string in entry to unix timestamp @@ -67,13 +63,7 @@ class DateFilter(BaseFilter): entries_to_include.add(id) break - # delete entries (and their embeddings) marked for exclusion - entries_to_exclude = set(range(len(entries))) - entries_to_include - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - - return query, entries, embeddings + return query, entries_to_include def extract_date_range(self, query): diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 45677bf6..3af67705 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -5,9 +5,6 @@ import time import logging from collections import defaultdict -# External Packages -import torch - # Internal Packages from src.search_filter.base_filter import BaseFilter from src.utils.helpers import LRU @@ -39,7 +36,7 
@@ class FileFilter(BaseFilter): start = time.time() raw_files_to_search = re.findall(self.file_filter_regex, raw_query) if not raw_files_to_search: - return raw_query, raw_entries, raw_embeddings + return raw_query, set(range(len(raw_entries))) # Convert simple file filters with no path separator into regex # e.g. "file:notes.org" -> "file:.*notes.org" @@ -57,8 +54,11 @@ class FileFilter(BaseFilter): cache_key = tuple(files_to_search) if cache_key in self.cache: logger.info(f"Return file filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices + + if not self.file_to_entry_map: + self.load(raw_entries, regenerate=False) # Mark entries that contain any blocked_words for exclusion start = time.time() @@ -68,21 +68,12 @@ class FileFilter(BaseFilter): for search_file in files_to_search if fnmatch.fnmatch(entry_file, search_file)], set()) if not included_entry_indices: - return query, [], torch.empty(0) + return query, {} end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # Get entries (and associated embeddings) satisfying file filters - start = time.time() - - entries = [raw_entries[id] for id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") - # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index 9f46edd2..bae6764e 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -78,14 +78,14 @@ class WordFilter(BaseFilter): logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # Return item from cache if exists cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) if cache_key in self.cache: logger.info(f"Return word filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices if not self.word_to_entry_index: self.load(raw_entries, regenerate=False) @@ -105,17 +105,10 @@ class WordFilter(BaseFilter): end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # get entries (and their embeddings) satisfying inclusion and exclusion filters - start = time.time() - + # get entries satisfying inclusion and exclusion filters included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words - entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 
3b050bf0..eb7b0d34 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -78,8 +78,21 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): # Filter query, entries and embeddings before semantic search start = time.time() filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] + included_entry_indices = set(range(len(entries))) for filter in filters_in_query: - query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) + query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings) + included_entry_indices.intersection_update(included_entry_indices_by_filter) + + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return [], [] + else: + start = time.time() + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) + end = time.time() + logger.debug(f"Keep entries satisfying all filter: {end - start} seconds") + end = time.time() logger.debug(f"Total Filter Time: {end - start:.3f} seconds") diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index ddb1fcf0..0ac444fd 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,40 +18,34 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings) assert ret_query == 'head tail' - assert len(ret_emb) == 3 - assert ret_entries == entries + assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings) assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[1]] - assert len(ret_emb) == 1 + assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 
'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[1], entries[2]] - assert len(ret_emb) == 2 + assert entry_indices == {1, 2} def test_extract_date_range(): diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index b15b8a69..bde53ae0 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -13,13 +13,12 @@ def test_no_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == False assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_file_filter_with_non_existent_file(): @@ -30,13 +29,12 @@ def test_file_filter_with_non_existent_file(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == {} def test_single_file_filter(): @@ -47,13 +45,12 @@ def test_single_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_file_filter_with_partial_match(): @@ -64,13 +61,12 @@ def test_file_filter_with_partial_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_file_filter_with_regex_match(): @@ -81,13 +77,12 @@ def test_file_filter_with_regex_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_multiple_file_filter(): @@ -98,13 +93,12 @@ def test_multiple_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - 
assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def arrange_content(): diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 3d584077..95743bfd 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -14,13 +14,12 @@ def test_no_word_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == False assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_word_exclude_filter(tmp_path): @@ -31,13 +30,12 @@ def test_word_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_exclude_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_word_include_filter(tmp_path): @@ -48,13 +46,12 @@ def test_word_include_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[2], entries[3]] + assert entry_indices == {2, 3} def test_word_include_and_exclude_filter(tmp_path): @@ -65,13 +62,12 @@ def test_word_include_and_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 1 - assert ret_entries == [entries[2]] + assert entry_indices == {2} def arrange_content(): From 31503e7afd5030fd703fb66df766ba262a0d04da Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 15:46:54 +0300 Subject: [PATCH 12/13] Do not pass embeddings as argument to filter.apply method --- src/search_filter/base_filter.py | 4 ++-- src/search_filter/date_filter.py | 2 +- src/search_filter/file_filter.py | 2 +- src/search_filter/word_filter.py | 2 +- src/search_type/text_search.py | 12 ++++++------ tests/test_date_filter.py | 12 ++++++------ tests/test_file_filter.py | 12 ++++++------ tests/test_text_search.py | 1 - tests/test_word_filter.py | 8 ++++---- 9 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py index dc079b45..735b6915 100644 --- a/src/search_filter/base_filter.py +++ b/src/search_filter/base_filter.py @@ -1,6 +1,6 @@ # Standard Packages from abc import ABC, abstractmethod -from typing import List, Tuple +from typing import List, Set, Tuple # External Packages import torch @@ -16,5 +16,5 @@ class BaseFilter(ABC): pass @abstractmethod - 
def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]: + def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]: pass \ No newline at end of file diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 73feaeed..683d7a64 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -35,7 +35,7 @@ class DateFilter(BaseFilter): return self.extract_date_range(raw_query) is not None - def apply(self, query, raw_entries, raw_embeddings): + def apply(self, query, raw_entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query query_daterange = self.extract_date_range(query) diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 3af67705..41f80274 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -31,7 +31,7 @@ class FileFilter(BaseFilter): def can_filter(self, raw_query): return re.search(self.file_filter_regex, raw_query) is not None - def apply(self, raw_query, raw_entries, raw_embeddings): + def apply(self, raw_query, raw_entries): # Extract file filters from raw query start = time.time() raw_files_to_search = re.findall(self.file_filter_regex, raw_query) diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index bae6764e..dcf9ca6b 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -65,7 +65,7 @@ class WordFilter(BaseFilter): return len(required_words) != 0 or len(blocked_words) != 0 - def apply(self, raw_query, raw_entries, raw_embeddings): + def apply(self, raw_query, raw_entries): "Find entries containing required and not blocked words specified in query" # Separate natural query from required, blocked words filters start = time.time() diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index eb7b0d34..8666056c 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -76,11 +76,11 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search - start = time.time() - filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] + start_filter = time.time() included_entry_indices = set(range(len(entries))) + filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] for filter in filters_in_query: - query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings) + query, included_entry_indices_by_filter = filter.apply(query, entries) included_entry_indices.intersection_update(included_entry_indices_by_filter) # Get entries (and associated embeddings) satisfying all filters @@ -91,10 +91,10 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): entries = [entries[id] for id in included_entry_indices] corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) end = time.time() - logger.debug(f"Keep entries satisfying all filter: {end - start} seconds") + logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") - end = time.time() - logger.debug(f"Total Filter Time: {end - start:.3f} seconds") + end_filter = time.time() + logger.debug(f"Total Filter Time: {end_filter - 
start_filter:.3f} seconds") if entries is None or len(entries) == 0: return [], [] diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 0ac444fd..345c5c4f 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,32 +18,32 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) assert ret_query == 'head tail' assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) assert ret_query == 'head tail' assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {1, 2} diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index bde53ae0..3f9c22b3 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -13,7 +13,7 @@ def test_no_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == False @@ -29,7 +29,7 @@ def test_file_filter_with_non_existent_file(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -45,7 +45,7 @@ def test_single_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -61,7 +61,7 @@ def test_file_filter_with_partial_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, 
entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -77,7 +77,7 @@ def test_file_filter_with_regex_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -93,7 +93,7 @@ def test_multiple_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True diff --git a/tests/test_text_search.py b/tests/test_text_search.py index d56d304d..39fed92e 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,6 +1,5 @@ # System Packages from pathlib import Path -from src.utils.config import SearchType # Internal Packages from src.utils.state import model diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 95743bfd..3efe8ed9 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -14,7 +14,7 @@ def test_no_word_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == False @@ -30,7 +30,7 @@ def test_word_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_exclude_filter) - ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries) # Assert assert can_filter == True @@ -46,7 +46,7 @@ def test_word_include_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_filter) - ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries) # Assert assert can_filter == True @@ -62,7 +62,7 @@ def test_word_include_and_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries) # Assert assert can_filter == True From 3707a4cdd420a5e9bc749d7efa6e236b3f32deff Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 18:21:29 +0300 Subject: [PATCH 13/13] Improve date filter perf. 
Precompute date to entry map, Cache results - Precompute date to entry map - Cache results for faster recall - Log performance timers in date filter --- src/search_filter/date_filter.py | 67 +++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 683d7a64..53c7b266 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -1,33 +1,51 @@ # Standard Packages import re +import time +import logging +from collections import defaultdict from datetime import timedelta, datetime from dateutil.relativedelta import relativedelta from math import inf -from copy import deepcopy # External Packages -import torch import dateparser as dtparse # Internal Packages from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU + + +logger = logging.getLogger(__name__) class DateFilter(BaseFilter): # Date Range Filter Regexes # Example filter queries: - # - dt>="yesterday" dt<"tomorrow" - # - dt>="last week" - # - dt:"2 years ago" + # - dt>="yesterday" dt<"tomorrow" + # - dt>="last week" + # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})\"(.*?)\"" def __init__(self, entry_key='raw'): self.entry_key = entry_key + self.date_to_entry_ids = defaultdict(set) + self.cache = LRU() - def load(*args, **kwargs): - pass + def load(self, entries, **_): + start = time.time() + for id, entry in enumerate(entries): + # Extract dates from entry + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): + # Convert date string in entry to unix timestamp + try: + date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() + except ValueError: + continue + self.date_to_entry_ids[date_in_entry].add(id) + end = time.time() + logger.debug(f"Created date filter index: {end - start} seconds") def can_filter(self, raw_query): @@ -38,7 +56,10 @@ class DateFilter(BaseFilter): def apply(self, query, raw_entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query + start = time.time() query_daterange = self.extract_date_range(query) + end = time.time() + logger.debug(f"Extract date range to filter from query: {end - start} seconds") # if no date in query, return all entries if query_daterange is None: @@ -48,20 +69,28 @@ class DateFilter(BaseFilter): query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces + # return results from cache if exists + cache_key = tuple(query_daterange) + if cache_key in self.cache: + logger.info(f"Return date filter results from cache") + entries_to_include = self.cache[cache_key] + return query, entries_to_include + + if not self.date_to_entry_ids: + self.load(raw_entries) + # find entries containing any dates that fall within date range specified in query + start = time.time() entries_to_include = set() - for id, entry in enumerate(raw_entries): - # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): - # Convert date string in entry to unix timestamp - try: - date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() - except ValueError: - continue - # Check if date in entry is within date range specified in query - if query_daterange[0] <= date_in_entry < query_daterange[1]: - entries_to_include.add(id) - break + for date_in_entry in
self.date_to_entry_ids.keys(): + # Check if date in entry is within date range specified in query + if query_daterange[0] <= date_in_entry < query_daterange[1]: + entries_to_include |= self.date_to_entry_ids[date_in_entry] + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # cache results + self.cache[cache_key] = entries_to_include return query, entries_to_include
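
A note on the LRU cache the filters above rely on: LRU is imported from src.utils.helpers, but its implementation is not part of this patch series. A minimal sketch of such a cache, assuming an OrderedDict-based design (the capacity default below is illustrative, not taken from the patches), could look like:

    # Hypothetical sketch of the LRU imported from src.utils.helpers.
    # Assumes an OrderedDict subclass; capacity=128 is an illustrative default.
    from collections import OrderedDict

    class LRU(OrderedDict):
        def __init__(self, *args, capacity=128, **kwargs):
            self.capacity = capacity
            super().__init__(*args, **kwargs)

        def __getitem__(self, key):
            # Reading a cached result marks it as most recently used
            value = super().__getitem__(key)
            self.move_to_end(key)
            return value

        def __setitem__(self, key, value):
            # Store result, then evict the least recently used entry once over capacity
            super().__setitem__(key, value)
            self.move_to_end(key)
            if len(self) > self.capacity:
                self.popitem(last=False)

Subclassing OrderedDict keeps the filter call sites unchanged: cache_key in self.cache and self.cache[cache_key] = entries_to_include behave as on a plain dict, with recency tracking and eviction layered on top.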