diff --git a/environment.yml b/environment.yml index bd7f9f59..67535adc 100644 --- a/environment.yml +++ b/environment.yml @@ -10,4 +10,6 @@ dependencies: - fastapi - uvicorn - pyyaml - - pytest \ No newline at end of file + - pytest + - pillow + - torchvision \ No newline at end of file diff --git a/src/main.py b/src/main.py index 3c9fa2a5..b8b03e06 100644 --- a/src/main.py +++ b/src/main.py @@ -8,7 +8,7 @@ import uvicorn from fastapi import FastAPI # Internal Packages -from search_type import asymmetric, symmetric_ledger +from search_type import asymmetric, symmetric_ledger, image_search from utils.helpers import get_from_dict from utils.cli import cli @@ -50,6 +50,22 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): # collate and return results return symmetric_ledger.collate_results(hits, transactions, results_count) + if (t == 'image' or t is None) and image_search_enabled: + # query images + hits = image_search.query_images( + user_query, + image_embeddings, + image_encoder, + results_count, + args.verbose) + + # collate and return results + return image_search.collate_results( + hits, + image_names, + image_config['input-directory'], + results_count) + else: return {} @@ -80,6 +96,16 @@ def regenerate(t: Optional[str] = None): regenerate=True, verbose=args.verbose) + if (t == 'image' or t is None) and image_search_enabled: + # Extract Images, Generate Embeddings + global image_embeddings + global image_names + image_names, image_embeddings, _ = image_search.setup( + pathlib.Path(image_config['input-directory']), + pathlib.Path(image_config['embeddings-file']), + regenerate=True, + verbose=args.verbose) + return {'status': 'ok', 'message': 'regeneration completed'} @@ -112,5 +138,16 @@ if __name__ == '__main__': args.regenerate, args.verbose) + # Initialize Image Search + image_config = get_from_dict(args.config, 'content-type', 'image') + image_search_enabled = False + if image_config and 'input-directory' in image_config: 
image_search_enabled = True + image_names, image_embeddings, image_encoder = image_search.setup( + pathlib.Path(image_config['input-directory']), + pathlib.Path(image_config['embeddings-file']), + args.regenerate, + args.verbose) + # Start Application Server uvicorn.run(app) diff --git a/src/search_type/image-search.py b/src/search_type/image_search.py similarity index 63% rename from src/search_type/image-search.py rename to src/search_type/image_search.py index d62f4239..88b69f83 100644 --- a/src/search_type/image-search.py +++ b/src/search_type/image_search.py @@ -1,31 +1,38 @@ -from sentence_transformers import SentenceTransformer, util -from PIL import Image -import torch +# Standard Packages import argparse import pathlib import copy +# External Packages +from sentence_transformers import SentenceTransformer, util +from PIL import Image +import torch + +# Internal Packages +from utils.helpers import get_absolute_path, resolve_absolute_path + def initialize_model(): # Initialize Model torch.set_num_threads(4) - top_k = 3 model = SentenceTransformer('clip-ViT-B-32') #Load the CLIP model - return model, top_k + return model def extract_entries(image_directory, verbose=False): + image_directory = resolve_absolute_path(image_directory, strict=True) image_names = list(image_directory.glob('*.jpg')) if verbose: print(f'Found {len(image_names)} images in {image_directory}') return image_names -def compute_embeddings(image_names, model, embeddings_file, verbose=False): +def compute_embeddings(image_names, model, embeddings_file, regenerate=False, verbose=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" + image_embeddings = None # Load pre-computed embeddings from file if exists - if embeddings_file.exists(): + if embeddings_file.exists() and not regenerate: image_embeddings = torch.load(embeddings_file) if verbose: print(f"Loaded pre-computed embeddings from {embeddings_file}") @@ -46,10 +53,10 @@ def compute_embeddings(image_names, model, 
embeddings_file, verbose=False): return image_embeddings -def search(query, image_embeddings, model, count=3, verbose=False): +def query_images(query, image_embeddings, model, count=3, verbose=False): # Set query to image content if query is a filepath - if pathlib.Path(query).expanduser().is_file(): - query_imagepath = pathlib.Path(query).expanduser().resolve(strict=True) + if pathlib.Path(query).expanduser().is_file(): + query_imagepath = resolve_absolute_path(pathlib.Path(query), strict=True) query = copy.deepcopy(Image.open(query_imagepath)) if verbose: print(f"Find Images similar to Image at {query_imagepath}") @@ -68,6 +75,8 @@ def search(query, image_embeddings, model, count=3, verbose=False): def render_results(hits, image_names, image_directory, count): + image_directory = resolve_absolute_path(image_directory, strict=True) + for hit in hits[:count]: print(image_names[hit['corpus_id']]) image_path = image_directory.joinpath(image_names[hit['corpus_id']]) @@ -75,28 +84,44 @@ def render_results(hits, image_names, image_directory, count): img.show() +def collate_results(hits, image_names, image_directory, count=5): + image_directory = resolve_absolute_path(image_directory, strict=True) + return [ + { + "Entry": image_directory.joinpath(image_names[hit['corpus_id']]), + "Score": f"{hit['score']:.3f}" + } + for hit + in hits[0:count]] + + +def setup(image_directory, embeddings_file, regenerate=False, verbose=False): + # Initialize Model + model = initialize_model() + + # Extract Entries + image_directory = resolve_absolute_path(image_directory, strict=True) + image_names = extract_entries(image_directory, verbose) + + # Compute or Load Embeddings + embeddings_file = resolve_absolute_path(embeddings_file) + image_embeddings = compute_embeddings(image_names, model, embeddings_file, regenerate=regenerate, verbose=verbose) + + return image_names, image_embeddings, model + + if __name__ == '__main__': # Setup Argument Parser parser = argparse.ArgumentParser(description="Semantic 
Search on Images") parser.add_argument('--image-directory', '-i', required=True, type=pathlib.Path, help="Image directory to query") - parser.add_argument('--embeddings-file', '-e', default='embeddings.pt', type=pathlib.Path, help="File to save/load model embeddings to/from. Default: ./embeddings.pt") + parser.add_argument('--embeddings-file', '-e', default='image_embeddings.pt', type=pathlib.Path, help="File to save/load model embeddings to/from. Default: ./image_embeddings.pt") + parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings of Images in Image Directory. Default: false") parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5") parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true") parser.add_argument('--verbose', action='store_true', default=False, help="Show verbose conversion logs. Default: false") args = parser.parse_args() - # Resolve file, directory paths in args to absolute paths - embeddings_file = args.embeddings_file.expanduser().resolve() - image_directory = args.image_directory.expanduser().resolve(strict=True) - - # Initialize Model - model, count = initialize_model() - - # Extract Entries - image_names = extract_entries(image_directory, args.verbose) - - # Compute or Load Embeddings - image_embeddings = compute_embeddings(image_names, model, embeddings_file, args.verbose) + image_names, image_embeddings, model = setup(args.image_directory, args.embeddings_file, regenerate=args.regenerate, verbose=args.verbose) # Run User Queries on Entries in Interactive Mode while args.interactive: @@ -105,8 +130,8 @@ if __name__ == '__main__': if user_query == "exit": exit(0) - # query notes - hits = search(user_query, image_embeddings, model, args.results_count, args.verbose) + # query images + hits = query_images(user_query, image_embeddings, model, args.results_count, args.verbose) # 
render results - render_results(hits, image_names, image_directory, count=args.results_count) + render_results(hits, image_names, args.image_directory, count=args.results_count) diff --git a/src/utils/cli.py b/src/utils/cli.py index cd1519d3..6f8489a3 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -53,6 +53,10 @@ default_config = { { 'compressed-jsonl': '.transactions.jsonl.gz', 'embeddings-file': '.transaction_embeddings.pt' + }, + 'image': + { + 'embeddings-file': '.image_embeddings.pt' } }, 'search-type': @@ -61,6 +65,10 @@ default_config = { { 'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3", 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2" + }, + 'image': + { + 'encoder': "clip-ViT-B-32" } } } diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 2ed0bf36..b9f60ef0 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -9,6 +9,10 @@ def get_absolute_path(filepath): return str(pathlib.Path(filepath).expanduser().absolute()) +def resolve_absolute_path(filepath, strict=False): + return pathlib.Path(filepath).expanduser().absolute().resolve(strict=strict) + + def get_from_dict(dictionary, *args): '''null-aware get from a nested dictionary Returns: dictionary[args[0]][args[1]]... or None if any keys missing'''