mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Add Validation logic to save PaintModel. Use API key from Paint Model
Rename Paint Model, Adapters to TextToImage for consistency
This commit is contained in:
@@ -45,9 +45,9 @@ from khoj.database.models import (
|
|||||||
Subscription,
|
Subscription,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
UserPaintModelConfig,
|
|
||||||
UserRequests,
|
UserRequests,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
|
UserTextToImageModelConfig,
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
VoiceModelOption,
|
VoiceModelOption,
|
||||||
)
|
)
|
||||||
@@ -897,25 +897,27 @@ class ConversationAdapters:
|
|||||||
return TextToImageModelConfig.objects.all()
|
return TextToImageModelConfig.objects.all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_user_paint_model_config(user: KhojUser):
|
def get_user_text_to_image_model_config(user: KhojUser):
|
||||||
config = UserPaintModelConfig.objects.filter(user=user).first()
|
config = UserTextToImageModelConfig.objects.filter(user=user).first()
|
||||||
if not config:
|
if not config:
|
||||||
return None
|
return None
|
||||||
return config.setting
|
return config.setting
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_user_paint_model(user: KhojUser):
|
async def aget_user_text_to_image_model(user: KhojUser):
|
||||||
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||||
if not config:
|
if not config:
|
||||||
return None
|
return None
|
||||||
return config.setting
|
return config.setting
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
|
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
|
||||||
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
|
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
|
||||||
if not config:
|
if not config:
|
||||||
return None
|
return None
|
||||||
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
|
||||||
|
user=user, defaults={"setting": config}
|
||||||
|
)
|
||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Generated by Django 4.2.11 on 2024-06-20 19:48
|
# Generated by Django 4.2.11 on 2024-06-26 03:27
|
||||||
|
|
||||||
import django.db.models.deletion
|
import django.db.models.deletion
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
@@ -7,7 +7,7 @@ from django.db import migrations, models
|
|||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
dependencies = [
|
dependencies = [
|
||||||
("database", "0047_alter_entry_file_type"),
|
("database", "0048_voicemodeloption_uservoicemodelconfig"),
|
||||||
]
|
]
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
@@ -16,6 +16,17 @@ class Migration(migrations.Migration):
|
|||||||
name="api_key",
|
name="api_key",
|
||||||
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
),
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="texttoimagemodelconfig",
|
||||||
|
name="openai_config",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="database.openaiprocessorconversationconfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="texttoimagemodelconfig",
|
model_name="texttoimagemodelconfig",
|
||||||
name="model_type",
|
name="model_type",
|
||||||
@@ -24,7 +35,7 @@ class Migration(migrations.Migration):
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name="UserPaintModelConfig",
|
name="UserTextToImageModelConfig",
|
||||||
fields=[
|
fields=[
|
||||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
@@ -239,6 +239,32 @@ class TextToImageModelConfig(BaseModel):
|
|||||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
openai_config = models.ForeignKey(
|
||||||
|
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def clean(self):
|
||||||
|
# Custom validation logic
|
||||||
|
error = {}
|
||||||
|
if self.model_type == self.ModelType.OPENAI:
|
||||||
|
if self.api_key and self.openai_config:
|
||||||
|
error[
|
||||||
|
"api_key"
|
||||||
|
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||||
|
error[
|
||||||
|
"openai_config"
|
||||||
|
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||||
|
if self.model_type != self.ModelType.OPENAI:
|
||||||
|
if not self.api_key:
|
||||||
|
error["api_key"] = "The API key field must be set for non OpenAI models."
|
||||||
|
if self.openai_config:
|
||||||
|
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
|
||||||
|
if error:
|
||||||
|
raise ValidationError(error)
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
self.clean()
|
||||||
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextModelOptions(BaseModel):
|
class SpeechToTextModelOptions(BaseModel):
|
||||||
@@ -265,7 +291,7 @@ class UserSearchModelConfig(BaseModel):
|
|||||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class UserPaintModelConfig(BaseModel):
|
class UserTextToImageModelConfig(BaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|||||||
@@ -328,7 +328,7 @@ async def update_paint_model(
|
|||||||
if not subscribed:
|
if not subscribed:
|
||||||
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||||
|
|
||||||
new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
|
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
|||||||
@@ -762,7 +762,7 @@ async def text_to_image(
|
|||||||
image_url = None
|
image_url = None
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
|
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
@@ -796,9 +796,19 @@ async def text_to_image(
|
|||||||
|
|
||||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
with timer("Generate image with OpenAI", logger):
|
with timer("Generate image with OpenAI", logger):
|
||||||
|
if text_to_image_config.api_key:
|
||||||
|
api_key = text_to_image_config.api_key
|
||||||
|
elif text_to_image_config.openai_config:
|
||||||
|
api_key = text_to_image_config.openai_config.api_key
|
||||||
|
elif state.openai_client:
|
||||||
|
api_key = state.openai_client.api_key
|
||||||
|
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||||
try:
|
try:
|
||||||
response = state.openai_client.images.generate(
|
response = state.openai_client.images.generate(
|
||||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
prompt=improved_image_prompt,
|
||||||
|
model=text2image_model,
|
||||||
|
response_format="b64_json",
|
||||||
|
extra_headers=auth_header,
|
||||||
)
|
)
|
||||||
image = response.data[0].b64_json
|
image = response.data[0].b64_json
|
||||||
decoded_image = base64.b64decode(image)
|
decoded_image = base64.b64decode(image)
|
||||||
|
|||||||
@@ -262,7 +262,7 @@ def config_page(request: Request):
|
|||||||
|
|
||||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
||||||
|
|
||||||
selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
|
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||||
all_paint_model_options = list()
|
all_paint_model_options = list()
|
||||||
for paint_model in paint_model_options:
|
for paint_model in paint_model_options:
|
||||||
|
|||||||
Reference in New Issue
Block a user