mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Track Price tier for each Chat, Speech, Image, Voice AI model in DB
Enables users on free plan to choose AI models marked for free tier
This commit is contained in:
@@ -48,6 +48,7 @@ from khoj.database.models import (
|
||||
KhojApiUser,
|
||||
KhojUser,
|
||||
NotionConfig,
|
||||
PriceTier,
|
||||
ProcessLock,
|
||||
PublicConversation,
|
||||
RateLimitRecord,
|
||||
@@ -1153,22 +1154,36 @@ class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_chat_model(user: KhojUser):
|
||||
subscribed = is_user_subscribed(user)
|
||||
if not subscribed:
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
config = UserConversationConfig.objects.filter(user=user).first()
|
||||
if config:
|
||||
return config.setting
|
||||
return ConversationAdapters.get_advanced_chat_model(user)
|
||||
if subscribed:
|
||||
# Subscibed users can use any available chat model
|
||||
if config:
|
||||
return config.setting
|
||||
# Fallback to the default advanced chat model
|
||||
return ConversationAdapters.get_advanced_chat_model(user)
|
||||
else:
|
||||
# Non-subscribed users can use any free chat model
|
||||
if config and config.setting.price_tier == PriceTier.FREE:
|
||||
return config.setting
|
||||
# Fallback to the default chat model
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_chat_model(user: KhojUser):
|
||||
subscribed = await ais_user_subscribed(user)
|
||||
if not subscribed:
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if config:
|
||||
return config.setting
|
||||
return ConversationAdapters.aget_advanced_chat_model(user)
|
||||
if subscribed:
|
||||
# Subscibed users can use any available chat model
|
||||
if config:
|
||||
return config.setting
|
||||
# Fallback to the default advanced chat model
|
||||
return await ConversationAdapters.aget_advanced_chat_model(user)
|
||||
else:
|
||||
# Non-subscribed users can use any free chat model
|
||||
if config and config.setting.price_tier == PriceTier.FREE:
|
||||
return config.setting
|
||||
# Fallback to the default chat model
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# Generated by Django 5.1.8 on 2025-04-18 15:15
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0088_ratelimitrecord"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="chatmodel",
|
||||
name="price_tier",
|
||||
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="speechtotextmodeloptions",
|
||||
name="price_tier",
|
||||
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="price_tier",
|
||||
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="voicemodeloption",
|
||||
name="price_tier",
|
||||
field=models.CharField(
|
||||
choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel):
|
||||
return self.name
|
||||
|
||||
|
||||
class PriceTier(models.TextChoices):
|
||||
FREE = "free"
|
||||
STANDARD = "standard"
|
||||
|
||||
|
||||
class ChatModel(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
@@ -207,6 +212,7 @@ class ChatModel(DbBaseModel):
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
vision_enabled = models.BooleanField(default=False)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
description = models.TextField(default=None, null=True, blank=True)
|
||||
@@ -219,6 +225,7 @@ class ChatModel(DbBaseModel):
|
||||
class VoiceModelOption(DbBaseModel):
|
||||
model_id = models.CharField(max_length=200)
|
||||
name = models.CharField(max_length=200)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD)
|
||||
|
||||
|
||||
class Agent(DbBaseModel):
|
||||
@@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel):
|
||||
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||
)
|
||||
|
||||
def clean(self):
|
||||
error = {}
|
||||
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
|
||||
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
|
||||
if error:
|
||||
raise ValidationError(error)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.clean()
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class LocalOrgConfig(DbBaseModel):
|
||||
input_files = models.JSONField(default=list, null=True)
|
||||
@@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel):
|
||||
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
@@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel):
|
||||
|
||||
model_name = models.CharField(max_length=200, default="base")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
Reference in New Issue
Block a user