mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Add feature to create a custom agent directly from the side panel with currently configured settings
- Also, when in not subscribed state, fallback to the default model when chatting with an agent - With conversion, create a brand new agent from inside the chat view that can be managed separately
This commit is contained in:
@@ -35,6 +35,7 @@ import { AppSidebar } from "../components/appSidebar/appSidebar";
|
|||||||
import { Separator } from "@/components/ui/separator";
|
import { Separator } from "@/components/ui/separator";
|
||||||
import { KhojLogoType } from "../components/logo/khojLogo";
|
import { KhojLogoType } from "../components/logo/khojLogo";
|
||||||
import { DialogTitle } from "@radix-ui/react-dialog";
|
import { DialogTitle } from "@radix-ui/react-dialog";
|
||||||
|
import Link from "next/link";
|
||||||
|
|
||||||
const agentsFetcher = () =>
|
const agentsFetcher = () =>
|
||||||
window
|
window
|
||||||
@@ -343,6 +344,14 @@ export default function Agents() {
|
|||||||
/>
|
/>
|
||||||
<span className="font-bold">How it works</span> Use any of these
|
<span className="font-bold">How it works</span> Use any of these
|
||||||
specialized personas to tune your conversation to your needs.
|
specialized personas to tune your conversation to your needs.
|
||||||
|
{
|
||||||
|
!isSubscribed && (
|
||||||
|
<span>
|
||||||
|
{" "}
|
||||||
|
<Link href="/settings" className="font-bold">Upgrade your plan</Link> to leverage custom models. You will fallback to the default model when chatting.
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
<div className="pt-6 md:pt-8">
|
<div className="pt-6 md:pt-8">
|
||||||
|
|||||||
@@ -453,7 +453,7 @@ export function AgentCard(props: AgentCardProps) {
|
|||||||
/>
|
/>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
) : (
|
) : (
|
||||||
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] rounded-lg">
|
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] md:max-w-[50vw] rounded-lg">
|
||||||
<DialogHeader>
|
<DialogHeader>
|
||||||
<div className="flex items-center">
|
<div className="flex items-center">
|
||||||
{getIconFromIconName(props.data.icon, props.data.color)}
|
{getIconFromIconName(props.data.icon, props.data.color)}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"use client"
|
"use client"
|
||||||
|
|
||||||
import { ArrowsDownUp, CaretCircleDown, CircleNotch, Sparkle } from "@phosphor-icons/react";
|
import { ArrowsDownUp, CaretCircleDown, CheckCircle, Circle, CircleNotch, PersonSimpleTaiChi, Sparkle } from "@phosphor-icons/react";
|
||||||
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
|
|
||||||
@@ -14,13 +14,20 @@ import { mutate } from "swr";
|
|||||||
import { Sheet, SheetContent } from "@/components/ui/sheet";
|
import { Sheet, SheetContent } from "@/components/ui/sheet";
|
||||||
import { AgentData } from "../agentCard/agentCard";
|
import { AgentData } from "../agentCard/agentCard";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
|
import { getAvailableIcons, getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
|
||||||
import { Label } from "@/components/ui/label";
|
import { Label } from "@/components/ui/label";
|
||||||
import { Checkbox } from "@/components/ui/checkbox";
|
import { Checkbox } from "@/components/ui/checkbox";
|
||||||
import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip";
|
import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip";
|
||||||
import { TooltipContent } from "@radix-ui/react-tooltip";
|
import { TooltipContent } from "@radix-ui/react-tooltip";
|
||||||
import { useAuthenticatedData } from "@/app/common/auth";
|
import { useAuthenticatedData } from "@/app/common/auth";
|
||||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||||
|
import { Dialog, DialogClose, DialogContent, DialogFooter, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog";
|
||||||
|
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||||
|
import { convertColorToTextClass, tailwindColors } from "@/app/common/colorUtils";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import Link from "next/link";
|
||||||
|
import { motion } from "framer-motion";
|
||||||
|
|
||||||
|
|
||||||
interface ChatSideBarProps {
|
interface ChatSideBarProps {
|
||||||
conversationId: string;
|
conversationId: string;
|
||||||
@@ -54,11 +61,245 @@ export function ChatSidebar({ ...props }: ChatSideBarProps) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface IAgentCreationProps {
|
||||||
|
customPrompt: string;
|
||||||
|
selectedModel: string;
|
||||||
|
inputTools: string[];
|
||||||
|
outputModes: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AgentError {
|
||||||
|
detail: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
function AgentCreationForm(props: IAgentCreationProps) {
|
||||||
|
const iconOptions = getAvailableIcons();
|
||||||
|
const colorOptions = tailwindColors;
|
||||||
|
|
||||||
|
const [isCreating, setIsCreating] = useState<boolean>(false);
|
||||||
|
const [customAgentName, setCustomAgentName] = useState<string | undefined>();
|
||||||
|
const [customAgentIcon, setCustomAgentIcon] = useState<string | undefined>();
|
||||||
|
const [customAgentColor, setCustomAgentColor] = useState<string | undefined>();
|
||||||
|
|
||||||
|
const [doneCreating, setDoneCreating] = useState<boolean>(false);
|
||||||
|
const [createdSlug, setCreatedSlug] = useState<string | undefined>();
|
||||||
|
const [isValid, setIsValid] = useState<boolean>(false);
|
||||||
|
const [error, setError] = useState<string | undefined>();
|
||||||
|
|
||||||
|
function createAgent() {
|
||||||
|
if (isCreating) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsCreating(true);
|
||||||
|
|
||||||
|
const data = {
|
||||||
|
name: customAgentName,
|
||||||
|
icon: customAgentIcon,
|
||||||
|
color: customAgentColor,
|
||||||
|
persona: props.customPrompt,
|
||||||
|
chat_model: props.selectedModel,
|
||||||
|
input_tools: props.inputTools,
|
||||||
|
output_modes: props.outputModes,
|
||||||
|
privacy_level: "private",
|
||||||
|
};
|
||||||
|
|
||||||
|
const createAgentUrl = `/api/agents`;
|
||||||
|
|
||||||
|
fetch(createAgentUrl, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
body: JSON.stringify(data)
|
||||||
|
})
|
||||||
|
.then((res) => res.json())
|
||||||
|
.then((data: AgentData | AgentError) => {
|
||||||
|
console.log("Success:", data);
|
||||||
|
if ('detail' in data) {
|
||||||
|
setError(`Error creating agent: ${data.detail}`);
|
||||||
|
setIsCreating(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setDoneCreating(true);
|
||||||
|
setCreatedSlug(data.slug);
|
||||||
|
setIsCreating(false);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error("Error:", error);
|
||||||
|
setError(`Error creating agent: ${error}`);
|
||||||
|
setIsCreating(false);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (customAgentName && customAgentIcon && customAgentColor) {
|
||||||
|
setIsValid(true);
|
||||||
|
} else {
|
||||||
|
setIsValid(false);
|
||||||
|
}
|
||||||
|
}, [customAgentName, customAgentIcon, customAgentColor]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
|
||||||
|
<Dialog>
|
||||||
|
<DialogTrigger asChild>
|
||||||
|
<Button
|
||||||
|
className="p-1"
|
||||||
|
variant="ghost"
|
||||||
|
>
|
||||||
|
Create Agent
|
||||||
|
</Button>
|
||||||
|
</DialogTrigger>
|
||||||
|
<DialogContent>
|
||||||
|
<DialogHeader>
|
||||||
|
{
|
||||||
|
doneCreating && createdSlug ? (
|
||||||
|
<DialogTitle>
|
||||||
|
Created {customAgentName}
|
||||||
|
</DialogTitle>
|
||||||
|
) : (
|
||||||
|
<DialogTitle>
|
||||||
|
Create a New Agent
|
||||||
|
</DialogTitle>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<DialogClose />
|
||||||
|
</DialogHeader>
|
||||||
|
<div className="py-4">
|
||||||
|
{
|
||||||
|
doneCreating && createdSlug ? (
|
||||||
|
<div className="flex flex-col items-center justify-center gap-4 py-8">
|
||||||
|
<motion.div
|
||||||
|
initial={{ scale: 0 }}
|
||||||
|
animate={{ scale: 1 }}
|
||||||
|
transition={{
|
||||||
|
type: "spring",
|
||||||
|
stiffness: 260,
|
||||||
|
damping: 20
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<CheckCircle
|
||||||
|
className="w-16 h-16 text-green-500"
|
||||||
|
weight="fill"
|
||||||
|
/>
|
||||||
|
</motion.div>
|
||||||
|
<motion.p
|
||||||
|
initial={{ opacity: 0, y: 10 }}
|
||||||
|
animate={{ opacity: 1, y: 0 }}
|
||||||
|
transition={{ delay: 0.2 }}
|
||||||
|
className="text-center text-lg font-medium text-accent-foreground"
|
||||||
|
>
|
||||||
|
Created successfully!
|
||||||
|
</motion.p>
|
||||||
|
<motion.div
|
||||||
|
initial={{ opacity: 0, y: 10 }}
|
||||||
|
animate={{ opacity: 1, y: 0 }}
|
||||||
|
transition={{ delay: 0.4 }}
|
||||||
|
>
|
||||||
|
<Link href={`/agents?agent=${createdSlug}`}>
|
||||||
|
<Button variant="secondary" className="mt-2">
|
||||||
|
Manage Agent
|
||||||
|
</Button>
|
||||||
|
</Link>
|
||||||
|
</motion.div>
|
||||||
|
</div>
|
||||||
|
) :
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<div>
|
||||||
|
<Label htmlFor="agent_name">Name</Label>
|
||||||
|
<Input
|
||||||
|
id="agent_name"
|
||||||
|
className="w-full p-2 border mt-4 border-slate-500 rounded-lg"
|
||||||
|
disabled={isCreating}
|
||||||
|
value={customAgentName}
|
||||||
|
onChange={(e) => setCustomAgentName(e.target.value)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-4">
|
||||||
|
<div className="flex-1">
|
||||||
|
<Select onValueChange={setCustomAgentColor} defaultValue={customAgentColor}>
|
||||||
|
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
|
||||||
|
<SelectValue placeholder="Color" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent className="items-center space-y-1 inline-flex flex-col">
|
||||||
|
{colorOptions.map((colorOption) => (
|
||||||
|
<SelectItem key={colorOption} value={colorOption}>
|
||||||
|
<div className="flex items-center space-x-2">
|
||||||
|
<Circle
|
||||||
|
className={`w-6 h-6 mr-2 ${convertColorToTextClass(colorOption)}`}
|
||||||
|
weight="fill"
|
||||||
|
/>
|
||||||
|
{colorOption}
|
||||||
|
</div>
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
<div className="flex-1">
|
||||||
|
<Select onValueChange={setCustomAgentIcon} defaultValue={customAgentIcon}>
|
||||||
|
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
|
||||||
|
<SelectValue placeholder="Icon" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent className="items-center space-y-1 inline-flex flex-col">
|
||||||
|
{iconOptions.map((iconOption) => (
|
||||||
|
<SelectItem key={iconOption} value={iconOption}>
|
||||||
|
<div className="flex items-center space-x-2">
|
||||||
|
{getIconFromIconName(
|
||||||
|
iconOption,
|
||||||
|
customAgentColor ?? "gray",
|
||||||
|
"w-6",
|
||||||
|
"h-6",
|
||||||
|
)}
|
||||||
|
{iconOption}
|
||||||
|
</div>
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
<DialogFooter>
|
||||||
|
{
|
||||||
|
error && (
|
||||||
|
<div className="text-red-500 text-sm">
|
||||||
|
{error}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
!doneCreating && (
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
onClick={() => createAgent()}
|
||||||
|
disabled={isCreating || !isValid}
|
||||||
|
>
|
||||||
|
{
|
||||||
|
isCreating ?
|
||||||
|
<CircleNotch className="animate-spin" />
|
||||||
|
:
|
||||||
|
<PersonSimpleTaiChi />
|
||||||
|
}
|
||||||
|
Create
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<DialogClose />
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog >
|
||||||
|
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
|
function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
|
||||||
const [isEditable, setIsEditable] = useState<boolean>(false);
|
const [isEditable, setIsEditable] = useState<boolean>(false);
|
||||||
const { data: agentConfigurationOptions, error: agentConfigurationOptionsError } =
|
const { data: agentConfigurationOptions, error: agentConfigurationOptionsError } =
|
||||||
useSWR<AgentConfigurationOptions>("/api/agents/options", fetcher);
|
useSWR<AgentConfigurationOptions>("/api/agents/options", fetcher);
|
||||||
|
|
||||||
const { data: agentData, isLoading: agentDataLoading, error: agentDataError } = useSWR<AgentData>(`/api/agents/conversation?conversation_id=${props.conversationId}`, fetcher);
|
const { data: agentData, isLoading: agentDataLoading, error: agentDataError } = useSWR<AgentData>(`/api/agents/conversation?conversation_id=${props.conversationId}`, fetcher);
|
||||||
const {
|
const {
|
||||||
@@ -211,9 +452,20 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
|
|||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div className="flex items-center relative text-sm">
|
<div className="flex items-center relative text-sm justify-between">
|
||||||
{getIconFromIconName("lightbulb", "orange")}
|
<p>
|
||||||
Chat Options
|
Chat Options
|
||||||
|
</p>
|
||||||
|
{
|
||||||
|
isEditable && customPrompt && !isDefaultAgent && selectedModel && (
|
||||||
|
<AgentCreationForm
|
||||||
|
customPrompt={customPrompt}
|
||||||
|
selectedModel={selectedModel}
|
||||||
|
inputTools={inputTools ?? []}
|
||||||
|
outputModes={outputModes ?? []}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1356,8 +1356,10 @@ class ConversationAdapters:
|
|||||||
return random.sample(all_questions, max_results)
|
return random.sample(all_questions, max_results)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
|
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
|
||||||
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
|
agent: Agent = (
|
||||||
|
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
|
||||||
|
)
|
||||||
if agent and agent.chat_model:
|
if agent and agent.chat_model:
|
||||||
chat_model = conversation.agent.chat_model
|
chat_model = conversation.agent.chat_model
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ async def get_agent_by_conversation(
|
|||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||||
|
is_subscribed = has_required_scope(request, ["premium"])
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
|
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
@@ -132,7 +133,7 @@ async def get_agent_by_conversation(
|
|||||||
"color": agent.style_color,
|
"color": agent.style_color,
|
||||||
"icon": agent.style_icon,
|
"icon": agent.style_icon,
|
||||||
"privacy_level": agent.privacy_level,
|
"privacy_level": agent.privacy_level,
|
||||||
"chat_model": agent.chat_model.name,
|
"chat_model": agent.chat_model.name if is_subscribed else None,
|
||||||
"has_files": has_files,
|
"has_files": has_files,
|
||||||
"input_tools": agent.input_tools,
|
"input_tools": agent.input_tools,
|
||||||
"output_modes": agent.output_modes,
|
"output_modes": agent.output_modes,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from urllib.parse import unquote
|
|||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import has_required_scope, requires
|
||||||
|
|
||||||
from khoj.app.settings import ALLOWED_HOSTS
|
from khoj.app.settings import ALLOWED_HOSTS
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
@@ -637,6 +637,7 @@ async def chat(
|
|||||||
chat_metadata: dict = {}
|
chat_metadata: dict = {}
|
||||||
connection_alive = True
|
connection_alive = True
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
is_subscribed = has_required_scope(request, ["premium"])
|
||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
train_of_thought = []
|
train_of_thought = []
|
||||||
@@ -1251,6 +1252,7 @@ async def chat(
|
|||||||
generated_mermaidjs_diagram,
|
generated_mermaidjs_diagram,
|
||||||
program_execution_context,
|
program_execution_context,
|
||||||
generated_asset_results,
|
generated_asset_results,
|
||||||
|
is_subscribed,
|
||||||
tracer,
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1390,6 +1390,7 @@ def generate_chat_response(
|
|||||||
generated_mermaidjs_diagram: str = None,
|
generated_mermaidjs_diagram: str = None,
|
||||||
program_execution_context: List[str] = [],
|
program_execution_context: List[str] = [],
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
|
is_subscribed: bool = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
@@ -1426,7 +1427,7 @@ def generate_chat_response(
|
|||||||
online_results = {}
|
online_results = {}
|
||||||
code_results = {}
|
code_results = {}
|
||||||
|
|
||||||
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
|
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
|
||||||
vision_available = chat_model.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user