mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Save Image Search Model to Disk
This commit is contained in:
@@ -10,16 +10,22 @@ from tqdm import trange
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
from src.utils.helpers import resolve_absolute_path, load_model
|
||||
import src.utils.exiftool as exiftool
|
||||
from src.utils.config import ImageSearchModel
|
||||
from src.utils.rawconfig import ImageSearchConfig
|
||||
from src.utils.rawconfig import ImageSearchConfig, ImageSearchTypeConfig
|
||||
|
||||
|
||||
def initialize_model():
|
||||
def initialize_model(search_config: ImageSearchTypeConfig):
|
||||
# Initialize Model
|
||||
torch.set_num_threads(4)
|
||||
encoder = SentenceTransformer('sentence-transformers/clip-ViT-B-32') #Load the CLIP model
|
||||
|
||||
# Load the CLIP model
|
||||
encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
@@ -154,9 +160,9 @@ def collate_results(hits, image_names, image_directory, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
|
||||
def setup(config: ImageSearchConfig, search_config: ImageSearchTypeConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
|
||||
# Initialize Model
|
||||
encoder = initialize_model()
|
||||
encoder = initialize_model(search_config)
|
||||
|
||||
# Extract Entries
|
||||
image_directory = resolve_absolute_path(config.input_directory, strict=True)
|
||||
|
||||
Reference in New Issue
Block a user