mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 05:39:12 +00:00
Add Validation logic to save PaintModel. Use API key from Paint Model
Rename Paint Model, Adapters to TextToImage for consistency
This commit is contained in:
@@ -45,9 +45,9 @@ from khoj.database.models import (
|
||||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
UserConversationConfig,
|
||||
UserPaintModelConfig,
|
||||
UserRequests,
|
||||
UserSearchModelConfig,
|
||||
UserTextToImageModelConfig,
|
||||
UserVoiceModelConfig,
|
||||
VoiceModelOption,
|
||||
)
|
||||
@@ -897,25 +897,27 @@ class ConversationAdapters:
|
||||
return TextToImageModelConfig.objects.all()
|
||||
|
||||
@staticmethod
|
||||
def get_user_paint_model_config(user: KhojUser):
|
||||
config = UserPaintModelConfig.objects.filter(user=user).first()
|
||||
def get_user_text_to_image_model_config(user: KhojUser):
|
||||
config = UserTextToImageModelConfig.objects.filter(user=user).first()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def aget_user_paint_model(user: KhojUser):
|
||||
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
async def aget_user_text_to_image_model(user: KhojUser):
|
||||
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
|
||||
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
|
||||
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
|
||||
if not config:
|
||||
return None
|
||||
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
|
||||
user=user, defaults={"setting": config}
|
||||
)
|
||||
return new_config
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Generated by Django 4.2.11 on 2024-06-20 19:48
|
||||
# Generated by Django 4.2.11 on 2024-06-26 03:27
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
@@ -7,7 +7,7 @@ from django.db import migrations, models
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0047_alter_entry_file_type"),
|
||||
("database", "0048_voicemodeloption_uservoicemodelconfig"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
@@ -16,6 +16,17 @@ class Migration(migrations.Migration):
|
||||
name="api_key",
|
||||
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="openai_config",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="database.openaiprocessorconversationconfig",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="model_type",
|
||||
@@ -24,7 +35,7 @@ class Migration(migrations.Migration):
|
||||
),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="UserPaintModelConfig",
|
||||
name="UserTextToImageModelConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
@@ -239,6 +239,32 @@ class TextToImageModelConfig(BaseModel):
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
openai_config = models.ForeignKey(
|
||||
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
|
||||
)
|
||||
|
||||
def clean(self):
|
||||
# Custom validation logic
|
||||
error = {}
|
||||
if self.model_type == self.ModelType.OPENAI:
|
||||
if self.api_key and self.openai_config:
|
||||
error[
|
||||
"api_key"
|
||||
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||
error[
|
||||
"openai_config"
|
||||
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||
if self.model_type != self.ModelType.OPENAI:
|
||||
if not self.api_key:
|
||||
error["api_key"] = "The API key field must be set for non OpenAI models."
|
||||
if self.openai_config:
|
||||
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
|
||||
if error:
|
||||
raise ValidationError(error)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.clean()
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class SpeechToTextModelOptions(BaseModel):
|
||||
@@ -265,7 +291,7 @@ class UserSearchModelConfig(BaseModel):
|
||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class UserPaintModelConfig(BaseModel):
|
||||
class UserTextToImageModelConfig(BaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
@@ -328,7 +328,7 @@ async def update_paint_model(
|
||||
if not subscribed:
|
||||
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||
|
||||
new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
|
||||
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
||||
@@ -762,7 +762,7 @@ async def text_to_image(
|
||||
image_url = None
|
||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||
|
||||
text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
|
||||
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
|
||||
if not text_to_image_config:
|
||||
# If the user has not configured a text to image model, return an unsupported on server error
|
||||
status_code = 501
|
||||
@@ -796,9 +796,19 @@ async def text_to_image(
|
||||
|
||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
with timer("Generate image with OpenAI", logger):
|
||||
if text_to_image_config.api_key:
|
||||
api_key = text_to_image_config.api_key
|
||||
elif text_to_image_config.openai_config:
|
||||
api_key = text_to_image_config.openai_config.api_key
|
||||
elif state.openai_client:
|
||||
api_key = state.openai_client.api_key
|
||||
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
try:
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||
prompt=improved_image_prompt,
|
||||
model=text2image_model,
|
||||
response_format="b64_json",
|
||||
extra_headers=auth_header,
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
decoded_image = base64.b64decode(image)
|
||||
|
||||
@@ -262,7 +262,7 @@ def config_page(request: Request):
|
||||
|
||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
||||
|
||||
selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
|
||||
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||
all_paint_model_options = list()
|
||||
for paint_model in paint_model_options:
|
||||
|
||||
Reference in New Issue
Block a user