diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 6ec8ffe3..ee3bb9b3 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -49,7 +49,7 @@ jobs: - name: 📂 Copy Generated Files run: | mkdir -p src/khoj/interface/compiled - cp -r /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/ + cp -r /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/ - name: ⚙️ Build Python Package run: | diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 2428bb1b..e4926ccc 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -24,9 +24,7 @@ import { ChatOptions, } from "../components/chatInputArea/chatInputArea"; import { useAuthenticatedData } from "../common/auth"; -import { - AgentData, -} from "@/app/components/agentCard/agentCard"; +import { AgentData } from "@/app/components/agentCard/agentCard"; import { ChatSessionActionMenu } from "../components/allConversations/allConversations"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; @@ -50,6 +48,7 @@ interface ChatBodyDataProps { setTriggeredAbort: (triggeredAbort: boolean) => void; isChatSideBarOpen: boolean; setIsChatSideBarOpen: (open: boolean) => void; + isActive?: boolean; } function ChatBodyData(props: ChatBodyDataProps) { @@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) { + isMobileWidth={props.isMobileWidth} + /> ); } @@ -480,12 +481,13 @@ export default function Chat() { setTriggeredAbort={setTriggeredAbort} isChatSideBarOpen={isChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen} + isActive={authenticatedData?.is_active} /> - + ); } diff --git a/src/interface/web/app/common/auth.ts b/src/interface/web/app/common/auth.ts index 917fa7da..4716fa13 100644 --- a/src/interface/web/app/common/auth.ts +++ b/src/interface/web/app/common/auth.ts @@ -33,6 +33,7 @@ export function useAuthenticatedData() { export interface ModelOptions { id: number; name: string; + tier: string; description: string; strengths: string; } diff --git a/src/interface/web/app/common/modelSelector.tsx b/src/interface/web/app/common/modelSelector.tsx index 70c171b0..d234d5b5 100644 --- a/src/interface/web/app/common/modelSelector.tsx +++ b/src/interface/web/app/common/modelSelector.tsx @@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton"; interface ModelSelectorProps extends PopoverProps { onSelect: (model: ModelOptions) => void; disabled?: boolean; + isActive?: boolean; initialModel?: string; } @@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) { setSelectedModel(model) setOpen(false) }} + isActive={props.isActive} /> ))} @@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) { setSelectedModel(model) setOpen(false) }} + isActive={props.isActive} /> ))} @@ -184,9 +187,10 @@ interface ModelItemProps { isSelected: boolean, onSelect: () => void, onPeek: (model: ModelOptions) => void + isActive?: boolean } -function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { +function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) { const ref = React.useRef(null) useMutationObserver(ref, (mutations) => { @@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { onSelect={onSelect} ref={ref} className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground" + disabled={!isActive && model.tier !== "free"} > - {model.name} + {model.name} {model.tier === "standard" && (Futurist)} diff --git a/src/interface/web/app/components/agentCard/agentCard.tsx b/src/interface/web/app/components/agentCard/agentCard.tsx index dfbade2a..04bc26a3 100644 --- a/src/interface/web/app/components/agentCard/agentCard.tsx +++ b/src/interface/web/app/components/agentCard/agentCard.tsx @@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {

Which chat model would you like to use?

)} - @@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
- {modelOption.name} + {modelOption.name}{" "} + {modelOption.tier === "standard" && ( + + (Futurist) + + )}
))} diff --git a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx index 1e7c0325..1d0b4c0d 100644 --- a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx +++ b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx @@ -34,6 +34,7 @@ interface ChatSideBarProps { isOpen: boolean; isMobileWidth?: boolean; onOpenChange: (open: boolean) => void; + isActive?: boolean; } const fetcher = (url: string) => fetch(url).then((res) => res.json()); @@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) { handleModelSelect(model.name)} initialModel={isDefaultAgent ? undefined : agentData?.chat_model} + isActive={props.isActive} /> diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index c0d5fc8c..5983c591 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -77,10 +77,11 @@ import { saveAs } from 'file-saver'; interface DropdownComponentProps { items: ModelOptions[]; selected: number; + isActive?: boolean; callbackFunc: (value: string) => Promise; } -const DropdownComponent: React.FC = ({ items, selected, callbackFunc }) => { +const DropdownComponent: React.FC = ({ items, selected, isActive, callbackFunc }) => { const [position, setPosition] = useState(selected?.toString() ?? "0"); return ( @@ -111,8 +112,9 @@ const DropdownComponent: React.FC = ({ items, selected, - {item.name} + {item.name} {item.tier === "standard" && (Futurist)} ))} @@ -520,33 +522,44 @@ export default function SettingsView() { } }; - const updateModel = (name: string) => async (id: string) => { - if (!userConfig?.is_active) { + const updateModel = (modelType: string) => async (id: string) => { + // Get the selected model from the options + const modelOptions = modelType === "chat" + ? userConfig?.chat_model_options + : modelType === "paint" + ? userConfig?.paint_model_options + : userConfig?.voice_model_options; + + const selectedModel = modelOptions?.find(model => model.id.toString() === id); + const modelName = selectedModel?.name; + + // Check if the model is free tier or if the user is active + if (!userConfig?.is_active && selectedModel?.tier !== "free") { toast({ title: `Model Update`, - description: `You need to be subscribed to update ${name} models`, + description: `Subscribe to switch ${modelType} model to ${modelName}.`, variant: "destructive", }); return; } try { - const response = await fetch(`/api/model/${name}?id=` + id, { + const response = await fetch(`/api/model/${modelType}?id=` + id, { method: "POST", headers: { "Content-Type": "application/json", }, }); - if (!response.ok) throw new Error("Failed to update model"); + if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`); toast({ - title: `✅ Updated ${toTitleCase(name)} Model`, + title: `✅ Switched ${modelType} model to ${modelName}`, }); } catch (error) { - console.error(`Failed to update ${name} model:`, error); + console.error(`Failed to update ${modelType} model to ${modelName}:`, error); toast({ - description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`, + description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`, variant: "destructive", }); } @@ -1103,13 +1116,16 @@ export default function SettingsView() { selected={ userConfig.selected_chat_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("chat")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.chat_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
@@ -1131,13 +1147,16 @@ export default function SettingsView() { selected={ userConfig.selected_paint_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("paint")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.paint_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
@@ -1159,13 +1178,16 @@ export default function SettingsView() { selected={ userConfig.selected_voice_model_config } + isActive={userConfig.is_active} callbackFunc={updateModel("voice")} /> {!userConfig.is_active && (

- Subscribe to switch model + {userConfig.voice_model_options.some(model => model.tier === "free") + ? "Free models available" + : "Subscribe to switch model"}

)}
diff --git a/src/interface/web/package.json b/src/interface/web/package.json index ae9fa15f..d82edcac 100644 --- a/src/interface/web/package.json +++ b/src/interface/web/package.json @@ -11,7 +11,7 @@ "cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'", "export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic", "ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic", - "pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic", + "pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic", "watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'", "windowswatch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn windowsexport'", "windowscollectstatic": "cd ..\\..\\.. && .\\.venv\\Scripts\\Activate.bat && py .\\src\\khoj\\manage.py collectstatic --noinput && .\\.venv\\Scripts\\deactivate.bat && cd ..", diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 086be4b0..92846020 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -48,6 +48,7 @@ from khoj.database.models import ( KhojApiUser, KhojUser, NotionConfig, + PriceTier, ProcessLock, PublicConversation, RateLimitRecord, @@ -1153,22 +1154,36 @@ class ConversationAdapters: @staticmethod def get_chat_model(user: KhojUser): subscribed = is_user_subscribed(user) - if not subscribed: - return ConversationAdapters.get_default_chat_model(user) config = UserConversationConfig.objects.filter(user=user).first() - if config: - return config.setting - return ConversationAdapters.get_advanced_chat_model(user) + if subscribed: + # Subscibed users can use any available chat model + if config: + return config.setting + # Fallback to the default advanced chat model + return ConversationAdapters.get_advanced_chat_model(user) + else: + # Non-subscribed users can use any free chat model + if config and config.setting.price_tier == PriceTier.FREE: + return config.setting + # Fallback to the default chat model + return ConversationAdapters.get_default_chat_model(user) @staticmethod async def aget_chat_model(user: KhojUser): subscribed = await ais_user_subscribed(user) - if not subscribed: - return await ConversationAdapters.aget_default_chat_model(user) config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() - if config: - return config.setting - return ConversationAdapters.aget_advanced_chat_model(user) + if subscribed: + # Subscibed users can use any available chat model + if config: + return config.setting + # Fallback to the default advanced chat model + return await ConversationAdapters.aget_advanced_chat_model(user) + else: + # Non-subscribed users can use any free chat model + if config and config.setting.price_tier == PriceTier.FREE: + return config.setting + # Fallback to the default chat model + return await ConversationAdapters.aget_default_chat_model(user) @staticmethod def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): @@ -1176,6 +1191,12 @@ class ConversationAdapters: return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first() return ChatModel.objects.filter(name=chat_model_name).first() + @staticmethod + async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): + if ai_model_api_name: + return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst() + return await ChatModel.objects.filter(name=chat_model_name).afirst() + @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() diff --git a/src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py b/src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py new file mode 100644 index 00000000..9c363098 --- /dev/null +++ b/src/khoj/database/migrations/0089_chatmodel_price_tier_and_more.py @@ -0,0 +1,34 @@ +# Generated by Django 5.1.8 on 2025-04-18 15:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0088_ratelimitrecord"), + ] + + operations = [ + migrations.AddField( + model_name="chatmodel", + name="price_tier", + field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20), + ), + migrations.AddField( + model_name="speechtotextmodeloptions", + name="price_tier", + field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20), + ), + migrations.AddField( + model_name="texttoimagemodelconfig", + name="price_tier", + field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20), + ), + migrations.AddField( + model_name="voicemodeloption", + name="price_tier", + field=models.CharField( + choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20 + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 429d010e..bd49aa8c 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel): return self.name +class PriceTier(models.TextChoices): + FREE = "free" + STANDARD = "standard" + + class ChatModel(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" @@ -207,6 +212,7 @@ class ChatModel(DbBaseModel): tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) vision_enabled = models.BooleanField(default=False) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) description = models.TextField(default=None, null=True, blank=True) @@ -219,6 +225,7 @@ class ChatModel(DbBaseModel): class VoiceModelOption(DbBaseModel): model_id = models.CharField(max_length=200) name = models.CharField(max_length=200) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD) class Agent(DbBaseModel): @@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel): WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" ) + def clean(self): + error = {} + if self.chat_default and self.chat_default.price_tier != PriceTier.FREE: + error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model." + if error: + raise ValidationError(error) + + def save(self, *args, **kwargs): + self.clean() + super().save(*args, **kwargs) + class LocalOrgConfig(DbBaseModel): input_files = models.JSONField(default=list, null=True) @@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel): model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) api_key = models.CharField(max_length=200, default=None, null=True, blank=True) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) @@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel): model_name = models.CharField(max_length=200, default="base") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) def __str__(self): diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index f97ee584..117663a7 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from starlette.authentication import has_required_scope, requires from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters -from khoj.database.models import Agent, Conversation, KhojUser +from khoj.database.models import Agent, Conversation, KhojUser, PriceTier from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt from khoj.utils.helpers import ( ConversationCommand, @@ -125,8 +125,20 @@ async def get_agent_by_conversation( else: agent = await AgentAdapters.aget_default_agent() + if agent is None: + return Response( + content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}), + media_type="application/json", + status_code=404, + ) + + chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) + if is_subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = chat_model.name + else: + agent_chat_model = None + has_files = agent.fileobject_set.exists() - agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, @@ -137,7 +149,7 @@ async def get_agent_by_conversation( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.name if is_subscribed else None, + "chat_model": agent_chat_model, "has_files": has_files, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -249,7 +261,11 @@ async def update_hidden_agent( user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user) @@ -264,7 +280,7 @@ async def update_hidden_agent( user=user, slug=body.slug, persona=body.persona, - chat_model=chat_model, + chat_model=agent_chat_model, input_tools=body.input_tools, output_modes=body.output_modes, existing_agent=selected_agent, @@ -295,7 +311,11 @@ async def create_hidden_agent( user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id) if not conversation: @@ -320,7 +340,7 @@ async def create_hidden_agent( user=user, slug=body.slug, persona=body.persona, - chat_model=chat_model, + chat_model=agent_chat_model, input_tools=body.input_tools, output_modes=body.output_modes, existing_agent=None, @@ -364,7 +384,11 @@ async def create_agent( ) subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None agent = await AgentAdapters.aupdate_agent( user, @@ -373,7 +397,7 @@ async def create_agent( body.privacy_level, body.icon, body.color, - chat_model, + agent_chat_model, body.files, body.input_tools, body.output_modes, @@ -431,7 +455,11 @@ async def update_agent( ) subscribed = has_required_scope(request, ["premium"]) - chat_model = body.chat_model if subscribed else None + chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model) + if subscribed or chat_model.price_tier == PriceTier.FREE: + agent_chat_model = body.chat_model + else: + agent_chat_model = None agent = await AgentAdapters.aupdate_agent( user, @@ -440,7 +468,7 @@ async def update_agent( body.privacy_level, body.icon, body.color, - chat_model, + agent_chat_model, body.files, body.input_tools, body.output_modes, diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py index 26404c3f..ac37eb0f 100644 --- a/src/khoj/routers/api_model.py +++ b/src/khoj/routers/api_model.py @@ -2,13 +2,18 @@ import json import logging from typing import Dict, Optional, Union -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Request from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import has_required_scope, requires -from khoj.database import adapters -from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.adapters import ConversationAdapters +from khoj.database.models import ( + ChatModel, + PriceTier, + TextToImageModelConfig, + VoiceModelOption, +) from khoj.routers.helpers import update_telemetry_state api_model = APIRouter() @@ -53,13 +58,24 @@ def get_user_chat_model( @api_model.post("/chat", status_code=200) -@requires(["authenticated", "premium"]) +@requires(["authenticated"]) async def update_chat_model( request: Request, id: str, client: Optional[str] = None, ): user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + # Validate if model can be switched + chat_model = await ChatModel.objects.filter(id=int(id)).afirst() + if chat_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"})) + if not subscribed and chat_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}), + ) new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) @@ -78,13 +94,24 @@ async def update_chat_model( @api_model.post("/voice", status_code=200) -@requires(["authenticated", "premium"]) +@requires(["authenticated"]) async def update_voice_model( request: Request, id: str, client: Optional[str] = None, ): user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + # Validate if model can be switched + voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst() + if voice_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"})) + if not subscribed and voice_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}), + ) new_config = await ConversationAdapters.aset_user_voice_model(user, id) @@ -111,8 +138,15 @@ async def update_paint_model( user = request.user.object subscribed = has_required_scope(request, ["premium"]) - if not subscribed: - raise HTTPException(status_code=403, detail="User is not subscribed to premium") + # Validate if model can be switched + image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst() + if image_model is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"})) + if not subscribed and image_model.price_tier != PriceTier.FREE: + raise Response( + status_code=403, + content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}), + ) new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5779fab6..443810e8 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) "id": chat_model.id, "strengths": chat_model.strengths, "description": chat_model.description, + "tier": chat_model.price_tier, } ) @@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() all_paint_model_options = list() for paint_model in paint_model_options: - all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id}) + all_paint_model_options.append( + { + "name": paint_model.model_name, + "id": paint_model.id, + "tier": paint_model.price_tier, + } + ) voice_models = ConversationAdapters.get_voice_model_options() voice_model_options = list() for voice_model in voice_models: - voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id}) + voice_model_options.append( + { + "name": voice_model.name, + "id": voice_model.model_id, + "tier": voice_model.price_tier, + } + ) if len(voice_model_options) == 0: eleven_labs_enabled = False