diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index f6187e54..23a64459 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -22,6 +22,12 @@ def initialize_model(search_config: ImageSearchConfig): # Initialize Model torch.set_num_threads(4) + # Convert model directory to absolute path + search_config.model_directory = resolve_absolute_path(search_config.model_directory) + + # Create model directory if it doesn't exist + search_config.model_directory.parent.mkdir(parents=True, exist_ok=True) + # Load the CLIP model encoder = load_model( model_dir = search_config.model_directory, @@ -61,9 +67,8 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5 image_embeddings = torch.load(embeddings_file) if verbose > 0: print(f"Loaded pre-computed embeddings from {embeddings_file}") - # Else compute the image embeddings from scratch, which can take a while - if image_embeddings is None: + elif image_embeddings is None: image_embeddings = [] for index in trange(0, len(image_names), batch_size): images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]] @@ -71,6 +76,11 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5 images, convert_to_tensor=True, batch_size=min(len(images), batch_size)) + + # Create directory for embeddings file, if it doesn't exist + embeddings_file.parent.mkdir(parents=True, exist_ok=True) + + # Save computed image embeddings to file torch.save(image_embeddings, embeddings_file) if verbose > 0: print(f"Saved computed embeddings to {embeddings_file}") diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 73f48930..c446f31c 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -22,6 +22,12 @@ def initialize_model(search_config: TextSearchConfig): # Number of entries we want to retrieve with the bi-encoder top_k = 15 + # Convert model directory to absolute path + search_config.model_directory = resolve_absolute_path(search_config.model_directory) + + # Create model directory if it doesn't exist + search_config.model_directory.parent.mkdir(parents=True, exist_ok=True) + # The bi-encoder encodes all entries to use for semantic search bi_encoder = load_model( model_dir = search_config.model_directory, @@ -47,7 +53,7 @@ def extract_entries(jsonl_file, verbose=0): def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" # Load pre-computed embeddings from file if exists - if resolve_absolute_path(embeddings_file).exists() and not regenerate: + if embeddings_file.exists() and not regenerate: corpus_embeddings = torch.load(get_absolute_path(embeddings_file)) if verbose > 0: print(f"Loaded embeddings from {embeddings_file}") @@ -56,7 +62,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True) corpus_embeddings.to(device) corpus_embeddings = util.normalize_embeddings(corpus_embeddings) - torch.save(corpus_embeddings, get_absolute_path(embeddings_file)) + torch.save(corpus_embeddings, embeddings_file) if verbose > 0: print(f"Computed embeddings and saved them to {embeddings_file}") @@ -165,7 +171,8 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon bi_encoder, cross_encoder, top_k = initialize_model(search_config) # Map notes in text files to (compressed) JSONL formatted file - if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: + config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) + if not config.compressed_jsonl.exists() or regenerate: text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) # Extract Entries @@ -173,6 +180,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus # Compute or Load Embeddings + config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose) return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) diff --git a/src/utils/jsonl.py b/src/utils/jsonl.py index 67fc7c9a..873cdd39 100644 --- a/src/utils/jsonl.py +++ b/src/utils/jsonl.py @@ -35,7 +35,10 @@ def load_jsonl(input_path, verbose=0): def dump_jsonl(jsonl_data, output_path, verbose=0): "Write List of JSON objects to JSON line file" - with open(get_absolute_path(output_path), 'w', encoding='utf-8') as f: + # Create output directory, if it doesn't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: f.write(jsonl_data) if verbose > 0: @@ -43,7 +46,10 @@ def dump_jsonl(jsonl_data, output_path, verbose=0): def compress_jsonl_data(jsonl_data, output_path, verbose=0): - with gzip.open(get_absolute_path(output_path), 'wt') as gzip_file: + # Create output directory, if it doesn't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + with gzip.open(output_path, 'wt') as gzip_file: gzip_file.write(jsonl_data) if verbose > 0: