Improve Query Speed: Normalize Embeddings and Move Them to the CUDA GPU

- Move embeddings to CUDA GPU for compute, when available
- Normalize embeddings and use dot product instead of cosine similarity
This commit is contained in:
Debanjum Singh Solanky
2022-06-30 00:59:57 +04:00
parent 2f7ef08b11
commit eda4b65ddb
3 changed files with 21 additions and 13 deletions

View File

@@ -4,6 +4,7 @@ from typing import Optional
# External Packages
import uvicorn
import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
@@ -24,6 +25,7 @@ processor_config = ProcessorConfigModel()
# Module-level state; config_file and verbose start at empty/zero defaults here.
config_file = ""
verbose = 0
app = FastAPI()
# Use the first CUDA GPU when torch reports CUDA is available, otherwise the CPU.
# This device is handed to the asymmetric search setup and query calls.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Serve the contents of the "views" directory under the /views URL path.
app.mount("/views", StaticFiles(directory="views"), name="views")
# Jinja2 templates are loaded from the same "views/" directory.
templates = Jinja2Templates(directory="views/")
@@ -56,14 +58,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and model.notes_search:
# query notes
hits = asymmetric.query(user_query, model.notes_search)
hits = asymmetric.query(user_query, model.notes_search, device=device)
# collate and return results
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
if (t == SearchType.Music or t == None) and model.music_search:
# query music library
hits = asymmetric.query(user_query, model.music_search)
hits = asymmetric.query(user_query, model.music_search, device=device)
# collate and return results
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
@@ -93,14 +95,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
# GET /reload: rebuild the global search model with regenerate=False,
# optionally restricted to the search type `t`, and return a status payload.
# NOTE(review): this handler is named `regenerate`, the same as the
# /regenerate handler; a distinct name such as `reload` would be clearer -
# confirm nothing refers to the function by name before renaming.
@app.get('/reload')
def regenerate(t: Optional[SearchType] = None):
global model
# The next two assignments appear to be the before/after sides of the diff:
# the call without `device` is replaced by the one passing device=device.
model = initialize_search(config, regenerate=False, t=t)
model = initialize_search(config, regenerate=False, t=t, device=device)
return {'status': 'ok', 'message': 'reload completed'}
# GET /regenerate: rebuild the global search model with regenerate=True,
# optionally restricted to the search type `t`, and return a status payload.
@app.get('/regenerate')
def regenerate(t: Optional[SearchType] = None):
global model
# The next two assignments appear to be the before/after sides of the diff:
# the call without `device` is replaced by the one passing device=device.
model = initialize_search(config, regenerate=True, t=t)
model = initialize_search(config, regenerate=True, t=t, device=device)
return {'status': 'ok', 'message': 'regeneration completed'}
@@ -149,12 +151,12 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
# Initialize Org Notes Search
if (t == SearchType.Notes or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger: