Show friendly name for available ai models on clients when set

This commit is contained in:
Debanjum
2025-06-27 23:51:04 -07:00
parent 487826bc32
commit a8c47a70f7
7 changed files with 77 additions and 20 deletions

View File

@@ -1207,6 +1207,14 @@ class ConversationAdapters:
return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst() return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst()
return await ChatModel.objects.filter(name=chat_model_name).prefetch_related("ai_model_api").afirst() return await ChatModel.objects.filter(name=chat_model_name).prefetch_related("ai_model_api").afirst()
@staticmethod
async def aget_chat_model_by_friendly_name(chat_model_name: str, ai_model_api_name: str = None):
if ai_model_api_name:
return await ChatModel.objects.filter(
friendly_name=chat_model_name, ai_model_api__name=ai_model_api_name
).afirst()
return await ChatModel.objects.filter(friendly_name=chat_model_name).prefetch_related("ai_model_api").afirst()
@staticmethod @staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()

View File

@@ -232,6 +232,7 @@ class KhojUserSubscription(unfold_admin.ModelAdmin):
class ChatModelAdmin(unfold_admin.ModelAdmin): class ChatModelAdmin(unfold_admin.ModelAdmin):
list_display = ( list_display = (
"id", "id",
"friendly_name",
"name", "name",
"ai_model_api", "ai_model_api",
"max_prompt_size", "max_prompt_size",
@@ -243,6 +244,7 @@ class ChatModelAdmin(unfold_admin.ModelAdmin):
class TextToImageModelOptionsAdmin(unfold_admin.ModelAdmin): class TextToImageModelOptionsAdmin(unfold_admin.ModelAdmin):
list_display = ( list_display = (
"id", "id",
"friendly_name",
"model_name", "model_name",
"model_type", "model_type",
) )

View File

@@ -0,0 +1,44 @@
# Generated by Django 5.1.10 on 2025-06-28 06:50
from django.db import migrations, models
def initialize_friendly_names(apps, schema_editor):
"""Initialize friendly_name fields with values from the name/model_name fields"""
ChatModel = apps.get_model("database", "ChatModel")
SpeechToTextModelOptions = apps.get_model("database", "SpeechToTextModelOptions")
TextToImageModelConfig = apps.get_model("database", "TextToImageModelConfig")
# Initialize ChatModel friendly_name with name field
ChatModel.objects.update(friendly_name=models.F("name"))
# Initialize SpeechToTextModelOptions friendly_name with model_name field
SpeechToTextModelOptions.objects.update(friendly_name=models.F("model_name"))
# Initialize TextToImageModelConfig friendly_name with model_name field
TextToImageModelConfig.objects.update(friendly_name=models.F("model_name"))
class Migration(migrations.Migration):
dependencies = [
("database", "0090_alter_khojuser_uuid"),
]
operations = [
migrations.AddField(
model_name="chatmodel",
name="friendly_name",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="speechtotextmodeloptions",
name="friendly_name",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="friendly_name",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.RunPython(initialize_friendly_names, reverse_code=migrations.RunPython.noop),
]

View File

@@ -214,6 +214,7 @@ class ChatModel(DbBaseModel):
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True) subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
vision_enabled = models.BooleanField(default=False) vision_enabled = models.BooleanField(default=False)
@@ -222,7 +223,7 @@ class ChatModel(DbBaseModel):
strengths = models.TextField(default=None, null=True, blank=True) strengths = models.TextField(default=None, null=True, blank=True)
def __str__(self): def __str__(self):
return self.name return self.friendly_name
class VoiceModelOption(DbBaseModel): class VoiceModelOption(DbBaseModel):
@@ -554,6 +555,7 @@ class TextToImageModelConfig(DbBaseModel):
GOOGLE = "google" GOOGLE = "google"
model_name = models.CharField(max_length=200, default="dall-e-3") model_name = models.CharField(max_length=200, default="dall-e-3")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True) api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
@@ -592,6 +594,7 @@ class SpeechToTextModelOptions(DbBaseModel):
OFFLINE = "offline" OFFLINE = "offline"
model_name = models.CharField(max_length=200, default="base") model_name = models.CharField(max_length=200, default="base")
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)

View File

