diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 2428bb1b..e4926ccc 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -24,9 +24,7 @@ import { ChatOptions, } from "../components/chatInputArea/chatInputArea"; import { useAuthenticatedData } from "../common/auth"; -import { - AgentData, -} from "@/app/components/agentCard/agentCard"; +import { AgentData } from "@/app/components/agentCard/agentCard"; import { ChatSessionActionMenu } from "../components/allConversations/allConversations"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; @@ -50,6 +48,7 @@ interface ChatBodyDataProps { setTriggeredAbort: (triggeredAbort: boolean) => void; isChatSideBarOpen: boolean; setIsChatSideBarOpen: (open: boolean) => void; + isActive?: boolean; } function ChatBodyData(props: ChatBodyDataProps) { @@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) { + isMobileWidth={props.isMobileWidth} + /> ); } @@ -480,12 +481,13 @@ export default function Chat() { setTriggeredAbort={setTriggeredAbort} isChatSideBarOpen={isChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen} + isActive={authenticatedData?.is_active} /> - + ); } diff --git a/src/interface/web/app/common/auth.ts b/src/interface/web/app/common/auth.ts index 917fa7da..4716fa13 100644 --- a/src/interface/web/app/common/auth.ts +++ b/src/interface/web/app/common/auth.ts @@ -33,6 +33,7 @@ export function useAuthenticatedData() { export interface ModelOptions { id: number; name: string; + tier: string; description: string; strengths: string; } diff --git a/src/interface/web/app/common/modelSelector.tsx b/src/interface/web/app/common/modelSelector.tsx index 70c171b0..d234d5b5 100644 --- a/src/interface/web/app/common/modelSelector.tsx +++ b/src/interface/web/app/common/modelSelector.tsx @@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton"; interface ModelSelectorProps extends PopoverProps { onSelect: (model: ModelOptions) => void; disabled?: boolean; + isActive?: boolean; initialModel?: string; } @@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) { setSelectedModel(model) setOpen(false) }} + isActive={props.isActive} /> ))} @@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) { setSelectedModel(model) setOpen(false) }} + isActive={props.isActive} /> ))} @@ -184,9 +187,10 @@ interface ModelItemProps { isSelected: boolean, onSelect: () => void, onPeek: (model: ModelOptions) => void + isActive?: boolean } -function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { +function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) { const ref = React.useRef(null) useMutationObserver(ref, (mutations) => { @@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { onSelect={onSelect} ref={ref} className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground" + disabled={!isActive && model.tier !== "free"} > - {model.name} + {model.name} {model.tier === "standard" && (Futurist)} diff --git a/src/interface/web/app/components/agentCard/agentCard.tsx b/src/interface/web/app/components/agentCard/agentCard.tsx index dfbade2a..04bc26a3 100644 --- a/src/interface/web/app/components/agentCard/agentCard.tsx +++ b/src/interface/web/app/components/agentCard/agentCard.tsx @@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {

Which chat model would you like to use?

)} - @@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
- {modelOption.name} + {modelOption.name}{" "} + {modelOption.tier === "standard" && ( + + (Futurist) + + )}
))} diff --git a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx index 1e7c0325..1d0b4c0d 100644 --- a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx +++ b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx @@ -34,6 +34,7 @@ interface ChatSideBarProps { isOpen: boolean; isMobileWidth?: boolean; onOpenChange: (open: boolean) => void; + isActive?: boolean; } const fetcher = (url: string) => fetch(url).then((res) => res.json()); @@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) { handleModelSelect(model.name)} initialModel={isDefaultAgent ? undefined : agentData?.chat_model} + isActive={props.isActive} /> diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index c0d5fc8c..5983c591 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -77,10 +77,11 @@ import { saveAs } from 'file-saver'; interface DropdownComponentProps { items: ModelOptions[]; selected: number; + isActive?: boolean; callbackFunc: (value: string) => Promise; } -const DropdownComponent: React.FC = ({ items, selected, callbackFunc }) => { +const DropdownComponent: React.FC = ({ items, selected, isActive, callbackFunc }) => { const [position, setPosition] = useState(selected?.toString() ?? "0"); return ( @@ -111,8 +112,9 @@ const DropdownComponent: React.FC = ({ items, selected, - {item.name} + {item.name} {item.tier === "standard" && (Futurist)} ))} @@ -520,33 +522,44 @@ export default function SettingsView() { } }; - const updateModel = (name: string) => async (id: string) => { - if (!userConfig?.is_active) { + const updateModel = (modelType: string) => async (id: string) => { + // Get the selected model from the options + const modelOptions = modelType === "chat" + ? userConfig?.chat_model_options + : modelType === "paint" + ? userConfig?.paint_model_options + : userConfig?.voice_model_options; + + const selectedModel = modelOptions?.find(model => model.id.toString() === id); + const modelName = selectedModel?.name; + + // Check if the model is free tier or if the user is active + if (!userConfig?.is_active && selectedModel?.tier !== "free") { toast({ title: `Model Update`, - description: `You need to be subscribed to update ${name} models`, + description: `Subscribe to switch ${modelType} model to ${modelName}.`, variant: "destructive", }); return; } try { - const response = await fetch(`/api/model/${name}?id=` + id, { + const response = await fetch(`/api/model/${modelType}?id=` + id, { method: "POST", headers: { "Content-Type": "application/json", }, }); - if (!response.ok) throw new Error("Failed to update model"); + if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`); toast({ - title: `✅ Updated ${toTitleCase(name)} Model`, + title: `✅ Switched ${modelType} model to ${modelName}`, }); } catch (error) { - console.error(`Failed to update ${name} model:`, error); + console.error(`Failed to update ${modelType} model to ${modelName}:`, error); toast({ - description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`, + description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`, variant: "destructive", }); } @@ -1103,13 +1116,16 @@ export default function SettingsView() { selected={ userConfig.selected_chat_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("chat")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.chat_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
@@ -1131,13 +1147,16 @@ export default function SettingsView() { selected={ userConfig.selected_paint_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("paint")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.paint_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
@@ -1159,13 +1178,16 @@ export default function SettingsView() { selected={ userConfig.selected_voice_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("voice")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.voice_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 28e3a04c..92846020 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1191,6 +1191,12 @@ class ConversationAdapters: return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first() return ChatModel.objects.filter(name=chat_model_name).first() + @staticmethod + async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): + if ai_model_api_name: + return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst() + return await ChatModel.objects.filter(name=chat_model_name).afirst() + @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index f97ee584..117663a7 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from starlette.authentication import has_required_scope, requires from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters -from khoj.database.models import Agent, Conversation, KhojUser +from khoj.database.models import Agent, Conversation, KhojUser, PriceTier from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt from khoj.utils.helpers import ( ConversationCommand, @@ -125,8 +125,20 @@ async def get_agent_by_conversation( else: agent = await AgentAdapters.aget_default_agent() + if agent is None: + return Response( + content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}), + media_type="application/json", + status_code=404, + ) + + chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) + if is_subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = chat_model.name + else: + agent_chat_model = None + has_files = agent.fileobject_set.exists() - agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, @@ -137,7 +149,7 @@ async def get_agent_by_conversation( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.name if is_subscribed else None, + "chat_model": agent_chat_model, "has_files": has_files, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -249,7 +261,11 @@ async def update_hidden_agent( user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user) @@ -264,7 +280,7 @@ async def update_hidden_agent( user=user, slug=body.slug, persona=body.persona, - chat_model=chat_model, + chat_model=agent_chat_model, input_tools=body.input_tools, output_modes=body.output_modes, existing_agent=selected_agent, @@ -295,7 +311,11 @@ async def create_hidden_agent( user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id) if not conversation: @@ -320,7 +340,7 @@ async def create_hidden_agent( user=user, slug=body.slug, persona=body.persona, - chat_model=chat_model, + chat_model=agent_chat_model, input_tools=body.input_tools, output_modes=body.output_modes, existing_agent=None, @@ -364,7 +384,11 @@ async def create_agent( ) subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None agent = await AgentAdapters.aupdate_agent( user, @@ -373,7 +397,7 @@ async def create_agent( body.privacy_level, body.icon, body.color, - chat_model, + agent_chat_model, body.files, body.input_tools, body.output_modes, @@ -431,7 +455,11 @@ async def update_agent( ) subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None agent = await AgentAdapters.aupdate_agent( user, @@ -440,7 +468,7 @@ async def update_agent( body.privacy_level, body.icon, body.color, - chat_model, + agent_chat_model, body.files, body.input_tools, body.output_modes, diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py index 26404c3f..ac37eb0f 100644 --- a/src/khoj/routers/api_model.py +++ b/src/khoj/routers/api_model.py @@ -2,13 +2,18 @@ import json import logging from typing import Dict, Optional, Union -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Request from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import has_required_scope, requires -from khoj.database import adapters -from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.adapters import ConversationAdapters +from khoj.database.models import ( + ChatModel, + PriceTier, + TextToImageModelConfig, + VoiceModelOption, +) from khoj.routers.helpers import update_telemetry_state api_model = APIRouter() @@ -53,13 +58,24 @@ def get_user_chat_model( @api_model.post("/chat", status_code=200) -@requires(["authenticated", "premium"]) +@requires(["authenticated"]) async def update_chat_model( request: Request, id: str, client: Optional[str] = None, ): user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + # Validate if model can be switched + chat_model = await ChatModel.objects.filter(id=int(id)).afirst() + if chat_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"})) + if not subscribed and chat_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}), + ) new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) @@ -78,13 +94,24 @@ async def update_chat_model( @api_model.post("/voice", status_code=200) -@requires(["authenticated", "premium"]) +@requires(["authenticated"]) async def update_voice_model( request: Request, id: str, client: Optional[str] = None, ): user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + # Validate if model can be switched + voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst() + if voice_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"})) + if not subscribed and voice_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}), + ) new_config = await ConversationAdapters.aset_user_voice_model(user, id) @@ -111,8 +138,15 @@ async def update_paint_model( user = request.user.object subscribed = has_required_scope(request, ["premium"]) - if not subscribed: - raise HTTPException(status_code=403, detail="User is not subscribed to premium") + # Validate if model can be switched + image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst() + if image_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"})) + if not subscribed and image_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}), + ) new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5779fab6..443810e8 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) "id": chat_model.id, "strengths": chat_model.strengths, "description": chat_model.description, + "tier": chat_model.price_tier, } ) @@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() all_paint_model_options = list() for paint_model in paint_model_options: - all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id}) + all_paint_model_options.append( + { + "name": paint_model.model_name, + "id": paint_model.id, + "tier": paint_model.price_tier, + } + ) voice_models = ConversationAdapters.get_voice_model_options() voice_model_options = list() for voice_model in voice_models: - voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id}) + voice_model_options.append( + { + "name": voice_model.name, + "id": voice_model.model_id, + "tier": voice_model.price_tier, + } + ) if len(voice_model_options) == 0: eleven_labs_enabled = False