mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Improve Query Speed. Normalize Embeddings, Moving them to Cuda GPU
- Move embeddings to CUDA GPU for compute, when available - Normalize embeddings and Use Dot Product instead of Cosine
This commit is contained in:
14
src/main.py
14
src/main.py
@@ -4,6 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import torch
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
@@ -24,6 +25,7 @@ processor_config = ProcessorConfigModel()
|
|||||||
config_file = ""
|
config_file = ""
|
||||||
verbose = 0
|
verbose = 0
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
|
||||||
app.mount("/views", StaticFiles(directory="views"), name="views")
|
app.mount("/views", StaticFiles(directory="views"), name="views")
|
||||||
templates = Jinja2Templates(directory="views/")
|
templates = Jinja2Templates(directory="views/")
|
||||||
@@ -56,14 +58,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||||||
|
|
||||||
if (t == SearchType.Notes or t == None) and model.notes_search:
|
if (t == SearchType.Notes or t == None) and model.notes_search:
|
||||||
# query notes
|
# query notes
|
||||||
hits = asymmetric.query(user_query, model.notes_search)
|
hits = asymmetric.query(user_query, model.notes_search, device=device)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and model.music_search:
|
if (t == SearchType.Music or t == None) and model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
hits = asymmetric.query(user_query, model.music_search)
|
hits = asymmetric.query(user_query, model.music_search, device=device)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
||||||
@@ -93,14 +95,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||||||
@app.get('/reload')
|
@app.get('/reload')
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
global model
|
global model
|
||||||
model = initialize_search(config, regenerate=False, t=t)
|
model = initialize_search(config, regenerate=False, t=t, device=device)
|
||||||
return {'status': 'ok', 'message': 'reload completed'}
|
return {'status': 'ok', 'message': 'reload completed'}
|
||||||
|
|
||||||
|
|
||||||
@app.get('/regenerate')
|
@app.get('/regenerate')
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
global model
|
global model
|
||||||
model = initialize_search(config, regenerate=True, t=t)
|
model = initialize_search(config, regenerate=True, t=t, device=device)
|
||||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||||
|
|
||||||
|
|
||||||
@@ -149,12 +151,12 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
|
|||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (t == SearchType.Notes or t == None) and config.content_type.org:
|
if (t == SearchType.Notes or t == None) and config.content_type.org:
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||||
|
|
||||||
# Initialize Org Music Search
|
# Initialize Org Music Search
|
||||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||||
# Extract Entries, Generate Music Embeddings
|
# Extract Entries, Generate Music Embeddings
|
||||||
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||||
|
|
||||||
# Initialize Ledger Search
|
# Initialize Ledger Search
|
||||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ def extract_entries(notesfile, verbose=0):
|
|||||||
return entries
|
return entries
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
|
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
# Load pre-computed embeddings from file if exists
|
# Load pre-computed embeddings from file if exists
|
||||||
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
||||||
@@ -84,6 +84,8 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
|||||||
|
|
||||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||||
corpus_embeddings = bi_encoder.encode([entry[0] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
|
corpus_embeddings = bi_encoder.encode([entry[0] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
|
||||||
|
corpus_embeddings.to(device)
|
||||||
|
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||||
torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
|
torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Computed embeddings and saved them to {embeddings_file}")
|
print(f"Computed embeddings and saved them to {embeddings_file}")
|
||||||
@@ -91,7 +93,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
|||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel):
|
def query(raw_query: str, model: TextSearchModel, device='cpu'):
|
||||||
"Search all notes for entries that answer the query"
|
"Search all notes for entries that answer the query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
@@ -99,10 +101,12 @@ def query(raw_query: str, model: TextSearchModel):
|
|||||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
|
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||||
|
question_embedding.to(device)
|
||||||
|
question_embedding = util.normalize_embeddings(question_embedding)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
|
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)
|
||||||
hits = hits[0] # Get the hits for the first query
|
hits = hits[0] # Get the hits for the first query
|
||||||
|
|
||||||
# Filter out entries that contain required words and do not contain blocked words
|
# Filter out entries that contain required words and do not contain blocked words
|
||||||
@@ -176,7 +180,7 @@ def collate_results(hits, entries, count=5):
|
|||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
|
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||||
|
|
||||||
@@ -189,7 +193,7 @@ def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, rege
|
|||||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
|
||||||
|
|
||||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.search_type import asymmetric, image_search
|
from src.search_type import asymmetric, image_search
|
||||||
@@ -35,6 +36,7 @@ def search_config(tmp_path_factory):
|
|||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
def model_dir(search_config):
|
def model_dir(search_config):
|
||||||
model_dir = search_config.asymmetric.model_directory
|
model_dir = search_config.asymmetric.model_directory
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
|
||||||
# Generate Image Embeddings from Test Images
|
# Generate Image Embeddings from Test Images
|
||||||
content_config = ContentConfig()
|
content_config = ContentConfig()
|
||||||
@@ -53,7 +55,7 @@ def model_dir(search_config):
|
|||||||
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
||||||
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
||||||
|
|
||||||
asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
|
asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True)
|
||||||
|
|
||||||
return model_dir
|
return model_dir
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user