= ({ items, selected,
- {item.name}
+ {item.name} {item.tier === "standard" && (Futurist)}
))}
@@ -520,33 +522,44 @@ export default function SettingsView() {
}
};
- const updateModel = (name: string) => async (id: string) => {
- if (!userConfig?.is_active) {
+ const updateModel = (modelType: string) => async (id: string) => {
+ // Get the selected model from the options
+ const modelOptions = modelType === "chat"
+ ? userConfig?.chat_model_options
+ : modelType === "paint"
+ ? userConfig?.paint_model_options
+ : userConfig?.voice_model_options;
+
+ const selectedModel = modelOptions?.find(model => model.id.toString() === id);
+ const modelName = selectedModel?.name;
+
+ // Check if the model is free tier or if the user is active
+ if (!userConfig?.is_active && selectedModel?.tier !== "free") {
toast({
title: `Model Update`,
- description: `You need to be subscribed to update ${name} models`,
+ description: `Subscribe to switch ${modelType} model to ${modelName}.`,
variant: "destructive",
});
return;
}
try {
- const response = await fetch(`/api/model/${name}?id=` + id, {
+ const response = await fetch(`/api/model/${modelType}?id=` + id, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
- if (!response.ok) throw new Error("Failed to update model");
+ if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`);
toast({
- title: `✅ Updated ${toTitleCase(name)} Model`,
+ title: `✅ Switched ${modelType} model to ${modelName}`,
});
} catch (error) {
- console.error(`Failed to update ${name} model:`, error);
+ console.error(`Failed to update ${modelType} model to ${modelName}:`, error);
toast({
- description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`,
+ description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`,
variant: "destructive",
});
}
@@ -1103,13 +1116,16 @@ export default function SettingsView() {
selected={
userConfig.selected_chat_model_config
}
+ isActive={userConfig.is_active}
callbackFunc={updateModel("chat")}
/>
{!userConfig.is_active && (
- Subscribe to switch model
+ {userConfig.chat_model_options.some(model => model.tier === "free")
+ ? "Free models available"
+ : "Subscribe to switch model"}
)}
@@ -1131,13 +1147,16 @@ export default function SettingsView() {
selected={
userConfig.selected_paint_model_config
}
+ isActive={userConfig.is_active}
callbackFunc={updateModel("paint")}
/>
{!userConfig.is_active && (
- Subscribe to switch model
+ {userConfig.paint_model_options.some(model => model.tier === "free")
+ ? "Free models available"
+ : "Subscribe to switch model"}
)}
@@ -1159,13 +1178,16 @@ export default function SettingsView() {
selected={
userConfig.selected_voice_model_config
}
+ isActive={userConfig.is_active}
callbackFunc={updateModel("voice")}
/>
{!userConfig.is_active && (
- Subscribe to switch model
+ {userConfig.voice_model_options.some(model => model.tier === "free")
+ ? "Free models available"
+ : "Subscribe to switch model"}
)}
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 28e3a04c..92846020 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -1191,6 +1191,12 @@ class ConversationAdapters:
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
return ChatModel.objects.filter(name=chat_model_name).first()
+ @staticmethod
+ async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
+ if ai_model_api_name:
+ return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst()
+ return await ChatModel.objects.filter(name=chat_model_name).afirst()
+
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py
index f97ee584..117663a7 100644
--- a/src/khoj/routers/api_agents.py
+++ b/src/khoj/routers/api_agents.py
@@ -12,7 +12,7 @@ from pydantic import BaseModel
from starlette.authentication import has_required_scope, requires
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
-from khoj.database.models import Agent, Conversation, KhojUser
+from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
ConversationCommand,
@@ -125,8 +125,20 @@ async def get_agent_by_conversation(
else:
agent = await AgentAdapters.aget_default_agent()
+ if agent is None:
+ return Response(
+ content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}),
+ media_type="application/json",
+ status_code=404,
+ )
+
+ chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
+ if is_subscribed or chat_model.price_tier == PriceTier.FREE:
+ agent_chat_model = chat_model.name
+ else:
+ agent_chat_model = None
+
has_files = agent.fileobject_set.exists()
- agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
@@ -137,7 +149,7 @@ async def get_agent_by_conversation(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
- "chat_model": agent.chat_model.name if is_subscribed else None,
+ "chat_model": agent_chat_model,
"has_files": has_files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -249,7 +261,11 @@ async def update_hidden_agent(
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
- chat_model = body.chat_model if subscribed else None
+ chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
+ if subscribed or chat_model.price_tier == PriceTier.FREE:
+ agent_chat_model = body.chat_model
+ else:
+ agent_chat_model = None
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
@@ -264,7 +280,7 @@ async def update_hidden_agent(
user=user,
slug=body.slug,
persona=body.persona,
- chat_model=chat_model,
+ chat_model=agent_chat_model,
input_tools=body.input_tools,
output_modes=body.output_modes,
existing_agent=selected_agent,
@@ -295,7 +311,11 @@ async def create_hidden_agent(
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
- chat_model = body.chat_model if subscribed else None
+ chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
+ if subscribed or chat_model.price_tier == PriceTier.FREE:
+ agent_chat_model = body.chat_model
+ else:
+ agent_chat_model = None
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
if not conversation:
@@ -320,7 +340,7 @@ async def create_hidden_agent(
user=user,
slug=body.slug,
persona=body.persona,
- chat_model=chat_model,
+ chat_model=agent_chat_model,
input_tools=body.input_tools,
output_modes=body.output_modes,
existing_agent=None,
@@ -364,7 +384,11 @@ async def create_agent(
)
subscribed = has_required_scope(request, ["premium"])
- chat_model = body.chat_model if subscribed else None
+ chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
+ if subscribed or chat_model.price_tier == PriceTier.FREE:
+ agent_chat_model = body.chat_model
+ else:
+ agent_chat_model = None
agent = await AgentAdapters.aupdate_agent(
user,
@@ -373,7 +397,7 @@ async def create_agent(
body.privacy_level,
body.icon,
body.color,
- chat_model,
+ agent_chat_model,
body.files,
body.input_tools,
body.output_modes,
@@ -431,7 +455,11 @@ async def update_agent(
)
subscribed = has_required_scope(request, ["premium"])
- chat_model = body.chat_model if subscribed else None
+ chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
+ if subscribed or chat_model.price_tier == PriceTier.FREE:
+ agent_chat_model = body.chat_model
+ else:
+ agent_chat_model = None
agent = await AgentAdapters.aupdate_agent(
user,
@@ -440,7 +468,7 @@ async def update_agent(
body.privacy_level,
body.icon,
body.color,
- chat_model,
+ agent_chat_model,
body.files,
body.input_tools,
body.output_modes,
diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py
index 26404c3f..ac37eb0f 100644
--- a/src/khoj/routers/api_model.py
+++ b/src/khoj/routers/api_model.py
@@ -2,13 +2,18 @@ import json
import logging
from typing import Dict, Optional, Union
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
-from khoj.database import adapters
-from khoj.database.adapters import ConversationAdapters, EntryAdapters
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import (
+ ChatModel,
+ PriceTier,
+ TextToImageModelConfig,
+ VoiceModelOption,
+)
from khoj.routers.helpers import update_telemetry_state
api_model = APIRouter()
@@ -53,13 +58,24 @@ def get_user_chat_model(
@api_model.post("/chat", status_code=200)
-@requires(["authenticated", "premium"])
+@requires(["authenticated"])
async def update_chat_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
+ subscribed = has_required_scope(request, ["premium"])
+
+ # Validate if model can be switched
+ chat_model = await ChatModel.objects.filter(id=int(id)).afirst()
+ if chat_model is None:
+ return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
+ if not subscribed and chat_model.price_tier != PriceTier.FREE:
+ raise Response(
+ status_code=403,
+ content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
+ )
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
@@ -78,13 +94,24 @@ async def update_chat_model(
@api_model.post("/voice", status_code=200)
-@requires(["authenticated", "premium"])
+@requires(["authenticated"])
async def update_voice_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
+ subscribed = has_required_scope(request, ["premium"])
+
+ # Validate if model can be switched
+ voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst()
+ if voice_model is None:
+ return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
+ if not subscribed and voice_model.price_tier != PriceTier.FREE:
+ raise Response(
+ status_code=403,
+ content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
+ )
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
@@ -111,8 +138,15 @@ async def update_paint_model(
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
- if not subscribed:
- raise HTTPException(status_code=403, detail="User is not subscribed to premium")
+ # Validate if model can be switched
+ image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst()
+ if image_model is None:
+ return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
+ if not subscribed and image_model.price_tier != PriceTier.FREE:
+ raise Response(
+ status_code=403,
+ content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
+ )
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 5779fab6..443810e8 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"id": chat_model.id,
"strengths": chat_model.strengths,
"description": chat_model.description,
+ "tier": chat_model.price_tier,
}
)
@@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
- all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id})
+ all_paint_model_options.append(
+ {
+ "name": paint_model.model_name,
+ "id": paint_model.id,
+ "tier": paint_model.price_tier,
+ }
+ )
voice_models = ConversationAdapters.get_voice_model_options()
voice_model_options = list()
for voice_model in voice_models:
- voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id})
+ voice_model_options.append(
+ {
+ "name": voice_model.name,
+ "id": voice_model.model_id,
+ "tier": voice_model.price_tier,
+ }
+ )
if len(voice_model_options) == 0:
eleven_labs_enabled = False