Do not pass ML compute `device' around as argument to search funcs

- It is a non-user configurable, app state that is set on app start
- Reduce passing unneeded arguments around. Just set device where
  required by looking for ML compute device in global state
This commit is contained in:
Debanjum Singh Solanky
2022-08-20 14:14:42 +03:00
parent acc9091260
commit 82d2891765
4 changed files with 23 additions and 23 deletions

View File

@@ -53,16 +53,16 @@ def extract_entries(jsonl_file, verbose=0):
in load_jsonl(jsonl_file, verbose=verbose)]
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
# Load pre-computed embeddings from file if exists
if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device)
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
if verbose > 0:
print(f"Loaded embeddings from {embeddings_file}")
else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True)
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
torch.save(corpus_embeddings, embeddings_file)
if verbose > 0:
@@ -71,7 +71,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
"Search for entries that answer the query"
query = raw_query
@@ -101,18 +101,18 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
# Encode the query using the bi-encoder
start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device)
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
if verbose > 1:
print(f"Query Encode Time: {end - start:.3f} seconds")
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
# Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
if verbose > 1:
print(f"Search Time: {end - start:.3f} seconds")
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
# Score all retrieved entries using the cross-encoder
if rank_results:
@@ -121,7 +121,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
@@ -134,7 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
if verbose > 1:
print(f"Rank Time: {end - start:.3f} seconds")
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
return hits, entries
@@ -167,7 +167,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
@@ -182,7 +182,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)