@@ -72,7 +72,7 @@ async def all_agents(
"color": agent.style_color, "color": agent.style_color,
"icon": agent.style_icon, "icon": agent.style_icon,
"privacy_level": agent.privacy_level, "privacy_level": agent.privacy_level,
"chat_model": agent_chat_model.name, "chat_model": agent_chat_model.friendly_name,
"files": file_names, "files": file_names,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,
@@ -134,7 +134,7 @@ async def get_agent_by_conversation(
chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
if is_subscribed or chat_model.price_tier == PriceTier.FREE: if is_subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = chat_model.name agent_chat_model = chat_model.friendly_name
else: else:
agent_chat_model = None agent_chat_model = None
@@ -219,7 +219,7 @@ async def get_agent(
"color": agent.style_color, "color": agent.style_color,
"icon": agent.style_icon, "icon": agent.style_icon,
"privacy_level": agent.privacy_level, "privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name, "chat_model": agent.chat_model.friendly_name,
"files": file_names, "files": file_names,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,
@@ -261,9 +261,9 @@ async def update_hidden_agent(
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) chat_model = await ConversationAdapters.aget_chat_model_by_friendly_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE: if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model agent_chat_model = chat_model.name
else: else:
agent_chat_model = None agent_chat_model = None
@@ -292,7 +292,7 @@ async def update_hidden_agent(
"name": agent.name, "name": agent.name,
"persona": agent.personality, "persona": agent.personality,
"creator": agent.creator.username if agent.creator else None, "creator": agent.creator.username if agent.creator else None,
"chat_model": agent.chat_model.name, "chat_model": agent.chat_model.friendly_name,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,
} }
@@ -311,9 +311,9 @@ async def create_hidden_agent(
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) chat_model = await ConversationAdapters.aget_chat_model_by_friendly_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE: if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model agent_chat_model = chat_model.name
else: else:
agent_chat_model = None agent_chat_model = None
@@ -355,7 +355,7 @@ async def create_hidden_agent(
"name": agent.name, "name": agent.name,
"persona": agent.personality, "persona": agent.personality,
"creator": agent.creator.username if agent.creator else None, "creator": agent.creator.username if agent.creator else None,
"chat_model": agent.chat_model.name, "chat_model": agent.chat_model.friendly_name,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,
} }
@@ -384,9 +384,9 @@ async def create_agent(
) )
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) chat_model = await ConversationAdapters.aget_chat_model_by_friendly_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE: if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model agent_chat_model = chat_model.name
else: else:
agent_chat_model = None agent_chat_model = None
@@ -415,7 +415,7 @@ async def create_agent(
"color": agent.style_color, "color": agent.style_color,
"icon": agent.style_icon, "icon": agent.style_icon,
"privacy_level": agent.privacy_level, "privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name, "chat_model": agent.chat_model.friendly_name,
"files": body.files, "files": body.files,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,
@@ -455,9 +455,9 @@ async def update_agent(
) )
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) chat_model = await ConversationAdapters.aget_chat_model_by_friendly_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE: if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model agent_chat_model = chat_model.name
else: else:
agent_chat_model = None agent_chat_model = None
@@ -485,7 +485,7 @@ async def update_agent(
"color": agent.style_color, "color": agent.style_color,
"icon": agent.style_icon, "icon": agent.style_icon,
"privacy_level": agent.privacy_level, "privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name, "chat_model": agent.chat_model.friendly_name,
"files": body.files, "files": body.files,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,

View File

@@ -31,7 +31,7 @@ def get_chat_model_options(
for chat_model in chat_models: for chat_model in chat_models:
chat_model_options.append( chat_model_options.append(
{ {
"name": chat_model.name, "name": chat_model.friendly_name,
"id": chat_model.id, "id": chat_model.id,
"strengths": chat_model.strengths, "strengths": chat_model.strengths,
"description": chat_model.description, "description": chat_model.description,
@@ -54,7 +54,7 @@ def get_user_chat_model(
if chat_model is None: if chat_model is None:
chat_model = ConversationAdapters.get_default_chat_model(user) chat_model = ConversationAdapters.get_default_chat_model(user)
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name})) return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.friendly_name}))
@api_model.post("/chat", status_code=200) @api_model.post("/chat", status_code=200)

View File

@@ -2550,7 +2550,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
for chat_model in chat_models: for chat_model in chat_models:
chat_model_options.append( chat_model_options.append(
{ {
"name": chat_model.name, "name": chat_model.friendly_name,
"id": chat_model.id, "id": chat_model.id,
"strengths": chat_model.strengths, "strengths": chat_model.strengths,
"description": chat_model.description, "description": chat_model.description,
@@ -2564,7 +2564,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
for paint_model in paint_model_options: for paint_model in paint_model_options:
all_paint_model_options.append( all_paint_model_options.append(
{ {
"name": paint_model.model_name, "name": paint_model.friendly_name,
"id": paint_model.id, "id": paint_model.id,
"tier": paint_model.price_tier, "tier": paint_model.price_tier,
} }