From 79fc911633b80bf70ea58e9adcda11ae3a8f8b6c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 1 Apr 2025 11:54:09 +0530 Subject: [PATCH] Enable free tier users to switch between free tier AI models - Update API to allow free tier users to switch between free models - Update web app to allow model switching on agent creation, settings chat page (via right side pane), even for free tier users. Previously the model switching APIs and UX fields on web app were completely disabled for free tier users --- src/interface/web/app/chat/page.tsx | 12 +++-- src/interface/web/app/common/auth.ts | 1 + .../web/app/common/modelSelector.tsx | 9 +++- .../app/components/agentCard/agentCard.tsx | 17 ++++--- .../components/chatSidebar/chatSidebar.tsx | 4 +- src/interface/web/app/settings/page.tsx | 48 +++++++++++++----- src/khoj/database/adapters/__init__.py | 6 +++ src/khoj/routers/api_agents.py | 50 +++++++++++++++---- src/khoj/routers/api_model.py | 48 +++++++++++++++--- src/khoj/routers/helpers.py | 17 ++++++- 10 files changed, 165 insertions(+), 47 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 2428bb1b..e4926ccc 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -24,9 +24,7 @@ import { ChatOptions, } from "../components/chatInputArea/chatInputArea"; import { useAuthenticatedData } from "../common/auth"; -import { - AgentData, -} from "@/app/components/agentCard/agentCard"; +import { AgentData } from "@/app/components/agentCard/agentCard"; import { ChatSessionActionMenu } from "../components/allConversations/allConversations"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; @@ -50,6 +48,7 @@ interface ChatBodyDataProps { setTriggeredAbort: (triggeredAbort: boolean) => void; isChatSideBarOpen: boolean; setIsChatSideBarOpen: (open: boolean) => void; + isActive?: boolean; } function ChatBodyData(props: ChatBodyDataProps) { @@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) { + isMobileWidth={props.isMobileWidth} + /> ); } @@ -480,12 +481,13 @@ export default function Chat() { setTriggeredAbort={setTriggeredAbort} isChatSideBarOpen={isChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen} + isActive={authenticatedData?.is_active} /> - + ); } diff --git a/src/interface/web/app/common/auth.ts b/src/interface/web/app/common/auth.ts index 917fa7da..4716fa13 100644 --- a/src/interface/web/app/common/auth.ts +++ b/src/interface/web/app/common/auth.ts @@ -33,6 +33,7 @@ export function useAuthenticatedData() { export interface ModelOptions { id: number; name: string; + tier: string; description: string; strengths: string; } diff --git a/src/interface/web/app/common/modelSelector.tsx b/src/interface/web/app/common/modelSelector.tsx index 70c171b0..d234d5b5 100644 --- a/src/interface/web/app/common/modelSelector.tsx +++ b/src/interface/web/app/common/modelSelector.tsx @@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton"; interface ModelSelectorProps extends PopoverProps { onSelect: (model: ModelOptions) => void; disabled?: boolean; + isActive?: boolean; initialModel?: string; } @@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) { setSelectedModel(model) setOpen(false) }} + isActive={props.isActive} /> ))} @@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) { setSelectedModel(model) setOpen(false) }} + isActive={props.isActive} /> ))} @@ -184,9 +187,10 @@ interface ModelItemProps { isSelected: boolean, onSelect: () => void, onPeek: (model: ModelOptions) => void + isActive?: boolean } -function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { +function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) { const ref = React.useRef(null) useMutationObserver(ref, (mutations) => { @@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { onSelect={onSelect} ref={ref} className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground" + disabled={!isActive && model.tier !== "free"} > - {model.name} + {model.name} {model.tier === "standard" && (Futurist)} diff --git a/src/interface/web/app/components/agentCard/agentCard.tsx b/src/interface/web/app/components/agentCard/agentCard.tsx index dfbade2a..04bc26a3 100644 --- a/src/interface/web/app/components/agentCard/agentCard.tsx +++ b/src/interface/web/app/components/agentCard/agentCard.tsx @@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {

Which chat model would you like to use?

)} - @@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
- {modelOption.name} + {modelOption.name}{" "} + {modelOption.tier === "standard" && ( + + (Futurist) + + )}
))} diff --git a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx index 1e7c0325..1d0b4c0d 100644 --- a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx +++ b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx @@ -34,6 +34,7 @@ interface ChatSideBarProps { isOpen: boolean; isMobileWidth?: boolean; onOpenChange: (open: boolean) => void; + isActive?: boolean; } const fetcher = (url: string) => fetch(url).then((res) => res.json()); @@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) { handleModelSelect(model.name)} initialModel={isDefaultAgent ? undefined : agentData?.chat_model} + isActive={props.isActive} /> diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index c0d5fc8c..5983c591 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -77,10 +77,11 @@ import { saveAs } from 'file-saver'; interface DropdownComponentProps { items: ModelOptions[]; selected: number; + isActive?: boolean; callbackFunc: (value: string) => Promise; } -const DropdownComponent: React.FC = ({ items, selected, callbackFunc }) => { +const DropdownComponent: React.FC = ({ items, selected, isActive, callbackFunc }) => { const [position, setPosition] = useState(selected?.toString() ?? "0"); return ( @@ -111,8 +112,9 @@ const DropdownComponent: React.FC = ({ items, selected, - {item.name} + {item.name} {item.tier === "standard" && (Futurist)} ))} @@ -520,33 +522,44 @@ export default function SettingsView() { } }; - const updateModel = (name: string) => async (id: string) => { - if (!userConfig?.is_active) { + const updateModel = (modelType: string) => async (id: string) => { + // Get the selected model from the options + const modelOptions = modelType === "chat" + ? userConfig?.chat_model_options + : modelType === "paint" + ? userConfig?.paint_model_options + : userConfig?.voice_model_options; + + const selectedModel = modelOptions?.find(model => model.id.toString() === id); + const modelName = selectedModel?.name; + + // Check if the model is free tier or if the user is active + if (!userConfig?.is_active && selectedModel?.tier !== "free") { toast({ title: `Model Update`, - description: `You need to be subscribed to update ${name} models`, + description: `Subscribe to switch ${modelType} model to ${modelName}.`, variant: "destructive", }); return; } try { - const response = await fetch(`/api/model/${name}?id=` + id, { + const response = await fetch(`/api/model/${modelType}?id=` + id, { method: "POST", headers: { "Content-Type": "application/json", }, }); - if (!response.ok) throw new Error("Failed to update model"); + if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`); toast({ - title: `✅ Updated ${toTitleCase(name)} Model`, + title: `✅ Switched ${modelType} model to ${modelName}`, }); } catch (error) { - console.error(`Failed to update ${name} model:`, error); + console.error(`Failed to update ${modelType} model to ${modelName}:`, error); toast({ - description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`, + description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`, variant: "destructive", }); } @@ -1103,13 +1116,16 @@ export default function SettingsView() { selected={ userConfig.selected_chat_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("chat")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.chat_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
@@ -1131,13 +1147,16 @@ export default function SettingsView() { selected={ userConfig.selected_paint_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("paint")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.paint_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
@@ -1159,13 +1178,16 @@ export default function SettingsView() { selected={ userConfig.selected_voice_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("voice")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.voice_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 28e3a04c..92846020 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1191,6 +1191,12 @@ class ConversationAdapters: return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first() return ChatModel.objects.filter(name=chat_model_name).first() + @staticmethod + async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): + if ai_model_api_name: + return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst() + return await ChatModel.objects.filter(name=chat_model_name).afirst() + @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index f97ee584..117663a7 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from starlette.authentication import has_required_scope, requires from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters -from khoj.database.models import Agent, Conversation, KhojUser +from khoj.database.models import Agent, Conversation, KhojUser, PriceTier from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt from khoj.utils.helpers import ( ConversationCommand, @@ -125,8 +125,20 @@ async def get_agent_by_conversation( else: agent = await AgentAdapters.aget_default_agent() + if agent is None: + return Response( + content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}), + media_type="application/json", + status_code=404, + ) + + chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) + if is_subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = chat_model.name + else: + agent_chat_model = None + has_files = agent.fileobject_set.exists() - agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, @@ -137,7 +149,7 @@ async def get_agent_by_conversation( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.name if is_subscribed else None, + "chat_model": agent_chat_model, "has_files": has_files, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -249,7 +261,11 @@ async def update_hidden_agent( user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user) @@ -264,7 +280,7 @@ async def update_hidden_agent( user=user, slug=body.slug, persona=body.persona, - chat_model=chat_model, + chat_model=agent_chat_model, input_tools=body.input_tools, output_modes=body.output_modes, existing_agent=selected_agent, @@ -295,7 +311,11 @@ async def create_hidden_agent( user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id) if not conversation: @@ -320,7 +340,7 @@ async def create_hidden_agent( user=user, slug=body.slug, persona=body.persona, - chat_model=chat_model, + chat_model=agent_chat_model, input_tools=body.input_tools, output_modes=body.output_modes, existing_agent=None, @@ -364,7 +384,11 @@ async def create_agent( ) subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None agent = await AgentAdapters.aupdate_agent( user, @@ -373,7 +397,7 @@ async def create_agent( body.privacy_level, body.icon, body.color, - chat_model, + agent_chat_model, body.files, body.input_tools, body.output_modes, @@ -431,7 +455,11 @@ async def update_agent( ) subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None agent = await AgentAdapters.aupdate_agent( user, @@ -440,7 +468,7 @@ async def update_agent( body.privacy_level, body.icon, body.color, - chat_model, + agent_chat_model, body.files, body.input_tools, body.output_modes, diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py index 26404c3f..ac37eb0f 100644 --- a/src/khoj/routers/api_model.py +++ b/src/khoj/routers/api_model.py @@ -2,13 +2,18 @@ import json import logging from typing import Dict, Optional, Union -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Request from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import has_required_scope, requires -from khoj.database import adapters -from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.adapters import ConversationAdapters +from khoj.database.models import ( + ChatModel, + PriceTier, + TextToImageModelConfig, + VoiceModelOption, +) from khoj.routers.helpers import update_telemetry_state api_model = APIRouter() @@ -53,13 +58,24 @@ def get_user_chat_model( @api_model.post("/chat", status_code=200) -@requires(["authenticated", "premium"]) +@requires(["authenticated"]) async def update_chat_model( request: Request, id: str, client: Optional[str] = None, ): user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + # Validate if model can be switched + chat_model = await ChatModel.objects.filter(id=int(id)).afirst() + if chat_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"})) + if not subscribed and chat_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}), + ) new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) @@ -78,13 +94,24 @@ async def update_chat_model( @api_model.post("/voice", status_code=200) -@requires(["authenticated", "premium"]) +@requires(["authenticated"]) async def update_voice_model( request: Request, id: str, client: Optional[str] = None, ): user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + # Validate if model can be switched + voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst() + if voice_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"})) + if not subscribed and voice_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}), + ) new_config = await ConversationAdapters.aset_user_voice_model(user, id) @@ -111,8 +138,15 @@ async def update_paint_model( user = request.user.object subscribed = has_required_scope(request, ["premium"]) - if not subscribed: - raise HTTPException(status_code=403, detail="User is not subscribed to premium") + # Validate if model can be switched + image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst() + if image_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"})) + if not subscribed and image_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}), + ) new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5779fab6..443810e8 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) "id": chat_model.id, "strengths": chat_model.strengths, "description": chat_model.description, + "tier": chat_model.price_tier, } ) @@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() all_paint_model_options = list() for paint_model in paint_model_options: - all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id}) + all_paint_model_options.append( + { + "name": paint_model.model_name, + "id": paint_model.id, + "tier": paint_model.price_tier, + } + ) voice_models = ConversationAdapters.get_voice_model_options() voice_model_options = list() for voice_model in voice_models: - voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id}) + voice_model_options.append( + { + "name": voice_model.name, + "id": voice_model.model_id, + "tier": voice_model.price_tier, + } + ) if len(voice_model_options) == 0: eleven_labs_enabled = False