From 30570e3e06549062ae125b5854f1b27cf4ca50a7 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 1 Apr 2025 11:41:47 +0530 Subject: [PATCH] Track Price tier for each Chat, Speech, Image, Voice AI model in DB Enables users on free plan to choose AI models marked for free tier --- src/khoj/database/adapters/__init__.py | 35 +++++++++++++------ .../0089_chatmodel_price_tier_and_more.py | 34 ++++++++++++++++++ src/khoj/database/models/__init__.py | 20 +++++++++++ 3 files changed, 79 insertions(+), 10 deletions(-) create mode 100644 src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 086be4b0..28e3a04c 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -48,6 +48,7 @@ from khoj.database.models import ( KhojApiUser, KhojUser, NotionConfig, + PriceTier, ProcessLock, PublicConversation, RateLimitRecord, @@ -1153,22 +1154,36 @@ class ConversationAdapters: @staticmethod def get_chat_model(user: KhojUser): subscribed = is_user_subscribed(user) - if not subscribed: - return ConversationAdapters.get_default_chat_model(user) config = UserConversationConfig.objects.filter(user=user).first() - if config: - return config.setting - return ConversationAdapters.get_advanced_chat_model(user) + if subscribed: + # Subscibed users can use any available chat model + if config: + return config.setting + # Fallback to the default advanced chat model + return ConversationAdapters.get_advanced_chat_model(user) + else: + # Non-subscribed users can use any free chat model + if config and config.setting.price_tier == PriceTier.FREE: + return config.setting + # Fallback to the default chat model + return ConversationAdapters.get_default_chat_model(user) @staticmethod async def aget_chat_model(user: KhojUser): subscribed = await ais_user_subscribed(user) - if not subscribed: - return await ConversationAdapters.aget_default_chat_model(user) config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() - if config: - return config.setting - return ConversationAdapters.aget_advanced_chat_model(user) + if subscribed: + # Subscibed users can use any available chat model + if config: + return config.setting + # Fallback to the default advanced chat model + return await ConversationAdapters.aget_advanced_chat_model(user) + else: + # Non-subscribed users can use any free chat model + if config and config.setting.price_tier == PriceTier.FREE: + return config.setting + # Fallback to the default chat model + return await ConversationAdapters.aget_default_chat_model(user) @staticmethod def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): diff --git a/src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py b/src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py new file mode 100644 index 00000000..9c363098 --- /dev/null +++ b/src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py @@ -0,0 +1,34 @@ +# Generated by Django 5.1.8 on 2025-04-18 15:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0088_ratelimitrecord"), + ] + + operations = [ + migrations.AddField( + model_name="chatmodel", + name="price_tier", + field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20), + ), + migrations.AddField( + model_name="speechtotextmodeloptions", + name="price_tier", + field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20), + ), + migrations.AddField( + model_name="texttoimagemodelconfig", + name="price_tier", + field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20), + ), + migrations.AddField( + model_name="voicemodeloption", + name="price_tier", + field=models.CharField( + choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20 + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 429d010e..bd49aa8c 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel): return self.name +class PriceTier(models.TextChoices): + FREE = "free" + STANDARD = "standard" + + class ChatModel(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" @@ -207,6 +212,7 @@ class ChatModel(DbBaseModel): tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) vision_enabled = models.BooleanField(default=False) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) description = models.TextField(default=None, null=True, blank=True) @@ -219,6 +225,7 @@ class ChatModel(DbBaseModel): class VoiceModelOption(DbBaseModel): model_id = models.CharField(max_length=200) name = models.CharField(max_length=200) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD) class Agent(DbBaseModel): @@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel): WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" ) + def clean(self): + error = {} + if self.chat_default and self.chat_default.price_tier != PriceTier.FREE: + error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model." + if error: + raise ValidationError(error) + + def save(self, *args, **kwargs): + self.clean() + super().save(*args, **kwargs) + class LocalOrgConfig(DbBaseModel): input_files = models.JSONField(default=list, null=True) @@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel): model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) api_key = models.CharField(max_length=200, default=None, null=True, blank=True) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) @@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel): model_name = models.CharField(max_length=200, default="base") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) def __str__(self):