Allow AI model switching based on User Tier (#1151)

Overview
---
Enable free tier users to chat with any AI model made available on free tier 
of production deployments like [Khoj cloud](https://app.khoj.dev).

Previously model switching was completely disabled for users on free tier.

Details
---
- Track price tier of each Chat, Speech, Image, Voice AI model in DB
- Update API to allow free tier users to switch between free models
- Update web app to allow model switching on agent creation, settings
  chat page (via right side pane), even for free tier users.
This commit is contained in:
Debanjum
2025-04-19 18:14:37 +05:30
committed by GitHub
14 changed files with 246 additions and 59 deletions

View File

@@ -49,7 +49,7 @@ jobs:
- name: 📂 Copy Generated Files
run: |
mkdir -p src/khoj/interface/compiled
cp -r /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/
cp -r /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/
- name: ⚙️ Build Python Package
run: |

View File

@@ -24,9 +24,7 @@ import {
ChatOptions,
} from "../components/chatInputArea/chatInputArea";
import { useAuthenticatedData } from "../common/auth";
import {
AgentData,
} from "@/app/components/agentCard/agentCard";
import { AgentData } from "@/app/components/agentCard/agentCard";
import { ChatSessionActionMenu } from "../components/allConversations/allConversations";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar";
@@ -50,6 +48,7 @@ interface ChatBodyDataProps {
setTriggeredAbort: (triggeredAbort: boolean) => void;
isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean;
}
function ChatBodyData(props: ChatBodyDataProps) {
@@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) {
</div>
<ChatSidebar
conversationId={conversationId}
isActive={props.isActive}
isOpen={props.isChatSideBarOpen}
onOpenChange={props.setIsChatSideBarOpen}
isMobileWidth={props.isMobileWidth} />
isMobileWidth={props.isMobileWidth}
/>
</div>
);
}
@@ -480,12 +481,13 @@ export default function Chat() {
setTriggeredAbort={setTriggeredAbort}
isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active}
/>
</Suspense>
</div>
</div>
</div>
</SidebarInset>
</SidebarProvider >
</SidebarProvider>
);
}

View File

@@ -33,6 +33,7 @@ export function useAuthenticatedData() {
export interface ModelOptions {
id: number;
name: string;
tier: string;
description: string;
strengths: string;
}

View File

@@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton";
interface ModelSelectorProps extends PopoverProps {
onSelect: (model: ModelOptions) => void;
disabled?: boolean;
isActive?: boolean;
initialModel?: string;
}
@@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
setSelectedModel(model)
setOpen(false)
}}
isActive={props.isActive}
/>
))}
</CommandGroup>
@@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
setSelectedModel(model)
setOpen(false)
}}
isActive={props.isActive}
/>
))}
</CommandGroup>
@@ -184,9 +187,10 @@ interface ModelItemProps {
isSelected: boolean,
onSelect: () => void,
onPeek: (model: ModelOptions) => void
isActive?: boolean
}
function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) {
const ref = React.useRef<HTMLDivElement>(null)
useMutationObserver(ref, (mutations) => {
@@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
onSelect={onSelect}
ref={ref}
className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground"
disabled={!isActive && model.tier !== "free"}
>
{model.name}
{model.name} {model.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
<Check
className={cn("ml-auto", isSelected ? "opacity-100" : "opacity-0")}
/>

View File

@@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
<p>Which chat model would you like to use?</p>
)}
</FormDescription>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!props.isSubscribed}
>
<Select onValueChange={field.onChange} defaultValue={field.value}>
<FormControl>
<SelectTrigger className="text-left dark:bg-muted">
<SelectValue />
@@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
<SelectItem
key={modelOption.id}
value={modelOption.name}
disabled={
!props.isSubscribed &&
modelOption.tier !== "free"
}
>
<div className="flex items-center space-x-2">
{modelOption.name}
{modelOption.name}{" "}
{modelOption.tier === "standard" && (
<span className="text-green-500 ml-2">
(Futurist)
</span>
)}
</div>
</SelectItem>
))}

View File

@@ -34,6 +34,7 @@ interface ChatSideBarProps {
isOpen: boolean;
isMobileWidth?: boolean;
onOpenChange: (open: boolean) => void;
isActive?: boolean;
}
const fetcher = (url: string) => fetch(url).then((res) => res.json());
@@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
<SidebarMenu className="p-0 m-0">
<SidebarMenuItem key={"model"} className="list-none">
<ModelSelector
disabled={!isEditable || !isSubscribed}
disabled={!isEditable}
onSelect={(model) => handleModelSelect(model.name)}
initialModel={isDefaultAgent ? undefined : agentData?.chat_model}
isActive={props.isActive}
/>
</SidebarMenuItem>
</SidebarMenu>

View File

