diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 0168778c..55aeaa21 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -51,11 +51,15 @@ def initialization(interactive: bool = True): # Get available chat models from OpenAI compatible API try: openai_client = openai.OpenAI(api_key=openai_api_key, base_url=openai_base_url) - default_chat_models = [model.id for model in openai_client.models.list()] + available_chat_models = [model.id for model in openai_client.models.list()] # Put the available default OpenAI models at the top - valid_default_models = [model for model in default_openai_chat_models if model in default_chat_models] - other_available_models = [model for model in default_chat_models if model not in valid_default_models] - default_chat_models = valid_default_models + other_available_models + known_available_models = [ + model for model in default_openai_chat_models if model in available_chat_models + ] + other_available_models = [ + model for model in available_chat_models if model not in known_available_models + ] + default_chat_models = known_available_models + other_available_models except Exception as e: logger.warning( f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}" @@ -75,7 +79,7 @@ def initialization(interactive: bool = True): # Setup OpenAI speech to text model if openai_configured: - default_speech2text_model = "whisper-1" + default_speech2text_model = "whisper-1" if openai_base_url is None else None if interactive: openai_speech2text_model = input( f"Enter the OpenAI speech to text model you want to use (default: {default_speech2text_model}): " @@ -83,13 +87,16 @@ def initialization(interactive: bool = True): openai_speech2text_model = openai_speech2text_model or default_speech2text_model else: openai_speech2text_model = default_speech2text_model - SpeechToTextModelOptions.objects.create( - model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI - ) + + if openai_speech2text_model: + SpeechToTextModelOptions.objects.create( + model_name=openai_speech2text_model, + model_type=SpeechToTextModelOptions.ModelType.OPENAI, + ) # Setup OpenAI text to image model if openai_configured: - default_text_to_image_model = "dall-e-3" + default_text_to_image_model = "dall-e-3" if openai_base_url is None else None if interactive: openai_text_to_image_model = input( f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): " @@ -97,11 +104,13 @@ def initialization(interactive: bool = True): openai_text_to_image_model = openai_text_to_image_model or default_text_to_image_model else: openai_text_to_image_model = default_text_to_image_model - TextToImageModelConfig.objects.create( - model_name=openai_text_to_image_model, - model_type=TextToImageModelConfig.ModelType.OPENAI, - ai_model_api=openai_provider, - ) + + if openai_text_to_image_model: + TextToImageModelConfig.objects.create( + model_name=openai_text_to_image_model, + model_type=TextToImageModelConfig.ModelType.OPENAI, + ai_model_api=openai_provider, + ) # Set up Google's Gemini online chat models google_ai_configured, google_ai_provider = _setup_chat_model_provider(