Add feature to create a custom agent directly from the side panel with currently configured settings

- Also, when in not subscribed state, fallback to the default model when chatting with an agent
- With conversion, create a brand new agent from inside the chat view that can be managed separately
This commit is contained in:
sabaimran
2025-02-12 18:24:41 -08:00
parent 5d6eca4c22
commit d0d30ace06
7 changed files with 279 additions and 12 deletions

View File

@@ -35,6 +35,7 @@ import { AppSidebar } from "../components/appSidebar/appSidebar";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo"; import { KhojLogoType } from "../components/logo/khojLogo";
import { DialogTitle } from "@radix-ui/react-dialog"; import { DialogTitle } from "@radix-ui/react-dialog";
import Link from "next/link";
const agentsFetcher = () => const agentsFetcher = () =>
window window
@@ -343,6 +344,14 @@ export default function Agents() {
/> />
<span className="font-bold">How it works</span> Use any of these <span className="font-bold">How it works</span> Use any of these
specialized personas to tune your conversation to your needs. specialized personas to tune your conversation to your needs.
{
!isSubscribed && (
<span>
{" "}
<Link href="/settings" className="font-bold">Upgrade your plan</Link> to leverage custom models. You will fallback to the default model when chatting.
</span>
)
}
</AlertDescription> </AlertDescription>
</Alert> </Alert>
<div className="pt-6 md:pt-8"> <div className="pt-6 md:pt-8">

View File

@@ -453,7 +453,7 @@ export function AgentCard(props: AgentCardProps) {
/> />
</DialogContent> </DialogContent>
) : ( ) : (
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] rounded-lg"> <DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] md:max-w-[50vw] rounded-lg">
<DialogHeader> <DialogHeader>
<div className="flex items-center"> <div className="flex items-center">
{getIconFromIconName(props.data.icon, props.data.color)} {getIconFromIconName(props.data.icon, props.data.color)}

View File

@@ -1,6 +1,6 @@
"use client" "use client"
import { ArrowsDownUp, CaretCircleDown, CircleNotch, Sparkle } from "@phosphor-icons/react"; import { ArrowsDownUp, CaretCircleDown, CheckCircle, Circle, CircleNotch, PersonSimpleTaiChi, Sparkle } from "@phosphor-icons/react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
@@ -14,13 +14,20 @@ import { mutate } from "swr";
import { Sheet, SheetContent } from "@/components/ui/sheet"; import { Sheet, SheetContent } from "@/components/ui/sheet";
import { AgentData } from "../agentCard/agentCard"; import { AgentData } from "../agentCard/agentCard";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils"; import { getAvailableIcons, getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
import { Label } from "@/components/ui/label"; import { Label } from "@/components/ui/label";
import { Checkbox } from "@/components/ui/checkbox"; import { Checkbox } from "@/components/ui/checkbox";
import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip";
import { TooltipContent } from "@radix-ui/react-tooltip"; import { TooltipContent } from "@radix-ui/react-tooltip";
import { useAuthenticatedData } from "@/app/common/auth"; import { useAuthenticatedData } from "@/app/common/auth";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Dialog, DialogClose, DialogContent, DialogFooter, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
import { convertColorToTextClass, tailwindColors } from "@/app/common/colorUtils";
import { Input } from "@/components/ui/input";
import Link from "next/link";
import { motion } from "framer-motion";
interface ChatSideBarProps { interface ChatSideBarProps {
conversationId: string; conversationId: string;
@@ -54,6 +61,240 @@ export function ChatSidebar({ ...props }: ChatSideBarProps) {
); );
} }
interface IAgentCreationProps {
customPrompt: string;
selectedModel: string;
inputTools: string[];
outputModes: string[];
}
interface AgentError {
detail: string;
}
function AgentCreationForm(props: IAgentCreationProps) {
const iconOptions = getAvailableIcons();
const colorOptions = tailwindColors;
const [isCreating, setIsCreating] = useState<boolean>(false);
const [customAgentName, setCustomAgentName] = useState<string | undefined>();
const [customAgentIcon, setCustomAgentIcon] = useState<string | undefined>();
const [customAgentColor, setCustomAgentColor] = useState<string | undefined>();
const [doneCreating, setDoneCreating] = useState<boolean>(false);
const [createdSlug, setCreatedSlug] = useState<string | undefined>();
const [isValid, setIsValid] = useState<boolean>(false);
const [error, setError] = useState<string | undefined>();
function createAgent() {
if (isCreating) {
return;
}
setIsCreating(true);
const data = {
name: customAgentName,
icon: customAgentIcon,
color: customAgentColor,
persona: props.customPrompt,
chat_model: props.selectedModel,
input_tools: props.inputTools,
output_modes: props.outputModes,
privacy_level: "private",
};
const createAgentUrl = `/api/agents`;
fetch(createAgentUrl, {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify(data)
})
.then((res) => res.json())
.then((data: AgentData | AgentError) => {
console.log("Success:", data);
if ('detail' in data) {
setError(`Error creating agent: ${data.detail}`);
setIsCreating(false);
return;
}
setDoneCreating(true);
setCreatedSlug(data.slug);
setIsCreating(false);
})
.catch((error) => {
console.error("Error:", error);
setError(`Error creating agent: ${error}`);
setIsCreating(false);
});
}
useEffect(() => {
if (customAgentName && customAgentIcon && customAgentColor) {
setIsValid(true);
} else {
setIsValid(false);
}
}, [customAgentName, customAgentIcon, customAgentColor]);
return (
<Dialog>
<DialogTrigger asChild>
<Button
className="p-1"
variant="ghost"
>
Create Agent
</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
{
doneCreating && createdSlug ? (
<DialogTitle>
Created {customAgentName}
</DialogTitle>
) : (
<DialogTitle>
Create a New Agent
</DialogTitle>
)
}
<DialogClose />
</DialogHeader>
<div className="py-4">
{
doneCreating && createdSlug ? (
<div className="flex flex-col items-center justify-center gap-4 py-8">
<motion.div
initial={{ scale: 0 }}
animate={{ scale: 1 }}
transition={{
type: "spring",
stiffness: 260,
damping: 20
}}
>
<CheckCircle
className="w-16 h-16 text-green-500"
weight="fill"
/>
</motion.div>
<motion.p
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.2 }}
className="text-center text-lg font-medium text-accent-foreground"
>
Created successfully!
</motion.p>
<motion.div
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.4 }}
>
<Link href={`/agents?agent=${createdSlug}`}>
<Button variant="secondary" className="mt-2">
Manage Agent
</Button>
</Link>
</motion.div>
</div>
) :
<div className="flex flex-col gap-4">
<div>
<Label htmlFor="agent_name">Name</Label>
<Input
id="agent_name"
className="w-full p-2 border mt-4 border-slate-500 rounded-lg"
disabled={isCreating}
value={customAgentName}
onChange={(e) => setCustomAgentName(e.target.value)}
/>
</div>
<div className="flex gap-4">
<div className="flex-1">
<Select onValueChange={setCustomAgentColor} defaultValue={customAgentColor}>
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
<SelectValue placeholder="Color" />
</SelectTrigger>
<SelectContent className="items-center space-y-1 inline-flex flex-col">
{colorOptions.map((colorOption) => (
<SelectItem key={colorOption} value={colorOption}>
<div className="flex items-center space-x-2">
<Circle
className={`w-6 h-6 mr-2 ${convertColorToTextClass(colorOption)}`}
weight="fill"
/>
{colorOption}
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex-1">
<Select onValueChange={setCustomAgentIcon} defaultValue={customAgentIcon}>
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
<SelectValue placeholder="Icon" />
</SelectTrigger>
<SelectContent className="items-center space-y-1 inline-flex flex-col">
{iconOptions.map((iconOption) => (
<SelectItem key={iconOption} value={iconOption}>
<div className="flex items-center space-x-2">
{getIconFromIconName(
iconOption,
customAgentColor ?? "gray",
"w-6",
"h-6",
)}
{iconOption}
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
</div>
}
</div>
<DialogFooter>
{
error && (
<div className="text-red-500 text-sm">
{error}
</div>
)
}
{
!doneCreating && (
<Button
type="submit"
onClick={() => createAgent()}
disabled={isCreating || !isValid}
>
{
isCreating ?
<CircleNotch className="animate-spin" />
:
<PersonSimpleTaiChi />
}
Create
</Button>
)
}
<DialogClose />
</DialogFooter>
</DialogContent>
</Dialog >
)
}
function ChatSidebarInternal({ ...props }: ChatSideBarProps) { function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
const [isEditable, setIsEditable] = useState<boolean>(false); const [isEditable, setIsEditable] = useState<boolean>(false);
@@ -211,9 +452,20 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
</a> </a>
</div> </div>
) : ( ) : (
<div className="flex items-center relative text-sm"> <div className="flex items-center relative text-sm justify-between">
{getIconFromIconName("lightbulb", "orange")} <p>
Chat Options Chat Options
</p>
{
isEditable && customPrompt && !isDefaultAgent && selectedModel && (
<AgentCreationForm
customPrompt={customPrompt}
selectedModel={selectedModel}
inputTools={inputTools ?? []}
outputModes={outputModes ?? []}
/>
)
}
</div> </div>
) )
} }

