Add feature to create a custom agent directly from the side panel with currently configured settings

- Also, when in not subscribed state, fallback to the default model when chatting with an agent
- With conversion, create a brand new agent from inside the chat view that can be managed separately
This commit is contained in:
sabaimran
2025-02-12 18:24:41 -08:00
parent 5d6eca4c22
commit d0d30ace06
7 changed files with 279 additions and 12 deletions

View File

@@ -35,6 +35,7 @@ import { AppSidebar } from "../components/appSidebar/appSidebar";
import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo";
import { DialogTitle } from "@radix-ui/react-dialog";
import Link from "next/link";
const agentsFetcher = () =>
window
@@ -343,6 +344,14 @@ export default function Agents() {
/>
<span className="font-bold">How it works</span> Use any of these
specialized personas to tune your conversation to your needs.
{
!isSubscribed && (
<span>
{" "}
<Link href="/settings" className="font-bold">Upgrade your plan</Link> to leverage custom models. You will fallback to the default model when chatting.
</span>
)
}
</AlertDescription>
</Alert>
<div className="pt-6 md:pt-8">

View File

@@ -453,7 +453,7 @@ export function AgentCard(props: AgentCardProps) {
/>
</DialogContent>
) : (
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] rounded-lg">
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] md:max-w-[50vw] rounded-lg">
<DialogHeader>
<div className="flex items-center">
{getIconFromIconName(props.data.icon, props.data.color)}

View File

