diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 9df6d46e..90bb9921 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -1,5 +1,6 @@ import logging import os +from typing import Tuple from khoj.database.adapters import ConversationAdapters from khoj.database.models import ( @@ -41,41 +42,18 @@ def initialization(interactive: bool = True): "🗣️ Configure chat models available to your server. You can always update these at /server/admin using your admin account" ) - # Set up OpenAI's online models - default_openai_api_key = os.getenv("OPENAI_API_KEY") - default_use_openai_model = {True: "y", False: "n"}[default_openai_api_key != None] - use_model_provider = default_use_openai_model if not interactive else input("Add OpenAI models? (y/n): ") - if use_model_provider == "y": - logger.info("️💬 Setting up your OpenAI configuration") - if interactive: - user_api_key = input(f"Enter your OpenAI API key (default: {default_openai_api_key}): ") - api_key = user_api_key if user_api_key != "" else default_openai_api_key - else: - api_key = default_openai_api_key - chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="OpenAI") + # Set up OpenAI's online chat models + openai_configured, openai_provider = _setup_chat_model_provider( + ChatModelOptions.ModelType.OPENAI, + default_openai_chat_models, + default_api_key=os.getenv("OPENAI_API_KEY"), + vision_enabled=True, + is_offline=False, + interactive=interactive, + ) - if interactive: - chat_model_names = input( - f"Enter the OpenAI chat models you want to use (default: {','.join(default_openai_chat_models)}): " - ) - chat_models = chat_model_names.split(",") if chat_model_names != "" else default_openai_chat_models - chat_models = [model.strip() for model in chat_models] - else: - chat_models = default_openai_chat_models - - # Add OpenAI chat models - for chat_model in chat_models: - vision_enabled = chat_model in 
["gpt-4o-mini", "gpt-4o"] - default_max_tokens = model_to_prompt_size.get(chat_model) - ChatModelOptions.objects.create( - chat_model=chat_model, - model_type=ChatModelOptions.ModelType.OPENAI, - max_prompt_size=default_max_tokens, - openai_config=chat_model_provider, - vision_enabled=vision_enabled, - ) - - # Add OpenAI speech to text model + # Setup OpenAI speech to text model + if openai_configured: default_speech2text_model = "whisper-1" if interactive: openai_speech2text_model = input( @@ -88,7 +66,8 @@ def initialization(interactive: bool = True): model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI ) - # Add OpenAI text to image model + # Setup OpenAI text to image model + if openai_configured: default_text_to_image_model = "dall-e-3" if interactive: openai_text_to_image_model = input( @@ -98,107 +77,44 @@ def initialization(interactive: bool = True): else: openai_text_to_image_model = default_text_to_image_model TextToImageModelConfig.objects.create( - model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI + model_name=openai_text_to_image_model, + model_type=TextToImageModelConfig.ModelType.OPENAI, + openai_config=openai_provider, ) # Set up Google's Gemini online chat models - default_gemini_api_key = os.getenv("GEMINI_API_KEY") - default_use_gemini_model = {True: "y", False: "n"}[default_gemini_api_key != None] - use_model_provider = default_use_gemini_model if not interactive else input("Add Google's chat models? 
(y/n): ") - if use_model_provider == "y": - logger.info("️💬 Setting up your Google Gemini configuration") - if interactive: - user_api_key = input(f"Enter your Gemini API key (default: {default_gemini_api_key}): ") - api_key = user_api_key if user_api_key != "" else default_gemini_api_key - else: - api_key = default_gemini_api_key - chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="Gemini") - - if interactive: - chat_model_names = input( - f"Enter the Gemini chat models you want to use (default: {','.join(default_gemini_chat_models)}): " - ) - chat_models = chat_model_names.split(",") if chat_model_names != "" else default_gemini_chat_models - chat_models = [model.strip() for model in chat_models] - else: - chat_models = default_gemini_chat_models - - # Add Gemini chat models - for chat_model in chat_models: - default_max_tokens = model_to_prompt_size.get(chat_model) - vision_enabled = False - ChatModelOptions.objects.create( - chat_model=chat_model, - model_type=ChatModelOptions.ModelType.GOOGLE, - max_prompt_size=default_max_tokens, - openai_config=chat_model_provider, - vision_enabled=False, - ) + _setup_chat_model_provider( + ChatModelOptions.ModelType.GOOGLE, + default_gemini_chat_models, + default_api_key=os.getenv("GEMINI_API_KEY"), + vision_enabled=False, + is_offline=False, + interactive=interactive, + provider_name="Google Gemini", + ) # Set up Anthropic's online chat models - default_anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") - default_use_anthropic_model = {True: "y", False: "n"}[default_anthropic_api_key != None] - use_model_provider = ( - default_use_anthropic_model if not interactive else input("Add Anthropic's chat models? 
(y/n): ") + _setup_chat_model_provider( + ChatModelOptions.ModelType.ANTHROPIC, + default_anthropic_chat_models, + default_api_key=os.getenv("ANTHROPIC_API_KEY"), + vision_enabled=False, + is_offline=False, + interactive=interactive, ) - if use_model_provider == "y": - logger.info("️💬 Setting up your Anthropic configuration") - if interactive: - user_api_key = input(f"Enter your Anthropic API key (default: {default_anthropic_api_key}): ") - api_key = user_api_key if user_api_key != "" else default_anthropic_api_key - else: - api_key = default_anthropic_api_key - chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="Anthropic") - - if interactive: - chat_model_names = input( - f"Enter the Anthropic chat models you want to use (default: {','.join(default_anthropic_chat_models)}): " - ) - chat_models = chat_model_names.split(",") if chat_model_names != "" else default_anthropic_chat_models - chat_models = [model.strip() for model in chat_models] - else: - chat_models = default_anthropic_chat_models - - # Add Anthropic chat models - for chat_model in chat_models: - vision_enabled = False - default_max_tokens = model_to_prompt_size.get(chat_model) - ChatModelOptions.objects.create( - chat_model=chat_model, - model_type=ChatModelOptions.ModelType.ANTHROPIC, - max_prompt_size=default_max_tokens, - openai_config=chat_model_provider, - vision_enabled=False, - ) # Set up offline chat models - use_model_provider = "y" if not interactive else input("Add Offline chat models? (y/n): ") - if use_model_provider == "y": - logger.info("️💬 Setting up Offline chat models") - - if interactive: - chat_model_names = input( - f"Enter the offline chat models you want to use. 
def _setup_chat_model_provider(
    model_type: ChatModelOptions.ModelType,
    default_chat_models: list,
    default_api_key: str,
    interactive: bool,
    vision_enabled: bool = False,
    is_offline: bool = False,
    provider_name: str = None,
) -> Tuple[bool, OpenAIProcessorConversationConfig]:
    """Configure one chat model provider and register its chat models.

    Args:
        model_type: Model type stored on every chat model created here. Its
            ``name`` (capitalized) is used as the provider name when
            ``provider_name`` is not given.
        default_chat_models: Chat model names used when not interactive, or
            when the user enters no usable model names.
        default_api_key: API key used when not interactive, or when the user
            enters no key. May be None, in which case a non-offline provider
            is skipped by default in non-interactive mode.
        interactive: When True, prompt on stdin for whether to add the
            provider, its API key and its chat models.
        vision_enabled: Whether this provider may have vision-capable models.
            A model is only marked vision enabled if it is also in the list of
            supported vision models.
        is_offline: When True, no API provider config is created and the
            provider is added by default in non-interactive mode.
        provider_name: Display name used in prompts, log messages and as the
            name of the created provider config.

    Returns:
        ``(False, None)`` if the provider was not set up. Otherwise
        ``(True, provider_config)``, where ``provider_config`` is the created
        provider config, or None for offline providers.
    """
    supported_vision_models = ["gpt-4o-mini", "gpt-4o"]
    provider_name = provider_name or model_type.name.capitalize()

    # Without a prompt, only add providers that are usable out of the box:
    # those with an API key in the environment, or offline ones that need no key.
    default_use_model = {True: "y", False: "n"}[default_api_key is not None or is_offline]
    use_model_provider = (
        default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ")
    )

    if use_model_provider != "y":
        return False, None

    logger.info(f"️💬 Setting up your {provider_name} chat configuration")

    # Offline models have no API provider config; online ones share a single one.
    chat_model_provider = None
    if not is_offline:
        if interactive:
            user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
            api_key = user_api_key if user_api_key != "" else default_api_key
        else:
            api_key = default_api_key
        chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name=provider_name)

    if interactive:
        chat_model_names = input(
            f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
        )
        # Drop blank entries (e.g. from a trailing comma) so no chat model with an
        # empty name is created; fall back to the defaults if nothing usable remains.
        requested_models = [name.strip() for name in chat_model_names.split(",")]
        chat_models = [name for name in requested_models if name] or default_chat_models
    else:
        chat_models = default_chat_models

    for chat_model in chat_models:
        default_max_tokens = model_to_prompt_size.get(chat_model)
        default_tokenizer = model_to_tokenizer.get(chat_model)
        # Decide vision support per model. Reassigning the `vision_enabled`
        # argument here would latch it to False after the first non-vision
        # model and wrongly disable vision for every model that follows.
        is_vision_enabled = vision_enabled and chat_model in supported_vision_models

        chat_model_options = {
            "chat_model": chat_model,
            "model_type": model_type,
            "max_prompt_size": default_max_tokens,
            "vision_enabled": is_vision_enabled,
            "tokenizer": default_tokenizer,
            "openai_config": chat_model_provider,
        }

        ChatModelOptions.objects.create(**chat_model_options)

    logger.info(f"🗣️ {provider_name} chat model configuration complete")
    return True, chat_model_provider