@@ -77,10 +77,11 @@ import { saveAs } from 'file-saver';
interface DropdownComponentProps {
items: ModelOptions[];
selected: number;
isActive?: boolean;
callbackFunc: (value: string) => Promise<void>;
}
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, callbackFunc }) => {
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, isActive, callbackFunc }) => {
const [position, setPosition] = useState(selected?.toString() ?? "0");
return (
@@ -111,8 +112,9 @@ const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected,
<DropdownMenuRadioItem
key={item.id.toString()}
value={item.id.toString()}
disabled={!isActive && item.tier !== "free"}
>
{item.name}
{item.name} {item.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
</DropdownMenuRadioItem>
))}
</DropdownMenuRadioGroup>
@@ -520,33 +522,44 @@ export default function SettingsView() {
}
};
const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active) {
const updateModel = (modelType: string) => async (id: string) => {
// Get the selected model from the options
const modelOptions = modelType === "chat"
? userConfig?.chat_model_options
: modelType === "paint"
? userConfig?.paint_model_options
: userConfig?.voice_model_options;
const selectedModel = modelOptions?.find(model => model.id.toString() === id);
const modelName = selectedModel?.name;
// Check if the model is free tier or if the user is active
if (!userConfig?.is_active && selectedModel?.tier !== "free") {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
description: `Subscribe to switch ${modelType} model to ${modelName}.`,
variant: "destructive",
});
return;
}
try {
const response = await fetch(`/api/model/${name}?id=` + id, {
const response = await fetch(`/api/model/${modelType}?id=` + id, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) throw new Error("Failed to update model");
if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`);
toast({
title: `Updated ${toTitleCase(name)} Model`,
title: `Switched ${modelType} model to ${modelName}`,
});
} catch (error) {
console.error(`Failed to update ${name} model:`, error);
console.error(`Failed to update ${modelType} model to ${modelName}:`, error);
toast({
description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`,
description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`,
variant: "destructive",
});
}
@@ -1103,13 +1116,16 @@ export default function SettingsView() {
selected={
userConfig.selected_chat_model_config
}
isActive={userConfig.is_active}
callbackFunc={updateModel("chat")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && (
<p className="text-gray-400">
Subscribe to switch model
{userConfig.chat_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p>
)}
</CardFooter>
@@ -1131,13 +1147,16 @@ export default function SettingsView() {
selected={
userConfig.selected_paint_model_config
}
isActive={userConfig.is_active}
callbackFunc={updateModel("paint")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && (
<p className="text-gray-400">
Subscribe to switch model
{userConfig.paint_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p>
)}
</CardFooter>
@@ -1159,13 +1178,16 @@ export default function SettingsView() {
selected={
userConfig.selected_voice_model_config
}
isActive={userConfig.is_active}
callbackFunc={updateModel("voice")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && (
<p className="text-gray-400">
Subscribe to switch model
{userConfig.voice_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p>
)}
</CardFooter>

View File

@@ -11,7 +11,7 @@
"cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'",
"export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic",
"ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic",
"pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic",
"pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic",
"watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'",
"windowswatch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn windowsexport'",
"windowscollectstatic": "cd ..\\..\\.. && .\\.venv\\Scripts\\Activate.bat && py .\\src\\khoj\\manage.py collectstatic --noinput && .\\.venv\\Scripts\\deactivate.bat && cd ..",

View File

@@ -48,6 +48,7 @@ from khoj.database.models import (
KhojApiUser,
KhojUser,
NotionConfig,
PriceTier,
ProcessLock,
PublicConversation,
RateLimitRecord,
@@ -1153,22 +1154,36 @@ class ConversationAdapters:
@staticmethod
def get_chat_model(user: KhojUser):
subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_chat_model(user)
config = UserConversationConfig.objects.filter(user=user).first()
if config:
return config.setting
return ConversationAdapters.get_advanced_chat_model(user)
if subscribed:
# Subscibed users can use any available chat model
if config:
return config.setting
# Fallback to the default advanced chat model
return ConversationAdapters.get_advanced_chat_model(user)
else:
# Non-subscribed users can use any free chat model
if config and config.setting.price_tier == PriceTier.FREE:
return config.setting
# Fallback to the default chat model
return ConversationAdapters.get_default_chat_model(user)
@staticmethod
async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_chat_model(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if config:
return config.setting
return ConversationAdapters.aget_advanced_chat_model(user)
if subscribed:
# Subscibed users can use any available chat model
if config:
return config.setting
# Fallback to the default advanced chat model
return await ConversationAdapters.aget_advanced_chat_model(user)
else:
# Non-subscribed users can use any free chat model
if config and config.setting.price_tier == PriceTier.FREE:
return config.setting
# Fallback to the default chat model
return await ConversationAdapters.aget_default_chat_model(user)
@staticmethod
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
@@ -1176,6 +1191,12 @@ class ConversationAdapters:
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
return ChatModel.objects.filter(name=chat_model_name).first()
@staticmethod
async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
if ai_model_api_name:
return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst()
return await ChatModel.objects.filter(name=chat_model_name).afirst()
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()

View File

@@ -0,0 +1,34 @@
# Generated by Django 5.1.8 on 2025-04-18 15:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0088_ratelimitrecord"),
]
operations = [
migrations.AddField(
model_name="chatmodel",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="speechtotextmodeloptions",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="voicemodeloption",
name="price_tier",
field=models.CharField(
choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20
),
),
]

View File

@@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel):
return self.name
class PriceTier(models.TextChoices):
FREE = "free"
STANDARD = "standard"
class ChatModel(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
@@ -207,6 +212,7 @@ class ChatModel(DbBaseModel):
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
description = models.TextField(default=None, null=True, blank=True)
@@ -219,6 +225,7 @@ class ChatModel(DbBaseModel):
class VoiceModelOption(DbBaseModel):
model_id = models.CharField(max_length=200)
name = models.CharField(max_length=200)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD)
class Agent(DbBaseModel):
@@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel):
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
)
def clean(self):
error = {}
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class LocalOrgConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True)
@@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel):
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel):
model_name = models.CharField(max_length=200, default="base")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
def __str__(self):

View File

@@ -12,7 +12,7 @@ from pydantic import BaseModel
from starlette.authentication import has_required_scope, requires
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.models import Agent, Conversation, KhojUser
from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
ConversationCommand,
@@ -125,8 +125,20 @@ async def get_agent_by_conversation(
else:
agent = await AgentAdapters.aget_default_agent()
if agent is None:
return Response(
content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}),
media_type="application/json",
status_code=404,
)
chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
if is_subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = chat_model.name
else:
agent_chat_model = None
has_files = agent.fileobject_set.exists()
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
@@ -137,7 +149,7 @@ async def get_agent_by_conversation(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name if is_subscribed else None,
"chat_model": agent_chat_model,
"has_files": has_files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -249,7 +261,11 @@ async def update_hidden_agent(
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
@@ -264,7 +280,7 @@ async def update_hidden_agent(
user=user,
slug=body.slug,
persona=body.persona,
chat_model=chat_model,
chat_model=agent_chat_model,
input_tools=body.input_tools,
output_modes=body.output_modes,
existing_agent=selected_agent,
@@ -295,7 +311,11 @@ async def create_hidden_agent(
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
if not conversation:
@@ -320,7 +340,7 @@ async def create_hidden_agent(
user=user,
slug=body.slug,
persona=body.persona,
chat_model=chat_model,
chat_model=agent_chat_model,
input_tools=body.input_tools,
output_modes=body.output_modes,
existing_agent=None,
@@ -364,7 +384,11 @@ async def create_agent(
)
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
agent = await AgentAdapters.aupdate_agent(
user,
@@ -373,7 +397,7 @@ async def create_agent(
body.privacy_level,
body.icon,
body.color,
chat_model,
agent_chat_model,
body.files,
body.input_tools,
body.output_modes,
@@ -431,7 +455,11 @@ async def update_agent(
)
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
agent = await AgentAdapters.aupdate_agent(
user,
@@ -440,7 +468,7 @@ async def update_agent(
body.privacy_level,
body.icon,
body.color,
chat_model,
agent_chat_model,
body.files,
body.input_tools,
body.output_modes,

View File

@@ -2,13 +2,18 @@ import json
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
ChatModel,
PriceTier,
TextToImageModelConfig,
VoiceModelOption,
)
from khoj.routers.helpers import update_telemetry_state
api_model = APIRouter()
@@ -53,13 +58,24 @@ def get_user_chat_model(
@api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"])
@requires(["authenticated"])
async def update_chat_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Validate if model can be switched
chat_model = await ChatModel.objects.filter(id=int(id)).afirst()
if chat_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
if not subscribed and chat_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
)
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
@@ -78,13 +94,24 @@ async def update_chat_model(
@api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"])
@requires(["authenticated"])
async def update_voice_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Validate if model can be switched
voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst()
if voice_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
if not subscribed and voice_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
)
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
@@ -111,8 +138,15 @@ async def update_paint_model(
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
# Validate if model can be switched
image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst()
if image_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
if not subscribed and image_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
)
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

View File

@@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"id": chat_model.id,
"strengths": chat_model.strengths,
"description": chat_model.description,
"tier": chat_model.price_tier,
}
)
@@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id})
all_paint_model_options.append(
{
"name": paint_model.model_name,
"id": paint_model.id,
"tier": paint_model.price_tier,
}
)
voice_models = ConversationAdapters.get_voice_model_options()
voice_model_options = list()
for voice_model in voice_models:
voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id})
voice_model_options.append(
{
"name": voice_model.name,
"id": voice_model.model_id,
"tier": voice_model.price_tier,
}
)
if len(voice_model_options) == 0:
eleven_labs_enabled = False