Allow AI model switching based on User Tier (#1151)

Overview
---
Enable free tier users to chat with any AI model made available on free tier 
of production deployments like [Khoj cloud](https://app.khoj.dev).

Previously model switching was completely disabled for users on free tier.

Details
---
- Track price tier of each Chat, Speech, Image, Voice AI model in DB
- Update API to allow free tier users to switch between free models
- Update web app to allow model switching on agent creation, settings
  chat page (via right side pane), even for free tier users.
This commit is contained in:
Debanjum
2025-04-19 18:14:37 +05:30
committed by GitHub
14 changed files with 246 additions and 59 deletions

View File

@@ -49,7 +49,7 @@ jobs:
- name: 📂 Copy Generated Files - name: 📂 Copy Generated Files
run: | run: |
mkdir -p src/khoj/interface/compiled mkdir -p src/khoj/interface/compiled
cp -r /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/ cp -r /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/
- name: ⚙️ Build Python Package - name: ⚙️ Build Python Package
run: | run: |

View File

@@ -24,9 +24,7 @@ import {
ChatOptions, ChatOptions,
} from "../components/chatInputArea/chatInputArea"; } from "../components/chatInputArea/chatInputArea";
import { useAuthenticatedData } from "../common/auth"; import { useAuthenticatedData } from "../common/auth";
import { import { AgentData } from "@/app/components/agentCard/agentCard";
AgentData,
} from "@/app/components/agentCard/agentCard";
import { ChatSessionActionMenu } from "../components/allConversations/allConversations"; import { ChatSessionActionMenu } from "../components/allConversations/allConversations";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar";
@@ -50,6 +48,7 @@ interface ChatBodyDataProps {
setTriggeredAbort: (triggeredAbort: boolean) => void; setTriggeredAbort: (triggeredAbort: boolean) => void;
isChatSideBarOpen: boolean; isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void; setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean;
} }
function ChatBodyData(props: ChatBodyDataProps) { function ChatBodyData(props: ChatBodyDataProps) {
@@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) {
</div> </div>
<ChatSidebar <ChatSidebar
conversationId={conversationId} conversationId={conversationId}
isActive={props.isActive}
isOpen={props.isChatSideBarOpen} isOpen={props.isChatSideBarOpen}
onOpenChange={props.setIsChatSideBarOpen} onOpenChange={props.setIsChatSideBarOpen}
isMobileWidth={props.isMobileWidth} /> isMobileWidth={props.isMobileWidth}
/>
</div> </div>
); );
} }
@@ -480,12 +481,13 @@ export default function Chat() {
setTriggeredAbort={setTriggeredAbort} setTriggeredAbort={setTriggeredAbort}
isChatSideBarOpen={isChatSideBarOpen} isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active}
/> />
</Suspense> </Suspense>
</div> </div>
</div> </div>
</div> </div>
</SidebarInset> </SidebarInset>
</SidebarProvider > </SidebarProvider>
); );
} }

View File

@@ -33,6 +33,7 @@ export function useAuthenticatedData() {
export interface ModelOptions { export interface ModelOptions {
id: number; id: number;
name: string; name: string;
tier: string;
description: string; description: string;
strengths: string; strengths: string;
} }

View File

@@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton";
interface ModelSelectorProps extends PopoverProps { interface ModelSelectorProps extends PopoverProps {
onSelect: (model: ModelOptions) => void; onSelect: (model: ModelOptions) => void;
disabled?: boolean; disabled?: boolean;
isActive?: boolean;
initialModel?: string; initialModel?: string;
} }
@@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
setSelectedModel(model) setSelectedModel(model)
setOpen(false) setOpen(false)
}} }}
isActive={props.isActive}
/> />
))} ))}
</CommandGroup> </CommandGroup>
@@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
setSelectedModel(model) setSelectedModel(model)
setOpen(false) setOpen(false)
}} }}
isActive={props.isActive}
/> />
))} ))}
</CommandGroup> </CommandGroup>
@@ -184,9 +187,10 @@ interface ModelItemProps {
isSelected: boolean, isSelected: boolean,
onSelect: () => void, onSelect: () => void,
onPeek: (model: ModelOptions) => void onPeek: (model: ModelOptions) => void
isActive?: boolean
} }
function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) { function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) {
const ref = React.useRef<HTMLDivElement>(null) const ref = React.useRef<HTMLDivElement>(null)
useMutationObserver(ref, (mutations) => { useMutationObserver(ref, (mutations) => {
@@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
onSelect={onSelect} onSelect={onSelect}
ref={ref} ref={ref}
className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground" className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground"
disabled={!isActive && model.tier !== "free"}
> >
{model.name} {model.name} {model.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
<Check <Check
className={cn("ml-auto", isSelected ? "opacity-100" : "opacity-0")} className={cn("ml-auto", isSelected ? "opacity-100" : "opacity-0")}
/> />

View File

@@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
<p>Which chat model would you like to use?</p> <p>Which chat model would you like to use?</p>
)} )}
</FormDescription> </FormDescription>
<Select <Select onValueChange={field.onChange} defaultValue={field.value}>
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!props.isSubscribed}
>
<FormControl> <FormControl>
<SelectTrigger className="text-left dark:bg-muted"> <SelectTrigger className="text-left dark:bg-muted">
<SelectValue /> <SelectValue />
@@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
<SelectItem <SelectItem
key={modelOption.id} key={modelOption.id}
value={modelOption.name} value={modelOption.name}
disabled={
!props.isSubscribed &&
modelOption.tier !== "free"
}
> >
<div className="flex items-center space-x-2"> <div className="flex items-center space-x-2">
{modelOption.name} {modelOption.name}{" "}
{modelOption.tier === "standard" && (
<span className="text-green-500 ml-2">
(Futurist)
</span>
)}
</div> </div>
</SelectItem> </SelectItem>
))} ))}

