From 82d2891765558aa83a4af453fc1b8635211c6166 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sat, 20 Aug 2022 14:14:42 +0300
Subject: [PATCH] Do not pass ML compute `device' around as argument to search funcs

- It is non-user-configurable app state that is set on app start
- Reduce passing unneeded arguments around. Just read the ML compute
  device from global state wherever it is required
---
 src/configure.py               | 10 +++++-----
 src/router.py                  | 12 ++++++------
 src/search_type/text_search.py | 22 +++++++++++-----------
 tests/conftest.py              |  2 +-
 4 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/src/configure.py b/src/configure.py
index 2981dc9d..7cdb4e5f 100644
--- a/src/configure.py
+++ b/src/configure.py
@@ -27,27 +27,27 @@ def configure_server(args, required=False):
     state.config = args.config

     # Initialize the search model from Config
-    state.model = configure_search(state.model, state.config, args.regenerate, device=state.device, verbose=state.verbose)
+    state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)

     # Initialize Processor from Config
     state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)


-def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu"), verbose: int = 0):
+def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0):
     # Initialize Org Notes Search
     if (t == SearchType.Org or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
-        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
+        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)

     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
+        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)

     # Initialize Markdown Search
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
         # Extract Entries, Generate Markdown Embeddings
-        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
+        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)

     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
diff --git a/src/router.py b/src/router.py
index 6613a422..0be166c2 100644
--- a/src/router.py
+++ b/src/router.py
@@ -62,7 +62,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Org or t == None) and state.model.orgmode_search:
         # query org-mode notes
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
         query_end = time.time()

         # collate and return results
@@ -73,7 +73,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Music or t == None) and state.model.music_search:
         # query music library
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
         query_end = time.time()

         # collate and return results
@@ -84,7 +84,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
         # query markdown files
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
         query_end = time.time()

         # collate and return results
@@ -95,7 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
         # query transactions
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
         query_end = time.time()

         # collate and return results
@@ -131,13 +131,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti

 @router.get('/reload')
 def reload(t: Optional[SearchType] = None):
-    state.model = configure_search(state.model, state.config, regenerate=False, t=t, device=state.device)
+    state.model = configure_search(state.model, state.config, regenerate=False, t=t)
     return {'status': 'ok', 'message': 'reload completed'}


 @router.get('/regenerate')
 def regenerate(t: Optional[SearchType] = None):
-    state.model = configure_search(state.model, state.config, regenerate=True, t=t, device=state.device)
+    state.model = configure_search(state.model, state.config, regenerate=True, t=t)
     return {'status': 'ok', 'message': 'regeneration completed'}

diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py
index c53048b6..2b47eabe 100644
--- a/src/search_type/text_search.py
+++ b/src/search_type/text_search.py
@@ -53,16 +53,16 @@ def extract_entries(jsonl_file, verbose=0):
             in load_jsonl(jsonl_file, verbose=verbose)]


-def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
+def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     # Load pre-computed embeddings from file if exists
     if embeddings_file.exists() and not regenerate:
-        corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device)
+        corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
         if verbose > 0:
             print(f"Loaded embeddings from {embeddings_file}")

     else:  # Else compute the corpus_embeddings from scratch, which can take a while
-        corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True)
+        corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
         corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
         torch.save(corpus_embeddings, embeddings_file)
         if verbose > 0:
@@ -71,7 +71,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
     return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
+def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
     "Search for entries that answer the query"
     query = raw_query

@@ -101,18 +101,18 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
     # Encode the query using the bi-encoder
     start = time.time()
-    question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device)
+    question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
     question_embedding = util.normalize_embeddings(question_embedding)
     end = time.time()
     if verbose > 1:
-        print(f"Query Encode Time: {end - start:.3f} seconds")
+        print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")

     # Find relevant entries for the query
     start = time.time()
     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
     end = time.time()
     if verbose > 1:
-        print(f"Search Time: {end - start:.3f} seconds")
+        print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")

     # Score all retrieved entries using the cross-encoder
     if rank_results:
@@ -121,7 +121,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
         cross_scores = model.cross_encoder.predict(cross_inp)
         end = time.time()
         if verbose > 1:
-            print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
+            print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")

         # Store cross-encoder scores in results dictionary for ranking
         for idx in range(len(cross_scores)):
@@ -134,7 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
         hits.sort(key=lambda x: x['cross-score'], reverse=True)  # sort by cross-encoder score
         end = time.time()
         if verbose > 1:
-            print(f"Rank Time: {end - start:.3f} seconds")
+            print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")

     return hits, entries

@@ -167,7 +167,7 @@ def collate_results(hits, entries, count=5):
             in hits[0:count]]


-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)

@@ -182,7 +182,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     # Compute or Load Embeddings
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
-    corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
+    corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)

     return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
diff --git a/tests/conftest.py b/tests/conftest.py
index 56610d45..f7c26b64 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -55,7 +55,7 @@ def model_dir(search_config):
         compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
         embeddings_file = model_dir.joinpath('note_embeddings.pt'))

-    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=state.device, verbose=True)
+    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)

     return model_dir
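Note: the `state.device` read throughout this patch is ML compute app state that is expected to be initialized once, on app start, in the global state module. A rough, hypothetical sketch of what that module might contain follows; the device-detection logic and the placeholder values are assumptions for illustration, not part of this commit:

    import torch

    # ML compute device: non-user-configurable app state, detected once on app start.
    # Search functions read it via `state.device` instead of receiving it as an argument.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Other app state referenced in the diff; populated by src/configure.py at startup.
    model = None             # SearchModels instance holding per-content-type search models
    config = None            # FullConfig parsed from the user's config file
    processor_config = None  # processor configuration
    verbose = 0              # logging verbosity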