From 7de9c58a1c1e237acbebd4ed17b130c1461de01c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 20 Aug 2022 13:18:31 +0300 Subject: [PATCH 1/5] Load models, corpus embeddings onto GPU device for text search, if available - Pass device to load models onto from app state. - SentenceTransformer models accept device to load models onto during initialization - Pass device to load corpus embeddings onto from app state --- src/search_type/text_search.py | 15 ++++++++------- src/utils/helpers.py | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index c446f31c..c53048b6 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -9,6 +9,7 @@ import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages +from src.utils import state from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.utils.config import TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig @@ -32,13 +33,15 @@ def initialize_model(search_config: TextSearchConfig): bi_encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.encoder, - model_type = SentenceTransformer) + model_type = SentenceTransformer, + device=f'{state.device}') # The cross-encoder re-ranks the results to improve quality cross_encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.cross_encoder, - model_type = CrossEncoder) + model_type = CrossEncoder, + device=f'{state.device}') return bi_encoder, cross_encoder, top_k @@ -54,13 +57,12 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" # Load pre-computed embeddings from file if exists if embeddings_file.exists() and not regenerate: - corpus_embeddings = torch.load(get_absolute_path(embeddings_file)) + corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device) if verbose > 0: print(f"Loaded embeddings from {embeddings_file}") else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True) - corpus_embeddings.to(device) + corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True) corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) if verbose > 0: @@ -99,8 +101,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp # Encode the query using the bi-encoder start = time.time() - question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True) - question_embedding.to(device) + question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device) question_embedding = util.normalize_embeddings(question_embedding) end = time.time() if verbose > 1: diff --git a/src/utils/helpers.py b/src/utils/helpers.py index e77e656d..66e9d8fc 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -41,17 +41,17 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict -def load_model(model_name, model_dir, model_type): +def load_model(model_name, model_dir, model_type, device:str=None): "Load model from disk or huggingface" # Construct model path model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None # Load model from model_path if it exists there if model_path is not None and resolve_absolute_path(model_path).exists(): - model = model_type(get_absolute_path(model_path)) + model = model_type(get_absolute_path(model_path), device=device) # Else load the model from the model_name else: - model = model_type(model_name) + model = model_type(model_name, device=device) if model_path is not None: model.save(model_path) From acc909126003adde64bd283a7def30c3f91d365d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 20 Aug 2022 13:21:21 +0300 Subject: [PATCH 2/5] Use MPS on Apple Mac M1 to GPU accelerate Encode, Query Performance - Note: Support for MPS in Pytorch is currently in v1.13.0 nightly builds - Users will have to wait for PyTorch MPS support to land in stable builds - Until then the code can be tweaked and tested to make use of the GPU acceleration on newer Macs --- src/utils/state.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/utils/state.py b/src/utils/state.py index 4f194331..b5c082d6 100644 --- a/src/utils/state.py +++ b/src/utils/state.py @@ -1,3 +1,5 @@ +# Standard Packages +from packaging import version # External Packages import torch from pathlib import Path @@ -12,7 +14,15 @@ model = SearchModels() processor_config = ProcessorConfigModel() config_file: Path = "" verbose: int = 0 -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available host: str = None port: int = None -cli_args = None \ No newline at end of file +cli_args = None + +if torch.cuda.is_available(): + # Use CUDA GPU + device = torch.device("cuda:0") +elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available(): + # Use Apple M1 Metal Acceleration + device = torch.device("mps") +else: + device = torch.device("cpu") From 82d2891765558aa83a4af453fc1b8635211c6166 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 20 Aug 2022 14:14:42 +0300 Subject: [PATCH 3/5] Do not pass ML compute `device' around as argument to search funcs - It is a non-user configurable, app state that is set on app start - Reduce passing unneeded arguments around. Just set device where required by looking for ML compute device in global state --- src/configure.py | 10 +++++----- src/router.py | 12 ++++++------ src/search_type/text_search.py | 22 +++++++++++----------- tests/conftest.py | 2 +- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/configure.py b/src/configure.py index 2981dc9d..7cdb4e5f 100644 --- a/src/configure.py +++ b/src/configure.py @@ -27,27 +27,27 @@ def configure_server(args, required=False): state.config = args.config # Initialize the search model from Config - state.model = configure_search(state.model, state.config, args.regenerate, device=state.device, verbose=state.verbose) + state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose) # Initialize Processor from Config state.processor_config = configure_processor(args.config.processor, verbose=state.verbose) -def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu"), verbose: int = 0): +def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0): # Initialize Org Notes Search if (t == SearchType.Org or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) + model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) + model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose) + model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: diff --git a/src/router.py b/src/router.py index 6613a422..0be166c2 100644 --- a/src/router.py +++ b/src/router.py @@ -62,7 +62,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes query_start = time.time() - hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -73,7 +73,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Music or t == None) and state.model.music_search: # query music library query_start = time.time() - hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -84,7 +84,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files query_start = time.time() - hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -95,7 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions query_start = time.time() - hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -131,13 +131,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti @router.get('/reload') def reload(t: Optional[SearchType] = None): - state.model = configure_search(state.model, state.config, regenerate=False, t=t, device=state.device) + state.model = configure_search(state.model, state.config, regenerate=False, t=t) return {'status': 'ok', 'message': 'reload completed'} @router.get('/regenerate') def regenerate(t: Optional[SearchType] = None): - state.model = configure_search(state.model, state.config, regenerate=True, t=t, device=state.device) + state.model = configure_search(state.model, state.config, regenerate=True, t=t) return {'status': 'ok', 'message': 'regeneration completed'} diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index c53048b6..2b47eabe 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -53,16 +53,16 @@ def extract_entries(jsonl_file, verbose=0): in load_jsonl(jsonl_file, verbose=verbose)] -def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0): +def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" # Load pre-computed embeddings from file if exists if embeddings_file.exists() and not regenerate: - corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device) + corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device) if verbose > 0: print(f"Loaded embeddings from {embeddings_file}") else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True) + corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True) corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) if verbose > 0: @@ -71,7 +71,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0): +def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0): "Search for entries that answer the query" query = raw_query @@ -101,18 +101,18 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp # Encode the query using the bi-encoder start = time.time() - question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device) + question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = util.normalize_embeddings(question_embedding) end = time.time() if verbose > 1: - print(f"Query Encode Time: {end - start:.3f} seconds") + print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}") # Find relevant entries for the query start = time.time() hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] end = time.time() if verbose > 1: - print(f"Search Time: {end - start:.3f} seconds") + print(f"Search Time: {end - start:.3f} seconds on device: {state.device}") # Score all retrieved entries using the cross-encoder if rank_results: @@ -121,7 +121,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp cross_scores = model.cross_encoder.predict(cross_inp) end = time.time() if verbose > 1: - print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds") + print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): @@ -134,7 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score end = time.time() if verbose > 1: - print(f"Rank Time: {end - start:.3f} seconds") + print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") return hits, entries @@ -167,7 +167,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel: +def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) @@ -182,7 +182,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) - corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose) + corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose) return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) diff --git a/tests/conftest.py b/tests/conftest.py index 56610d45..f7c26b64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,7 +55,7 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=state.device, verbose=True) + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True) return model_dir From 972523e8a9cd7f316c5044259f17eb4ac3e14dc1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 20 Aug 2022 14:21:04 +0300 Subject: [PATCH 4/5] Re-enable tests for image search Verify if recent fixes resolve test flakiness --- tests/conftest.py | 24 ++++++++++++------------ tests/test_client.py | 1 - tests/test_image_search.py | 2 -- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f7c26b64..b70deb87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,14 +39,14 @@ def model_dir(search_config): model_dir = search_config.asymmetric.model_directory # Generate Image Embeddings from Test Images - # content_config = ContentConfig() - # content_config.image = ImageContentConfig( - # input_directories = ['tests/data/images'], - # embeddings_file = model_dir.joinpath('image_embeddings.pt'), - # batch_size = 10, - # use_xmp_metadata = False) + content_config = ContentConfig() + content_config.image = ImageContentConfig( + input_directories = ['tests/data/images'], + embeddings_file = model_dir.joinpath('image_embeddings.pt'), + batch_size = 10, + use_xmp_metadata = False) - # image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True) + image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( @@ -69,10 +69,10 @@ def content_config(model_dir): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - # content_config.image = ImageContentConfig( - # input_directories = ['tests/data/images'], - # embeddings_file = model_dir.joinpath('image_embeddings.pt'), - # batch_size = 10, - # use_xmp_metadata = False) + content_config.image = ImageContentConfig( + input_directories = ['tests/data/images'], + embeddings_file = model_dir.joinpath('image_embeddings.pt'), + batch_size = 1, + use_xmp_metadata = False) return content_config \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py index 85aad8d7..38b98c1f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -90,7 +90,6 @@ def test_regenerate_with_valid_content_type(content_config: ContentConfig, searc # ---------------------------------------------------------------------------------------------------- -@pytest.mark.skip(reason="Flaky test. Search doesn't always return expected image path.") def test_image_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange config.content_type = content_config diff --git a/tests/test_image_search.py b/tests/test_image_search.py index 4eb52048..80c4fdf6 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -15,7 +15,6 @@ from src.utils.rawconfig import ContentConfig, SearchConfig # Test # ---------------------------------------------------------------------------------------------------- -@pytest.mark.skip(reason="upstream issues in loading image search model. disabled for now") def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate image search embeddings during image setup @@ -27,7 +26,6 @@ def test_image_search_setup(content_config: ContentConfig, search_config: Search # ---------------------------------------------------------------------------------------------------- -@pytest.mark.skip(reason="results inconsistent currently") def test_image_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange output_directory = resolve_absolute_path(web_directory) From e6abe76875352b4f72d21e3406f285db2fc56665 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 20 Aug 2022 14:45:43 +0300 Subject: [PATCH 5/5] Upgrade torch, torchvision package versions --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a8640aa3..bf9f3ce9 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,8 @@ setup( ), install_requires=[ "numpy == 1.22.4", - "torch == 1.11.0", - "torchvision == 0.12.0", + "torch == 1.12.1", + "torchvision == 0.13.1", "transformers == 4.21.0", "sentence-transformers == 2.1.0", "openai == 0.20.0",