mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Allow AI model switching based on User Tier (#1151)
Overview --- Enable free tier users to chat with any AI model made available on free tier of production deployments like [Khoj cloud](https://app.khoj.dev). Previously model switching was completely disabled for users on free tier. Details --- - Track price tier of each Chat, Speech, Image, Voice AI model in DB - Update API to allow free tier users to switch between free models - Update web app to allow model switching on agent creation, settings chat page (via right side pane), even for free tier users.
This commit is contained in:
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
- name: 📂 Copy Generated Files
|
||||
run: |
|
||||
mkdir -p src/khoj/interface/compiled
|
||||
cp -r /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/
|
||||
cp -r /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled/* src/khoj/interface/compiled/
|
||||
|
||||
- name: ⚙️ Build Python Package
|
||||
run: |
|
||||
|
||||
@@ -24,9 +24,7 @@ import {
|
||||
ChatOptions,
|
||||
} from "../components/chatInputArea/chatInputArea";
|
||||
import { useAuthenticatedData } from "../common/auth";
|
||||
import {
|
||||
AgentData,
|
||||
} from "@/app/components/agentCard/agentCard";
|
||||
import { AgentData } from "@/app/components/agentCard/agentCard";
|
||||
import { ChatSessionActionMenu } from "../components/allConversations/allConversations";
|
||||
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||
import { AppSidebar } from "../components/appSidebar/appSidebar";
|
||||
@@ -50,6 +48,7 @@ interface ChatBodyDataProps {
|
||||
setTriggeredAbort: (triggeredAbort: boolean) => void;
|
||||
isChatSideBarOpen: boolean;
|
||||
setIsChatSideBarOpen: (open: boolean) => void;
|
||||
isActive?: boolean;
|
||||
}
|
||||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
@@ -180,9 +179,11 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||
</div>
|
||||
<ChatSidebar
|
||||
conversationId={conversationId}
|
||||
isActive={props.isActive}
|
||||
isOpen={props.isChatSideBarOpen}
|
||||
onOpenChange={props.setIsChatSideBarOpen}
|
||||
isMobileWidth={props.isMobileWidth} />
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -480,12 +481,13 @@ export default function Chat() {
|
||||
setTriggeredAbort={setTriggeredAbort}
|
||||
isChatSideBarOpen={isChatSideBarOpen}
|
||||
setIsChatSideBarOpen={setIsChatSideBarOpen}
|
||||
isActive={authenticatedData?.is_active}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</SidebarInset>
|
||||
</SidebarProvider >
|
||||
</SidebarProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ export function useAuthenticatedData() {
|
||||
export interface ModelOptions {
|
||||
id: number;
|
||||
name: string;
|
||||
tier: string;
|
||||
description: string;
|
||||
strengths: string;
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import { Skeleton } from "@/components/ui/skeleton";
|
||||
interface ModelSelectorProps extends PopoverProps {
|
||||
onSelect: (model: ModelOptions) => void;
|
||||
disabled?: boolean;
|
||||
isActive?: boolean;
|
||||
initialModel?: string;
|
||||
}
|
||||
|
||||
@@ -116,6 +117,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
|
||||
setSelectedModel(model)
|
||||
setOpen(false)
|
||||
}}
|
||||
isActive={props.isActive}
|
||||
/>
|
||||
))}
|
||||
</CommandGroup>
|
||||
@@ -165,6 +167,7 @@ export function ModelSelector({ ...props }: ModelSelectorProps) {
|
||||
setSelectedModel(model)
|
||||
setOpen(false)
|
||||
}}
|
||||
isActive={props.isActive}
|
||||
/>
|
||||
))}
|
||||
</CommandGroup>
|
||||
@@ -184,9 +187,10 @@ interface ModelItemProps {
|
||||
isSelected: boolean,
|
||||
onSelect: () => void,
|
||||
onPeek: (model: ModelOptions) => void
|
||||
isActive?: boolean
|
||||
}
|
||||
|
||||
function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
|
||||
function ModelItem({ model, isSelected, onSelect, onPeek, isActive }: ModelItemProps) {
|
||||
const ref = React.useRef<HTMLDivElement>(null)
|
||||
|
||||
useMutationObserver(ref, (mutations) => {
|
||||
@@ -207,8 +211,9 @@ function ModelItem({ model, isSelected, onSelect, onPeek }: ModelItemProps) {
|
||||
onSelect={onSelect}
|
||||
ref={ref}
|
||||
className="data-[selected=true]:bg-muted data-[selected=true]:text-secondary-foreground"
|
||||
disabled={!isActive && model.tier !== "free"}
|
||||
>
|
||||
{model.name}
|
||||
{model.name} {model.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
|
||||
<Check
|
||||
className={cn("ml-auto", isSelected ? "opacity-100" : "opacity-0")}
|
||||
/>
|
||||
|
||||
@@ -773,11 +773,7 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
|
||||
<p>Which chat model would you like to use?</p>
|
||||
)}
|
||||
</FormDescription>
|
||||
<Select
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
disabled={!props.isSubscribed}
|
||||
>
|
||||
<Select onValueChange={field.onChange} defaultValue={field.value}>
|
||||
<FormControl>
|
||||
<SelectTrigger className="text-left dark:bg-muted">
|
||||
<SelectValue />
|
||||
@@ -788,9 +784,18 @@ export function AgentModificationForm(props: AgentModificationFormProps) {
|
||||
<SelectItem
|
||||
key={modelOption.id}
|
||||
value={modelOption.name}
|
||||
disabled={
|
||||
!props.isSubscribed &&
|
||||
modelOption.tier !== "free"
|
||||
}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{modelOption.name}
|
||||
{modelOption.name}{" "}
|
||||
{modelOption.tier === "standard" && (
|
||||
<span className="text-green-500 ml-2">
|
||||
(Futurist)
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
|
||||
@@ -34,6 +34,7 @@ interface ChatSideBarProps {
|
||||
isOpen: boolean;
|
||||
isMobileWidth?: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
isActive?: boolean;
|
||||
}
|
||||
|
||||
const fetcher = (url: string) => fetch(url).then((res) => res.json());
|
||||
@@ -527,9 +528,10 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
|
||||
<SidebarMenu className="p-0 m-0">
|
||||
<SidebarMenuItem key={"model"} className="list-none">
|
||||
<ModelSelector
|
||||
disabled={!isEditable || !isSubscribed}
|
||||
disabled={!isEditable}
|
||||
onSelect={(model) => handleModelSelect(model.name)}
|
||||
initialModel={isDefaultAgent ? undefined : agentData?.chat_model}
|
||||
isActive={props.isActive}
|
||||
/>
|
||||
</SidebarMenuItem>
|
||||
</SidebarMenu>
|
||||
|
||||
@@ -77,10 +77,11 @@ import { saveAs } from 'file-saver';
|
||||
interface DropdownComponentProps {
|
||||
items: ModelOptions[];
|
||||
selected: number;
|
||||
isActive?: boolean;
|
||||
callbackFunc: (value: string) => Promise<void>;
|
||||
}
|
||||
|
||||
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, callbackFunc }) => {
|
||||
const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected, isActive, callbackFunc }) => {
|
||||
const [position, setPosition] = useState(selected?.toString() ?? "0");
|
||||
|
||||
return (
|
||||
@@ -111,8 +112,9 @@ const DropdownComponent: React.FC<DropdownComponentProps> = ({ items, selected,
|
||||
<DropdownMenuRadioItem
|
||||
key={item.id.toString()}
|
||||
value={item.id.toString()}
|
||||
disabled={!isActive && item.tier !== "free"}
|
||||
>
|
||||
{item.name}
|
||||
{item.name} {item.tier === "standard" && <span className="text-green-500 ml-2">(Futurist)</span>}
|
||||
</DropdownMenuRadioItem>
|
||||
))}
|
||||
</DropdownMenuRadioGroup>
|
||||
@@ -520,33 +522,44 @@ export default function SettingsView() {
|
||||
}
|
||||
};
|
||||
|
||||
const updateModel = (name: string) => async (id: string) => {
|
||||
if (!userConfig?.is_active) {
|
||||
const updateModel = (modelType: string) => async (id: string) => {
|
||||
// Get the selected model from the options
|
||||
const modelOptions = modelType === "chat"
|
||||
? userConfig?.chat_model_options
|
||||
: modelType === "paint"
|
||||
? userConfig?.paint_model_options
|
||||
: userConfig?.voice_model_options;
|
||||
|
||||
const selectedModel = modelOptions?.find(model => model.id.toString() === id);
|
||||
const modelName = selectedModel?.name;
|
||||
|
||||
// Check if the model is free tier or if the user is active
|
||||
if (!userConfig?.is_active && selectedModel?.tier !== "free") {
|
||||
toast({
|
||||
title: `Model Update`,
|
||||
description: `You need to be subscribed to update ${name} models`,
|
||||
description: `Subscribe to switch ${modelType} model to ${modelName}.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/model/${name}?id=` + id, {
|
||||
const response = await fetch(`/api/model/${modelType}?id=` + id, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error("Failed to update model");
|
||||
if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`);
|
||||
|
||||
toast({
|
||||
title: `✅ Updated ${toTitleCase(name)} Model`,
|
||||
title: `✅ Switched ${modelType} model to ${modelName}`,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Failed to update ${name} model:`, error);
|
||||
console.error(`Failed to update ${modelType} model to ${modelName}:`, error);
|
||||
toast({
|
||||
description: `❌ Failed to update ${toTitleCase(name)} model. Try again.`,
|
||||
description: `❌ Failed to switch ${modelType} model to ${modelName}. Try again.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
@@ -1103,13 +1116,16 @@ export default function SettingsView() {
|
||||
selected={
|
||||
userConfig.selected_chat_model_config
|
||||
}
|
||||
isActive={userConfig.is_active}
|
||||
callbackFunc={updateModel("chat")}
|
||||
/>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-wrap gap-4">
|
||||
{!userConfig.is_active && (
|
||||
<p className="text-gray-400">
|
||||
Subscribe to switch model
|
||||
{userConfig.chat_model_options.some(model => model.tier === "free")
|
||||
? "Free models available"
|
||||
: "Subscribe to switch model"}
|
||||
</p>
|
||||
)}
|
||||
</CardFooter>
|
||||
@@ -1131,13 +1147,16 @@ export default function SettingsView() {
|
||||
selected={
|
||||
userConfig.selected_paint_model_config
|
||||
}
|
||||
isActive={userConfig.is_active}
|
||||
callbackFunc={updateModel("paint")}
|
||||
/>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-wrap gap-4">
|
||||
{!userConfig.is_active && (
|
||||
<p className="text-gray-400">
|
||||
Subscribe to switch model
|
||||
{userConfig.paint_model_options.some(model => model.tier === "free")
|
||||
? "Free models available"
|
||||
: "Subscribe to switch model"}
|
||||
</p>
|
||||
)}
|
||||
</CardFooter>
|
||||
@@ -1159,13 +1178,16 @@ export default function SettingsView() {
|
||||
selected={
|
||||
userConfig.selected_voice_model_config
|
||||
}
|
||||
isActive={userConfig.is_active}
|
||||
callbackFunc={updateModel("voice")}
|
||||
/>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-wrap gap-4">
|
||||
{!userConfig.is_active && (
|
||||
<p className="text-gray-400">
|
||||
Subscribe to switch model
|
||||
{userConfig.voice_model_options.some(model => model.tier === "free")
|
||||
? "Free models available"
|
||||
: "Subscribe to switch model"}
|
||||
</p>
|
||||
)}
|
||||
</CardFooter>
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'",
|
||||
"export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic",
|
||||
"ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic",
|
||||
"pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic",
|
||||
"pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic",
|
||||
"watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'",
|
||||
"windowswatch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn windowsexport'",
|
||||
"windowscollectstatic": "cd ..\\..\\.. && .\\.venv\\Scripts\\Activate.bat && py .\\src\\khoj\\manage.py collectstatic --noinput && .\\.venv\\Scripts\\deactivate.bat && cd ..",
|
||||
|
||||
@@ -48,6 +48,7 @@ from khoj.database.models import (
|
||||
KhojApiUser,
|
||||
KhojUser,
|
||||
NotionConfig,
|
||||
PriceTier,
|
||||
ProcessLock,
|
||||
PublicConversation,
|
||||
RateLimitRecord,
|
||||
@@ -1153,22 +1154,36 @@ class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_chat_model(user: KhojUser):
|
||||
subscribed = is_user_subscribed(user)
|
||||
if not subscribed:
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
config = UserConversationConfig.objects.filter(user=user).first()
|
||||
if config:
|
||||
return config.setting
|
||||
return ConversationAdapters.get_advanced_chat_model(user)
|
||||
if subscribed:
|
||||
# Subscibed users can use any available chat model
|
||||
if config:
|
||||
return config.setting
|
||||
# Fallback to the default advanced chat model
|
||||
return ConversationAdapters.get_advanced_chat_model(user)
|
||||
else:
|
||||
# Non-subscribed users can use any free chat model
|
||||
if config and config.setting.price_tier == PriceTier.FREE:
|
||||
return config.setting
|
||||
# Fallback to the default chat model
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_chat_model(user: KhojUser):
|
||||
subscribed = await ais_user_subscribed(user)
|
||||
if not subscribed:
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if config:
|
||||
return config.setting
|
||||
return ConversationAdapters.aget_advanced_chat_model(user)
|
||||
if subscribed:
|
||||
# Subscibed users can use any available chat model
|
||||
if config:
|
||||
return config.setting
|
||||
# Fallback to the default advanced chat model
|
||||
return await ConversationAdapters.aget_advanced_chat_model(user)
|
||||
else:
|
||||
# Non-subscribed users can use any free chat model
|
||||
if config and config.setting.price_tier == PriceTier.FREE:
|
||||
return config.setting
|
||||
# Fallback to the default chat model
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
def get_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
||||
@@ -1176,6 +1191,12 @@ class ConversationAdapters:
|
||||
return ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).first()
|
||||
return ChatModel.objects.filter(name=chat_model_name).first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None):
|
||||
if ai_model_api_name:
|
||||
return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst()
|
||||
return await ChatModel.objects.filter(name=chat_model_name).afirst()
|
||||
|
||||
@staticmethod
|
||||
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
||||
voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# Generated by Django 5.1.8 on 2025-04-18 15:15
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0088_ratelimitrecord"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="chatmodel",
|
||||
name="price_tier",
|
||||
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="speechtotextmodeloptions",
|
||||
name="price_tier",
|
||||
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="price_tier",
|
||||
field=models.CharField(choices=[("free", "Free"), ("standard", "Standard")], default="free", max_length=20),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="voicemodeloption",
|
||||
name="price_tier",
|
||||
field=models.CharField(
|
||||
choices=[("free", "Free"), ("standard", "Standard")], default="standard", max_length=20
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -195,6 +195,11 @@ class AiModelApi(DbBaseModel):
|
||||
return self.name
|
||||
|
||||
|
||||
class PriceTier(models.TextChoices):
|
||||
FREE = "free"
|
||||
STANDARD = "standard"
|
||||
|
||||
|
||||
class ChatModel(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
@@ -207,6 +212,7 @@ class ChatModel(DbBaseModel):
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
vision_enabled = models.BooleanField(default=False)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
description = models.TextField(default=None, null=True, blank=True)
|
||||
@@ -219,6 +225,7 @@ class ChatModel(DbBaseModel):
|
||||
class VoiceModelOption(DbBaseModel):
|
||||
model_id = models.CharField(max_length=200)
|
||||
name = models.CharField(max_length=200)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.STANDARD)
|
||||
|
||||
|
||||
class Agent(DbBaseModel):
|
||||
@@ -452,6 +459,17 @@ class ServerChatSettings(DbBaseModel):
|
||||
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||
)
|
||||
|
||||
def clean(self):
|
||||
error = {}
|
||||
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
|
||||
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
|
||||
if error:
|
||||
raise ValidationError(error)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.clean()
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class LocalOrgConfig(DbBaseModel):
|
||||
input_files = models.JSONField(default=list, null=True)
|
||||
@@ -534,6 +552,7 @@ class TextToImageModelConfig(DbBaseModel):
|
||||
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
@@ -571,6 +590,7 @@ class SpeechToTextModelOptions(DbBaseModel):
|
||||
|
||||
model_name = models.CharField(max_length=200, default="base")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, Conversation, KhojUser
|
||||
from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
|
||||
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
@@ -125,8 +125,20 @@ async def get_agent_by_conversation(
|
||||
else:
|
||||
agent = await AgentAdapters.aget_default_agent()
|
||||
|
||||
if agent is None:
|
||||
return Response(
|
||||
content=json.dumps({"error": f"Agent for conversation id {conversation_id} not found for user {user}."}),
|
||||
media_type="application/json",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
if is_subscribed or chat_model.price_tier == PriceTier.FREE:
|
||||
agent_chat_model = chat_model.name
|
||||
else:
|
||||
agent_chat_model = None
|
||||
|
||||
has_files = agent.fileobject_set.exists()
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
agents_packet = {
|
||||
"slug": agent.slug,
|
||||
@@ -137,7 +149,7 @@ async def get_agent_by_conversation(
|
||||
"color": agent.style_color,
|
||||
"icon": agent.style_icon,
|
||||
"privacy_level": agent.privacy_level,
|
||||
"chat_model": agent.chat_model.name if is_subscribed else None,
|
||||
"chat_model": agent_chat_model,
|
||||
"has_files": has_files,
|
||||
"input_tools": agent.input_tools,
|
||||
"output_modes": agent.output_modes,
|
||||
@@ -249,7 +261,11 @@ async def update_hidden_agent(
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
chat_model = body.chat_model if subscribed else None
|
||||
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
|
||||
if subscribed or chat_model.price_tier == PriceTier.FREE:
|
||||
agent_chat_model = body.chat_model
|
||||
else:
|
||||
agent_chat_model = None
|
||||
|
||||
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
|
||||
|
||||
@@ -264,7 +280,7 @@ async def update_hidden_agent(
|
||||
user=user,
|
||||
slug=body.slug,
|
||||
persona=body.persona,
|
||||
chat_model=chat_model,
|
||||
chat_model=agent_chat_model,
|
||||
input_tools=body.input_tools,
|
||||
output_modes=body.output_modes,
|
||||
existing_agent=selected_agent,
|
||||
@@ -295,7 +311,11 @@ async def create_hidden_agent(
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
chat_model = body.chat_model if subscribed else None
|
||||
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
|
||||
if subscribed or chat_model.price_tier == PriceTier.FREE:
|
||||
agent_chat_model = body.chat_model
|
||||
else:
|
||||
agent_chat_model = None
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
|
||||
if not conversation:
|
||||
@@ -320,7 +340,7 @@ async def create_hidden_agent(
|
||||
user=user,
|
||||
slug=body.slug,
|
||||
persona=body.persona,
|
||||
chat_model=chat_model,
|
||||
chat_model=agent_chat_model,
|
||||
input_tools=body.input_tools,
|
||||
output_modes=body.output_modes,
|
||||
existing_agent=None,
|
||||
@@ -364,7 +384,11 @@ async def create_agent(
|
||||
)
|
||||
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
chat_model = body.chat_model if subscribed else None
|
||||
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
|
||||
if subscribed or chat_model.price_tier == PriceTier.FREE:
|
||||
agent_chat_model = body.chat_model
|
||||
else:
|
||||
agent_chat_model = None
|
||||
|
||||
agent = await AgentAdapters.aupdate_agent(
|
||||
user,
|
||||
@@ -373,7 +397,7 @@ async def create_agent(
|
||||
body.privacy_level,
|
||||
body.icon,
|
||||
body.color,
|
||||
chat_model,
|
||||
agent_chat_model,
|
||||
body.files,
|
||||
body.input_tools,
|
||||
body.output_modes,
|
||||
@@ -431,7 +455,11 @@ async def update_agent(
|
||||
)
|
||||
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
chat_model = body.chat_model if subscribed else None
|
||||
chat_model = await ConversationAdapters.aget_chat_model_by_name(body.chat_model)
|
||||
if subscribed or chat_model.price_tier == PriceTier.FREE:
|
||||
agent_chat_model = body.chat_model
|
||||
else:
|
||||
agent_chat_model = None
|
||||
|
||||
agent = await AgentAdapters.aupdate_agent(
|
||||
user,
|
||||
@@ -440,7 +468,7 @@ async def update_agent(
|
||||
body.privacy_level,
|
||||
body.icon,
|
||||
body.color,
|
||||
chat_model,
|
||||
agent_chat_model,
|
||||
body.files,
|
||||
body.input_tools,
|
||||
body.output_modes,
|
||||
|
||||
@@ -2,13 +2,18 @@ import json
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import (
|
||||
ChatModel,
|
||||
PriceTier,
|
||||
TextToImageModelConfig,
|
||||
VoiceModelOption,
|
||||
)
|
||||
from khoj.routers.helpers import update_telemetry_state
|
||||
|
||||
api_model = APIRouter()
|
||||
@@ -53,13 +58,24 @@ def get_user_chat_model(
|
||||
|
||||
|
||||
@api_model.post("/chat", status_code=200)
|
||||
@requires(["authenticated", "premium"])
|
||||
@requires(["authenticated"])
|
||||
async def update_chat_model(
|
||||
request: Request,
|
||||
id: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
# Validate if model can be switched
|
||||
chat_model = await ChatModel.objects.filter(id=int(id)).afirst()
|
||||
if chat_model is None:
|
||||
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Chat model not found"}))
|
||||
if not subscribed and chat_model.price_tier != PriceTier.FREE:
|
||||
raise Response(
|
||||
status_code=403,
|
||||
content=json.dumps({"status": "error", "message": "Subscribe to switch to this chat model"}),
|
||||
)
|
||||
|
||||
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
||||
|
||||
@@ -78,13 +94,24 @@ async def update_chat_model(
|
||||
|
||||
|
||||
@api_model.post("/voice", status_code=200)
|
||||
@requires(["authenticated", "premium"])
|
||||
@requires(["authenticated"])
|
||||
async def update_voice_model(
|
||||
request: Request,
|
||||
id: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
# Validate if model can be switched
|
||||
voice_model = await VoiceModelOption.objects.filter(id=int(id)).afirst()
|
||||
if voice_model is None:
|
||||
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Voice model not found"}))
|
||||
if not subscribed and voice_model.price_tier != PriceTier.FREE:
|
||||
raise Response(
|
||||
status_code=403,
|
||||
content=json.dumps({"status": "error", "message": "Subscribe to switch to this voice model"}),
|
||||
)
|
||||
|
||||
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
|
||||
|
||||
@@ -111,8 +138,15 @@ async def update_paint_model(
|
||||
user = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
if not subscribed:
|
||||
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||
# Validate if model can be switched
|
||||
image_model = await TextToImageModelConfig.objects.filter(id=int(id)).afirst()
|
||||
if image_model is None:
|
||||
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Image model not found"}))
|
||||
if not subscribed and image_model.price_tier != PriceTier.FREE:
|
||||
raise Response(
|
||||
status_code=403,
|
||||
content=json.dumps({"status": "error", "message": "Subscribe to switch to this image model"}),
|
||||
)
|
||||
|
||||
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
|
||||
|
||||
|
||||
@@ -2364,6 +2364,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||
"id": chat_model.id,
|
||||
"strengths": chat_model.strengths,
|
||||
"description": chat_model.description,
|
||||
"tier": chat_model.price_tier,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -2371,12 +2372,24 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||
all_paint_model_options = list()
|
||||
for paint_model in paint_model_options:
|
||||
all_paint_model_options.append({"name": paint_model.model_name, "id": paint_model.id})
|
||||
all_paint_model_options.append(
|
||||
{
|
||||
"name": paint_model.model_name,
|
||||
"id": paint_model.id,
|
||||
"tier": paint_model.price_tier,
|
||||
}
|
||||
)
|
||||
|
||||
voice_models = ConversationAdapters.get_voice_model_options()
|
||||
voice_model_options = list()
|
||||
for voice_model in voice_models:
|
||||
voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id})
|
||||
voice_model_options.append(
|
||||
{
|
||||
"name": voice_model.name,
|
||||
"id": voice_model.model_id,
|
||||
"tier": voice_model.price_tier,
|
||||
}
|
||||
)
|
||||
|
||||
if len(voice_model_options) == 0:
|
||||
eleven_labs_enabled = False
|
||||
|
||||
Reference in New Issue
Block a user