Enable free tier users to switch between free tier AI models

- Update API to allow free tier users to switch between free models
- Update web app to allow model switching on agent creation, settings
  chat page (via right side pane), even for free tier users.

Previously the model switching APIs and UX fields on web app were
completely disabled for free tier users
This commit is contained in:
Debanjum
2025-04-01 11:54:09 +05:30
parent 30570e3e06
commit 79fc911633
10 changed files with 165 additions and 47 deletions

View File

@@ -24,9 +24,7 @@ import {
ChatOptions,
} from "../components/chatInputArea/chatInputArea";
import { useAuthenticatedData } from "../common/auth";
import {
AgentData,
} from "@/app/components/agentCard/agentCard";
import { AgentData } from "@/app/components/agentCard/agentCard";
import { ChatSessionActionMenu } from "../components/allConversations/allConversations";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar";
@@ -50,6 +48,7 @@ interface ChatBodyDataProps {
setTriggeredAbort: (triggeredAbort: boolean) => void;
isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean;
}
function ChatBodyData(props: ChatBodyDataProps) {
@@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) {
</div>
<ChatSidebar
conversationId={conversationId}
isActive={props.isActive}
isOpen={props.isChatSideBarOpen}
onOpenChange={props.setIsChatSideBarOpen}
isMobileWidth={props.isMobileWidth} />
isMobileWidth={props.isMobileWidth}
/>
</div>
);
}
@@ -480,12 +481,13 @@ export default function Chat() {
setTriggeredAbort={setTriggeredAbort}
isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active}
/>
</Suspense>
</div>
</div>
</div>
</SidebarInset>
</SidebarProvider >
</SidebarProvider>
);
}

View File

@@ -33,6 +33,7 @@ export function useAuthenticatedData() {
export interface ModelOptions {
id: number;
name: string;
tier: string;
description: string;
strengths: string;
}

View File

@@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton";
interface ModelSelectorProps extends PopoverProps {
onSelect: (model: ModelOptions) => void;
disabled?: boolean;
isActive?: boolean;
initialModel?: string;
}
@@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
setSelectedModel(model)
setOpen(false)
}}
isActive={props.isActive}
/>
))}
</CommandGroup>
@@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
setSelectedModel(model)
setOpen(false)
}}
isActive={props.isActive}
/>
))}
</CommandGroup>
@@ -184,9 +187,10 @@ interface ModelItemProps {
isSelected: boolean,
onSelect: () => void,
onPeek: (model: ModelOptions) => void
isActive?: boolean
}
function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) {
const ref = React.useRef<HTMLDivElement>(null)
useMutationObserver(ref, (mutations) => {
@@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
onSelect={onSelect}
ref={ref}
className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground"
disabled={!isActive && model.tier !== "free"}
>
{model.name}
{model.name} {model.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
<Check
className={cn("ml-auto", isSelected ? "opacity-100" : "opacity-0")}
/>

View File

@@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
<p>Which chat model would you like to use?</p>
)}
</FormDescription>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!props.isSubscribed}
>
<Select onValueChange={field.onChange} defaultValue={field.value}>
<FormControl>
<SelectTrigger className="text-left dark:bg-muted">
<SelectValue />
@@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
<SelectItem
key={modelOption.id}
value={modelOption.name}
disabled={
!props.isSubscribed &&
modelOption.tier !== "free"
}
>
<div className="flex items-center space-x-2">
{modelOption.name}
{modelOption.name}{" "}
{modelOption.tier === "standard" && (
<span className="text-green-500 ml-2">
(Futurist)
</span>
)}
</div>
</SelectItem>
))}

View File

@@ -34,6 +34,7 @@ interface ChatSideBarProps {
isOpen: boolean;
isMobileWidth?: boolean;
onOpenChange: (open: boolean) => void;
isActive?: boolean;
}
const fetcher = (url: string) => fetch(url).then((res) => res.json());
@@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
<SidebarMenu className="p-0 m-0">
<SidebarMenuItem key={"model"} className="list-none">
<ModelSelector
disabled={!isEditable || !isSubscribed}
disabled={!isEditable}
onSelect={(model) => handleModelSelect(model.name)}
initialModel={isDefaultAgent ? undefined : agentData?.chat_model}
isActive={props.isActive}
/>
</SidebarMenuItem>
</SidebarMenu>

View File

