Track Price tier for each Chat, Speech, Image, Voice AI model in DB

Enables users on free plan to choose AI models marked for free tier
This commit is contained in:
Debanjum
2025-04-01 11:41:47 +05:30
parent fdaf51f0ea
commit 30570e3e06
3 changed files with 79 additions and 10 deletions

View File

@@ -48,6 +48,7 @@ from khoj.database.models import (
KhojApiUser,
KhojUser,
NotionConfig,
PriceTier,
ProcessLock,
PublicConversation,
RateLimitRecord,
@@ -1153,22 +1154,36 @@ class ConversationAdapters:
@staticmethod
def get_chat_model(user: KhojUser):
subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_chat_model(user)
config = UserConversationConfig.objects.filter(user=user).first()
if config:
return config.setting
return ConversationAdapters.get_advanced_chat_model(user)
if subscribed:
# Subscibed users can use any available chat model
if config:
return config.setting
# Fallback to the default advanced chat model
return ConversationAdapters.get_advanced_chat_model(user)
else:
# Non-subscribed users can use any free chat model
if config and config.setting.price_tier == PriceTier.FREE:
return config.setting
# Fallback to the default chat model
return ConversationAdapters.get_default_chat_model(user)
@staticmethod
async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_chat_model(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if config:
return config.setting
return ConversationAdapters.aget_advanced_chat_model(user)
if subscribed:
# Subscibed users can use any available chat model
if config:
return config.setting
# Fallback to the default advanced chat model
return await ConversationAdapters.aget_advanced_chat_model(user)
else:
# Non-subscribed users can use any free chat model
if config and config.setting.price_tier == PriceTier.FREE:
return config.setting
# Fallback to the default chat model
return await ConversationAdapters.aget_default_chat_model(user)
@staticmethod
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):

View File

@@ -0,0 +1,34 @@
# Generated by Django 5.1.8 on 2025-04-18 15:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0088_ratelimitrecord"),
]
operations = [
migrations.AddField(
model_name="chatmodel",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="speechtotextmodeloptions",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="voicemodeloption",
name="price_tier",
field=models.CharField(
choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20
),
),
]

View File

@@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel):
return self.name
class PriceTier(models.TextChoices):
FREE = "free"
STANDARD = "standard"
class ChatModel(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
@@ -207,6 +212,7 @@ class ChatModel(DbBaseModel):
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
description = models.TextField(default=None, null=True, blank=True)
@@ -219,6 +225,7 @@ class ChatModel(DbBaseModel):
class VoiceModelOption(DbBaseModel):
model_id = models.CharField(max_length=200)
name = models.CharField(max_length=200)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD)
class Agent(DbBaseModel):
@@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel):
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
)
def clean(self):
error = {}
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class LocalOrgConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
@@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel):
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel):
model_name = models.CharField(max_length=200, default="base")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
def __str__(self):