@@ -1,6 +1,6 @@
"use client"
import { ArrowsDownUp, CaretCircleDown, CircleNotch, Sparkle } from "@phosphor-icons/react";
import { ArrowsDownUp, CaretCircleDown, CheckCircle, Circle, CircleNotch, PersonSimpleTaiChi, Sparkle } from "@phosphor-icons/react";
import { Button } from "@/components/ui/button";
@@ -14,13 +14,20 @@ import { mutate } from "swr";
import { Sheet, SheetContent } from "@/components/ui/sheet";
import { AgentData } from "../agentCard/agentCard";
import { useEffect, useState } from "react";
import { getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
import { getAvailableIcons, getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
import { Label } from "@/components/ui/label";
import { Checkbox } from "@/components/ui/checkbox";
import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip";
import { TooltipContent } from "@radix-ui/react-tooltip";
import { useAuthenticatedData } from "@/app/common/auth";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Dialog, DialogClose, DialogContent, DialogFooter, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
import { convertColorToTextClass, tailwindColors } from "@/app/common/colorUtils";
import { Input } from "@/components/ui/input";
import Link from "next/link";
import { motion } from "framer-motion";
interface ChatSideBarProps {
conversationId: string;
@@ -54,11 +61,245 @@ export function ChatSidebar({ ...props }: ChatSideBarProps) {
);
}
interface IAgentCreationProps {
customPrompt: string;
selectedModel: string;
inputTools: string[];
outputModes: string[];
}
interface AgentError {
detail: string;
}
function AgentCreationForm(props: IAgentCreationProps) {
const iconOptions = getAvailableIcons();
const colorOptions = tailwindColors;
const [isCreating, setIsCreating] = useState<boolean>(false);
const [customAgentName, setCustomAgentName] = useState<string | undefined>();
const [customAgentIcon, setCustomAgentIcon] = useState<string | undefined>();
const [customAgentColor, setCustomAgentColor] = useState<string | undefined>();
const [doneCreating, setDoneCreating] = useState<boolean>(false);
const [createdSlug, setCreatedSlug] = useState<string | undefined>();
const [isValid, setIsValid] = useState<boolean>(false);
const [error, setError] = useState<string | undefined>();
function createAgent() {
if (isCreating) {
return;
}
setIsCreating(true);
const data = {
name: customAgentName,
icon: customAgentIcon,
color: customAgentColor,
persona: props.customPrompt,
chat_model: props.selectedModel,
input_tools: props.inputTools,
output_modes: props.outputModes,
privacy_level: "private",
};
const createAgentUrl = `/api/agents`;
fetch(createAgentUrl, {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify(data)
})
.then((res) => res.json())
.then((data: AgentData | AgentError) => {
console.log("Success:", data);
if ('detail' in data) {
setError(`Error creating agent: ${data.detail}`);
setIsCreating(false);
return;
}
setDoneCreating(true);
setCreatedSlug(data.slug);
setIsCreating(false);
})
.catch((error) => {
console.error("Error:", error);
setError(`Error creating agent: ${error}`);
setIsCreating(false);
});
}
useEffect(() => {
if (customAgentName && customAgentIcon && customAgentColor) {
setIsValid(true);
} else {
setIsValid(false);
}
}, [customAgentName, customAgentIcon, customAgentColor]);
return (
<Dialog>
<DialogTrigger asChild>
<Button
className="p-1"
variant="ghost"
>
Create Agent
</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
{
doneCreating && createdSlug ? (
<DialogTitle>
Created {customAgentName}
</DialogTitle>
) : (
<DialogTitle>
Create a New Agent
</DialogTitle>
)
}
<DialogClose />
</DialogHeader>
<div className="py-4">
{
doneCreating && createdSlug ? (
<div className="flex flex-col items-center justify-center gap-4 py-8">
<motion.div
initial={{ scale: 0 }}
animate={{ scale: 1 }}
transition={{
type: "spring",
stiffness: 260,
damping: 20
}}
>
<CheckCircle
className="w-16 h-16 text-green-500"
weight="fill"
/>
</motion.div>
<motion.p
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.2 }}
className="text-center text-lg font-medium text-accent-foreground"
>
Created successfully!
</motion.p>
<motion.div
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.4 }}
>
<Link href={`/agents?agent=${createdSlug}`}>
<Button variant="secondary" className="mt-2">
Manage Agent
</Button>
</Link>
</motion.div>
</div>
) :
<div className="flex flex-col gap-4">
<div>
<Label htmlFor="agent_name">Name</Label>
<Input
id="agent_name"
className="w-full p-2 border mt-4 border-slate-500 rounded-lg"
disabled={isCreating}
value={customAgentName}
onChange={(e) => setCustomAgentName(e.target.value)}
/>
</div>
<div className="flex gap-4">
<div className="flex-1">
<Select onValueChange={setCustomAgentColor} defaultValue={customAgentColor}>
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
<SelectValue placeholder="Color" />
</SelectTrigger>
<SelectContent className="items-center space-y-1 inline-flex flex-col">
{colorOptions.map((colorOption) => (
<SelectItem key={colorOption} value={colorOption}>
<div className="flex items-center space-x-2">
<Circle
className={`w-6 h-6 mr-2 ${convertColorToTextClass(colorOption)}`}
weight="fill"
/>
{colorOption}
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex-1">
<Select onValueChange={setCustomAgentIcon} defaultValue={customAgentIcon}>
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
<SelectValue placeholder="Icon" />
</SelectTrigger>
<SelectContent className="items-center space-y-1 inline-flex flex-col">
{iconOptions.map((iconOption) => (
<SelectItem key={iconOption} value={iconOption}>
<div className="flex items-center space-x-2">
{getIconFromIconName(
iconOption,
customAgentColor ?? "gray",
"w-6",
"h-6",
)}
{iconOption}
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
</div>
}
</div>
<DialogFooter>
{
error && (
<div className="text-red-500 text-sm">
{error}
</div>
)
}
{
!doneCreating && (
<Button
type="submit"
onClick={() => createAgent()}
disabled={isCreating || !isValid}
>
{
isCreating ?
<CircleNotch className="animate-spin" />
:
<PersonSimpleTaiChi />
}
Create
</Button>
)
}
<DialogClose />
</DialogFooter>
</DialogContent>
</Dialog >
)
}
function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
const [isEditable, setIsEditable] = useState<boolean>(false);
const { data: agentConfigurationOptions, error: agentConfigurationOptionsError } =
useSWR<AgentConfigurationOptions>("/api/agents/options", fetcher);
useSWR<AgentConfigurationOptions>("/api/agents/options", fetcher);
const { data: agentData, isLoading: agentDataLoading, error: agentDataError } = useSWR<AgentData>(`/api/agents/conversation?conversation_id=${props.conversationId}`, fetcher);
const {
@@ -211,9 +452,20 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
</a>
</div>
) : (
<div className="flex items-center relative text-sm">
{getIconFromIconName("lightbulb", "orange")}
Chat Options
<div className="flex items-center relative text-sm justify-between">
<p>
Chat Options
</p>
{
isEditable && customPrompt && !isDefaultAgent && selectedModel && (
<AgentCreationForm
customPrompt={customPrompt}
selectedModel={selectedModel}
inputTools={inputTools ?? []}
outputModes={outputModes ?? []}
/>
)
}
</div>
)
}

View File

@@ -1356,8 +1356,10 @@ class ConversationAdapters:
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
agent: Agent = (
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
)
if agent and agent.chat_model:
chat_model = conversation.agent.chat_model
else:

View File

@@ -110,6 +110,7 @@ async def get_agent_by_conversation(
conversation_id: str,
) -> Response:
user: KhojUser = request.user.object if request.user.is_authenticated else None
is_subscribed = has_required_scope(request, ["premium"])
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
if not conversation:
@@ -132,7 +133,7 @@ async def get_agent_by_conversation(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name,
"chat_model": agent.chat_model.name if is_subscribed else None,
"has_files": has_files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,

View File

@@ -12,7 +12,7 @@ from urllib.parse import unquote
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from starlette.authentication import has_required_scope, requires
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
@@ -637,6 +637,7 @@ async def chat(
chat_metadata: dict = {}
connection_alive = True
user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗"
q = unquote(q)
train_of_thought = []
@@ -1251,6 +1252,7 @@ async def chat(
generated_mermaidjs_diagram,
program_execution_context,
generated_asset_results,
is_subscribed,
tracer,
)

View File

@@ -1390,6 +1390,7 @@ def generate_chat_response(
generated_mermaidjs_diagram: str = None,
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
@@ -1426,7 +1427,7 @@ def generate_chat_response(
online_results = {}
code_results = {}
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()