Initial changes to support multiple search model configurations

- All search models are loaded into memory, and stored in a dictionary indexed by name
- Still need to add database migrations and create a UI for user to select their choice. Presently, it uses the default option
This commit is contained in:
sabaimran
2023-12-05 00:35:40 -05:00
parent d2ddbef08f
commit ef21d78c99
9 changed files with 46 additions and 19 deletions

View File

@@ -23,7 +23,7 @@ from starlette.authentication import (
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import (
get_all_users,
get_or_create_search_model,
get_or_create_search_models,
aget_user_subscription_state,
SubscriptionState,
)
@@ -140,8 +140,14 @@ def configure_server(
# Initialize Search Models from Config and initialize content
try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
search_models = get_or_create_search_models()
state.embeddings_model = dict()
state.cross_encoder_model = dict()
for model in search_models:
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)})
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)
initialize_content(regenerate, search_type, init, user)

View File

@@ -249,12 +249,19 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
def get_or_create_search_model():
search_model = SearchModelConfig.objects.filter().first()
if not search_model:
search_model = SearchModelConfig.objects.create()
def get_default_search_model():
if SearchModelConfig.objects.filter(name="default").exists():
return SearchModelConfig.objects.filter(name="default").first()
return SearchModelConfig.objects.first()
return search_model
def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
SearchModelConfig.objects.create()
search_models = SearchModelConfig.objects.all()
return search_models
class ConversationAdapters:

View File

@@ -12,6 +12,7 @@ from khoj.database.models import (
SpeechToTextModelOptions,
Subscription,
ReflectiveQuestion,
UserSearchModelConfig,
)
admin.site.register(KhojUser, UserAdmin)
@@ -23,3 +24,4 @@ admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModelConfig)
admin.site.register(Subscription)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)

View File

@@ -145,6 +145,11 @@ class UserConversationConfig(BaseModel):
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)

View File

@@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
from khoj.utils.rawconfig import Entry
from khoj.search_filter.date_filter import DateFilter
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
from khoj.database.adapters import EntryAdapters
from khoj.database.adapters import EntryAdapters, get_default_search_model
logger = logging.getLogger(__name__)
@@ -112,7 +112,8 @@ class TextToEntries(ABC):
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)
model = get_default_search_model()
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
added_entries: list[DbEntry] = []
with timer("Added entries to database in", logger):

View File

@@ -18,7 +18,7 @@ from starlette.authentication import requires
# Internal Packages
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
from khoj.database.models import ChatModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
@@ -412,7 +412,8 @@ async def search(
]
if text_search_models:
with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
search_model = await sync_to_async(get_default_search_model)()
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
if t in [

View File

@@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.database.adapters import EntryAdapters
from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.models import KhojUser, Entry as DbEntry
logger = logging.getLogger(__name__)
@@ -115,7 +115,8 @@ async def query(
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
question_embedding = state.embeddings_model.embed_query(query)
search_model = await sync_to_async(get_default_search_model)()
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
# Find relevant entries for the query
top_k = 10

View File

@@ -18,7 +18,7 @@ from khoj.utils.rawconfig import FullConfig
# Application Global State
config = FullConfig()
search_models = SearchModels()
embeddings_model: EmbeddingsModel = None
embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None

View File

@@ -45,8 +45,10 @@ def enable_db_access_for_all_tests(db):
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
state.embeddings_model = dict()
state.embeddings_model["default"] = EmbeddingsModel()
state.cross_encoder_model = dict()
state.cross_encoder_model["default"] = CrossEncoderModel()
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
@@ -317,8 +319,10 @@ def client(
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types()
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
state.embeddings_model = dict()
state.embeddings_model["default"] = EmbeddingsModel()
state.cross_encoder_model = dict()
state.cross_encoder_model["default"] = CrossEncoderModel()
# These lines help us Mock the Search models for these search types
state.search_models.image_search = image_search.initialize_model(search_config.image)