From d0d30ace06bfc5e2972472805bc2dde1a015f1c0 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 12 Feb 2025 18:24:41 -0800 Subject: [PATCH] Add feature to create a custom agent directly from the side panel with currently configured settings - Also, when in not subscribed state, fallback to the default model when chatting with an agent - With conversion, create a brand new agent from inside the chat view that can be managed separately --- src/interface/web/app/agents/page.tsx | 9 + .../app/components/agentCard/agentCard.tsx | 2 +- .../components/chatSidebar/chatSidebar.tsx | 264 +++++++++++++++++- src/khoj/database/adapters/__init__.py | 6 +- src/khoj/routers/api_agents.py | 3 +- src/khoj/routers/api_chat.py | 4 +- src/khoj/routers/helpers.py | 3 +- 7 files changed, 279 insertions(+), 12 deletions(-) diff --git a/src/interface/web/app/agents/page.tsx b/src/interface/web/app/agents/page.tsx index 532af822..35d23660 100644 --- a/src/interface/web/app/agents/page.tsx +++ b/src/interface/web/app/agents/page.tsx @@ -35,6 +35,7 @@ import { AppSidebar } from "../components/appSidebar/appSidebar"; import { Separator } from "@/components/ui/separator"; import { KhojLogoType } from "../components/logo/khojLogo"; import { DialogTitle } from "@radix-ui/react-dialog"; +import Link from "next/link"; const agentsFetcher = () => window @@ -343,6 +344,14 @@ export default function Agents() { /> How it works Use any of these specialized personas to tune your conversation to your needs. + { + !isSubscribed && ( + + {" "} + Upgrade your plan to leverage custom models. You will fallback to the default model when chatting. + + ) + }
diff --git a/src/interface/web/app/components/agentCard/agentCard.tsx b/src/interface/web/app/components/agentCard/agentCard.tsx index ccb601e5..dfbade2a 100644 --- a/src/interface/web/app/components/agentCard/agentCard.tsx +++ b/src/interface/web/app/components/agentCard/agentCard.tsx @@ -453,7 +453,7 @@ export function AgentCard(props: AgentCardProps) { /> ) : ( - +
{getIconFromIconName(props.data.icon, props.data.color)} diff --git a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx index 48e7d857..0775abba 100644 --- a/src/interface/web/app/components/chatSidebar/chatSidebar.tsx +++ b/src/interface/web/app/components/chatSidebar/chatSidebar.tsx @@ -1,6 +1,6 @@ "use client" -import { ArrowsDownUp, CaretCircleDown, CircleNotch, Sparkle } from "@phosphor-icons/react"; +import { ArrowsDownUp, CaretCircleDown, CheckCircle, Circle, CircleNotch, PersonSimpleTaiChi, Sparkle } from "@phosphor-icons/react"; import { Button } from "@/components/ui/button"; @@ -14,13 +14,20 @@ import { mutate } from "swr"; import { Sheet, SheetContent } from "@/components/ui/sheet"; import { AgentData } from "../agentCard/agentCard"; import { useEffect, useState } from "react"; -import { getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils"; +import { getAvailableIcons, getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils"; import { Label } from "@/components/ui/label"; import { Checkbox } from "@/components/ui/checkbox"; import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip"; import { TooltipContent } from "@radix-ui/react-tooltip"; import { useAuthenticatedData } from "@/app/common/auth"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Dialog, DialogClose, DialogContent, DialogFooter, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { convertColorToTextClass, tailwindColors } from "@/app/common/colorUtils"; +import { Input } from "@/components/ui/input"; +import Link from "next/link"; +import { motion } from "framer-motion"; + interface ChatSideBarProps { conversationId: string; @@ -54,11 +61,245 @@ export function ChatSidebar({ ...props }: ChatSideBarProps) { ); } +interface IAgentCreationProps { + customPrompt: string; + selectedModel: string; + inputTools: string[]; + outputModes: string[]; +} + +interface AgentError { + detail: string; +} + +function AgentCreationForm(props: IAgentCreationProps) { + const iconOptions = getAvailableIcons(); + const colorOptions = tailwindColors; + + const [isCreating, setIsCreating] = useState(false); + const [customAgentName, setCustomAgentName] = useState(); + const [customAgentIcon, setCustomAgentIcon] = useState(); + const [customAgentColor, setCustomAgentColor] = useState(); + + const [doneCreating, setDoneCreating] = useState(false); + const [createdSlug, setCreatedSlug] = useState(); + const [isValid, setIsValid] = useState(false); + const [error, setError] = useState(); + + function createAgent() { + if (isCreating) { + return; + } + + setIsCreating(true); + + const data = { + name: customAgentName, + icon: customAgentIcon, + color: customAgentColor, + persona: props.customPrompt, + chat_model: props.selectedModel, + input_tools: props.inputTools, + output_modes: props.outputModes, + privacy_level: "private", + }; + + const createAgentUrl = `/api/agents`; + + fetch(createAgentUrl, { + method: "POST", + headers: { + "Content-Type": "application/json" + }, + body: JSON.stringify(data) + }) + .then((res) => res.json()) + .then((data: AgentData | AgentError) => { + console.log("Success:", data); + if ('detail' in data) { + setError(`Error creating agent: ${data.detail}`); + setIsCreating(false); + return; + } + setDoneCreating(true); + setCreatedSlug(data.slug); + setIsCreating(false); + }) + .catch((error) => { + console.error("Error:", error); + setError(`Error creating agent: ${error}`); + setIsCreating(false); + }); + } + + useEffect(() => { + if (customAgentName && customAgentIcon && customAgentColor) { + setIsValid(true); + } else { + setIsValid(false); + } + }, [customAgentName, customAgentIcon, customAgentColor]); + + return ( + + + + + + + + { + doneCreating && createdSlug ? ( + + Created {customAgentName} + + ) : ( + + Create a New Agent + + ) + } + + +
+ { + doneCreating && createdSlug ? ( +
+ + + + + Created successfully! + + + + + + +
+ ) : +
+
+ + setCustomAgentName(e.target.value)} + /> +
+
+
+ +
+
+ +
+
+
+ } +
+ + { + error && ( +
+ {error} +
+ ) + } + { + !doneCreating && ( + + ) + } + +
+
+
+ + ) +} function ChatSidebarInternal({ ...props }: ChatSideBarProps) { const [isEditable, setIsEditable] = useState(false); const { data: agentConfigurationOptions, error: agentConfigurationOptionsError } = - useSWR("/api/agents/options", fetcher); + useSWR("/api/agents/options", fetcher); const { data: agentData, isLoading: agentDataLoading, error: agentDataError } = useSWR(`/api/agents/conversation?conversation_id=${props.conversationId}`, fetcher); const { @@ -211,9 +452,20 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
) : ( -
- {getIconFromIconName("lightbulb", "orange")} - Chat Options +
+

+ Chat Options +

+ { + isEditable && customPrompt && !isDefaultAgent && selectedModel && ( + + ) + }
) } diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 01f74de3..ccbbcd99 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1356,8 +1356,10 @@ class ConversationAdapters: return random.sample(all_questions, max_results) @staticmethod - def get_valid_chat_model(user: KhojUser, conversation: Conversation): - agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None + def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): + agent: Agent = ( + conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None + ) if agent and agent.chat_model: chat_model = conversation.agent.chat_model else: diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index 05fe27fc..4da756ad 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -110,6 +110,7 @@ async def get_agent_by_conversation( conversation_id: str, ) -> Response: user: KhojUser = request.user.object if request.user.is_authenticated else None + is_subscribed = has_required_scope(request, ["premium"]) conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id) if not conversation: @@ -132,7 +133,7 @@ async def get_agent_by_conversation( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.name, + "chat_model": agent.chat_model.name if is_subscribed else None, "has_files": has_files, "input_tools": agent.input_tools, "output_modes": agent.output_modes, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4d346d40..83053bab 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -12,7 +12,7 @@ from urllib.parse import unquote from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import Response, StreamingResponse -from starlette.authentication import requires +from starlette.authentication import has_required_scope, requires from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -637,6 +637,7 @@ async def chat( chat_metadata: dict = {} connection_alive = True user: KhojUser = request.user.object + is_subscribed = has_required_scope(request, ["premium"]) event_delimiter = "␃🔚␗" q = unquote(q) train_of_thought = [] @@ -1251,6 +1252,7 @@ async def chat( generated_mermaidjs_diagram, program_execution_context, generated_asset_results, + is_subscribed, tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e6857d45..fb2d49c6 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1390,6 +1390,7 @@ def generate_chat_response( generated_mermaidjs_diagram: str = None, program_execution_context: List[str] = [], generated_asset_results: Dict[str, Dict] = {}, + is_subscribed: bool = False, tracer: dict = {}, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables @@ -1426,7 +1427,7 @@ def generate_chat_response( online_results = {} code_results = {} - chat_model = ConversationAdapters.get_valid_chat_model(user, conversation) + chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed) vision_available = chat_model.vision_enabled if not vision_available and query_images: vision_enabled_config = ConversationAdapters.get_vision_enabled_config()