diff --git a/src/configure.py b/src/configure.py index 1ddc9f76..c1132e8c 100644 --- a/src/configure.py +++ b/src/configure.py @@ -34,14 +34,14 @@ def configure_server(args, required=False): else: state.config = args.config + # Initialize Processor from Config + state.processor_config = configure_processor(args.config.processor) + # Initialize the search model from Config state.search_index_lock.acquire() state.model = configure_search(state.model, state.config, args.regenerate) state.search_index_lock.release() - # Initialize Processor from Config - state.processor_config = configure_processor(args.config.processor) - @schedule.repeat(schedule.every(1).hour) def update_search_index(): diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py index 0eb60e6c..4d88a612 100644 --- a/src/processor/text_to_jsonl.py +++ b/src/processor/text_to_jsonl.py @@ -24,11 +24,13 @@ class TextToJsonl(ABC): return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest() @staticmethod - def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256) -> list[Entry]: + def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256, max_word_length: int=500) -> list[Entry]: "Split entries if compiled entry length exceeds the max tokens supported by the ML model." 
chunked_entries: list[Entry] = [] for entry in entries: compiled_entry_words = entry.compiled.split() + # Drop long words instead of having entry truncated to maintain quality of entry processed by models + compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] for chunk_index in range(0, len(compiled_entry_words), max_tokens): compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens] compiled_entry_chunk = ' '.join(compiled_entry_words_chunk) diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index e04bbe49..10d429c3 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -36,7 +36,7 @@ def initialize_model(search_config: ImageSearchConfig): encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.encoder, - model_type = SentenceTransformer) + model_type = search_config.encoder_type or SentenceTransformer) return encoder diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 53eb3c3d..816972cc 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -27,17 +27,18 @@ def initialize_model(search_config: TextSearchConfig): # Number of entries we want to retrieve with the bi-encoder top_k = 15 - # Convert model directory to absolute path - search_config.model_directory = resolve_absolute_path(search_config.model_directory) - - # Create model directory if it doesn't exist - search_config.model_directory.parent.mkdir(parents=True, exist_ok=True) + # If model directory is configured + if search_config.model_directory: + # Convert model directory to absolute path + search_config.model_directory = resolve_absolute_path(search_config.model_directory) + # Create model directory if it doesn't exist + search_config.model_directory.parent.mkdir(parents=True, exist_ok=True) # The bi-encoder encodes all entries to use for semantic search bi_encoder = load_model( 
model_dir = search_config.model_directory, model_name = search_config.encoder, - model_type = SentenceTransformer, + model_type = search_config.encoder_type or SentenceTransformer, device=f'{state.device}') # The cross-encoder re-ranks the results to improve quality diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 2285018d..38181ad6 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,5 +1,6 @@ # Standard Packages from pathlib import Path +from importlib import import_module import sys from os.path import join from collections import OrderedDict @@ -44,17 +45,18 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict -def load_model(model_name, model_dir, model_type, device:str=None): +def load_model(model_name: str, model_type, model_dir=None, device:str=None): "Load model from disk or huggingface" # Construct model path model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None # Load model from model_path if it exists there + model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type if model_path is not None and resolve_absolute_path(model_path).exists(): - model = model_type(get_absolute_path(model_path), device=device) + model = model_type_class(get_absolute_path(model_path), device=device) # Else load the model from the model_name else: - model = model_type(model_name, device=device) + model = model_type_class(model_name, device=device) if model_path is not None: model.save(model_path) @@ -66,6 +68,12 @@ def is_pyinstaller_app(): return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') +def get_class_by_name(name: str) -> object: + "Returns the class object from name string" + module_name, class_name = name.rsplit('.', 1) + return getattr(import_module(module_name), class_name) + + class LRU(OrderedDict): def __init__(self, *args, capacity=128, **kwargs): self.capacity = capacity diff --git a/src/utils/models.py 
b/src/utils/models.py new file mode 100644 index 00000000..9ba204b7 --- /dev/null +++ b/src/utils/models.py @@ -0,0 +1,38 @@ +# External Packages +import openai +import torch +from tqdm import trange + +# Internal Packages +from src.utils.state import processor_config, config_file + + +class OpenAI: + def __init__(self, model_name, device=None): + self.model_name = model_name + if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key: + raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}") + openai.api_key = processor_config.conversation.openai_api_key + self.embedding_dimensions = None + + def encode(self, entries: list[str], device=None, **kwargs): + embedding_tensors = [] + + for index in trange(0, len(entries)): + # OpenAI models create better embeddings for entries without newlines + processed_entry = entries[index].replace('\n', ' ') + + try: + response = openai.Embedding.create(input=processed_entry, model=self.model_name) + embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)] + # Remember current model's embedding dimension from the first successful response + # Else default to embedding dimensions of the text-embedding-ada-002 model (1536) when creating zero vectors below + self.embedding_dimensions = self.embedding_dimensions or len(response.data[0].embedding) + except Exception as e: + print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}") + # Use zero embedding vector for entries with failed embeddings + # This ensures entry embeddings match the order of the source entries + # And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector) + embedding_tensors += [torch.zeros(self.embedding_dimensions or 1536, device=device)] + + return torch.stack(embedding_tensors) \ No newline at end of file diff --git a/src/utils/rawconfig.py index
165be0d1..5ed3a9eb 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -50,10 +50,12 @@ class ContentConfig(ConfigBase): class TextSearchConfig(ConfigBase): encoder: str cross_encoder: str + encoder_type: Optional[str] model_directory: Optional[Path] class ImageSearchConfig(ConfigBase): encoder: str + encoder_type: Optional[str] model_directory: Optional[Path] class SearchConfig(ConfigBase): diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index fe64cc67..3f30b7fc 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -5,6 +5,7 @@ import json from src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.processor.text_to_jsonl import TextToJsonl from src.utils.helpers import is_none_or_empty +from src.utils.rawconfig import Entry def test_configure_heading_entry_to_jsonl(tmp_path): @@ -61,6 +62,24 @@ def test_entry_split_when_exceeds_max_words(tmp_path): assert len(jsonl_data) == 2 +def test_entry_split_drops_large_words(tmp_path): + "Ensure splitting entries drops words larger than specified max word length from the compiled version." + # Arrange + entry_text = '''*** Heading + \t\r + Body Line 1 + ''' + entry = Entry(raw=entry_text, compiled=entry_text) + + # Act + # Split entry by max words and drop words larger than max word length + processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length = 5)[0] + + # Assert + # "Heading" dropped from compiled version because it's over the set max word limit + assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1 + + def test_entry_with_body_to_jsonl(tmp_path): "Ensure entries with valid body text are loaded." # Arrange