diff --git a/src/main.py b/src/main.py index f991adcf..515c64a8 100644 --- a/src/main.py +++ b/src/main.py @@ -11,7 +11,9 @@ from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates # Internal Packages -from src.search_type import asymmetric, symmetric_ledger, image_search +from src.search_type import image_search, text_search +from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl from src.utils.helpers import get_absolute_path, get_from_dict from src.utils.cli import cli from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel @@ -66,24 +68,24 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and model.notes_search: # query notes - hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter]) + hits, entries = text_search.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter]) # collate and return results - return asymmetric.collate_results(hits, entries, results_count) + return text_search.collate_results(hits, entries, results_count) if (t == SearchType.Music or t == None) and model.music_search: # query music library - hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter]) + hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter]) # collate and return results - return asymmetric.collate_results(hits, entries, results_count) + return text_search.collate_results(hits, entries, results_count) if (t == SearchType.Ledger or t == None) and model.ledger_search: # query transactions - hits, entries = symmetric_ledger.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter]) + hits, entries = 
text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter]) # collate and return results - return symmetric_ledger.collate_results(hits, entries, results_count) + return text_search.collate_results(hits, entries, results_count) if (t == SearchType.Image or t == None) and model.image_search: # query transactions @@ -163,17 +165,17 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None # Initialize Org Notes Search if (t == SearchType.Notes or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings - model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) + model.notes_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) + model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings - model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose) + model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose) # Initialize Image Search if (t == SearchType.Image or t == None) and 
config.content_type.image: diff --git a/src/search_type/symmetric.py b/src/search_type/symmetric.py deleted file mode 100644 index 2e26e7c4..00000000 --- a/src/search_type/symmetric.py +++ /dev/null @@ -1,97 +0,0 @@ -import pandas as pd -import faiss -import numpy as np - -from sentence_transformers import SentenceTransformer - -import argparse -import os - -def create_index( - model, - dataset_path, - index_path, - column_name, - recreate): - # Load Dataset - dataset = pd.read_csv(dataset_path) - - # Clean Dataset - dataset = dataset.dropna() - dataset[column_name] = dataset[column_name].str.strip() - - # Create Index or Load it if it already exists - if os.path.exists(index_path) and not recreate: - index = faiss.read_index(index_path) - else: - # Create Embedding Vectors of Documents - embeddings = model.encode(dataset[column_name].to_list(), show_progress_bar=True) - embeddings = np.array([embedding for embedding in embeddings]).astype("float32") - - index = faiss.IndexIDMap( - faiss.IndexFlatL2( - embeddings.shape[1])) - - index.add_with_ids(embeddings, dataset.index.values) - - faiss.write_index(index, index_path) - - return index, dataset - - -def resolve_column(dataset, Id, column): - return [list(dataset[dataset.index == idx][column]) for idx in Id[0]] - - -def vector_search(query, index, dataset, column_name, num_results=10): - query_vector = np.array(query).astype("float32") - D, Id = index.search(query_vector, k=num_results) - - return zip(D[0], Id[0], resolve_column(dataset, Id, column_name)) - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Find most suitable match based on users exclude, include preferences") - parser.add_argument('positives', type=str, help="Terms to find closest match to") - parser.add_argument('--negatives', '-n', type=str, help="Terms to find farthest match from") - - parser.add_argument('--recreate', action='store_true', default=False, help="Recreate index at index_path from dataset at dataset path") 
- parser.add_argument('--index', type=str, default="./.faiss_index", help="Path to index for storing vector embeddings") - parser.add_argument('--dataset', type=str, default="./.dataset", help="Path to dataset to generate index from") - parser.add_argument('--column', type=str, default="DATA", help="Name of dataset column to index") - parser.add_argument('--num_results', type=int, default=10, help="Number of most suitable matches to show") - parser.add_argument('--model_name', type=str, default='all-MiniLM-L6-v2', help="Specify name of the SentenceTransformer model to use for encoding") - args = parser.parse_args() - - model = SentenceTransformer(args.model_name) - - if args.positives and not args.negatives: - # Get index, create it from dataset if doesn't exist - index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate) - - # Create vector to represent user's stated positive preference - preference_vector = model.encode([args.positives]) - - # Find and display most suitable matches for users preferences in the dataset - results = vector_search(preference_vector, index, dataset, args.column, args.num_results) - - print("Most Suitable Matches:") - for similarity, id_, data in results: - print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}") - - elif args.positives and args.negatives: - # Get index, create it from dataset if doesn't exist - index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate) - - # Create vector to represent user's stated preference - positives_vector = np.array(model.encode([args.positives])).astype("float32") - negatives_vector = np.array(model.encode([args.negatives])).astype("float32") - - # preference_vector = np.mean([positives_vector, -1 * negatives_vector], axis=0) - preference_vector = np.add(positives_vector, -1 * negatives_vector) - - # Find and display most suitable matches for users preferences in the dataset - results = 
vector_search(preference_vector, index, dataset, args.column, args.num_results) - - print("Most Suitable Matches:") - for similarity, id_, data in results: - print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}") diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py deleted file mode 100644 index 0dcb94bf..00000000 --- a/src/search_type/symmetric_ledger.py +++ /dev/null @@ -1,170 +0,0 @@ -# Standard Packages -import argparse -import pathlib -from copy import deepcopy - -# External Packages -import torch -from sentence_transformers import SentenceTransformer, CrossEncoder, util - -# Internal Packages -from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model -from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl -from src.utils.config import TextSearchModel -from src.utils.rawconfig import SymmetricSearchConfig, TextContentConfig -from src.utils.jsonl import load_jsonl - - -def initialize_model(search_config: SymmetricSearchConfig): - "Initialize model for symmetric semantic search. 
That is, where query of similar size to results" - torch.set_num_threads(4) - - # Number of entries we want to retrieve with the bi-encoder - top_k = 30 - - # The bi-encoder encodes all entries to use for semantic search - bi_encoder = load_model( - model_dir = search_config.model_directory, - model_name = search_config.encoder, - model_type = SentenceTransformer) - - # The cross-encoder re-ranks the results to improve quality - cross_encoder = load_model( - model_dir = search_config.model_directory, - model_name = search_config.cross_encoder, - model_type = CrossEncoder) - - return bi_encoder, cross_encoder, top_k - - -def extract_entries(notesfile, verbose=0): - "Load entries from compressed jsonl" - return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'} - for entry - in load_jsonl(notesfile, verbose=verbose)] - - -def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0): - "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" - # Load pre-computed embeddings from file if exists - if resolve_absolute_path(embeddings_file).exists() and not regenerate: - corpus_embeddings = torch.load(get_absolute_path(embeddings_file)) - if verbose > 0: - print(f"Loaded embeddings from {embeddings_file}") - - else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode(entries, convert_to_tensor=True, show_progress_bar=True) - torch.save(corpus_embeddings, get_absolute_path(embeddings_file)) - if verbose > 0: - print(f"Computed embeddings and saved them to {embeddings_file}") - - return corpus_embeddings - - -def query(raw_query, model: TextSearchModel, filters=[]): - "Search all notes for entries that answer the query" - # Copy original embeddings, entries to filter them for query - query = raw_query - corpus_embeddings = deepcopy(model.corpus_embeddings) - entries = deepcopy(model.entries) - - # Filter query, entries and embeddings before semantic search - for 
filter in filters: - query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings) - if entries is None or len(entries) == 0: - return [], [] - - # Encode the query using the bi-encoder - question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True) - - # Find relevant entries for the query - hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0] - - # Score all retrieved entries using the cross-encoder - cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] - cross_scores = model.cross_encoder.predict(cross_inp) - - # Store cross-encoder scores in results dictionary for ranking - for idx in range(len(cross_scores)): - hits[idx]['cross-score'] = cross_scores[idx] - - # Order results by cross encoder score followed by biencoder score - hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score - hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score - - return hits, entries - - -def render_results(hits, entries, count=5, display_biencoder_results=False): - "Render the Results returned by Search for the Query" - if display_biencoder_results: - # Output of top hits from bi-encoder - print("\n-------------------------\n") - print(f"Top-{count} Bi-Encoder Retrieval hits") - hits = sorted(hits, key=lambda x: x['score'], reverse=True) - for hit in hits[0:count]: - print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}") - - # Output of top hits from re-ranker - print("\n-------------------------\n") - print(f"Top-{count} Cross-Encoder Re-ranker hits") - hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) - for hit in hits[0:count]: - print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") - - -def collate_results(hits, entries, count=5): - return [ - { - "entry": entries[hit['corpus_id']]['raw'], - "score": f"{hit['cross-score']:.3f}" - } - 
for hit - in hits[0:count]] - - -def setup(config: TextContentConfig, search_config: SymmetricSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel: - # Initialize Model - bi_encoder, cross_encoder, top_k = initialize_model(search_config) - - # Map notes in Org-Mode files to (compressed) JSONL formatted file - if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: - beancount_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) - - # Extract Entries - entries = extract_entries(config.compressed_jsonl, verbose) - top_k = min(len(entries), top_k) - - # Compute or Load Embeddings - corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose) - - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) - - -if __name__ == '__main__': - # Setup Argument Parser - parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format") - parser.add_argument('--input-files', '-i', nargs='*', help="List of Beancount files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Beancount files to process") - parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".transactions.jsonl.gz"), help="Compressed JSONL formatted transactions file to compute embeddings from") - parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".transaction_embeddings.pt"), help="File to save/load model embeddings to/from") - parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from Beancount files. Default: false") - parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. 
Default: 5") - parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true") - parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0") - args = parser.parse_args() - - entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose) - - # Run User Queries on Entries in Interactive Mode - while args.interactive: - # get query from user - user_query = input("Enter your query: ") - if user_query == "exit": - exit(0) - - # query - hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) - - # render results - render_results(hits, entries, count=args.results_count) diff --git a/src/search_type/asymmetric.py b/src/search_type/text_search.py similarity index 82% rename from src/search_type/asymmetric.py rename to src/search_type/text_search.py index 0fb879d7..39ae19b8 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/text_search.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Standard Packages import argparse import pathlib @@ -11,14 +9,13 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model -from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.utils.config import TextSearchModel -from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig +from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl -def initialize_model(search_config: AsymmetricSearchConfig): - "Initialize model for assymetric semantic search. 
That is, where query smaller than results" +def initialize_model(search_config: TextSearchConfig): + "Initialize model for semantic search on text" torch.set_num_threads(4) # Number of entries we want to retrieve with the bi-encoder @@ -39,11 +36,11 @@ def initialize_model(search_config: AsymmetricSearchConfig): return bi_encoder, cross_encoder, top_k -def extract_entries(notesfile, verbose=0): +def extract_entries(jsonl_file, verbose=0): "Load entries from compressed jsonl" return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'} for entry - in load_jsonl(notesfile, verbose=verbose)] + in load_jsonl(jsonl_file, verbose=verbose)] def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0): @@ -65,9 +62,8 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []): - "Search all notes for entries that answer the query" - +def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []): + "Search for entries that answer the query" # Copy original embeddings, entries to filter them for query query = raw_query corpus_embeddings = deepcopy(model.corpus_embeddings) @@ -130,13 +126,13 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device=torch.device('cpu'), verbose: bool=False) -> TextSearchModel: +def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) - # Map notes in Org-Mode files to (compressed) JSONL formatted file + # Map notes in text files to (compressed) JSONL formatted file if not resolve_absolute_path(config.compressed_jsonl).exists() or 
regenerate: - org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) + text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) # Extract Entries entries = extract_entries(config.compressed_jsonl, verbose) @@ -150,12 +146,12 @@ def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, rege if __name__ == '__main__': # Setup Argument Parser - parser = argparse.ArgumentParser(description="Map Org-Mode notes into (compressed) JSONL format") - parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process") - parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from") - parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from") - parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false") + parser = argparse.ArgumentParser(description="Map Text files into (compressed) JSONL format") + parser.add_argument('--input-files', '-i', nargs='*', help="List of Text files to process") + parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Text files to process") + parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path("text.jsonl.gz"), help="Compressed JSONL to compute embeddings from") + parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path("text_embeddings.pt"), help="File to save/load model embeddings to/from") + parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from text files. 
Default: false") parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5") parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true") parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0") diff --git a/src/utils/install.py b/src/utils/install.py index e894a7fe..a92eec3a 100644 --- a/src/utils/install.py +++ b/src/utils/install.py @@ -36,7 +36,7 @@ conda activate khoj cd {get_absolute(args.script_dir)} # Act -python3 search_types/asymmetric.py -j {get_absolute(args.model_dir)}/notes.jsonl.gz -e {get_absolute(args.model_dir)}/notes_embeddings.pt -n 5 --interactive +python3 search_types/text_search.py -j {get_absolute(args.model_dir)}/notes.jsonl.gz -e {get_absolute(args.model_dir)}/notes_embeddings.pt -n 5 --interactive ''' search_cmd_content = f'''#!/bin/bash diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 4a8749a8..a355d6a4 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -32,12 +32,7 @@ class ContentConfig(ConfigBase): image: Optional[ImageContentConfig] music: Optional[TextContentConfig] -class SymmetricSearchConfig(ConfigBase): - encoder: Optional[str] - cross_encoder: Optional[str] - model_directory: Optional[Path] - -class AsymmetricSearchConfig(ConfigBase): +class TextSearchConfig(ConfigBase): encoder: Optional[str] cross_encoder: Optional[str] model_directory: Optional[Path] @@ -47,8 +42,8 @@ class ImageSearchConfig(ConfigBase): model_directory: Optional[Path] class SearchConfig(ConfigBase): - asymmetric: Optional[AsymmetricSearchConfig] - symmetric: Optional[SymmetricSearchConfig] + asymmetric: Optional[TextSearchConfig] + symmetric: Optional[TextSearchConfig] image: Optional[ImageSearchConfig] class ConversationProcessorConfig(ConfigBase): diff --git a/tests/conftest.py b/tests/conftest.py index 2471f8e0..622da0da 
100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ import pytest import torch # Internal Packages -from src.search_type import asymmetric, image_search -from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, SymmetricSearchConfig, AsymmetricSearchConfig, ImageSearchConfig +from src.search_type import image_search, text_search +from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig +from src.processor.org_mode.org_to_jsonl import org_to_jsonl @pytest.fixture(scope='session') @@ -13,13 +14,13 @@ def search_config(tmp_path_factory): search_config = SearchConfig() - search_config.asymmetric = SymmetricSearchConfig( + search_config.symmetric = TextSearchConfig( encoder = "sentence-transformers/all-MiniLM-L6-v2", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory = model_dir ) - search_config.asymmetric = AsymmetricSearchConfig( + search_config.asymmetric = TextSearchConfig( encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory = model_dir @@ -55,7 +56,7 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True) + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True) return model_dir diff --git a/tests/test_asymmetric_search.py b/tests/test_asymmetric_search.py index 0a167341..b14cc10d 100644 --- a/tests/test_asymmetric_search.py +++ b/tests/test_asymmetric_search.py @@ -3,8 +3,9 @@ from pathlib import Path # Internal Packages from src.main import model -from src.search_type import asymmetric +from src.search_type import text_search from src.utils.rawconfig 
import ContentConfig, SearchConfig +from src.processor.org_mode.org_to_jsonl import org_to_jsonl # Test @@ -12,7 +13,7 @@ from src.utils.rawconfig import ContentConfig, SearchConfig def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True) + notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -22,15 +23,15 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) query = "How to git install application?" 
# Act - hits, entries = asymmetric.query( + hits, entries = text_search.query( query, model = model.notes_search) - results = asymmetric.collate_results( + results = text_search.collate_results( hits, entries, count=1) @@ -44,7 +45,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model= asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -57,11 +58,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True) + regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) # Assert assert len(regenerated_notes_model.entries) == 11 diff --git a/tests/test_client.py b/tests/test_client.py index 975a9f18..73b0f210 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,9 +8,9 @@ import pytest # Internal Packages from src.main import app, model, config 
-from src.search_type import asymmetric, image_search -from src.utils.helpers import resolve_absolute_path +from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig +from src.processor.org_mode.org_to_jsonl import org_to_jsonl # Arrange @@ -115,7 +115,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application?" # Act @@ -131,7 +131,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application?
+Emacs" # Act @@ -147,7 +147,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = "How to git install application? -clone" # Act