View File

@@ -1356,8 +1356,10 @@ class ConversationAdapters:
return random.sample(all_questions, max_results) return random.sample(all_questions, max_results)
@staticmethod @staticmethod
def get_valid_chat_model(user: KhojUser, conversation: Conversation): def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None agent: Agent = (
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
)
if agent and agent.chat_model: if agent and agent.chat_model:
chat_model = conversation.agent.chat_model chat_model = conversation.agent.chat_model
else: else:

View File

@@ -110,6 +110,7 @@ async def get_agent_by_conversation(
conversation_id: str, conversation_id: str,
) -> Response: ) -> Response:
user: KhojUser = request.user.object if request.user.is_authenticated else None user: KhojUser = request.user.object if request.user.is_authenticated else None
is_subscribed = has_required_scope(request, ["premium"])
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id) conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
if not conversation: if not conversation:
@@ -132,7 +133,7 @@ async def get_agent_by_conversation(
"color": agent.style_color, "color": agent.style_color,
"icon": agent.style_icon, "icon": agent.style_icon,
"privacy_level": agent.privacy_level, "privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name, "chat_model": agent.chat_model.name if is_subscribed else None,
"has_files": has_files, "has_files": has_files,
"input_tools": agent.input_tools, "input_tools": agent.input_tools,
"output_modes": agent.output_modes, "output_modes": agent.output_modes,

View File

@@ -12,7 +12,7 @@ from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import has_required_scope, requires
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
@@ -637,6 +637,7 @@ async def chat(
chat_metadata: dict = {} chat_metadata: dict = {}
connection_alive = True connection_alive = True
user: KhojUser = request.user.object user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
q = unquote(q) q = unquote(q)
train_of_thought = [] train_of_thought = []
@@ -1251,6 +1252,7 @@ async def chat(
generated_mermaidjs_diagram, generated_mermaidjs_diagram,
program_execution_context, program_execution_context,
generated_asset_results, generated_asset_results,
is_subscribed,
tracer, tracer,
) )

View File

@@ -1390,6 +1390,7 @@ def generate_chat_response(
generated_mermaidjs_diagram: str = None, generated_mermaidjs_diagram: str = None,
program_execution_context: List[str] = [], program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
tracer: dict = {}, tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
@@ -1426,7 +1427,7 @@ def generate_chat_response(
online_results = {} online_results = {}
code_results = {} code_results = {}
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation) chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
vision_available = chat_model.vision_enabled vision_available = chat_model.vision_enabled
if not vision_available and query_images: if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config() vision_enabled_config = ConversationAdapters.get_vision_enabled_config()