diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 53e19f71..27c5ed08 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -244,7 +244,7 @@ def configure_server( state.SearchType = configure_search_types() state.search_models = configure_search(state.search_models, state.config.search_type) - setup_default_agent() + setup_default_agent(user) message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled" logger.info(message) @@ -256,8 +256,8 @@ def configure_server( raise e -def setup_default_agent(): - AgentAdapters.create_default_agent() +def setup_default_agent(user: KhojUser): + AgentAdapters.create_default_agent(user) def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9687ec01..40c82fa6 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -643,8 +643,8 @@ class AgentAdapters: return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() @staticmethod - def create_default_agent(): - default_conversation_config = ConversationAdapters.get_default_conversation_config() + def create_default_agent(user: KhojUser): + default_conversation_config = ConversationAdapters.get_default_conversation_config(user) if default_conversation_config is None: logger.info("No default conversation config found, skipping default agent creation") return None @@ -968,29 +968,51 @@ class ConversationAdapters: return VoiceModelOption.objects.first() @staticmethod - def get_default_conversation_config(): + def get_default_conversation_config(user: KhojUser = None): + """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" + # Get the server chat settings server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is None or server_chat_settings.chat_default is None: - return ChatModelOptions.objects.filter().first() - return server_chat_settings.chat_default + if server_chat_settings is not None and server_chat_settings.chat_default is not None: + return server_chat_settings.chat_default + + # Get the user's chat settings, if the server chat settings are not set + user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None + if user_chat_settings is not None and user_chat_settings.setting is not None: + return user_chat_settings.setting + + # Get the first chat model if even the user chat settings are not set + return ChatModelOptions.objects.filter().first() @staticmethod - async def aget_default_conversation_config(): + async def aget_default_conversation_config(user: KhojUser = None): + """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" + # Get the server chat settings server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related("chat_default", "chat_default__openai_config") .afirst() ) - if server_chat_settings is None or server_chat_settings.chat_default is None: - return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() - return server_chat_settings.chat_default + if server_chat_settings is not None and server_chat_settings.chat_default is not None: + return server_chat_settings.chat_default + + # Get the user's chat settings, if the server chat settings are not set + user_chat_settings = ( + (await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()) + if user + else None + ) + if user_chat_settings is not None and user_chat_settings.setting is not None: + return user_chat_settings.setting + + # Get the first chat model if even the user chat settings are not set + return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() @staticmethod def get_advanced_conversation_config(): server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is None or server_chat_settings.chat_advanced is None: - return ConversationAdapters.get_default_conversation_config() - return server_chat_settings.chat_advanced + if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: + return server_chat_settings.chat_advanced + return ConversationAdapters.get_default_conversation_config() @staticmethod async def aget_advanced_conversation_config(): @@ -999,9 +1021,9 @@ class ConversationAdapters: .prefetch_related("chat_advanced", "chat_advanced__openai_config") .afirst() ) - if server_chat_settings is None or server_chat_settings.chat_advanced is None: - return await ConversationAdapters.aget_default_conversation_config() - return server_chat_settings.chat_advanced + if server_chat_settings is not None or server_chat_settings.chat_advanced is not None: + return server_chat_settings.chat_advanced + return await ConversationAdapters.aget_default_conversation_config() @staticmethod def create_conversation_from_public_conversation(