diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 19e7d403..8c302450 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -23,7 +23,7 @@ from starlette.authentication import (
 from khoj.database.models import KhojUser, Subscription
 from khoj.database.adapters import (
     get_all_users,
-    get_or_create_search_model,
+    get_or_create_search_models,
     aget_user_subscription_state,
     SubscriptionState,
 )
@@ -140,8 +140,15 @@ def configure_server(
 
     # Initialize Search Models from Config and initialize content
     try:
-        state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
-        state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
+        # Load one bi-encoder and one cross-encoder per configured search model, keyed by model name
+        search_models = get_or_create_search_models()
+        state.embeddings_model = {}
+        state.cross_encoder_model = {}
+
+        for model in search_models:
+            state.embeddings_model[model.name] = EmbeddingsModel(model.bi_encoder)
+            state.cross_encoder_model[model.name] = CrossEncoderModel(model.cross_encoder)
+
         state.SearchType = configure_search_types()
         state.search_models = configure_search(state.search_models, state.config.search_type)
         initialize_content(regenerate, search_type, init, user)
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 9d18c815..41b0ef06 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -249,12 +249,25 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
     return config
 
 
-def get_or_create_search_model():
-    search_model = SearchModelConfig.objects.filter().first()
-    if not search_model:
-        search_model = SearchModelConfig.objects.create()
+def get_default_search_model():
+    """Return the search model config named "default", else the first configured one.
+
+    Returns None when no search model config exists.
+    """
+    default_search_model = SearchModelConfig.objects.filter(name="default").first()
+    if default_search_model is not None:
+        return default_search_model
+    return SearchModelConfig.objects.first()
+
+
+def get_or_create_search_models():
+    """Return all search model configs, first creating one from field defaults if none exist."""
+    search_models = SearchModelConfig.objects.all()
+    if not search_models.exists():
+        SearchModelConfig.objects.create()
+        search_models = SearchModelConfig.objects.all()
 
-    return search_model
+    return search_models
 
 
 class ConversationAdapters:
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 2213fb6e..2561a5da 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -12,6 +12,7 @@ from khoj.database.models import (
     SpeechToTextModelOptions,
     Subscription,
     ReflectiveQuestion,
+    UserSearchModelConfig,
 )
 
 admin.site.register(KhojUser, UserAdmin)
@@ -23,3 +24,4 @@ admin.site.register(OfflineChatProcessorConversationConfig)
 admin.site.register(SearchModelConfig)
 admin.site.register(Subscription)
 admin.site.register(ReflectiveQuestion)
+admin.site.register(UserSearchModelConfig)
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 82348fbe..19393c9c 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -145,6 +145,12 @@ class UserConversationConfig(BaseModel):
     setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
 
 
+class UserSearchModelConfig(BaseModel):
+    """Per-user selection of a search model config (at most one row per user)."""
+    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
+    setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
+
+
 class Conversation(BaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py
index 109c58e6..bfcf37f7 100644
--- a/src/khoj/processor/content/text_to_entries.py
+++ b/src/khoj/processor/content/text_to_entries.py
@@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
 from khoj.utils.rawconfig import Entry
 from khoj.search_filter.date_filter import DateFilter
 from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
-from khoj.database.adapters import EntryAdapters
+from khoj.database.adapters import EntryAdapters, get_default_search_model
 
 
 logger = logging.getLogger(__name__)
@@ -112,7 +112,8 @@ class TextToEntries(ABC):
         with timer("Generated embeddings for entries to add to database in", logger):
             entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
             data_to_embed = [getattr(entry, key) for entry in entries_to_process]
-            embeddings += self.embeddings_model.embed_documents(data_to_embed)
+            model = get_default_search_model()
+            embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
 
         added_entries: list[DbEntry] = []
         with timer("Added entries to database in", logger):
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae125980..52b772fd 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -18,7 +18,7 @@ from starlette.authentication import requires
 # Internal Packages
 from khoj.configure import configure_server
 from khoj.database import adapters
-from khoj.database.adapters import ConversationAdapters, EntryAdapters
+from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
 from khoj.database.models import ChatModelOptions
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import (
@@ -412,7 +412,8 @@ async def search(
         ]
         if text_search_models:
             with timer("Encoding query took", logger=logger):
-                encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
+                search_model = await sync_to_async(get_default_search_model)()
+                encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
 
     with concurrent.futures.ThreadPoolExecutor() as executor:
         if t in [
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index d04d4c6a..1523473c 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
 from khoj.utils.rawconfig import SearchResponse, Entry
 from khoj.utils.jsonl import load_jsonl
 from khoj.processor.content.text_to_entries import TextToEntries
-from khoj.database.adapters import EntryAdapters
+from khoj.database.adapters import EntryAdapters, get_default_search_model
 from khoj.database.models import KhojUser, Entry as DbEntry
 
 logger = logging.getLogger(__name__)
@@ -115,7 +115,8 @@ async def query(
     # Encode the query using the bi-encoder
     if question_embedding is None:
         with timer("Query Encode Time", logger, state.device):
-            question_embedding = state.embeddings_model.embed_query(query)
+            search_model = await sync_to_async(get_default_search_model)()
+            question_embedding = state.embeddings_model[search_model.name].embed_query(query)
 
     # Find relevant entries for the query
     top_k = 10
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index b54cf4b3..4e135b18 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -18,7 +18,7 @@ from khoj.utils.rawconfig import FullConfig
 # Application Global State
 config = FullConfig()
 search_models = SearchModels()
-embeddings_model: EmbeddingsModel = None
-cross_encoder_model: CrossEncoderModel = None
+embeddings_model: Dict[str, EmbeddingsModel] = None
+cross_encoder_model: Dict[str, CrossEncoderModel] = None
 content_index = ContentIndex()
 gpt4all_processor_config: GPT4AllProcessorModel = None
diff --git a/tests/conftest.py b/tests/conftest.py
index 9a500609..bbb3aa39 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -45,8 +45,8 @@ def enable_db_access_for_all_tests(db):
 
 @pytest.fixture(scope="session")
 def search_config() -> SearchConfig:
-    state.embeddings_model = EmbeddingsModel()
-    state.cross_encoder_model = CrossEncoderModel()
+    state.embeddings_model = {"default": EmbeddingsModel()}
+    state.cross_encoder_model = {"default": CrossEncoderModel()}
 
     model_dir = resolve_absolute_path("~/.khoj/search")
     model_dir.mkdir(parents=True, exist_ok=True)
@@ -317,8 +317,8 @@ def client(
     state.config.content_type = content_config
     state.config.search_type = search_config
     state.SearchType = configure_search_types()
-    state.embeddings_model = EmbeddingsModel()
-    state.cross_encoder_model = CrossEncoderModel()
+    state.embeddings_model = {"default": EmbeddingsModel()}
+    state.cross_encoder_model = {"default": CrossEncoderModel()}
 
     # These lines help us Mock the Search models for these search types
     state.search_models.image_search = image_search.initialize_model(search_config.image)