diff --git a/sample_config.yml b/sample_config.yml index ee19c8ee..3509f038 100644 --- a/sample_config.yml +++ b/sample_config.yml @@ -24,6 +24,11 @@ content-type: embeddings-file: "tests/data/.song_embeddings.pt" search-type: + symmetric: + encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2" + cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" + model_directory: "tests/data/.symmetric" + asymmetric: encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3" cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" diff --git a/src/main.py b/src/main.py index 71397486..2edbc209 100644 --- a/src/main.py +++ b/src/main.py @@ -140,7 +140,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings - model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose) + model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py index 4a747215..c7d9cc9a 100644 --- a/src/search_type/symmetric_ledger.py +++ b/src/search_type/symmetric_ledger.py @@ -10,18 +10,31 @@ import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages -from src.utils.helpers import get_absolute_path, resolve_absolute_path +from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl from src.utils.config import TextSearchModel -from src.utils.rawconfig import TextSearchConfig +from src.utils.rawconfig import SymmetricConfig, TextSearchConfig -def initialize_model(): +def initialize_model(search_config: SymmetricConfig): "Initialize model for symmetric semantic search. That is, where query of similar size to results" torch.set_num_threads(4) - bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search - top_k = 30 # Number of entries we want to retrieve with the bi-encoder - cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality + + # Number of entries we want to retrieve with the bi-encoder + top_k = 30 + + # The bi-encoder encodes all entries to use for semantic search + bi_encoder = load_model( + model_dir = search_config.model_directory, + model_name = search_config.encoder, + model_type = SentenceTransformer) + + # The cross-encoder re-ranks the results to improve quality + cross_encoder = load_model( + model_dir = search_config.model_directory, + model_name = search_config.cross_encoder, + model_type = CrossEncoder) + return bi_encoder, cross_encoder, top_k @@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel: +def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel: # Initialize Model - bi_encoder, cross_encoder, top_k = initialize_model() + bi_encoder, cross_encoder, top_k = initialize_model(search_config) # Map notes in Org-Mode files to (compressed) JSONL formatted file if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 44e4cce1..00eb2fc7 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -37,6 +37,11 @@ class ContentTypeConfig(ConfigBase): image: Optional[ImageSearchConfig] music: Optional[TextSearchConfig] +class SymmetricConfig(ConfigBase): + encoder: Optional[str] + cross_encoder: Optional[str] + model_directory: Optional[Path] + class AsymmetricConfig(ConfigBase): encoder: Optional[str] cross_encoder: Optional[str] @@ -47,6 +52,7 @@ class ImageSearchTypeConfig(ConfigBase): class SearchTypeConfig(ConfigBase): asymmetric: Optional[AsymmetricConfig] + symmetric: Optional[SymmetricConfig] image: Optional[ImageSearchTypeConfig] class ConversationProcessorConfig(ConfigBase):