mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Improve Query Speed. Normalize Embeddings, Moving them to Cuda GPU
- Move embeddings to CUDA GPU for compute, when available - Normalize embeddings and Use Dot Product instead of Cosine
This commit is contained in:
@@ -74,7 +74,7 @@ def extract_entries(notesfile, verbose=0):
|
||||
return entries
|
||||
|
||||
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
||||
@@ -84,6 +84,8 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry[0] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
|
||||
corpus_embeddings.to(device)
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
|
||||
if verbose > 0:
|
||||
print(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
@@ -91,7 +93,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel):
|
||||
def query(raw_query: str, model: TextSearchModel, device='cpu'):
|
||||
"Search all notes for entries that answer the query"
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
@@ -99,10 +101,12 @@ def query(raw_query: str, model: TextSearchModel):
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||
question_embedding.to(device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
|
||||
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)
|
||||
hits = hits[0] # Get the hits for the first query
|
||||
|
||||
# Filter out entries that contain required words and do not contain blocked words
|
||||
@@ -176,7 +180,7 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
|
||||
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
@@ -189,7 +193,7 @@ def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, rege
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user