Consolidate the search config models and pass verbose as a top level flag

This commit is contained in:
Saba
2021-12-04 11:43:48 -05:00
parent 43e647835b
commit 10e4065e05
6 changed files with 52 additions and 90 deletions

View File

@@ -14,7 +14,8 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils.config import TextSearchModel, TextSearchConfig
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfigModel
def initialize_model():
@@ -148,22 +149,22 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model()
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose)
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
# Extract Entries
entries = extract_entries(config.compressed_jsonl, config.verbose)
entries = extract_entries(config.compressed_jsonl, verbose)
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
if __name__ == '__main__':

View File

@@ -10,9 +10,10 @@ from tqdm import trange
import torch
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.utils.helpers import resolve_absolute_path
import src.utils.exiftool as exiftool
from src.utils.config import ImageSearchModel, ImageSearchConfig
from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageSearchConfigModel
def initialize_model():
@@ -153,7 +154,7 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]]
def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
def setup(config: ImageSearchConfigModel, regenerate: bool, verbose: bool) -> ImageSearchModel:
# Initialize Model
encoder = initialize_model()
@@ -170,13 +171,13 @@ def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
batch_size=config.batch_size,
regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata,
verbose=config.verbose)
verbose=verbose)
return ImageSearchModel(image_names,
image_embeddings,
image_metadata_embeddings,
encoder,
config.verbose)
verbose)
if __name__ == '__main__':

View File

@@ -12,7 +12,8 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.config import TextSearchModel, TextSearchConfig
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfigModel
def initialize_model():
@@ -140,7 +141,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model()
@@ -153,9 +154,9 @@ def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
top_k = min(len(entries), top_k)
# Compute or Load Embeddings
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
if __name__ == '__main__':