diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 12a127e9..f1f4031e 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -22,6 +22,7 @@ from khoj.database.models import ( GithubConfig, GithubRepoConfig, GoogleUser, + TextToImageModelConfig, KhojApiUser, KhojUser, NotionConfig, @@ -414,6 +415,10 @@ class ConversationAdapters: else: raise ValueError("Invalid conversation config - either configure offline chat or openai chat") + @staticmethod + async def aget_text_to_image_model_config(): + return await TextToImageModelConfig.objects.filter().afirst() + class EntryAdapters: word_filer = WordFilter() diff --git a/src/khoj/database/migrations/0022_texttoimagemodelconfig.py b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py new file mode 100644 index 00000000..7450dc40 --- /dev/null +++ b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.7 on 2023-12-04 22:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0021_speechtotextmodeloptions_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="TextToImageModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("model_name", models.CharField(default="dall-e-3", max_length=200)), + ("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 82348fbe..00700f2f 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel): cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") +class TextToImageModelConfig(BaseModel): + class ModelType(models.TextChoices): + OPENAI = "openai" + + model_name = models.CharField(max_length=200, default="dall-e-3") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + + class OpenAIProcessorConversationConfig(BaseModel): api_key = models.CharField(max_length=200) diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 313b18fc..0bb78dbe 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -7,6 +7,7 @@ from khoj.database.models import ( OpenAIProcessorConversationConfig, ChatModelOptions, SpeechToTextModelOptions, + TextToImageModelConfig, ) from khoj.utils.constants import default_offline_chat_model, default_online_chat_model @@ -103,6 +104,15 @@ def initialization(): model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI ) + default_text_to_image_model = "dall-e-3" + openai_text_to_image_model = input( + f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): " + ) + openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model + TextToImageModelConfig.objects.create( + model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI + ) + if use_offline_model == "y" or use_openai_model == "y": logger.info("🗣️ Chat model configuration complete")