mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 05:39:06 +00:00
Modularize Code. Wrap Search, Model Config in Classes. Add Tests
Details
- Rename method query_* to query in search_types for standardization
- Wrapping Config code in classes simplified mocking test config
- Reduce args beings passed to a function by passing it as single
argument wrapped in a class
- Minimize setup in main.py:__main__. Put most of it into functions
These functions can be mocked if required in tests later too
Setup Flow:
CLI_Args|Config_YAML -> (Text|Image)SearchConfig -> (Text|Image)SearchModel
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
# System Packages
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
# Internal Packages
|
||||
from utils.helpers import get_from_dict
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
@@ -10,43 +14,82 @@ class SearchType(str, Enum):
|
||||
Image = "image"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchSettings():
|
||||
notes_search_enabled: bool = False
|
||||
ledger_search_enabled: bool = False
|
||||
music_search_enabled: bool = False
|
||||
image_search_enabled: bool = False
|
||||
|
||||
|
||||
class AsymmetricSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
self.cross_encoder = cross_encoder
|
||||
self.top_k = top_k
|
||||
|
||||
|
||||
class LedgerSearchModel():
|
||||
def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k):
|
||||
self.transactions = transactions
|
||||
self.transaction_embeddings = transaction_embeddings
|
||||
self.symmetric_encoder = symmetric_encoder
|
||||
self.symmetric_cross_encoder = symmetric_cross_encoder
|
||||
self.top_k = top_k
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
class ImageSearchModel():
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose):
|
||||
self.image_encoder = image_encoder
|
||||
self.image_names = image_names
|
||||
self.image_embeddings = image_embeddings
|
||||
self.image_metadata_embeddings = image_metadata_embeddings
|
||||
self.image_encoder = image_encoder
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchModels():
|
||||
notes_search: AsymmetricSearchModel = None
|
||||
ledger_search: LedgerSearchModel = None
|
||||
music_search: AsymmetricSearchModel = None
|
||||
notes_search: TextSearchModel = None
|
||||
ledger_search: TextSearchModel = None
|
||||
music_search: TextSearchModel = None
|
||||
image_search: ImageSearchModel = None
|
||||
|
||||
|
||||
class TextSearchConfig():
|
||||
def __init__(self, input_files, input_filter, compressed_jsonl, embeddings_file, verbose):
|
||||
self.input_files = input_files
|
||||
self.input_filter = input_filter
|
||||
self.compressed_jsonl = Path(compressed_jsonl)
|
||||
self.embeddings_file = Path(embeddings_file)
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
def create_from_dictionary(config, key_tree, verbose):
|
||||
text_config = get_from_dict(config, *key_tree)
|
||||
search_enabled = text_config and ('input-files' in text_config or 'input-filter' in text_config)
|
||||
if not search_enabled:
|
||||
return None
|
||||
|
||||
return TextSearchConfig(
|
||||
input_files = text_config['input-files'],
|
||||
input_filter = text_config['input-filter'],
|
||||
compressed_jsonl = Path(text_config['compressed-jsonl']),
|
||||
embeddings_file = Path(text_config['embeddings-file']),
|
||||
verbose = verbose)
|
||||
|
||||
|
||||
class ImageSearchConfig():
|
||||
def __init__(self, input_directory, embeddings_file, batch_size, use_xmp_metadata, verbose):
|
||||
self.input_directory = input_directory
|
||||
self.embeddings_file = Path(embeddings_file)
|
||||
self.batch_size = batch_size
|
||||
self.use_xmp_metadata = use_xmp_metadata
|
||||
self.verbose = verbose
|
||||
|
||||
def create_from_dictionary(config, key_tree, verbose):
|
||||
image_config = get_from_dict(config, *key_tree)
|
||||
search_enabled = image_config and 'input-directory' in image_config
|
||||
if not search_enabled:
|
||||
return None
|
||||
|
||||
return ImageSearchConfig(
|
||||
input_directory = Path(image_config['input-directory']),
|
||||
embeddings_file = Path(image_config['embeddings-file']),
|
||||
batch_size = image_config['batch-size'],
|
||||
use_xmp_metadata = {'yes': True, 'no': False}[image_config['use-xmp-metadata']],
|
||||
verbose = verbose)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchConfig():
|
||||
notes: TextSearchConfig = None
|
||||
ledger: TextSearchConfig = None
|
||||
music: TextSearchConfig = None
|
||||
image: ImageSearchConfig = None
|
||||
|
||||
Reference in New Issue
Block a user