Modularize Code. Wrap Search, Model Config in Classes. Add Tests

Details
  - Rename method query_* to query in search_types for standardization
  - Wrapping Config code in classes simplified mocking test config
  - Reduce args being passed to a function by passing them as a single
    argument wrapped in a class
  - Minimize setup in main.py:__main__. Put most of it into functions.
    These functions can be mocked if required in tests later too

Setup Flow:
  CLI_Args|Config_YAML -> (Text|Image)SearchConfig -> (Text|Image)SearchModel
This commit is contained in:
Debanjum Singh Solanky
2021-09-30 02:04:04 -07:00
parent f4dd9cd117
commit d5597442f4
6 changed files with 201 additions and 154 deletions

View File

@@ -17,7 +17,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from utils.helpers import get_absolute_path, resolve_absolute_path
from processor.org_mode.org_to_jsonl import org_to_jsonl
from utils.config import AsymmetricSearchModel
from utils.config import TextSearchModel, TextSearchConfig
def initialize_model():
@@ -66,7 +66,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
return corpus_embeddings
def query_notes(raw_query: str, model: AsymmetricSearchModel):
def query(raw_query: str, model: TextSearchModel):
"Search all notes for entries that answer the query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
@@ -151,21 +151,21 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False):
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model()
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(compressed_jsonl).exists() or regenerate:
org_to_jsonl(input_files, input_filter, compressed_jsonl, verbose)
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose)
# Extract Entries
entries = extract_entries(compressed_jsonl, verbose)
entries = extract_entries(config.compressed_jsonl, config.verbose)
# Compute or Load Embeddings
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
return AsymmetricSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
if __name__ == '__main__':
@@ -191,7 +191,7 @@ if __name__ == '__main__':
exit(0)
# query notes
hits = query_notes(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# render results
render_results(hits, entries, count=args.results_count)

View File

@@ -12,6 +12,8 @@ import torch
# Internal Packages
from utils.helpers import get_absolute_path, resolve_absolute_path
import utils.exiftool as exiftool
from utils.config import ImageSearchModel, ImageSearchConfig
def initialize_model():
# Initialize Model
@@ -93,30 +95,31 @@ def extract_metadata(image_name, verbose=0):
return image_processed_metadata
def query_images(query, image_embeddings, image_metadata_embeddings, model, count=3, verbose=0):
def query(raw_query, count, model: ImageSearchModel):
# Set query to image content if query is a filepath
if pathlib.Path(query).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(query), strict=True)
if pathlib.Path(raw_query).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
query = copy.deepcopy(Image.open(query_imagepath))
if verbose > 0:
if model.verbose > 0:
print(f"Find Images similar to Image at {query_imagepath}")
else:
if verbose > 0:
query = raw_query
if model.verbose > 0:
print(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string)
query_embedding = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
image_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]}
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if image_metadata_embeddings:
if model.image_metadata_embeddings:
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]}
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():
@@ -150,20 +153,30 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]]
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model
model = initialize_model()
# Extract Entries
image_directory = resolve_absolute_path(image_directory, strict=True)
image_names = extract_entries(image_directory, verbose)
image_directory = resolve_absolute_path(config.input_directory, strict=True)
image_names = extract_entries(config.input_directory, config.verbose)
# Compute or Load Embeddings
embeddings_file = resolve_absolute_path(embeddings_file)
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file,
batch_size=batch_size, regenerate=regenerate, use_xmp_metadata=use_xmp_metadata, verbose=verbose)
embeddings_file = resolve_absolute_path(config.embeddings_file)
image_embeddings, image_metadata_embeddings = compute_embeddings(
image_names,
model,
embeddings_file,
batch_size=config.batch_size,
regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata,
verbose=config.verbose)
return image_names, image_embeddings, image_metadata_embeddings, model
return ImageSearchModel(image_names,
image_embeddings,
image_metadata_embeddings,
model,
config.verbose)
if __name__ == '__main__':
@@ -187,7 +200,7 @@ if __name__ == '__main__':
exit(0)
# query images
hits = query_images(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose)
hits = query(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose)
# render results
render_results(hits, image_names, args.image_directory, count=args.results_count)

View File

@@ -15,6 +15,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from utils.helpers import get_absolute_path, resolve_absolute_path
from processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from utils.config import TextSearchModel, TextSearchConfig
def initialize_model():
@@ -59,7 +60,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
return corpus_embeddings
def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
def query(raw_query, model: TextSearchModel):
"Search all notes for entries that answer the query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
@@ -67,20 +68,20 @@ def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
# Encode the query using the bi-encoder
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
# Find relevant entries for the query
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
hits = hits[0] # Get the hits for the first query
# Filter results using explicit filters
hits = explicit_filter(hits, entries, required_words, blocked_words)
hits = explicit_filter(hits, model.entries, required_words, blocked_words)
if hits is None or len(hits) == 0:
return hits
# Score all retrieved entries using the cross-encoder
cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
cross_inp = [[query, model.entries[hit['corpus_id']]] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
@@ -142,21 +143,21 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False):
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model()
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(compressed_jsonl).exists() or regenerate:
beancount_to_jsonl(input_files, input_filter, compressed_jsonl, verbose)
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
beancount_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose)
# Extract Entries
entries = extract_entries(compressed_jsonl, verbose)
entries = extract_entries(config.compressed_jsonl, config.verbose)
# Compute or Load Embeddings
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
if __name__ == '__main__':
@@ -181,8 +182,8 @@ if __name__ == '__main__':
if user_query == "exit":
exit(0)
# query notes
hits = query_transactions(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# query
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# render results
render_results(hits, entries, count=args.results_count)