@@ -77,10 +77,11 @@ import { saveAs } from 'file-saver';
interface DropdownComponentProps {
items: ModelOptions[];
selected: number;
isActive?: boolean;
callbackFunc: (value: string) => Promise<void>;
}
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, callbackFunc }) => {
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, isActive, callbackFunc }) => {
const [position, setPosition] = useState(selected?.toString() ?? "0");
return (
@@ -111,8 +112,9 @@ const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected,
<DropdownMenuRadioItem
key={item.id.toString()}
value={item.id.toString()}
disabled={!isActive && item.tier !== "free"}
>
{item.name}
{item.name} {item.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
</DropdownMenuRadioItem>
))}
</DropdownMenuRadioGroup>
@@ -520,33 +522,44 @@ export default function SettingsView() {
}
};
const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active) {
const updateModel = (modelType: string) => async (id: string) => {
// Get the selected model from the options
const modelOptions = modelType === "chat"
? userConfig?.chat_model_options
: modelType === "paint"
? userConfig?.paint_model_options
: userConfig?.voice_model_options;
const selectedModel = modelOptions?.find(model => model.id.toString() === id);
const modelName = selectedModel?.name;
// Check if the model is free tier or if the user is active
if (!userConfig?.is_active && selectedModel?.tier !== "free") {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
description: `Subscribe to switch ${modelType} model to ${modelName}.`,
variant: "destructive",
});
return;
}
try {
const response = await fetch(`/api/model/${name}?id=` + id, {
const response = await fetch(`/api/model/${modelType}?id=` + id, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) throw new Error("Failed to update model");
if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`);
toast({
title: `Updated ${toTitleCase(name)} Model`,
title: `Switched ${modelType} model to ${modelName}`,
});
} catch (error) {
console.error(`Failed to update ${name} model:`, error);
console.error(`Failed to update ${modelType} model to ${modelName}:`, error);
toast({
description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`,
description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`,
variant: "destructive",
});
}
@@ -1103,13 +1116,16 @@ export default function SettingsView() {
selected={
userConfig.selected_chat_model_config
}
isActive={userConfig.is_active}
callbackFunc={updateModel("chat")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && (
<p className="text-gray-400">
Subscribe to switch model
{userConfig.chat_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p>
)}
</CardFooter>
@@ -1131,13 +1147,16 @@ export default function SettingsView() {
selected={
userConfig.selected_paint_model_config
}
isActive={userConfig.is_active}
callbackFunc={updateModel("paint")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && (
<p className="text-gray-400">
Subscribe to switch model
{userConfig.paint_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p>
)}
</CardFooter>
@@ -1159,13 +1178,16 @@ export default function SettingsView() {
selected={
userConfig.selected_voice_model_config
}
isActive={userConfig.is_active}
callbackFunc={updateModel("voice")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && (
<p className="text-gray-400">
Subscribe to switch model
{userConfig.voice_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p>
)}
</CardFooter>

View File

@@ -1191,6 +1191,12 @@ class ConversationAdapters:
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
return ChatModel.objects.filter(name=chat_model_name).first()
@staticmethod
async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
if ai_model_api_name:
return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst()
return await ChatModel.objects.filter(name=chat_model_name).afirst()
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()

View File

@@ -12,7 +12,7 @@ from pydantic import BaseModel
from starlette.authentication import has_required_scope, requires
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.models import Agent, Conversation, KhojUser
from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
ConversationCommand,
@@ -125,8 +125,20 @@ async def get_agent_by_conversation(
else:
agent = await AgentAdapters.aget_default_agent()
if agent is None:
return Response(
content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}),
media_type="application/json",
status_code=404,
)
chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
if is_subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = chat_model.name
else:
agent_chat_model = None
has_files = agent.fileobject_set.exists()
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
@@ -137,7 +149,7 @@ async def get_agent_by_conversation(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name if is_subscribed else None,
"chat_model": agent_chat_model,
"has_files": has_files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -249,7 +261,11 @@ async def update_hidden_agent(
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
@@ -264,7 +280,7 @@ async def update_hidden_agent(
user=user,
slug=body.slug,
persona=body.persona,
chat_model=chat_model,
chat_model=agent_chat_model,
input_tools=body.input_tools,
output_modes=body.output_modes,
existing_agent=selected_agent,
@@ -295,7 +311,11 @@ async def create_hidden_agent(
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
if not conversation:
@@ -320,7 +340,7 @@ async def create_hidden_agent(
user=user,
slug=body.slug,
persona=body.persona,
chat_model=chat_model,
chat_model=agent_chat_model,
input_tools=body.input_tools,
output_modes=body.output_modes,
existing_agent=None,
@@ -364,7 +384,11 @@ async def create_agent(
)
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
agent = await AgentAdapters.aupdate_agent(
user,
@@ -373,7 +397,7 @@ async def create_agent(
body.privacy_level,
body.icon,
body.color,
chat_model,
agent_chat_model,
body.files,
body.input_tools,
body.output_modes,
@@ -431,7 +455,11 @@ async def update_agent(
)
subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
agent = await AgentAdapters.aupdate_agent(
user,
@@ -440,7 +468,7 @@ async def update_agent(
body.privacy_level,
body.icon,
body.color,
chat_model,
agent_chat_model,
body.files,
body.input_tools,
body.output_modes,

View File

@@ -2,13 +2,18 @@ import json
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
ChatModel,
PriceTier,
TextToImageModelConfig,
VoiceModelOption,
)
from khoj.routers.helpers import update_telemetry_state
api_model = APIRouter()
@@ -53,13 +58,24 @@ def get_user_chat_model(
@api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"])
@requires(["authenticated"])
async def update_chat_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Validate if model can be switched
chat_model = await ChatModel.objects.filter(id=int(id)).afirst()
if chat_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
if not subscribed and chat_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
)
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
@@ -78,13 +94,24 @@ async def update_chat_model(
@api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"])
@requires(["authenticated"])
async def update_voice_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Validate if model can be switched
voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst()
if voice_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
if not subscribed and voice_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
)
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
@@ -111,8 +138,15 @@ async def update_paint_model(
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
# Validate if model can be switched
image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst()
if image_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
if not subscribed and image_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
)
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

View File

@@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"id": chat_model.id,
"strengths": chat_model.strengths,
"description": chat_model.description,
"tier": chat_model.price_tier,
}
)
@@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id})
all_paint_model_options.append(
{
"name": paint_model.model_name,
"id": paint_model.id,
"tier": paint_model.price_tier,
}
)
voice_models = ConversationAdapters.get_voice_model_options()
voice_model_options = list()
for voice_model in voice_models:
voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id})
voice_model_options.append(
{
"name": voice_model.name,
"id": voice_model.model_id,
"tier": voice_model.price_tier,
}
)
if len(voice_model_options) == 0:
eleven_labs_enabled = False