From eda4b65ddb7f8e038d1364183624817439939f25 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Thu, 30 Jun 2022 00:59:57 +0400
Subject: [PATCH] Improve Query Speed. Normalize Embeddings, Moving them to
 Cuda GPU

- Move embeddings to CUDA GPU for compute, when available
- Normalize embeddings and Use Dot Product instead of Cosine
---
 src/main.py                   | 14 ++++++++------
 src/search_type/asymmetric.py | 16 ++++++++++------
 tests/conftest.py             |  4 +++-
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/src/main.py b/src/main.py
index 7be3ebe2..557dd29d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -4,6 +4,7 @@ from typing import Optional
 
 # External Packages
 import uvicorn
+import torch
 from fastapi import FastAPI, Request
 from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
@@ -24,6 +25,7 @@ processor_config = ProcessorConfigModel()
 config_file = ""
 verbose = 0
 app = FastAPI()
+device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
 app.mount("/views", StaticFiles(directory="views"), name="views")
 templates = Jinja2Templates(directory="views/")
@@ -56,14 +58,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
 
     if (t == SearchType.Notes or t == None) and model.notes_search:
         # query notes
-        hits = asymmetric.query(user_query, model.notes_search)
+        hits = asymmetric.query(user_query, model.notes_search, device=device)
 
         # collate and return results
         return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
 
     if (t == SearchType.Music or t == None) and model.music_search:
         # query music library
-        hits = asymmetric.query(user_query, model.music_search)
+        hits = asymmetric.query(user_query, model.music_search, device=device)
 
         # collate and return results
         return asymmetric.collate_results(hits, model.music_search.entries, results_count)
@@ -93,14 +95,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
 
 @app.get('/reload')
 def regenerate(t: Optional[SearchType] = None):
     global model
-    model = initialize_search(config, regenerate=False, t=t)
+    model = initialize_search(config, regenerate=False, t=t, device=device)
     return {'status': 'ok', 'message': 'reload completed'}
 
 
 @app.get('/regenerate')
 def regenerate(t: Optional[SearchType] = None):
     global model
-    model = initialize_search(config, regenerate=True, t=t)
+    model = initialize_search(config, regenerate=True, t=t, device=device)
     return {'status': 'ok', 'message': 'regeneration completed'}
 
@@ -149,12 +151,12 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
     # Initialize Org Notes Search
     if (t == SearchType.Notes or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
-        model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
+        model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
 
     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
+        model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
 
     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py
index 4dd17f82..aa64e128 100644
--- a/src/search_type/asymmetric.py
+++ b/src/search_type/asymmetric.py
@@ -74,7 +74,7 @@ def extract_entries(notesfile, verbose=0):
     return entries
 
 
-def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
+def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     # Load pre-computed embeddings from file if exists
     if resolve_absolute_path(embeddings_file).exists() and not regenerate:
@@ -84,6 +84,8 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
     else:  # Else compute the corpus_embeddings from scratch, which can take a while
         corpus_embeddings = bi_encoder.encode([entry[0] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
+        corpus_embeddings.to(device)
+        corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
         torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
         if verbose > 0:
             print(f"Computed embeddings and saved them to {embeddings_file}")
@@ -91,7 +93,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
     return corpus_embeddings
 
 
-def query(raw_query: str, model: TextSearchModel):
+def query(raw_query: str, model: TextSearchModel, device='cpu'):
     "Search all notes for entries that answer the query"
     # Separate natural query from explicit required, blocked words filters
     query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
@@ -99,10 +101,12 @@ def query(raw_query: str, model: TextSearchModel):
     blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
 
     # Encode the query using the bi-encoder
-    question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
+    question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
+    question_embedding.to(device)
+    question_embedding = util.normalize_embeddings(question_embedding)
 
     # Find relevant entries for the query
-    hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
+    hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)
     hits = hits[0]  # Get the hits for the first query
 
     # Filter out entries that contain required words and do not contain blocked words
@@ -176,7 +180,7 @@ def collate_results(hits, entries, count=5):
         in hits[0:count]]
 
 
-def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
+def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)
 
@@ -189,7 +193,7 @@ def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, rege
     top_k = min(len(entries), top_k)  # top_k hits can't be more than the total entries in corpus
 
     # Compute or Load Embeddings
-    corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
+    corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
 
     return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
diff --git a/tests/conftest.py b/tests/conftest.py
index af0e9e36..7d9adfaa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,6 @@
 # Standard Packages
 import pytest
+import torch
 
 # Internal Packages
 from src.search_type import asymmetric, image_search
@@ -35,6 +36,7 @@ def search_config(tmp_path_factory):
 @pytest.fixture(scope='session')
 def model_dir(search_config):
     model_dir = search_config.asymmetric.model_directory
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
     # Generate Image Embeddings from Test Images
     content_config = ContentConfig()
@@ -53,7 +55,7 @@ def model_dir(search_config):
                        compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
                        embeddings_file = model_dir.joinpath('note_embeddings.pt'))
 
-    asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
+    asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True)
 
     return model_dir
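
Note for reviewers: the standalone sketch below shows the pattern this patch applies, using sentence-transformers and torch directly. It is a minimal, hypothetical example, not code from the repository; the model name, sample corpus, and query are illustrative. It also reassigns the result of Tensor.to(), since torch tensors are copied rather than moved in place by that call.

    # Minimal sketch of GPU compute + normalized embeddings + dot-product search
    import torch
    from sentence_transformers import SentenceTransformer, util

    # Use the CUDA GPU for compute when available, else fall back to CPU
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)  # assumed model
    corpus = ["Notes on emacs org-mode", "Road trip playlist", "Recipe for dal"]  # illustrative data

    # Encode the corpus once, keep it on the compute device, and L2-normalize it up front.
    # Tensor.to() returns a copy, so its result is reassigned here.
    corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)
    corpus_embeddings = corpus_embeddings.to(device)
    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)

    # Encode the query the same way at search time
    question_embedding = bi_encoder.encode(["notes about org-mode"], convert_to_tensor=True)
    question_embedding = question_embedding.to(device)
    question_embedding = util.normalize_embeddings(question_embedding)

    # On unit-length vectors the dot product equals cosine similarity, so the cheaper
    # util.dot_score can replace the default util.cos_sim as the score function.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=2, score_function=util.dot_score)
    print(hits[0])

Normalizing the corpus once at index time means each query only pays for a batch of dot products, which semantic_search runs on the GPU when the tensors live there; cosine scoring would re-normalize both sides on every search.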