mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Use MPS, CUDA to GPU Accelerate Query Performance
- Load Models and Embeddings onto GPU if available - Use MPS for GPU acceleration when available - Note: Support for [MPS](https://developer.apple.com/metal/) in Pytorch is currently in v1.13.0 nightly builds. See [Announcement](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/) - Users will have to wait for PyTorch MPS support to land in stable builds - Until then code can be tuned and tested for GPU acceleration on newer Macs - Re-enable Tests for Image Search
This commit is contained in:
4
setup.py
4
setup.py
@@ -24,8 +24,8 @@ setup(
|
||||
),
|
||||
install_requires=[
|
||||
"numpy == 1.22.4",
|
||||
"torch == 1.11.0",
|
||||
"torchvision == 0.12.0",
|
||||
"torch == 1.12.1",
|
||||
"torchvision == 0.13.1",
|
||||
"transformers == 4.21.0",
|
||||
"sentence-transformers == 2.1.0",
|
||||
"openai == 0.20.0",
|
||||
|
||||
@@ -27,27 +27,27 @@ def configure_server(args, required=False):
|
||||
state.config = args.config
|
||||
|
||||
# Initialize the search model from Config
|
||||
state.model = configure_search(state.model, state.config, args.regenerate, device=state.device, verbose=state.verbose)
|
||||
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu"), verbose: int = 0):
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0):
|
||||
# Initialize Org Notes Search
|
||||
if (t == SearchType.Org or t == None) and config.content_type.org:
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
# Initialize Org Music Search
|
||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||
# Extract Entries, Generate Music Embeddings
|
||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
# Initialize Ledger Search
|
||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
|
||||
@@ -62,7 +62,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||
# query org-mode notes
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -73,7 +73,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||
# query music library
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -84,7 +84,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||
# query markdown files
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -95,7 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||
# query transactions
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -131,13 +131,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
|
||||
@router.get('/reload')
|
||||
def reload(t: Optional[SearchType] = None):
|
||||
state.model = configure_search(state.model, state.config, regenerate=False, t=t, device=state.device)
|
||||
state.model = configure_search(state.model, state.config, regenerate=False, t=t)
|
||||
return {'status': 'ok', 'message': 'reload completed'}
|
||||
|
||||
|
||||
@router.get('/regenerate')
|
||||
def regenerate(t: Optional[SearchType] = None):
|
||||
state.model = configure_search(state.model, state.config, regenerate=True, t=t, device=state.device)
|
||||
state.model = configure_search(state.model, state.config, regenerate=True, t=t)
|
||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
@@ -32,13 +33,15 @@ def initialize_model(search_config: TextSearchConfig):
|
||||
bi_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer)
|
||||
model_type = SentenceTransformer,
|
||||
device=f'{state.device}')
|
||||
|
||||
# The cross-encoder re-ranks the results to improve quality
|
||||
cross_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.cross_encoder,
|
||||
model_type = CrossEncoder)
|
||||
model_type = CrossEncoder,
|
||||
device=f'{state.device}')
|
||||
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
@@ -50,17 +53,16 @@ def extract_entries(jsonl_file, verbose=0):
|
||||
in load_jsonl(jsonl_file, verbose=verbose)]
|
||||
|
||||
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
if verbose > 0:
|
||||
print(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
|
||||
corpus_embeddings.to(device)
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
if verbose > 0:
|
||||
@@ -69,7 +71,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
@@ -99,19 +101,18 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||
question_embedding.to(device)
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Find relevant entries for the query
|
||||
start = time.time()
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Search Time: {end - start:.3f} seconds")
|
||||
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
@@ -120,7 +121,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
@@ -133,7 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Rank Time: {end - start:.3f} seconds")
|
||||
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
return hits, entries
|
||||
|
||||
@@ -166,7 +167,7 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
@@ -181,7 +182,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||
|
||||
|
||||
@@ -41,17 +41,17 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
||||
return merged_dict
|
||||
|
||||
|
||||
def load_model(model_name, model_dir, model_type):
|
||||
def load_model(model_name, model_dir, model_type, device:str=None):
|
||||
"Load model from disk or huggingface"
|
||||
# Construct model path
|
||||
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||
|
||||
# Load model from model_path if it exists there
|
||||
if model_path is not None and resolve_absolute_path(model_path).exists():
|
||||
model = model_type(get_absolute_path(model_path))
|
||||
model = model_type(get_absolute_path(model_path), device=device)
|
||||
# Else load the model from the model_name
|
||||
else:
|
||||
model = model_type(model_name)
|
||||
model = model_type(model_name, device=device)
|
||||
if model_path is not None:
|
||||
model.save(model_path)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Standard Packages
|
||||
from packaging import version
|
||||
# External Packages
|
||||
import torch
|
||||
from pathlib import Path
|
||||
@@ -12,7 +14,15 @@ model = SearchModels()
|
||||
processor_config = ProcessorConfigModel()
|
||||
config_file: Path = ""
|
||||
verbose: int = 0
|
||||
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available
|
||||
host: str = None
|
||||
port: int = None
|
||||
cli_args = None
|
||||
cli_args = None
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use CUDA GPU
|
||||
device = torch.device("cuda:0")
|
||||
elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available():
|
||||
# Use Apple M1 Metal Acceleration
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
@@ -39,14 +39,14 @@ def model_dir(search_config):
|
||||
model_dir = search_config.asymmetric.model_directory
|
||||
|
||||
# Generate Image Embeddings from Test Images
|
||||
# content_config = ContentConfig()
|
||||
# content_config.image = ImageContentConfig(
|
||||
# input_directories = ['tests/data/images'],
|
||||
# embeddings_file = model_dir.joinpath('image_embeddings.pt'),
|
||||
# batch_size = 10,
|
||||
# use_xmp_metadata = False)
|
||||
content_config = ContentConfig()
|
||||
content_config.image = ImageContentConfig(
|
||||
input_directories = ['tests/data/images'],
|
||||
embeddings_file = model_dir.joinpath('image_embeddings.pt'),
|
||||
batch_size = 10,
|
||||
use_xmp_metadata = False)
|
||||
|
||||
# image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
|
||||
image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
|
||||
|
||||
# Generate Notes Embeddings from Test Notes
|
||||
content_config.org = TextContentConfig(
|
||||
@@ -55,7 +55,7 @@ def model_dir(search_config):
|
||||
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
||||
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
||||
|
||||
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=state.device, verbose=True)
|
||||
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
|
||||
|
||||
return model_dir
|
||||
|
||||
@@ -69,10 +69,10 @@ def content_config(model_dir):
|
||||
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
||||
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
||||
|
||||
# content_config.image = ImageContentConfig(
|
||||
# input_directories = ['tests/data/images'],
|
||||
# embeddings_file = model_dir.joinpath('image_embeddings.pt'),
|
||||
# batch_size = 10,
|
||||
# use_xmp_metadata = False)
|
||||
content_config.image = ImageContentConfig(
|
||||
input_directories = ['tests/data/images'],
|
||||
embeddings_file = model_dir.joinpath('image_embeddings.pt'),
|
||||
batch_size = 1,
|
||||
use_xmp_metadata = False)
|
||||
|
||||
return content_config
|
||||
@@ -90,7 +90,6 @@ def test_regenerate_with_valid_content_type(content_config: ContentConfig, searc
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skip(reason="Flaky test. Search doesn't always return expected image path.")
|
||||
def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
config.content_type = content_config
|
||||
|
||||
@@ -15,7 +15,6 @@ from src.utils.rawconfig import ContentConfig, SearchConfig
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skip(reason="upstream issues in loading image search model. disabled for now")
|
||||
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Act
|
||||
# Regenerate image search embeddings during image setup
|
||||
@@ -27,7 +26,6 @@ def test_image_search_setup(content_config: ContentConfig, search_config: Search
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skip(reason="results inconsistent currently")
|
||||
def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
output_directory = resolve_absolute_path(web_directory)
|
||||
|
||||
Reference in New Issue
Block a user