Add Validation logic to save PaintModel. Use API key from Paint Model

Rename Paint Model, Adapters to TextToImage for consistency
This commit is contained in:
Debanjum Singh Solanky
2024-06-26 09:51:06 +05:30
parent 1acf969c6e
commit c793d8a69e
6 changed files with 64 additions and 15 deletions

View File

@@ -45,9 +45,9 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserPaintModelConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
)
@@ -897,25 +897,27 @@ class ConversationAdapters:
return TextToImageModelConfig.objects.all()
@staticmethod
def get_user_paint_model_config(user: KhojUser):
config = UserPaintModelConfig.objects.filter(user=user).first()
def get_user_text_to_image_model_config(user: KhojUser):
config = UserTextToImageModelConfig.objects.filter(user=user).first()
if not config:
return None
return config.setting
@staticmethod
async def aget_user_paint_model(user: KhojUser):
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
async def aget_user_text_to_image_model(user: KhojUser):
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
@staticmethod
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
user=user, defaults={"setting": config}
)
return new_config

View File

@@ -1,4 +1,4 @@
# Generated by Django 4.2.11 on 2024-06-20 19:48
# Generated by Django 4.2.11 on 2024-06-26 03:27
import django.db.models.deletion
from django.conf import settings
@@ -7,7 +7,7 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0047_alter_entry_file_type"),
("database", "0048_voicemodeloption_uservoicemodelconfig"),
]
operations = [
@@ -16,6 +16,17 @@ class Migration(migrations.Migration):
name="api_key",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="openai_config",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.openaiprocessorconversationconfig",
),
),
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
@@ -24,7 +35,7 @@ class Migration(migrations.Migration):
),
),
migrations.CreateModel(
name="UserPaintModelConfig",
name="UserTextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),

View File

@@ -239,6 +239,32 @@ class TextToImageModelConfig(BaseModel):
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
openai_config = models.ForeignKey(
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
)
def clean(self):
# Custom validation logic
error = {}
if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.openai_config:
error[
"api_key"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
error[
"openai_config"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
if self.model_type != self.ModelType.OPENAI:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI models."
if self.openai_config:
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class SpeechToTextModelOptions(BaseModel):
@@ -265,7 +291,7 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserPaintModelConfig(BaseModel):
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)

View File

@@ -328,7 +328,7 @@ async def update_paint_model(
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
update_telemetry_state(
request=request,

View File

@@ -762,7 +762,7 @@ async def text_to_image(
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
@@ -796,9 +796,19 @@ async def text_to_image(
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
prompt=improved_image_prompt,
model=text2image_model,
response_format="b64_json",
extra_headers=auth_header,
)
image = response.data[0].b64_json
decoded_image = base64.b64decode(image)

View File

@@ -262,7 +262,7 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user)
selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options: