diff --git a/.gitignore b/.gitignore index 9d33a849..a2e89c26 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ src/.data /dist/ /khoj_assistant.egg-info/ /config/khoj*.yml +.pytest_cache diff --git a/Readme.md b/Readme.md index 76182560..628ce458 100644 --- a/Readme.md +++ b/Readme.md @@ -125,7 +125,6 @@ pip install --upgrade khoj-assistant - Semantic search using the bi-encoder is fairly fast at \<50 ms - Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results -- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc ### Indexing performance diff --git a/src/configure.py b/src/configure.py index 938062eb..3ee594b4 100644 --- a/src/configure.py +++ b/src/configure.py @@ -14,6 +14,9 @@ from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, Con from src.utils import state from src.utils.helpers import resolve_absolute_path from src.utils.rawconfig import FullConfig, ProcessorConfig +from src.search_filter.date_filter import DateFilter +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter logger = logging.getLogger(__name__) @@ -39,23 +42,25 @@ def configure_server(args, required=False): def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None): # Initialize Org Notes Search if (t == SearchType.Org or t == None) and config.content_type.org: + filter_directory = resolve_absolute_path(config.content_type.org.compressed_jsonl.parent) + filters = [DateFilter(), WordFilter(filter_directory, search_type=SearchType.Org), FileFilter()] # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate) + model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, filters=filters) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate) + model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate) + model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings - model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate) + model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index a41705e0..131ef919 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -28,10 +28,10 @@ def org_to_jsonl(org_files, org_file_filter, output_file): org_files = get_org_files(org_files, org_file_filter) # Extract Entries from specified Org files - entries = extract_org_entries(org_files) + entries, file_to_entries = extract_org_entries(org_files) # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries) + jsonl_data = convert_org_entries_to_jsonl(entries, file_to_entries) # Compress JSONL formatted Data if output_file.suffix == ".gz": @@ -66,18 +66,19 @@ def get_org_files(org_files=None, org_file_filter=None): def extract_org_entries(org_files): "Extract entries from specified Org files" entries = [] + entry_to_file_map = [] for org_file in org_files: - entries.extend( - orgnode.makelist( - str(org_file))) + org_file_entries = orgnode.makelist(str(org_file)) + entry_to_file_map += [org_file]*len(org_file_entries) + entries.extend(org_file_entries) - return entries + return entries, entry_to_file_map -def convert_org_entries_to_jsonl(entries) -> str: +def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str: "Convert each Org-Mode entries to JSON and collate as JSONL" jsonl = '' - for entry in entries: + for entry_id, entry in enumerate(entries): entry_dict = dict() # Ignore title notes i.e notes with just headings and empty body @@ -106,6 +107,7 @@ def convert_org_entries_to_jsonl(entries) -> str: if entry_dict: entry_dict["raw"] = f'{entry}' + entry_dict["file"] = f'{entry_to_file_map[entry_id]}' # Convert Dictionary to JSON and Append to JSONL string jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' diff --git a/src/router.py b/src/router.py index 127623c6..a4bd2f84 100644 --- a/src/router.py +++ b/src/router.py @@ -16,8 +16,6 @@ from fastapi.templating import Jinja2Templates from src.configure import configure_search from src.search_type import image_search, text_search from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize -from src.search_filter.explicit_filter import ExplicitFilter -from src.search_filter.date_filter import DateFilter from src.utils.rawconfig import FullConfig from src.utils.config import SearchType from src.utils.helpers import get_absolute_path, get_from_dict diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py new file mode 100644 index 00000000..735b6915 --- /dev/null +++ b/src/search_filter/base_filter.py @@ -0,0 +1,20 @@ +# Standard Packages +from abc import ABC, abstractmethod +from typing import List, Set, Tuple + +# External Packages +import torch + + +class BaseFilter(ABC): + @abstractmethod + def load(self, *args, **kwargs): + pass + + @abstractmethod + def can_filter(self, raw_query:str) -> bool: + pass + + @abstractmethod + def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]: + pass \ No newline at end of file diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index cab47cbb..53c7b266 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -1,56 +1,40 @@ # Standard Packages import re +import time +import logging +from collections import defaultdict from datetime import timedelta, datetime -from dateutil.relativedelta import relativedelta, MO +from dateutil.relativedelta import relativedelta from math import inf -from copy import deepcopy # External Packages -import torch import dateparser as dtparse +# Internal Packages +from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU -class DateFilter: + +logger = logging.getLogger(__name__) + + +class DateFilter(BaseFilter): # Date Range Filter Regexes # Example filter queries: - # - dt>="yesterday" dt<"tomorrow" - # - dt>="last week" - # - dt:"2 years ago" + # - dt>="yesterday" dt<"tomorrow" + # - dt>="last week" + # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})\"(.*?)\"" def __init__(self, entry_key='raw'): self.entry_key = entry_key + self.date_to_entry_ids = defaultdict(set) + self.cache = LRU() - def load(*args, **kwargs): - pass - - - def can_filter(self, raw_query): - "Check if query contains date filters" - return self.extract_date_range(raw_query) is not None - - - def apply(self, query, raw_entries, raw_embeddings): - "Find entries containing any dates that fall within date range specified in query" - # extract date range specified in date filter of query - query_daterange = self.extract_date_range(query) - - # if no date in query, return all entries - if query_daterange is None: - return query, raw_entries, raw_embeddings - - # remove date range filter from query - query = re.sub(rf'\s+{self.date_regex}', ' ', query) - query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces - - # deep copy original embeddings, entries before filtering - embeddings= deepcopy(raw_embeddings) - entries = deepcopy(raw_entries) - - # find entries containing any dates that fall with date range specified in query - entries_to_include = set() + def load(self, entries, **_): + start = time.time() for id, entry in enumerate(entries): # Extract dates from entry for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): @@ -59,18 +43,56 @@ class DateFilter: date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() except ValueError: continue - # Check if date in entry is within date range specified in query - if query_daterange[0] <= date_in_entry < query_daterange[1]: - entries_to_include.add(id) - break + self.date_to_entry_ids[date_in_entry].add(id) + end = time.time() + logger.debug(f"Created file filter index: {end - start} seconds") - # delete entries (and their embeddings) marked for exclusion - entries_to_exclude = set(range(len(entries))) - entries_to_include - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - return query, entries, embeddings + def can_filter(self, raw_query): + "Check if query contains date filters" + return self.extract_date_range(raw_query) is not None + + + def apply(self, query, raw_entries): + "Find entries containing any dates that fall within date range specified in query" + # extract date range specified in date filter of query + start = time.time() + query_daterange = self.extract_date_range(query) + end = time.time() + logger.debug(f"Extract date range to filter from query: {end - start} seconds") + + # if no date in query, return all entries + if query_daterange is None: + return query, set(range(len(raw_entries))) + + # remove date range filter from query + query = re.sub(rf'\s+{self.date_regex}', ' ', query) + query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces + + # return results from cache if exists + cache_key = tuple(query_daterange) + if cache_key in self.cache: + logger.info(f"Return date filter results from cache") + entries_to_include = self.cache[cache_key] + return query, entries_to_include + + if not self.date_to_entry_ids: + self.load(raw_entries) + + # find entries containing any dates that fall with date range specified in query + start = time.time() + entries_to_include = set() + for date_in_entry in self.date_to_entry_ids.keys(): + # Check if date in entry is within date range specified in query + if query_daterange[0] <= date_in_entry < query_daterange[1]: + entries_to_include |= self.date_to_entry_ids[date_in_entry] + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # cache results + self.cache[cache_key] = entries_to_include + + return query, entries_to_include def extract_date_range(self, query): diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py new file mode 100644 index 00000000..41f80274 --- /dev/null +++ b/src/search_filter/file_filter.py @@ -0,0 +1,79 @@ +# Standard Packages +import re +import fnmatch +import time +import logging +from collections import defaultdict + +# Internal Packages +from src.search_filter.base_filter import BaseFilter +from src.utils.helpers import LRU + + +logger = logging.getLogger(__name__) + + +class FileFilter(BaseFilter): + file_filter_regex = r'file:"(.+?)" ?' + + def __init__(self, entry_key='file'): + self.entry_key = entry_key + self.file_to_entry_map = defaultdict(set) + self.cache = LRU() + + def load(self, entries, *args, **kwargs): + start = time.time() + for id, entry in enumerate(entries): + self.file_to_entry_map[entry[self.entry_key]].add(id) + end = time.time() + logger.debug(f"Created file filter index: {end - start} seconds") + + def can_filter(self, raw_query): + return re.search(self.file_filter_regex, raw_query) is not None + + def apply(self, raw_query, raw_entries): + # Extract file filters from raw query + start = time.time() + raw_files_to_search = re.findall(self.file_filter_regex, raw_query) + if not raw_files_to_search: + return raw_query, set(range(len(raw_entries))) + + # Convert simple file filters with no path separator into regex + # e.g. "file:notes.org" -> "file:.*notes.org" + files_to_search = [] + for file in sorted(raw_files_to_search): + if '/' not in file and '\\' not in file and '*' not in file: + files_to_search += [f'*{file}'] + else: + files_to_search += [file] + end = time.time() + logger.debug(f"Extract files_to_search from query: {end - start} seconds") + + # Return item from cache if exists + query = re.sub(self.file_filter_regex, '', raw_query).strip() + cache_key = tuple(files_to_search) + if cache_key in self.cache: + logger.info(f"Return file filter results from cache") + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices + + if not self.file_to_entry_map: + self.load(raw_entries, regenerate=False) + + # Mark entries that contain any blocked_words for exclusion + start = time.time() + + included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] + for entry_file in self.file_to_entry_map.keys() + for search_file in files_to_search + if fnmatch.fnmatch(entry_file, search_file)], set()) + if not included_entry_indices: + return query, {} + + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # Cache results + self.cache[cache_key] = included_entry_indices + + return query, included_entry_indices diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/word_filter.py similarity index 73% rename from src/search_filter/explicit_filter.py rename to src/search_filter/word_filter.py index 7a26f830..dcf9ca6b 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/word_filter.py @@ -8,6 +8,7 @@ import logging import torch # Internal Packages +from src.search_filter.base_filter import BaseFilter from src.utils.helpers import LRU, resolve_absolute_path from src.utils.config import SearchType @@ -15,13 +16,13 @@ from src.utils.config import SearchType logger = logging.getLogger(__name__) -class ExplicitFilter: +class WordFilter(BaseFilter): # Filter Regex required_regex = r'\+"(\w+)" ?' blocked_regex = r'\-"(\w+)" ?' def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): - self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") + self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl") self.entry_key = entry_key self.search_type = search_type self.word_to_entry_index = dict() @@ -34,7 +35,7 @@ class ExplicitFilter: with self.filter_file.open('rb') as f: self.word_to_entry_index = pickle.load(f) end = time.time() - logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") + logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds") else: start = time.time() self.cache = {} # Clear cache on (re-)generating entries_by_word_set @@ -51,23 +52,22 @@ class ExplicitFilter: with self.filter_file.open('wb') as f: pickle.dump(self.word_to_entry_index, f) end = time.time() - logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") + logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds") return self.word_to_entry_index def can_filter(self, raw_query): - "Check if query contains explicit filters" - # Extract explicit query portion with required, blocked words to filter from natural query + "Check if query contains word filters" required_words = re.findall(self.required_regex, raw_query) blocked_words = re.findall(self.blocked_regex, raw_query) return len(required_words) != 0 or len(blocked_words) != 0 - def apply(self, raw_query, raw_entries, raw_embeddings): + def apply(self, raw_query, raw_entries): "Find entries containing required and not blocked words specified in query" - # Separate natural query from explicit required, blocked words filters + # Separate natural query from required, blocked words filters start = time.time() required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) @@ -78,14 +78,14 @@ class ExplicitFilter: logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # Return item from cache if exists cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) if cache_key in self.cache: - logger.info(f"Explicit filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + logger.info(f"Return word filter results from cache") + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices if not self.word_to_entry_index: self.load(raw_entries, regenerate=False) @@ -105,17 +105,10 @@ class ExplicitFilter: end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # get entries (and their embeddings) satisfying inclusion and exclusion filters - start = time.time() - + # get entries satisfying inclusion and exclusion filters included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words - entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 742ff5ed..8666056c 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -7,13 +7,12 @@ import time # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util -from src.search_filter.date_filter import DateFilter -from src.search_filter.explicit_filter import ExplicitFilter +from src.search_filter.base_filter import BaseFilter # Internal Packages from src.utils import state from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model -from src.utils.config import SearchType, TextSearchModel +from src.utils.config import TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl @@ -53,9 +52,7 @@ def initialize_model(search_config: TextSearchConfig): def extract_entries(jsonl_file): "Load entries from compressed jsonl" - return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'} - for entry - in load_jsonl(jsonl_file)] + return load_jsonl(jsonl_file) def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False): @@ -79,12 +76,25 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search - start = time.time() + start_filter = time.time() + included_entry_indices = set(range(len(entries))) filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] for filter in filters_in_query: - query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) - end = time.time() - logger.debug(f"Filter Time: {end - start:.3f} seconds") + query, included_entry_indices_by_filter = filter.apply(query, entries) + included_entry_indices.intersection_update(included_entry_indices_by_filter) + + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return [], [] + else: + start = time.time() + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) + end = time.time() + logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") + + end_filter = time.time() + logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds") if entries is None or len(entries) == 0: return [], [] @@ -153,7 +163,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel: +def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) @@ -170,8 +180,6 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) - filter_directory = resolve_absolute_path(config.compressed_jsonl.parent) - filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)] for filter in filters: filter.load(entries, regenerate=regenerate) diff --git a/src/utils/config.py b/src/utils/config.py index a4de6b81..c163e22f 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -2,9 +2,11 @@ from enum import Enum from dataclasses import dataclass from pathlib import Path +from typing import List # Internal Packages from src.utils.rawconfig import ConversationProcessorConfig +from src.search_filter.base_filter import BaseFilter class SearchType(str, Enum): @@ -20,7 +22,7 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder diff --git a/tests/conftest.py b/tests/conftest.py index 930ec734..7545527f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Standard Packages +# External Packages import pytest # Internal Packages @@ -6,10 +6,13 @@ from src.search_type import image_search, text_search from src.utils.config import SearchType from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.search_filter.date_filter import DateFilter +from src.search_filter.word_filter import WordFilter +from src.search_filter.file_filter import FileFilter @pytest.fixture(scope='session') -def search_config(tmp_path_factory): +def search_config(tmp_path_factory) -> SearchConfig: model_dir = tmp_path_factory.mktemp('data') search_config = SearchConfig() @@ -35,7 +38,7 @@ def search_config(tmp_path_factory): @pytest.fixture(scope='session') -def model_dir(search_config): +def model_dir(search_config: SearchConfig): model_dir = search_config.asymmetric.model_directory # Generate Image Embeddings from Test Images @@ -55,13 +58,14 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + filters = [DateFilter(), WordFilter(model_dir, search_type=SearchType.Org), FileFilter()] + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) return model_dir @pytest.fixture(scope='session') -def content_config(model_dir): +def content_config(model_dir) -> ContentConfig: content_config = ContentConfig() content_config.org = TextContentConfig( input_files = None, diff --git a/tests/test_client.py b/tests/test_client.py index e9b632be..578c789c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,6 @@ from PIL import Image # External Packages from fastapi.testclient import TestClient -import pytest # Internal Packages from src.main import app @@ -12,7 +11,8 @@ from src.utils.config import SearchType from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig -from src.processor.org_mode import org_to_jsonl +from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.search_filter.word_filter import WordFilter # Arrange @@ -116,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application?" # Act @@ -132,7 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = 'How to git install application? +"Emacs"' # Act @@ -140,7 +141,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # Assert assert response.status_code == 200 - # assert actual_data contains explicitly included word "Emacs" + # assert actual_data contains word "Emacs" search_result = response.json()[0]["entry"] assert "Emacs" in search_result @@ -148,7 +149,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)] + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = 'How to git install application? -"clone"' # Act @@ -156,6 +158,6 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_ # Assert assert response.status_code == 200 - # assert actual_data does not contains explicitly excluded word "Emacs" + # assert actual_data does not contains word "Emacs" search_result = response.json()[0]["entry"] assert "clone" not in search_result diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index ddb1fcf0..345c5c4f 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,40 +18,34 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) assert ret_query == 'head tail' - assert len(ret_emb) == 3 - assert ret_entries == entries + assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[1]] - assert len(ret_emb) == 1 + assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' - assert ret_entries == [entries[1], entries[2]] - assert len(ret_emb) == 2 + assert entry_indices == {1, 2} def test_extract_date_range(): diff --git a/tests/test_explicit_filter.py b/tests/test_explicit_filter.py deleted file mode 100644 index 5f34b0ac..00000000 --- a/tests/test_explicit_filter.py +++ /dev/null @@ -1,85 +0,0 @@ -# External Packages -import torch - -# Application Packages -from src.search_filter.explicit_filter import ExplicitFilter -from src.utils.config import SearchType - - -def test_no_explicit_filter(tmp_path): - # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() - q_with_no_filter = 'head tail' - - # Act - can_filter = explicit_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) - - # Assert - assert can_filter == False - assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries - - -def test_explicit_exclude_filter(tmp_path): - # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() - q_with_exclude_filter = 'head -"exclude_word" tail' - - # Act - can_filter = explicit_filter.can_filter(q_with_exclude_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) - - # Assert - assert can_filter == True - assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] - - -def test_explicit_include_filter(tmp_path): - # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() - query_with_include_filter = 'head +"include_word" tail' - - # Act - can_filter = explicit_filter.can_filter(query_with_include_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) - - # Assert - assert can_filter == True - assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[2], entries[3]] - - -def test_explicit_include_and_exclude_filter(tmp_path): - # Arrange - explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() - query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' - - # Act - can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) - - # Assert - assert can_filter == True - assert ret_query == 'head tail' - assert len(ret_emb) == 1 - assert ret_entries == [entries[2]] - - -def arrange_content(): - embeddings = torch.randn(4, 10) - entries = [ - {'compiled': '', 'raw': 'Minimal Entry'}, - {'compiled': '', 'raw': 'Entry with exclude_word'}, - {'compiled': '', 'raw': 'Entry with include_word'}, - {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] - - return embeddings, entries diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py new file mode 100644 index 00000000..3f9c22b3 --- /dev/null +++ b/tests/test_file_filter.py @@ -0,0 +1,112 @@ +# External Packages +import torch + +# Application Packages +from src.search_filter.file_filter import FileFilter + + +def test_no_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == False + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def test_file_filter_with_non_existent_file(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"nonexistent.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {} + + +def test_single_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"file 1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 2} + + +def test_file_filter_with_partial_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 2} + + +def test_file_filter_with_regex_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"*.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def test_multiple_file_filter(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def arrange_content(): + embeddings = torch.randn(4, 10) + entries = [ + {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, + {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, + {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'}, + {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] + + return embeddings, entries diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index cadd4a6a..6a626299 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -21,10 +21,10 @@ def test_entry_with_empty_body_line_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries = extract_org_entries(org_files=[orgfile]) + entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries) + jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map) # Assert assert is_none_or_empty(jsonl_data) @@ -43,10 +43,10 @@ def test_entry_with_body_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries = extract_org_entries(org_files=[orgfile]) + entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_string = convert_org_entries_to_jsonl(entries) + jsonl_string = convert_org_entries_to_jsonl(entries, entry_to_file_map) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 84f16df5..39fed92e 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,6 +1,5 @@ # System Packages from pathlib import Path -from src.utils.config import SearchType # Internal Packages from src.utils.state import model @@ -14,7 +13,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) + notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -24,7 +23,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) query = "How to git install application?" # Act @@ -47,7 +46,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -60,11 +59,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) + regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) # Assert assert len(regenerated_notes_model.entries) == 11 diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py new file mode 100644 index 00000000..3efe8ed9 --- /dev/null +++ b/tests/test_word_filter.py @@ -0,0 +1,81 @@ +# External Packages +import torch + +# Application Packages +from src.search_filter.word_filter import WordFilter +from src.utils.config import SearchType + + +def test_no_word_filter(tmp_path): + # Arrange + word_filter = WordFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail' + + # Act + can_filter = word_filter.can_filter(q_with_no_filter) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries) + + # Assert + assert can_filter == False + assert ret_query == 'head tail' + assert entry_indices == {0, 1, 2, 3} + + +def test_word_exclude_filter(tmp_path): + # Arrange + word_filter = WordFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + q_with_exclude_filter = 'head -"exclude_word" tail' + + # Act + can_filter = word_filter.can_filter(q_with_exclude_filter) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {0, 2} + + +def test_word_include_filter(tmp_path): + # Arrange + word_filter = WordFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + query_with_include_filter = 'head +"include_word" tail' + + # Act + can_filter = word_filter.can_filter(query_with_include_filter) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {2, 3} + + +def test_word_include_and_exclude_filter(tmp_path): + # Arrange + word_filter = WordFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' + + # Act + can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert entry_indices == {2} + + +def arrange_content(): + embeddings = torch.randn(4, 10) + entries = [ + {'compiled': '', 'raw': 'Minimal Entry'}, + {'compiled': '', 'raw': 'Entry with exclude_word'}, + {'compiled': '', 'raw': 'Entry with include_word'}, + {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] + + return embeddings, entries