View File

@@ -34,6 +34,7 @@ interface ChatSideBarProps {
isOpen: boolean; isOpen: boolean;
isMobileWidth?: boolean; isMobileWidth?: boolean;
onOpenChange: (open: boolean) => void; onOpenChange: (open: boolean) => void;
isActive?: boolean;
} }
const fetcher = (url: string) => fetch(url).then((res) => res.json()); const fetcher = (url: string) => fetch(url).then((res) => res.json());
@@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
<SidebarMenu className="p-0 m-0"> <SidebarMenu className="p-0 m-0">
<SidebarMenuItem key={"model"} className="list-none"> <SidebarMenuItem key={"model"} className="list-none">
<ModelSelector <ModelSelector
disabled={!isEditable || !isSubscribed} disabled={!isEditable}
onSelect={(model) => handleModelSelect(model.name)} onSelect={(model) => handleModelSelect(model.name)}
initialModel={isDefaultAgent ? undefined : agentData?.chat_model} initialModel={isDefaultAgent ? undefined : agentData?.chat_model}
isActive={props.isActive}
/> />
</SidebarMenuItem> </SidebarMenuItem>
</SidebarMenu> </SidebarMenu>

View File

@@ -77,10 +77,11 @@ import { saveAs } from 'file-saver';
interface DropdownComponentProps { interface DropdownComponentProps {
items: ModelOptions[]; items: ModelOptions[];
selected: number; selected: number;
isActive?: boolean;
callbackFunc: (value: string) => Promise<void>; callbackFunc: (value: string) => Promise<void>;
} }
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, callbackFunc }) => { const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, isActive, callbackFunc }) => {
const [position, setPosition] = useState(selected?.toString() ?? "0"); const [position, setPosition] = useState(selected?.toString() ?? "0");
return ( return (
@@ -111,8 +112,9 @@ const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected,
<DropdownMenuRadioItem <DropdownMenuRadioItem
key={item.id.toString()} key={item.id.toString()}
value={item.id.toString()} value={item.id.toString()}
disabled={!isActive && item.tier !== "free"}
> >
{item.name} {item.name} {item.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
</DropdownMenuRadioItem> </DropdownMenuRadioItem>
))} ))}
</DropdownMenuRadioGroup> </DropdownMenuRadioGroup>
@@ -520,33 +522,44 @@ export default function SettingsView() {
} }
}; };
const updateModel = (name: string) => async (id: string) => { const updateModel = (modelType: string) => async (id: string) => {
if (!userConfig?.is_active) { // Get the selected model from the options
const modelOptions = modelType === "chat"
? userConfig?.chat_model_options
: modelType === "paint"
? userConfig?.paint_model_options
: userConfig?.voice_model_options;
const selectedModel = modelOptions?.find(model => model.id.toString() === id);
const modelName = selectedModel?.name;
// Check if the model is free tier or if the user is active
if (!userConfig?.is_active && selectedModel?.tier !== "free") {
toast({ toast({
title: `Model Update`, title: `Model Update`,
description: `You need to be subscribed to update ${name} models`, description: `Subscribe to switch ${modelType} model to ${modelName}.`,
variant: "destructive", variant: "destructive",
}); });
return; return;
} }
try { try {
const response = await fetch(`/api/model/${name}?id=` + id, { const response = await fetch(`/api/model/${modelType}?id=` + id, {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
}); });
if (!response.ok) throw new Error("Failed to update model"); if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`);
toast({ toast({
title: `Updated ${toTitleCase(name)} Model`, title: `Switched ${modelType} model to ${modelName}`,
}); });
} catch (error) { } catch (error) {
console.error(`Failed to update ${name} model:`, error); console.error(`Failed to update ${modelType} model to ${modelName}:`, error);
toast({ toast({
description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`, description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`,
variant: "destructive", variant: "destructive",
}); });
} }
@@ -1103,13 +1116,16 @@ export default function SettingsView() {
selected={ selected={
userConfig.selected_chat_model_config userConfig.selected_chat_model_config
} }
isActive={userConfig.is_active}
callbackFunc={updateModel("chat")} callbackFunc={updateModel("chat")}
/> />
</CardContent> </CardContent>
<CardFooter className="flex flex-wrap gap-4"> <CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && ( {!userConfig.is_active && (
<p className="text-gray-400"> <p className="text-gray-400">
Subscribe to switch model {userConfig.chat_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p> </p>
)} )}
</CardFooter> </CardFooter>
@@ -1131,13 +1147,16 @@ export default function SettingsView() {
selected={ selected={
userConfig.selected_paint_model_config userConfig.selected_paint_model_config
} }
isActive={userConfig.is_active}
callbackFunc={updateModel("paint")} callbackFunc={updateModel("paint")}
/> />
</CardContent> </CardContent>
<CardFooter className="flex flex-wrap gap-4"> <CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && ( {!userConfig.is_active && (
<p className="text-gray-400"> <p className="text-gray-400">
Subscribe to switch model {userConfig.paint_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p> </p>
)} )}
</CardFooter> </CardFooter>
@@ -1159,13 +1178,16 @@ export default function SettingsView() {
selected={ selected={
userConfig.selected_voice_model_config userConfig.selected_voice_model_config
} }
isActive={userConfig.is_active}
callbackFunc={updateModel("voice")} callbackFunc={updateModel("voice")}
/> />
</CardContent> </CardContent>
<CardFooter className="flex flex-wrap gap-4"> <CardFooter className="flex flex-wrap gap-4">
{!userConfig.is_active && ( {!userConfig.is_active && (
<p className="text-gray-400"> <p className="text-gray-400">
Subscribe to switch model {userConfig.voice_model_options.some(model => model.tier === "free")
? "Free models available"
: "Subscribe to switch model"}
</p> </p>
)} )}
</CardFooter> </CardFooter>

View File

@@ -11,7 +11,7 @@
"cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'", "cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'",
"export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic", "export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic",
"ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic", "ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic",
"pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic", "pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic",
"watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'", "watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'",
"windowswatch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn windowsexport'", "windowswatch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn windowsexport'",
"windowscollectstatic": "cd ..\\..\\.. && .\\.venv\\Scripts\\Activate.bat && py .\\src\\khoj\\manage.py collectstatic --noinput && .\\.venv\\Scripts\\deactivate.bat && cd ..", "windowscollectstatic": "cd ..\\..\\.. && .\\.venv\\Scripts\\Activate.bat && py .\\src\\khoj\\manage.py collectstatic --noinput && .\\.venv\\Scripts\\deactivate.bat && cd ..",

View File

@@ -48,6 +48,7 @@ from khoj.database.models import (
KhojApiUser, KhojApiUser,
KhojUser, KhojUser,
NotionConfig, NotionConfig,
PriceTier,
ProcessLock, ProcessLock,
PublicConversation, PublicConversation,
RateLimitRecord, RateLimitRecord,
@@ -1153,22 +1154,36 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_chat_model(user: KhojUser): def get_chat_model(user: KhojUser):
subscribed = is_user_subscribed(user) subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_chat_model(user)
config = UserConversationConfig.objects.filter(user=user).first() config = UserConversationConfig.objects.filter(user=user).first()
if subscribed:
# Subscibed users can use any available chat model
if config: if config:
return config.setting return config.setting
# Fallback to the default advanced chat model
return ConversationAdapters.get_advanced_chat_model(user) return ConversationAdapters.get_advanced_chat_model(user)
else:
# Non-subscribed users can use any free chat model
if config and config.setting.price_tier == PriceTier.FREE:
return config.setting
# Fallback to the default chat model
return ConversationAdapters.get_default_chat_model(user)
@staticmethod @staticmethod
async def aget_chat_model(user: KhojUser): async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user) subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_chat_model(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if subscribed:
# Subscibed users can use any available chat model
if config: if config:
return config.setting return config.setting
return ConversationAdapters.aget_advanced_chat_model(user) # Fallback to the default advanced chat model
return await ConversationAdapters.aget_advanced_chat_model(user)
else:
# Non-subscribed users can use any free chat model
if config and config.setting.price_tier == PriceTier.FREE:
return config.setting
# Fallback to the default chat model
return await ConversationAdapters.aget_default_chat_model(user)
@staticmethod @staticmethod
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
@@ -1176,6 +1191,12 @@ class ConversationAdapters:
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first() return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
return ChatModel.objects.filter(name=chat_model_name).first() return ChatModel.objects.filter(name=chat_model_name).first()
@staticmethod
async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
if ai_model_api_name:
return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst()
return await ChatModel.objects.filter(name=chat_model_name).afirst()
@staticmethod @staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()

View File

@@ -0,0 +1,34 @@
# Generated by Django 5.1.8 on 2025-04-18 15:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0088_ratelimitrecord"),
]
operations = [
migrations.AddField(
model_name="chatmodel",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="speechtotextmodeloptions",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="price_tier",
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
),
migrations.AddField(
model_name="voicemodeloption",
name="price_tier",
field=models.CharField(
choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20
),
),
]

View File

@@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel):
return self.name return self.name
class PriceTier(models.TextChoices):
FREE = "free"
STANDARD = "standard"
class ChatModel(DbBaseModel): class ChatModel(DbBaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"
@@ -207,6 +212,7 @@ class ChatModel(DbBaseModel):
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
vision_enabled = models.BooleanField(default=False) vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
description = models.TextField(default=None, null=True, blank=True) description = models.TextField(default=None, null=True, blank=True)
@@ -219,6 +225,7 @@ class ChatModel(DbBaseModel):
class VoiceModelOption(DbBaseModel): class VoiceModelOption(DbBaseModel):
model_id = models.CharField(max_length=200) model_id = models.CharField(max_length=200)
name = models.CharField(max_length=200) name = models.CharField(max_length=200)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD)
class Agent(DbBaseModel): class Agent(DbBaseModel):
@@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel):
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
) )
def clean(self):
error = {}
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class LocalOrgConfig(DbBaseModel): class LocalOrgConfig(DbBaseModel):
input_files = models.JSONField(default=list, null=True) input_files = models.JSONField(default=list, null=True)
@@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel):
model_name = models.CharField(max_length=200, default="dall-e-3") model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True) api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel):
model_name = models.CharField(max_length=200, default="base") model_name = models.CharField(max_length=200, default="base")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
def __str__(self): def __str__(self):

View File

@@ -12,7 +12,7 @@ from pydantic import BaseModel
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.models import Agent, Conversation, KhojUser from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import ( from khoj.utils.helpers import (
ConversationCommand, ConversationCommand,
@@ -125,8 +125,20 @@ async def get_agent_by_conversation(
else: else:
agent = await AgentAdapters.aget_default_agent() agent = await AgentAdapters.aget_default_agent()
if agent is None:
return Response(
content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}),
media_type="application/json",
status_code=404,
)
chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
if is_subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = chat_model.name
else:
agent_chat_model = None
has_files = agent.fileobject_set.exists() has_files = agent.fileobject_set.exists()
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = { agents_packet = {
"slug": agent.slug, "slug": agent.slug,
@@ -137,7 +149,7 @@ async def get_agent_by_conversation(
"color": agent.style_color, "color": agent.style_color,
"icon": agent.style_icon, "icon": agent.style_icon,
"privacy_level": agent.privacy_level, "privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name if is_subscribed else None, "chat_model": agent_chat_model,
"has_files": has_files, "has_files": has_files,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,
@@ -249,7 +261,11 @@ async def update_hidden_agent(
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user) selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
@@ -264,7 +280,7 @@ async def update_hidden_agent(
user=user, user=user,
slug=body.slug, slug=body.slug,
persona=body.persona, persona=body.persona,
chat_model=chat_model, chat_model=agent_chat_model,
input_tools=body.input_tools, input_tools=body.input_tools,
output_modes=body.output_modes, output_modes=body.output_modes,
existing_agent=selected_agent, existing_agent=selected_agent,
@@ -295,7 +311,11 @@ async def create_hidden_agent(
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id) conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
if not conversation: if not conversation:
@@ -320,7 +340,7 @@ async def create_hidden_agent(
user=user, user=user,
slug=body.slug, slug=body.slug,
persona=body.persona, persona=body.persona,
chat_model=chat_model, chat_model=agent_chat_model,
input_tools=body.input_tools, input_tools=body.input_tools,
output_modes=body.output_modes, output_modes=body.output_modes,
existing_agent=None, existing_agent=None,
@@ -364,7 +384,11 @@ async def create_agent(
) )
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
agent = await AgentAdapters.aupdate_agent( agent = await AgentAdapters.aupdate_agent(
user, user,
@@ -373,7 +397,7 @@ async def create_agent(
body.privacy_level, body.privacy_level,
body.icon, body.icon,
body.color, body.color,
chat_model, agent_chat_model,
body.files, body.files,
body.input_tools, body.input_tools,
body.output_modes, body.output_modes,
@@ -431,7 +455,11 @@ async def update_agent(
) )
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
chat_model = body.chat_model if subscribed else None chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
if subscribed or chat_model.price_tier == PriceTier.FREE:
agent_chat_model = body.chat_model
else:
agent_chat_model = None
agent = await AgentAdapters.aupdate_agent( agent = await AgentAdapters.aupdate_agent(
user, user,
@@ -440,7 +468,7 @@ async def update_agent(
body.privacy_level, body.privacy_level,
body.icon, body.icon,
body.color, body.color,
chat_model, agent_chat_model,
body.files, body.files,
body.input_tools, body.input_tools,
body.output_modes, body.output_modes,

View File

@@ -2,13 +2,18 @@ import json
import logging import logging
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
from khoj.database import adapters from khoj.database.adapters import ConversationAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.models import (
ChatModel,
PriceTier,
TextToImageModelConfig,
VoiceModelOption,
)
from khoj.routers.helpers import update_telemetry_state from khoj.routers.helpers import update_telemetry_state
api_model = APIRouter() api_model = APIRouter()
@@ -53,13 +58,24 @@ def get_user_chat_model(
@api_model.post("/chat", status_code=200) @api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"]) @requires(["authenticated"])
async def update_chat_model( async def update_chat_model(
request: Request, request: Request,
id: str, id: str,
client: Optional[str] = None, client: Optional[str] = None,
): ):
user = request.user.object user = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Validate if model can be switched
chat_model = await ChatModel.objects.filter(id=int(id)).afirst()
if chat_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
if not subscribed and chat_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
)
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
@@ -78,13 +94,24 @@ async def update_chat_model(
@api_model.post("/voice", status_code=200) @api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"]) @requires(["authenticated"])
async def update_voice_model( async def update_voice_model(
request: Request, request: Request,
id: str, id: str,
client: Optional[str] = None, client: Optional[str] = None,
): ):
user = request.user.object user = request.user.object
subscribed = has_required_scope(request, ["premium"])
# Validate if model can be switched
voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst()
if voice_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
if not subscribed and voice_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
)
new_config = await ConversationAdapters.aset_user_voice_model(user, id) new_config = await ConversationAdapters.aset_user_voice_model(user, id)
@@ -111,8 +138,15 @@ async def update_paint_model(
user = request.user.object user = request.user.object
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
if not subscribed: # Validate if model can be switched
raise HTTPException(status_code=403, detail="User is not subscribed to premium") image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst()
if image_model is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
if not subscribed and image_model.price_tier != PriceTier.FREE:
raise Response(
status_code=403,
content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
)
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

View File

@@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"id": chat_model.id, "id": chat_model.id,
"strengths": chat_model.strengths, "strengths": chat_model.strengths,
"description": chat_model.description, "description": chat_model.description,
"tier": chat_model.price_tier,
} }
) )
@@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list() all_paint_model_options = list()
for paint_model in paint_model_options: for paint_model in paint_model_options:
all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id}) all_paint_model_options.append(
{
"name": paint_model.model_name,
"id": paint_model.id,
"tier": paint_model.price_tier,
}
)
voice_models = ConversationAdapters.get_voice_model_options() voice_models = ConversationAdapters.get_voice_model_options()
voice_model_options = list() voice_model_options = list()
for voice_model in voice_models: for voice_model in voice_models:
voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id}) voice_model_options.append(
{
"name": voice_model.name,
"id": voice_model.model_id,
"tier": voice_model.price_tier,
}
)
if len(voice_model_options) == 0: if len(voice_model_options) == 0:
eleven_labs_enabled = False eleven_labs_enabled = False