mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Do not pass ML compute `device` around as an argument to search funcs
- It is non-user-configurable app state that is set on app start
- Reduce passing unneeded arguments around. Just set the device where required by looking up the ML compute device in global state
This commit is contained in:
@@ -27,27 +27,27 @@ def configure_server(args, required=False):
|
|||||||
state.config = args.config
|
state.config = args.config
|
||||||
|
|
||||||
# Initialize the search model from Config
|
# Initialize the search model from Config
|
||||||
state.model = configure_search(state.model, state.config, args.regenerate, device=state.device, verbose=state.verbose)
|
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
|
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
|
||||||
|
|
||||||
|
|
||||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu"), verbose: int = 0):
|
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0):
|
||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (t == SearchType.Org or t == None) and config.content_type.org:
|
if (t == SearchType.Org or t == None) and config.content_type.org:
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||||
|
|
||||||
# Initialize Org Music Search
|
# Initialize Org Music Search
|
||||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||||
# Extract Entries, Generate Music Embeddings
|
# Extract Entries, Generate Music Embeddings
|
||||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||||
|
|
||||||
# Initialize Markdown Search
|
# Initialize Markdown Search
|
||||||
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||||
# Extract Entries, Generate Markdown Embeddings
|
# Extract Entries, Generate Markdown Embeddings
|
||||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||||
|
|
||||||
# Initialize Ledger Search
|
# Initialize Ledger Search
|
||||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -73,7 +73,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -84,7 +84,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||||
# query markdown files
|
# query markdown files
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -95,7 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
@@ -131,13 +131,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
|
|
||||||
@router.get('/reload')
|
@router.get('/reload')
|
||||||
def reload(t: Optional[SearchType] = None):
|
def reload(t: Optional[SearchType] = None):
|
||||||
state.model = configure_search(state.model, state.config, regenerate=False, t=t, device=state.device)
|
state.model = configure_search(state.model, state.config, regenerate=False, t=t)
|
||||||
return {'status': 'ok', 'message': 'reload completed'}
|
return {'status': 'ok', 'message': 'reload completed'}
|
||||||
|
|
||||||
|
|
||||||
@router.get('/regenerate')
|
@router.get('/regenerate')
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
state.model = configure_search(state.model, state.config, regenerate=True, t=t, device=state.device)
|
state.model = configure_search(state.model, state.config, regenerate=True, t=t)
|
||||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -53,16 +53,16 @@ def extract_entries(jsonl_file, verbose=0):
|
|||||||
in load_jsonl(jsonl_file, verbose=verbose)]
|
in load_jsonl(jsonl_file, verbose=verbose)]
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
|
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
# Load pre-computed embeddings from file if exists
|
# Load pre-computed embeddings from file if exists
|
||||||
if embeddings_file.exists() and not regenerate:
|
if embeddings_file.exists() and not regenerate:
|
||||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device)
|
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Loaded embeddings from {embeddings_file}")
|
print(f"Loaded embeddings from {embeddings_file}")
|
||||||
|
|
||||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True)
|
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||||
torch.save(corpus_embeddings, embeddings_file)
|
torch.save(corpus_embeddings, embeddings_file)
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
@@ -71,7 +71,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
|||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
|
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query = raw_query
|
query = raw_query
|
||||||
|
|
||||||
@@ -101,18 +101,18 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
|||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
start = time.time()
|
start = time.time()
|
||||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device)
|
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||||
question_embedding = util.normalize_embeddings(question_embedding)
|
question_embedding = util.normalize_embeddings(question_embedding)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if verbose > 1:
|
if verbose > 1:
|
||||||
print(f"Query Encode Time: {end - start:.3f} seconds")
|
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
start = time.time()
|
start = time.time()
|
||||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if verbose > 1:
|
if verbose > 1:
|
||||||
print(f"Search Time: {end - start:.3f} seconds")
|
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
@@ -121,7 +121,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
|||||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if verbose > 1:
|
if verbose > 1:
|
||||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
@@ -134,7 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
|||||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if verbose > 1:
|
if verbose > 1:
|
||||||
print(f"Rank Time: {end - start:.3f} seconds")
|
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
return hits, entries
|
return hits, entries
|
||||||
|
|
||||||
@@ -167,7 +167,7 @@ def collate_results(hits, entries, count=5):
|
|||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
|
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||||
|
|
||||||
@@ -182,7 +182,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
|||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||||
|
|
||||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def model_dir(search_config):
|
|||||||
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
||||||
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
||||||
|
|
||||||
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=state.device, verbose=True)
|
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
|
||||||
|
|
||||||
return model_dir
|
return model_dir
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user