mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Consolidate the search config models and pass verbose as a top level flag
This commit is contained in:
@@ -14,7 +14,8 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||
from src.utils.config import TextSearchModel, TextSearchConfig
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfigModel
|
||||
|
||||
|
||||
def initialize_model():
|
||||
@@ -148,22 +149,22 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||
def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model()
|
||||
|
||||
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
||||
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
|
||||
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose)
|
||||
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
|
||||
|
||||
# Extract Entries
|
||||
entries = extract_entries(config.compressed_jsonl, config.verbose)
|
||||
entries = extract_entries(config.compressed_jsonl, verbose)
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -10,9 +10,10 @@ from tqdm import trange
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
import src.utils.exiftool as exiftool
|
||||
from src.utils.config import ImageSearchModel, ImageSearchConfig
|
||||
from src.utils.config import ImageSearchModel
|
||||
from src.utils.rawconfig import ImageSearchConfigModel
|
||||
|
||||
|
||||
def initialize_model():
|
||||
@@ -153,7 +154,7 @@ def collate_results(hits, image_names, image_directory, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||
def setup(config: ImageSearchConfigModel, regenerate: bool, verbose: bool) -> ImageSearchModel:
|
||||
# Initialize Model
|
||||
encoder = initialize_model()
|
||||
|
||||
@@ -170,13 +171,13 @@ def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||
batch_size=config.batch_size,
|
||||
regenerate=regenerate,
|
||||
use_xmp_metadata=config.use_xmp_metadata,
|
||||
verbose=config.verbose)
|
||||
verbose=verbose)
|
||||
|
||||
return ImageSearchModel(image_names,
|
||||
image_embeddings,
|
||||
image_metadata_embeddings,
|
||||
encoder,
|
||||
config.verbose)
|
||||
verbose)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -12,7 +12,8 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
|
||||
from src.utils.config import TextSearchModel, TextSearchConfig
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfigModel
|
||||
|
||||
|
||||
def initialize_model():
|
||||
@@ -140,7 +141,7 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||
def setup(config: TextSearchConfigModel, regenerate: bool, verbose: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model()
|
||||
|
||||
@@ -153,9 +154,9 @@ def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||
top_k = min(len(entries), top_k)
|
||||
|
||||
# Compute or Load Embeddings
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose)
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user