mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Track Price tier for each Chat, Speech, Image, Voice AI model in DB
Enables users on free plan to choose AI models marked for free tier
This commit is contained in:
@@ -48,6 +48,7 @@ from khoj.database.models import (
|
|||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
|
PriceTier,
|
||||||
ProcessLock,
|
ProcessLock,
|
||||||
PublicConversation,
|
PublicConversation,
|
||||||
RateLimitRecord,
|
RateLimitRecord,
|
||||||
@@ -1153,22 +1154,36 @@ class ConversationAdapters:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_chat_model(user: KhojUser):
|
def get_chat_model(user: KhojUser):
|
||||||
subscribed = is_user_subscribed(user)
|
subscribed = is_user_subscribed(user)
|
||||||
if not subscribed:
|
|
||||||
return ConversationAdapters.get_default_chat_model(user)
|
|
||||||
config = UserConversationConfig.objects.filter(user=user).first()
|
config = UserConversationConfig.objects.filter(user=user).first()
|
||||||
if config:
|
if subscribed:
|
||||||
return config.setting
|
# Subscibed users can use any available chat model
|
||||||
return ConversationAdapters.get_advanced_chat_model(user)
|
if config:
|
||||||
|
return config.setting
|
||||||
|
# Fallback to the default advanced chat model
|
||||||
|
return ConversationAdapters.get_advanced_chat_model(user)
|
||||||
|
else:
|
||||||
|
# Non-subscribed users can use any free chat model
|
||||||
|
if config and config.setting.price_tier == PriceTier.FREE:
|
||||||
|
return config.setting
|
||||||
|
# Fallback to the default chat model
|
||||||
|
return ConversationAdapters.get_default_chat_model(user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_chat_model(user: KhojUser):
|
async def aget_chat_model(user: KhojUser):
|
||||||
subscribed = await ais_user_subscribed(user)
|
subscribed = await ais_user_subscribed(user)
|
||||||
if not subscribed:
|
|
||||||
return await ConversationAdapters.aget_default_chat_model(user)
|
|
||||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||||
if config:
|
if subscribed:
|
||||||
return config.setting
|
# Subscibed users can use any available chat model
|
||||||
return ConversationAdapters.aget_advanced_chat_model(user)
|
if config:
|
||||||
|
return config.setting
|
||||||
|
# Fallback to the default advanced chat model
|
||||||
|
return await ConversationAdapters.aget_advanced_chat_model(user)
|
||||||
|
else:
|
||||||
|
# Non-subscribed users can use any free chat model
|
||||||
|
if config and config.setting.price_tier == PriceTier.FREE:
|
||||||
|
return config.setting
|
||||||
|
# Fallback to the default chat model
|
||||||
|
return await ConversationAdapters.aget_default_chat_model(user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
||||||
|
|||||||
@@ -0,0 +1,34 @@
|
|||||||
|
# Generated by Django 5.1.8 on 2025-04-18 15:15
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0088_ratelimitrecord"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="chatmodel",
|
||||||
|
name="price_tier",
|
||||||
|
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="speechtotextmodeloptions",
|
||||||
|
name="price_tier",
|
||||||
|
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="texttoimagemodelconfig",
|
||||||
|
name="price_tier",
|
||||||
|
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="voicemodeloption",
|
||||||
|
name="price_tier",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class PriceTier(models.TextChoices):
|
||||||
|
FREE = "free"
|
||||||
|
STANDARD = "standard"
|
||||||
|
|
||||||
|
|
||||||
class ChatModel(DbBaseModel):
|
class ChatModel(DbBaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
@@ -207,6 +212,7 @@ class ChatModel(DbBaseModel):
|
|||||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||||
|
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||||
vision_enabled = models.BooleanField(default=False)
|
vision_enabled = models.BooleanField(default=False)
|
||||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
description = models.TextField(default=None, null=True, blank=True)
|
description = models.TextField(default=None, null=True, blank=True)
|
||||||
@@ -219,6 +225,7 @@ class ChatModel(DbBaseModel):
|
|||||||
class VoiceModelOption(DbBaseModel):
|
class VoiceModelOption(DbBaseModel):
|
||||||
model_id = models.CharField(max_length=200)
|
model_id = models.CharField(max_length=200)
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
|
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD)
|
||||||
|
|
||||||
|
|
||||||
class Agent(DbBaseModel):
|
class Agent(DbBaseModel):
|
||||||
@@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel):
|
|||||||
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def clean(self):
|
||||||
|
error = {}
|
||||||
|
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
|
||||||
|
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
|
||||||
|
if error:
|
||||||
|
raise ValidationError(error)
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
self.clean()
|
||||||
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LocalOrgConfig(DbBaseModel):
|
class LocalOrgConfig(DbBaseModel):
|
||||||
input_files = models.JSONField(default=list, null=True)
|
input_files = models.JSONField(default=list, null=True)
|
||||||
@@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel):
|
|||||||
|
|
||||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
|
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||||
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
@@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel):
|
|||||||
|
|
||||||
model_name = models.CharField(max_length=200, default="base")
|
model_name = models.CharField(max_length=200, default="base")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||||
|
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user