Save Asymmetric Search Model to Disk

- Improve application load time
- Remove dependence on the internet to start up the application and perform semantic search
This commit is contained in:
Debanjum Singh Solanky
2022-01-14 16:31:55 -05:00
parent 2e53fbc844
commit b63026d97c
5 changed files with 43 additions and 10 deletions

View File

@@ -12,18 +12,31 @@ import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig
from src.utils.rawconfig import AsymmetricConfig, TextSearchConfig
def initialize_model(search_config: AsymmetricConfig):
    """Initialize encoder models for asymmetric semantic search, i.e. where the query is smaller than the results.

    Returns a (bi_encoder, cross_encoder, top_k) tuple:
    the bi-encoder embeds all entries for retrieval, the cross-encoder
    re-ranks the retrieved candidates, and top_k is the number of
    candidates the bi-encoder stage should return.

    NOTE(review): load_model presumably loads the model from
    search_config.model_directory when present on disk and downloads it by
    name otherwise — confirm against the helper in src.utils.helpers.
    """
    # Cap CPU threads used by torch for encoding
    torch.set_num_threads(4)

    # Number of entries we want to retrieve with the bi-encoder
    top_k = 30

    # The bi-encoder encodes all entries to use for semantic search
    bi_encoder = load_model(
        model_dir=search_config.model_directory,
        model_name=search_config.encoder,
        model_type=SentenceTransformer)

    # The cross-encoder re-ranks the bi-encoder results to improve quality
    cross_encoder = load_model(
        model_dir=search_config.model_directory,
        model_name=search_config.cross_encoder,
        model_type=CrossEncoder)

    return bi_encoder, cross_encoder, top_k
@@ -149,9 +162,9 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
def setup(config: TextSearchConfig, search_config: AsymmetricConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model()
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: