Make Encoder Type Configurable. Allow using OpenAI Model for Search

- 2fe37a0 Make type of encoder to use for embeddings configurable via `khoj.yml'
  - Previously `encoder_type' was set in the setup code of search_type
    - All *encoders* were of type `SentenceTransformer'
    - All *cross_encoders* were of type `CrossEncoder'
  - Now the `encoder_type' can be configured via the new `encoder_type' field 
    in `TextSearchConfig' under `search_type` in `khoj.yml'
  - All the specified `encoder-type' class needs is an `encode' method
    that takes entries and returns embedding vectors
  
- 826f9dc Drop long words from compiled entries to be within max token limit of models
  Long words (>500 characters) provide less useful context to models.
   
  Dropping very long words allows models to create better embeddings by
  passing more of the useful context from the entry to the model

- c0ae8ee Allow using OpenAI models for search in Khoj
  To use OpenAI models for search in Khoj, in `~/.khoj/khoj.yml'
  1. Set `encoder' to name of an OpenAI model. E.g *text-embedding-ada-002*
  2. Set `encoder-type' to *src.utils.models.OpenAI*
  3. Set `model-directory` to *null*, as this is an online model and
     cannot be stored on the file system
This commit is contained in:
Debanjum
2023-01-08 11:10:25 -03:00
committed by GitHub
8 changed files with 84 additions and 14 deletions

View File

@@ -34,14 +34,14 @@ def configure_server(args, required=False):
else:
state.config = args.config
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor)
# Initialize the search model from Config
state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, args.regenerate)
state.search_index_lock.release()
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor)
@schedule.repeat(schedule.every(1).hour)
def update_search_index():

View File

@@ -24,11 +24,13 @@ class TextToJsonl(ABC):
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
@staticmethod
def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256) -> list[Entry]:
def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256, max_word_length: int=500) -> list[Entry]:
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
chunked_entries: list[Entry] = []
for entry in entries:
compiled_entry_words = entry.compiled.split()
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)

View File

@@ -36,7 +36,7 @@ def initialize_model(search_config: ImageSearchConfig):
encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
model_type = search_config.encoder_type or SentenceTransformer)
return encoder

View File

@@ -27,17 +27,18 @@ def initialize_model(search_config: TextSearchConfig):
# Number of entries we want to retrieve with the bi-encoder
top_k = 15
# Convert model directory to absolute path
search_config.model_directory = resolve_absolute_path(search_config.model_directory)
# Create model directory if it doesn't exist
search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
# If model directory is configured
if search_config.model_directory:
# Convert model directory to absolute path
search_config.model_directory = resolve_absolute_path(search_config.model_directory)
# Create model directory if it doesn't exist
search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer,
model_type = search_config.encoder_type or SentenceTransformer,
device=f'{state.device}')
# The cross-encoder re-ranks the results to improve quality

View File

@@ -1,5 +1,6 @@
# Standard Packages
from pathlib import Path
from importlib import import_module
import sys
from os.path import join
from collections import OrderedDict
@@ -44,17 +45,18 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name, model_dir, model_type, device:str=None):
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
"Load model from disk or huggingface"
# Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
# Load model from model_path if it exists there
model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type
if model_path is not None and resolve_absolute_path(model_path).exists():
model = model_type(get_absolute_path(model_path), device=device)
model = model_type_class(get_absolute_path(model_path), device=device)
# Else load the model from the model_name
else:
model = model_type(model_name, device=device)
model = model_type_class(model_name, device=device)
if model_path is not None:
model.save(model_path)
@@ -66,6 +68,12 @@ def is_pyinstaller_app():
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
def get_class_by_name(name: str) -> object:
    "Resolve a dotted 'package.module.Class' string into the class object it names"
    module_path, class_attr = name.rsplit('.', 1)
    loaded_module = import_module(module_path)
    return getattr(loaded_module, class_attr)
class LRU(OrderedDict):
def __init__(self, *args, capacity=128, **kwargs):
self.capacity = capacity

38
src/utils/models.py Normal file
View File

@@ -0,0 +1,38 @@
# External Packages
import openai
import torch
from tqdm import trange
# Internal Packages
from src.utils.state import processor_config, config_file
class OpenAI:
    "Encode text entries into embedding vectors using an OpenAI embedding model."

    def __init__(self, model_name, device=None):
        # Name of the OpenAI embedding model to use, e.g. "text-embedding-ada-002"
        self.model_name = model_name
        # Fail fast with a helpful message if no API key is configured
        if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
            raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}")
        openai.api_key = processor_config.conversation.openai_api_key
        # Embedding dimension of the configured model; learnt from the first successful API response
        self.embedding_dimensions = None

    def encode(self, entries: list[str], device=None, **kwargs):
        """Return a 2D tensor of embeddings, one row per entry, in input order.

        Entries whose API call fails are mapped to zero vectors so output rows
        stay aligned with the source entries; zero vectors are orthogonal to
        every other vector, so they have minimal similarity to real entries.
        """
        embedding_tensors = []
        for index in trange(0, len(entries)):
            # OpenAI models create better embeddings for entries without newlines
            processed_entry = entries[index].replace('\n', ' ')
            try:
                response = openai.Embedding.create(input=processed_entry, model=self.model_name)
                embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
                # Remember the current model's embedding dimension from the first successful response.
                # Bug fix: the previous ternary overwrote an already-known dimension with 1536
                # (the text-embedding-ada-002 default) instead of keeping the model's actual dimension.
                if self.embedding_dimensions is None:
                    self.embedding_dimensions = len(response.data[0].embedding)
            except Exception as e:
                print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}")
                # Use a zero embedding vector for entries with failed embeddings.
                # This keeps entry embeddings aligned with the order of the source entries.
                # Fall back to text-embedding-ada-002's 1536 dimensions if no request has
                # succeeded yet, so a failure on the first entry doesn't crash torch.zeros(None).
                embedding_tensors += [torch.zeros(self.embedding_dimensions or 1536, device=device)]
        return torch.stack(embedding_tensors)

View File

@@ -50,10 +50,12 @@ class ContentConfig(ConfigBase):
class TextSearchConfig(ConfigBase):
    # Name of the bi-encoder model that embeds entries for semantic search,
    # e.g. a SentenceTransformer model name or an OpenAI model like "text-embedding-ada-002"
    encoder: str
    # Name of the cross-encoder model used to re-rank retrieved results
    cross_encoder: str
    # Optional dotted path to the encoder class, e.g. "src.utils.models.OpenAI";
    # when unset the encoder defaults to SentenceTransformer
    encoder_type: Optional[str]
    # Directory to cache model files in; set to null for online-only models
    # (e.g. OpenAI) that cannot be stored on the file system
    model_directory: Optional[Path]
class ImageSearchConfig(ConfigBase):
    # Name of the image encoder model passed to load_model
    encoder: str
    # Optional dotted path to the encoder class; when unset the encoder
    # defaults to SentenceTransformer
    encoder_type: Optional[str]
    # Directory to cache model files in; null when the model cannot be stored on disk
    model_directory: Optional[Path]
class SearchConfig(ConfigBase):

View File

@@ -5,6 +5,7 @@ import json
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import is_none_or_empty
from src.utils.rawconfig import Entry
def test_configure_heading_entry_to_jsonl(tmp_path):
@@ -61,6 +62,24 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
assert len(jsonl_data) == 2
def test_entry_split_drops_large_words(tmp_path):
    "Ensure entries drop words larger than the specified max word length from their compiled version."
    # Arrange
    # Whitespace-split words of this text are: ***, Heading, Body, Line, 1.
    # Only "Heading" (7 characters) exceeds the 5 character limit set below.
    entry_text = f'''*** Heading
\t\r
Body Line 1
'''
    entry = Entry(raw=entry_text, compiled=entry_text)
    # Act
    # Split entry by max words and drop words larger than max word length
    processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length = 5)[0]
    # Assert
    # "Heading" dropped from compiled version because it's over the set max word limit
    assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
def test_entry_with_body_to_jsonl(tmp_path):
"Ensure entries with valid body text are loaded."
# Arrange