mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Create BaseEncoder class. Make OpenAI encoder its child. Use for typing
- Set type of all bi_encoders to BaseEncoder - Make load_model return type Union of CrossEncoder and BaseEncoder
This commit is contained in:
@@ -13,6 +13,7 @@ from src.search_filter.base_filter import BaseFilter
|
|||||||
from src.utils import state
|
from src.utils import state
|
||||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||||
from src.utils.config import TextSearchModel
|
from src.utils.config import TextSearchModel
|
||||||
|
from src.utils.models import BaseEncoder
|
||||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
||||||
from src.utils.jsonl import load_jsonl
|
from src.utils.jsonl import load_jsonl
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ def extract_entries(jsonl_file) -> list[Entry]:
|
|||||||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
|
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
new_entries = []
|
new_entries = []
|
||||||
# Load pre-computed embeddings from file if exists and update them if required
|
# Load pre-computed embeddings from file if exists and update them if required
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import torch
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.rawconfig import ConversationProcessorConfig, Entry
|
from src.utils.rawconfig import ConversationProcessorConfig, Entry
|
||||||
from src.search_filter.base_filter import BaseFilter
|
from src.search_filter.base_filter import BaseFilter
|
||||||
|
from src.utils.models import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
class SearchType(str, Enum):
|
class SearchType(str, Enum):
|
||||||
@@ -24,7 +25,7 @@ class ProcessorType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class TextSearchModel():
|
class TextSearchModel():
|
||||||
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.corpus_embeddings = corpus_embeddings
|
self.corpus_embeddings = corpus_embeddings
|
||||||
self.bi_encoder = bi_encoder
|
self.bi_encoder = bi_encoder
|
||||||
@@ -34,7 +35,7 @@ class TextSearchModel():
|
|||||||
|
|
||||||
|
|
||||||
class ImageSearchModel():
|
class ImageSearchModel():
|
||||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
|
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
||||||
self.image_encoder = image_encoder
|
self.image_encoder = image_encoder
|
||||||
self.image_names = image_names
|
self.image_names = image_names
|
||||||
self.image_embeddings = image_embeddings
|
self.image_embeddings = image_embeddings
|
||||||
|
|||||||
@@ -7,6 +7,12 @@ from collections import OrderedDict
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from src.utils.models import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
def is_none_or_empty(item):
|
def is_none_or_empty(item):
|
||||||
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
|
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
|
||||||
@@ -45,7 +51,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
|||||||
return merged_dict
|
return merged_dict
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
|
def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
|
||||||
"Load model from disk or huggingface"
|
"Load model from disk or huggingface"
|
||||||
# Construct model path
|
# Construct model path
|
||||||
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
# Standard Packages
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import openai
|
import openai
|
||||||
import torch
|
import torch
|
||||||
@@ -7,7 +10,15 @@ from tqdm import trange
|
|||||||
from src.utils.state import processor_config, config_file
|
from src.utils.state import processor_config, config_file
|
||||||
|
|
||||||
|
|
||||||
class OpenAI:
|
class BaseEncoder(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAI(BaseEncoder):
|
||||||
def __init__(self, model_name, device=None):
|
def __init__(self, model_name, device=None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
|
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
|
||||||
@@ -15,7 +26,7 @@ class OpenAI:
|
|||||||
openai.api_key = processor_config.conversation.openai_api_key
|
openai.api_key = processor_config.conversation.openai_api_key
|
||||||
self.embedding_dimensions = None
|
self.embedding_dimensions = None
|
||||||
|
|
||||||
def encode(self, entries: list[str], device=None, **kwargs):
|
def encode(self, entries, device=None, **kwargs):
|
||||||
embedding_tensors = []
|
embedding_tensors = []
|
||||||
|
|
||||||
for index in trange(0, len(entries)):
|
for index in trange(0, len(entries)):
|
||||||
|
|||||||
Reference in New Issue
Block a user