mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Merge branch 'master' into features/advanced-reasoning
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"id": "khoj",
|
"id": "khoj",
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.24.1",
|
"version": "1.25.0",
|
||||||
"minAppVersion": "0.15.0",
|
"minAppVersion": "0.15.0",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Khoj Inc.",
|
"author": "Khoj Inc.",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.24.1",
|
"version": "1.25.0",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Khoj Inc. <team@khoj.dev>",
|
"author": "Khoj Inc. <team@khoj.dev>",
|
||||||
"license": "GPL-3.0-or-later",
|
"license": "GPL-3.0-or-later",
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
;; Saba Imran <saba@khoj.dev>
|
;; Saba Imran <saba@khoj.dev>
|
||||||
;; Description: Your Second Brain
|
;; Description: Your Second Brain
|
||||||
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
|
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
|
||||||
;; Version: 1.24.1
|
;; Version: 1.25.0
|
||||||
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
|
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
|
||||||
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
|
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"id": "khoj",
|
"id": "khoj",
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.24.1",
|
"version": "1.25.0",
|
||||||
"minAppVersion": "0.15.0",
|
"minAppVersion": "0.15.0",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Khoj Inc.",
|
"author": "Khoj Inc.",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.24.1",
|
"version": "1.25.0",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
|
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
|
||||||
"license": "GPL-3.0-or-later",
|
"license": "GPL-3.0-or-later",
|
||||||
|
|||||||
@@ -76,5 +76,6 @@
|
|||||||
"1.23.2": "0.15.0",
|
"1.23.2": "0.15.0",
|
||||||
"1.23.3": "0.15.0",
|
"1.23.3": "0.15.0",
|
||||||
"1.24.0": "0.15.0",
|
"1.24.0": "0.15.0",
|
||||||
"1.24.1": "0.15.0"
|
"1.24.1": "0.15.0",
|
||||||
|
"1.25.0": "0.15.0"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ import {
|
|||||||
Globe,
|
Globe,
|
||||||
LockOpen,
|
LockOpen,
|
||||||
FloppyDisk,
|
FloppyDisk,
|
||||||
DotsThreeCircleVertical,
|
|
||||||
DotsThreeVertical,
|
DotsThreeVertical,
|
||||||
Pencil,
|
Pencil,
|
||||||
Trash,
|
Trash,
|
||||||
@@ -46,16 +45,6 @@ import {
|
|||||||
DialogHeader,
|
DialogHeader,
|
||||||
DialogTrigger,
|
DialogTrigger,
|
||||||
} from "@/components/ui/dialog";
|
} from "@/components/ui/dialog";
|
||||||
import {
|
|
||||||
Drawer,
|
|
||||||
DrawerClose,
|
|
||||||
DrawerContent,
|
|
||||||
DrawerDescription,
|
|
||||||
DrawerFooter,
|
|
||||||
DrawerHeader,
|
|
||||||
DrawerTitle,
|
|
||||||
DrawerTrigger,
|
|
||||||
} from "@/components/ui/drawer";
|
|
||||||
import LoginPrompt from "../components/loginPrompt/loginPrompt";
|
import LoginPrompt from "../components/loginPrompt/loginPrompt";
|
||||||
import { InlineLoading } from "../components/loading/loading";
|
import { InlineLoading } from "../components/loading/loading";
|
||||||
import SidePanel from "../components/sidePanel/chatHistorySidePanel";
|
import SidePanel from "../components/sidePanel/chatHistorySidePanel";
|
||||||
@@ -340,7 +329,6 @@ function AgentCard(props: AgentCardProps) {
|
|||||||
)}
|
)}
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>
|
<CardTitle>
|
||||||
{!props.isMobileWidth ? (
|
|
||||||
<Dialog
|
<Dialog
|
||||||
open={showModal}
|
open={showModal}
|
||||||
onOpenChange={() => {
|
onOpenChange={() => {
|
||||||
@@ -396,12 +384,9 @@ function AgentCard(props: AgentCardProps) {
|
|||||||
className="items-center justify-start"
|
className="items-center justify-start"
|
||||||
variant={"destructive"}
|
variant={"destructive"}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
fetch(
|
fetch(`/api/agents/${props.data.slug}`, {
|
||||||
`/api/agents/${props.data.slug}`,
|
|
||||||
{
|
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
},
|
}).then(() => {
|
||||||
).then(() => {
|
|
||||||
props.setAgentChangeTriggered(true);
|
props.setAgentChangeTriggered(true);
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
@@ -457,7 +442,7 @@ function AgentCard(props: AgentCardProps) {
|
|||||||
/>
|
/>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
) : (
|
) : (
|
||||||
<DialogContent className="whitespace-pre-line max-h-[80vh]">
|
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] rounded-lg">
|
||||||
<DialogHeader>
|
<DialogHeader>
|
||||||
<div className="flex items-center">
|
<div className="flex items-center">
|
||||||
{getIconFromIconName(props.data.icon, props.data.color)}
|
{getIconFromIconName(props.data.icon, props.data.color)}
|
||||||
@@ -487,134 +472,6 @@ function AgentCard(props: AgentCardProps) {
|
|||||||
</DialogContent>
|
</DialogContent>
|
||||||
)}
|
)}
|
||||||
</Dialog>
|
</Dialog>
|
||||||
) : (
|
|
||||||
<Drawer
|
|
||||||
open={showModal}
|
|
||||||
onOpenChange={(open) => {
|
|
||||||
setShowModal(open);
|
|
||||||
window.history.pushState({}, `Khoj AI - Agents`, `/agents`);
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<DrawerTrigger>
|
|
||||||
<div className="flex items-center">
|
|
||||||
{getIconFromIconName(props.data.icon, props.data.color)}
|
|
||||||
{props.data.name}
|
|
||||||
</div>
|
|
||||||
</DrawerTrigger>
|
|
||||||
<div className="flex float-right">
|
|
||||||
{props.editCard && (
|
|
||||||
<div className="float-right">
|
|
||||||
<Popover>
|
|
||||||
<PopoverTrigger>
|
|
||||||
<Button
|
|
||||||
className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
|
|
||||||
>
|
|
||||||
<DotsThreeVertical
|
|
||||||
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
|
|
||||||
/>
|
|
||||||
</Button>
|
|
||||||
</PopoverTrigger>
|
|
||||||
<PopoverContent
|
|
||||||
className="w-fit grid p-1"
|
|
||||||
side={"bottom"}
|
|
||||||
align={"end"}
|
|
||||||
>
|
|
||||||
<Button
|
|
||||||
className="items-center justify-start"
|
|
||||||
variant={"ghost"}
|
|
||||||
onClick={() => setShowModal(true)}
|
|
||||||
>
|
|
||||||
<Pencil className="w-4 h-4 mr-2" />
|
|
||||||
Edit
|
|
||||||
</Button>
|
|
||||||
{props.editCard &&
|
|
||||||
props.data.privacy_level !== "private" && (
|
|
||||||
<ShareLink
|
|
||||||
buttonTitle="Share"
|
|
||||||
title="Share Agent"
|
|
||||||
description="Share a link to this agent with others. They'll be able to chat with it, and ask questions to all of its knowledge base."
|
|
||||||
buttonVariant={"ghost" as const}
|
|
||||||
includeIcon={true}
|
|
||||||
url={`${window.location.origin}/agents?agent=${props.data.slug}`}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{props.data.creator === userData?.username && (
|
|
||||||
<Button
|
|
||||||
className="items-center justify-start"
|
|
||||||
variant={"destructive"}
|
|
||||||
onClick={() => {
|
|
||||||
fetch(
|
|
||||||
`/api/agents/${props.data.slug}`,
|
|
||||||
{
|
|
||||||
method: "DELETE",
|
|
||||||
},
|
|
||||||
).then(() => {
|
|
||||||
props.setAgentChangeTriggered(true);
|
|
||||||
});
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Trash className="w-4 h-4 mr-2" />
|
|
||||||
Delete
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</PopoverContent>
|
|
||||||
</Popover>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
<div className="float-right">
|
|
||||||
{props.userProfile ? (
|
|
||||||
<Button
|
|
||||||
className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
|
|
||||||
onClick={() => openChat(props.data.slug, userData)}
|
|
||||||
>
|
|
||||||
<PaperPlaneTilt
|
|
||||||
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
|
|
||||||
/>
|
|
||||||
</Button>
|
|
||||||
) : (
|
|
||||||
<Button
|
|
||||||
className={`bg-[hsl(var(--background))] w-14 h-14 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
|
|
||||||
onClick={() => setShowLoginPrompt(true)}
|
|
||||||
>
|
|
||||||
<PaperPlaneTilt
|
|
||||||
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
|
|
||||||
/>
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{props.editCard ? (
|
|
||||||
<DrawerContent className="whitespace-pre-line p-2">
|
|
||||||
<AgentModificationForm
|
|
||||||
form={form}
|
|
||||||
onSubmit={onSubmit}
|
|
||||||
create={false}
|
|
||||||
errors={errors}
|
|
||||||
filesOptions={props.filesOptions}
|
|
||||||
modelOptions={props.modelOptions}
|
|
||||||
slug={props.data.slug}
|
|
||||||
inputToolOptions={props.inputToolOptions}
|
|
||||||
outputModeOptions={props.outputModeOptions}
|
|
||||||
isSubscribed={props.isSubscribed}
|
|
||||||
/>
|
|
||||||
</DrawerContent>
|
|
||||||
) : (
|
|
||||||
<DrawerContent className="whitespace-pre-line p-2">
|
|
||||||
<DrawerHeader>
|
|
||||||
<DrawerTitle>{props.data.name}</DrawerTitle>
|
|
||||||
<DrawerDescription>Persona</DrawerDescription>
|
|
||||||
</DrawerHeader>
|
|
||||||
{props.data.persona}
|
|
||||||
<div className="flex flex-wrap items-center gap-1">
|
|
||||||
{makeBadgeFooter()}
|
|
||||||
</div>
|
|
||||||
<DrawerFooter>
|
|
||||||
<DrawerClose>Done</DrawerClose>
|
|
||||||
</DrawerFooter>
|
|
||||||
</DrawerContent>
|
|
||||||
)}
|
|
||||||
</Drawer>
|
|
||||||
)}
|
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
@@ -930,7 +787,7 @@ function AgentModificationForm(props: AgentModificationFormProps) {
|
|||||||
/>
|
/>
|
||||||
<div className="grid">
|
<div className="grid">
|
||||||
<FormLabel className="mb-2">Look & Feel</FormLabel>
|
<FormLabel className="mb-2">Look & Feel</FormLabel>
|
||||||
<div className="flex gap-1 justify-left">
|
<div className="flex gap-1 justify-left flex-col md:flex-row">
|
||||||
<FormField
|
<FormField
|
||||||
control={props.form.control}
|
control={props.form.control}
|
||||||
name="color"
|
name="color"
|
||||||
@@ -1378,44 +1235,6 @@ function CreateAgentCard(props: CreateAgentCardProps) {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
if (props.isMobileWidth) {
|
|
||||||
return (
|
|
||||||
<Drawer open={showModal} onOpenChange={setShowModal}>
|
|
||||||
<DrawerTrigger>
|
|
||||||
<div className="flex items-center">
|
|
||||||
<Plus />
|
|
||||||
Create Agent
|
|
||||||
</div>
|
|
||||||
</DrawerTrigger>
|
|
||||||
<DrawerContent className="p-2">
|
|
||||||
<DrawerHeader>
|
|
||||||
<DrawerTitle>Create Agent</DrawerTitle>
|
|
||||||
</DrawerHeader>
|
|
||||||
{!props.userProfile && showLoginPrompt && (
|
|
||||||
<LoginPrompt
|
|
||||||
loginRedirectMessage="Sign in to start chatting with a specialized agent"
|
|
||||||
onOpenChange={setShowLoginPrompt}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
<AgentModificationForm
|
|
||||||
form={form}
|
|
||||||
onSubmit={onSubmit}
|
|
||||||
create={true}
|
|
||||||
errors={errors}
|
|
||||||
filesOptions={props.filesOptions}
|
|
||||||
modelOptions={props.modelOptions}
|
|
||||||
inputToolOptions={props.inputToolOptions}
|
|
||||||
outputModeOptions={props.outputModeOptions}
|
|
||||||
isSubscribed={props.isSubscribed}
|
|
||||||
/>
|
|
||||||
<DrawerFooter>
|
|
||||||
<DrawerClose>Dismiss</DrawerClose>
|
|
||||||
</DrawerFooter>
|
|
||||||
</DrawerContent>
|
|
||||||
</Drawer>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Dialog open={showModal} onOpenChange={setShowModal}>
|
<Dialog open={showModal} onOpenChange={setShowModal}>
|
||||||
<DialogTrigger>
|
<DialogTrigger>
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "khoj-ai",
|
"name": "khoj-ai",
|
||||||
"version": "1.24.1",
|
"version": "1.25.0",
|
||||||
"private": true,
|
"private": true,
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "next dev",
|
"dev": "next dev",
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from khoj.database.adapters import (
|
|||||||
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
|
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.routers.api_content import configure_content, configure_search
|
from khoj.routers.api_content import configure_content, configure_search
|
||||||
|
from khoj.routers.helpers import update_telemetry_state
|
||||||
from khoj.routers.twilio import is_twilio_enabled
|
from khoj.routers.twilio import is_twilio_enabled
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import SearchType
|
from khoj.utils.config import SearchType
|
||||||
@@ -165,7 +166,15 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
|
|
||||||
create_if_not_exists = request.query_params.get("create_if_not_exists")
|
create_if_not_exists = request.query_params.get("create_if_not_exists")
|
||||||
if create_if_not_exists:
|
if create_if_not_exists:
|
||||||
user = await aget_or_create_user_by_phone_number(phone_number)
|
user, is_new = await aget_or_create_user_by_phone_number(phone_number)
|
||||||
|
if user and is_new:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="create_user",
|
||||||
|
metadata={"user_id": str(user.uuid)},
|
||||||
|
)
|
||||||
|
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||||
else:
|
else:
|
||||||
user = await aget_user_by_phone_number(phone_number)
|
user = await aget_user_by_phone_number(phone_number)
|
||||||
|
|
||||||
@@ -244,7 +253,7 @@ def configure_server(
|
|||||||
|
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
setup_default_agent()
|
setup_default_agent(user)
|
||||||
|
|
||||||
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
|
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
@@ -256,8 +265,8 @@ def configure_server(
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def setup_default_agent():
|
def setup_default_agent(user: KhojUser):
|
||||||
AgentAdapters.create_default_agent()
|
AgentAdapters.create_default_agent(user)
|
||||||
|
|
||||||
|
|
||||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
|
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
|
||||||
|
|||||||
@@ -113,13 +113,15 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
|
async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUser, bool]:
|
||||||
|
is_new = False
|
||||||
if is_none_or_empty(phone_number):
|
if is_none_or_empty(phone_number):
|
||||||
return None
|
return None, is_new
|
||||||
user = await aget_user_by_phone_number(phone_number)
|
user = await aget_user_by_phone_number(phone_number)
|
||||||
if not user:
|
if not user:
|
||||||
user = await acreate_user_by_phone_number(phone_number)
|
user = await acreate_user_by_phone_number(phone_number)
|
||||||
return user
|
is_new = True
|
||||||
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
|
||||||
@@ -165,8 +167,10 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def aget_or_create_user_by_email(email: str) -> KhojUser:
|
async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
|
||||||
user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})
|
user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
|
||||||
|
defaults={"username": email, "email": email}
|
||||||
|
)
|
||||||
await user.asave()
|
await user.asave()
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
@@ -177,7 +181,7 @@ async def aget_or_create_user_by_email(email: str) -> KhojUser:
|
|||||||
if not user_subscription:
|
if not user_subscription:
|
||||||
await Subscription.objects.acreate(user=user, type="trial")
|
await Subscription.objects.acreate(user=user, type="trial")
|
||||||
|
|
||||||
return user
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
|
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
|
||||||
@@ -248,9 +252,9 @@ def get_user_subscription(email: str) -> Optional[Subscription]:
|
|||||||
|
|
||||||
async def set_user_subscription(
|
async def set_user_subscription(
|
||||||
email: str, is_recurring=None, renewal_date=None, type="standard"
|
email: str, is_recurring=None, renewal_date=None, type="standard"
|
||||||
) -> Optional[Subscription]:
|
) -> tuple[Optional[Subscription], bool]:
|
||||||
# Get or create the user object and their subscription
|
# Get or create the user object and their subscription
|
||||||
user = await aget_or_create_user_by_email(email)
|
user, is_new = await aget_or_create_user_by_email(email)
|
||||||
user_subscription = await Subscription.objects.filter(user=user).afirst()
|
user_subscription = await Subscription.objects.filter(user=user).afirst()
|
||||||
|
|
||||||
# Update the user subscription state
|
# Update the user subscription state
|
||||||
@@ -262,7 +266,7 @@ async def set_user_subscription(
|
|||||||
elif renewal_date is not None:
|
elif renewal_date is not None:
|
||||||
user_subscription.renewal_date = renewal_date
|
user_subscription.renewal_date = renewal_date
|
||||||
await user_subscription.asave()
|
await user_subscription.asave()
|
||||||
return user_subscription
|
return user_subscription, is_new
|
||||||
|
|
||||||
|
|
||||||
def subscription_to_state(subscription: Subscription) -> str:
|
def subscription_to_state(subscription: Subscription) -> str:
|
||||||
@@ -643,8 +647,8 @@ class AgentAdapters:
|
|||||||
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
|
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_default_agent():
|
def create_default_agent(user: KhojUser):
|
||||||
default_conversation_config = ConversationAdapters.get_default_conversation_config()
|
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
|
||||||
if default_conversation_config is None:
|
if default_conversation_config is None:
|
||||||
logger.info("No default conversation config found, skipping default agent creation")
|
logger.info("No default conversation config found, skipping default agent creation")
|
||||||
return None
|
return None
|
||||||
@@ -968,29 +972,51 @@ class ConversationAdapters:
|
|||||||
return VoiceModelOption.objects.first()
|
return VoiceModelOption.objects.first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_conversation_config():
|
def get_default_conversation_config(user: KhojUser = None):
|
||||||
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
|
# Get the server chat settings
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
if server_chat_settings is None or server_chat_settings.chat_default is None:
|
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||||
return ChatModelOptions.objects.filter().first()
|
|
||||||
return server_chat_settings.chat_default
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
|
# Get the user's chat settings, if the server chat settings are not set
|
||||||
|
user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
|
||||||
|
if user_chat_settings is not None and user_chat_settings.setting is not None:
|
||||||
|
return user_chat_settings.setting
|
||||||
|
|
||||||
|
# Get the first chat model if even the user chat settings are not set
|
||||||
|
return ChatModelOptions.objects.filter().first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_default_conversation_config():
|
async def aget_default_conversation_config(user: KhojUser = None):
|
||||||
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
|
# Get the server chat settings
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
await ServerChatSettings.objects.filter()
|
await ServerChatSettings.objects.filter()
|
||||||
.prefetch_related("chat_default", "chat_default__openai_config")
|
.prefetch_related("chat_default", "chat_default__openai_config")
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if server_chat_settings is None or server_chat_settings.chat_default is None:
|
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
|
||||||
return server_chat_settings.chat_default
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
|
# Get the user's chat settings, if the server chat settings are not set
|
||||||
|
user_chat_settings = (
|
||||||
|
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst())
|
||||||
|
if user
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_chat_settings is not None and user_chat_settings.setting is not None:
|
||||||
|
return user_chat_settings.setting
|
||||||
|
|
||||||
|
# Get the first chat model if even the user chat settings are not set
|
||||||
|
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_advanced_conversation_config():
|
def get_advanced_conversation_config():
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
|
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||||
return ConversationAdapters.get_default_conversation_config()
|
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
|
return ConversationAdapters.get_default_conversation_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_advanced_conversation_config():
|
async def aget_advanced_conversation_config():
|
||||||
@@ -999,9 +1025,9 @@ class ConversationAdapters:
|
|||||||
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
|
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
|
if server_chat_settings is not None or server_chat_settings.chat_advanced is not None:
|
||||||
return await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
|
return await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_conversation_from_public_conversation(
|
def create_conversation_from_public_conversation(
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ async def text_to_image(
|
|||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
references: List[Dict[str, Any]],
|
references: List[Dict[str, Any]],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
subscribed: bool = False,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
uploaded_image_url: Optional[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
@@ -66,8 +65,8 @@ async def text_to_image(
|
|||||||
note_references=references,
|
note_references=references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
model_type=text_to_image_config.model_type,
|
model_type=text_to_image_config.model_type,
|
||||||
subscribed=subscribed,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ async def search_online(
|
|||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [
|
tasks = [
|
||||||
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
|
read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
|
||||||
for link, subquery, content in webpages
|
for link, subquery, content in webpages
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
@@ -160,7 +160,7 @@ async def read_webpages(
|
|||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
|
tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
response: Dict[str, Dict] = defaultdict(dict)
|
response: Dict[str, Dict] = defaultdict(dict)
|
||||||
@@ -171,14 +171,14 @@ async def read_webpages(
|
|||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
|
subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||||
) -> Tuple[str, Union[None, str], str]:
|
) -> Tuple[str, Union[None, str], str]:
|
||||||
try:
|
try:
|
||||||
if is_none_or_empty(content):
|
if is_none_or_empty(content):
|
||||||
with timer(f"Reading web page at '{url}' took", logger):
|
with timer(f"Reading web page at '{url}' took", logger):
|
||||||
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
||||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||||
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
|
extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
|
||||||
return subquery, extracted_info, url
|
return subquery, extracted_info, url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to read web page at '{url}' with {e}")
|
logger.error(f"Failed to read web page at '{url}' with {e}")
|
||||||
|
|||||||
@@ -394,7 +394,7 @@ async def extract_references_and_questions(
|
|||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
vision_enabled = conversation_config.vision_enabled
|
vision_enabled = conversation_config.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ def chat_history(
|
|||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
validate_conversation_config()
|
validate_conversation_config(user)
|
||||||
|
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
conversation = ConversationAdapters.get_conversation_by_user(
|
conversation = ConversationAdapters.get_conversation_by_user(
|
||||||
@@ -309,7 +309,7 @@ def get_shared_chat(
|
|||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="chat_history",
|
api="get_shared_chat_history",
|
||||||
**common.__dict__,
|
**common.__dict__,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -742,12 +742,12 @@ async def chat(
|
|||||||
q,
|
q,
|
||||||
meta_log,
|
meta_log,
|
||||||
is_automated_task,
|
is_automated_task,
|
||||||
subscribed=subscribed,
|
user=user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
@@ -1001,7 +1001,6 @@ async def chat(
|
|||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
subscribed=subscribed,
|
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def get_user_chat_model(
|
|||||||
chat_model = ConversationAdapters.get_conversation_config(user)
|
chat_model = ConversationAdapters.get_conversation_config(user)
|
||||||
|
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_model = ConversationAdapters.get_default_conversation_config()
|
chat_model = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
||||||
|
|
||||||
|
|||||||
@@ -80,11 +80,19 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
|
|||||||
request.session.pop("user", None)
|
request.session.pop("user", None)
|
||||||
|
|
||||||
email = form.email
|
email = form.email
|
||||||
user = await aget_or_create_user_by_email(email)
|
user, is_new = await aget_or_create_user_by_email(email)
|
||||||
unique_id = user.email_verification_code
|
unique_id = user.email_verification_code
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
await send_magic_link_email(email, unique_id, request.base_url)
|
await send_magic_link_email(email, unique_id, request.base_url)
|
||||||
|
if is_new:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="create_user",
|
||||||
|
metadata={"user_id": str(user.uuid)},
|
||||||
|
)
|
||||||
|
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||||
|
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|||||||
@@ -124,20 +124,20 @@ def is_query_empty(query: str) -> bool:
|
|||||||
return is_none_or_empty(query.strip())
|
return is_none_or_empty(query.strip())
|
||||||
|
|
||||||
|
|
||||||
def validate_conversation_config():
|
def validate_conversation_config(user: KhojUser):
|
||||||
default_config = ConversationAdapters.get_default_conversation_config()
|
default_config = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
if default_config is None:
|
if default_config is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||||
|
|
||||||
if default_config.model_type == "openai" and not default_config.openai_config:
|
if default_config.model_type == "openai" and not default_config.openai_config:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||||
|
|
||||||
|
|
||||||
async def is_ready_to_chat(user: KhojUser):
|
async def is_ready_to_chat(user: KhojUser):
|
||||||
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
|
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||||
await ConversationAdapters.aget_default_conversation_config()
|
if user_conversation_config == None:
|
||||||
)
|
user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
@@ -239,19 +239,19 @@ async def agenerate_chat_response(*args):
|
|||||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||||
|
|
||||||
|
|
||||||
async def acreate_title_from_query(query: str) -> str:
|
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||||
"""
|
"""
|
||||||
Create a title from the given query
|
Create a title from the given query
|
||||||
"""
|
"""
|
||||||
title_generation_prompt = prompts.subject_generation.format(query=query)
|
title_generation_prompt = prompts.subject_generation.format(query=query)
|
||||||
|
|
||||||
with timer("Chat actor: Generate title from query", logger):
|
with timer("Chat actor: Generate title from query", logger):
|
||||||
response = await send_message_to_model_wrapper(title_generation_prompt)
|
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||||
|
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
|
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Check if the system prompt is safe to use
|
Check if the system prompt is safe to use
|
||||||
"""
|
"""
|
||||||
@@ -260,7 +260,7 @@ async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
|
|||||||
reason = ""
|
reason = ""
|
||||||
|
|
||||||
with timer("Chat actor: Check if safe prompt", logger):
|
with timer("Chat actor: Check if safe prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(safe_prompt_check)
|
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
|
||||||
|
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
try:
|
try:
|
||||||
@@ -281,7 +281,7 @@ async def aget_relevant_information_sources(
|
|||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
is_task: bool,
|
is_task: bool,
|
||||||
subscribed: bool,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
@@ -319,7 +319,7 @@ async def aget_relevant_information_sources(
|
|||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
subscribed=subscribed,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -355,7 +355,12 @@ async def aget_relevant_information_sources(
|
|||||||
|
|
||||||
|
|
||||||
async def aget_relevant_output_modes(
|
async def aget_relevant_output_modes(
|
||||||
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
|
query: str,
|
||||||
|
conversation_history: dict,
|
||||||
|
is_task: bool = False,
|
||||||
|
user: KhojUser = None,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
@@ -391,7 +396,7 @@ async def aget_relevant_output_modes(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
|
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
@@ -446,7 +451,7 @@ async def infer_webpage_urls(
|
|||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
|
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||||
@@ -493,7 +498,7 @@ async def generate_online_subqueries(
|
|||||||
|
|
||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
|
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
@@ -511,7 +516,9 @@ async def generate_online_subqueries(
|
|||||||
return [q]
|
return [q]
|
||||||
|
|
||||||
|
|
||||||
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
|
async def schedule_query(
|
||||||
|
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||||
"""
|
"""
|
||||||
@@ -523,7 +530,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
|
|||||||
)
|
)
|
||||||
|
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
|
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
@@ -537,7 +544,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
|
|||||||
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
|
async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
"""
|
"""
|
||||||
@@ -555,14 +562,11 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Ag
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
chat_model_option=chat_model,
|
user=user,
|
||||||
subscribed=subscribed,
|
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
@@ -571,8 +575,8 @@ async def extract_relevant_summary(
|
|||||||
q: str,
|
q: str,
|
||||||
corpus: str,
|
corpus: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
subscribed: bool = False,
|
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
@@ -595,14 +599,11 @@ async def extract_relevant_summary(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
chat_model_option=chat_model,
|
user=user,
|
||||||
subscribed=subscribed,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
@@ -667,8 +668,8 @@ async def generate_better_image_prompt(
|
|||||||
note_references: List[Dict[str, Any]],
|
note_references: List[Dict[str, Any]],
|
||||||
online_results: Optional[dict] = None,
|
online_results: Optional[dict] = None,
|
||||||
model_type: Optional[str] = None,
|
model_type: Optional[str] = None,
|
||||||
subscribed: bool = False,
|
|
||||||
uploaded_image_url: Optional[str] = None,
|
uploaded_image_url: Optional[str] = None,
|
||||||
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -718,12 +719,8 @@ async def generate_better_image_prompt(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
|
||||||
image_prompt, chat_model_option=chat_model, subscribed=subscribed, uploaded_image_url=uploaded_image_url
|
|
||||||
)
|
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
@@ -735,8 +732,9 @@ def send_message_to_model_wrapper_sync(
|
|||||||
message: str,
|
message: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
|
user: KhojUser = None,
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config()
|
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
@@ -1124,7 +1122,7 @@ class CommonQueryParamsClass:
|
|||||||
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||||
|
|
||||||
|
|
||||||
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool:
|
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
|
||||||
"""
|
"""
|
||||||
Decide whether to notify the user of the AI response.
|
Decide whether to notify the user of the AI response.
|
||||||
Default to notifying the user for now.
|
Default to notifying the user for now.
|
||||||
@@ -1141,7 +1139,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str) ->
|
|||||||
with timer("Chat actor: Decide to notify user of automation response", logger):
|
with timer("Chat actor: Decide to notify user of automation response", logger):
|
||||||
try:
|
try:
|
||||||
# TODO Replace with async call so we don't have to maintain a sync version
|
# TODO Replace with async call so we don't have to maintain a sync version
|
||||||
response = send_message_to_model_wrapper_sync(to_notify_or_not)
|
response = send_message_to_model_wrapper_sync(to_notify_or_not, user)
|
||||||
should_notify_result = "no" not in response.lower()
|
should_notify_result = "no" not in response.lower()
|
||||||
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
|
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
|
||||||
return should_notify_result
|
return should_notify_result
|
||||||
@@ -1233,7 +1231,9 @@ def scheduled_chat(
|
|||||||
ai_response = raw_response.text
|
ai_response = raw_response.text
|
||||||
|
|
||||||
# Notify user if the AI response is satisfactory
|
# Notify user if the AI response is satisfactory
|
||||||
if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response):
|
if should_notify(
|
||||||
|
original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
|
||||||
|
):
|
||||||
if is_resend_enabled():
|
if is_resend_enabled():
|
||||||
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
|
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
|
||||||
else:
|
else:
|
||||||
@@ -1243,7 +1243,7 @@ def scheduled_chat(
|
|||||||
async def create_automation(
|
async def create_automation(
|
||||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
||||||
):
|
):
|
||||||
crontime, query_to_run, subject = await schedule_query(q, meta_log)
|
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
||||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||||
return job, crontime, query_to_run, subject
|
return job, crontime, query_to_run, subject
|
||||||
|
|
||||||
@@ -1429,9 +1429,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
|||||||
current_notion_config = get_user_notion_config(user)
|
current_notion_config = get_user_notion_config(user)
|
||||||
notion_token = current_notion_config.token if current_notion_config else ""
|
notion_token = current_notion_config.token if current_notion_config else ""
|
||||||
|
|
||||||
selected_chat_model_config = (
|
selected_chat_model_config = ConversationAdapters.get_conversation_config(
|
||||||
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
|
user
|
||||||
)
|
) or ConversationAdapters.get_default_conversation_config(user)
|
||||||
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
||||||
chat_model_options = list()
|
chat_model_options = list()
|
||||||
for chat_model in chat_models:
|
for chat_model in chat_models:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from fastapi import APIRouter, Request
|
|||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
|
from khoj.routers.helpers import update_telemetry_state
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
# Stripe integration for Khoj Cloud Subscription
|
# Stripe integration for Khoj Cloud Subscription
|
||||||
@@ -48,6 +49,8 @@ async def subscribe(request: Request):
|
|||||||
customer_id = subscription["customer"]
|
customer_id = subscription["customer"]
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
customer = stripe.Customer.retrieve(customer_id)
|
||||||
customer_email = customer["email"]
|
customer_email = customer["email"]
|
||||||
|
user = None
|
||||||
|
is_new = False
|
||||||
|
|
||||||
# Handle valid stripe webhook events
|
# Handle valid stripe webhook events
|
||||||
success = True
|
success = True
|
||||||
@@ -55,7 +58,9 @@ async def subscribe(request: Request):
|
|||||||
# Mark the user as subscribed and update the next renewal date on payment
|
# Mark the user as subscribed and update the next renewal date on payment
|
||||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||||
user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
|
user, is_new = await adapters.set_user_subscription(
|
||||||
|
customer_email, is_recurring=True, renewal_date=renewal_date
|
||||||
|
)
|
||||||
success = user is not None
|
success = user is not None
|
||||||
elif event_type in {"customer.subscription.updated"}:
|
elif event_type in {"customer.subscription.updated"}:
|
||||||
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
|
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
|
||||||
@@ -63,15 +68,24 @@ async def subscribe(request: Request):
|
|||||||
if user_subscription and user_subscription.renewal_date:
|
if user_subscription and user_subscription.renewal_date:
|
||||||
# Mark user as unsubscribed or resubscribed
|
# Mark user as unsubscribed or resubscribed
|
||||||
is_recurring = not subscription["cancel_at_period_end"]
|
is_recurring = not subscription["cancel_at_period_end"]
|
||||||
updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
|
user, is_new = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
|
||||||
success = updated_user is not None
|
success = user is not None
|
||||||
elif event_type in {"customer.subscription.deleted"}:
|
elif event_type in {"customer.subscription.deleted"}:
|
||||||
# Reset the user to trial state
|
# Reset the user to trial state
|
||||||
user = await adapters.set_user_subscription(
|
user, is_new = await adapters.set_user_subscription(
|
||||||
customer_email, is_recurring=False, renewal_date=False, type="trial"
|
customer_email, is_recurring=False, renewal_date=False, type="trial"
|
||||||
)
|
)
|
||||||
success = user is not None
|
success = user is not None
|
||||||
|
|
||||||
|
if user and is_new:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="create_user",
|
||||||
|
metadata={"user_id": str(user.user.uuid)},
|
||||||
|
)
|
||||||
|
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
|
||||||
|
|
||||||
logger.info(f'Stripe subscription {event["type"]} for {customer_email}')
|
logger.info(f'Stripe subscription {event["type"]} for {customer_email}')
|
||||||
return {"success": success}
|
return {"success": success}
|
||||||
|
|
||||||
|
|||||||
@@ -129,9 +129,6 @@ def initialization(interactive: bool = True):
|
|||||||
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
|
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
|
||||||
default_chat_model_name = user_chat_model_name
|
default_chat_model_name = user_chat_model_name
|
||||||
|
|
||||||
# Create a server chat settings object with the default chat model
|
|
||||||
default_chat_model = ChatModelOptions.objects.filter(chat_model=default_chat_model_name).first()
|
|
||||||
ServerChatSettings.objects.create(chat_default=default_chat_model)
|
|
||||||
logger.info("🗣️ Chat model configuration complete")
|
logger.info("🗣️ Chat model configuration complete")
|
||||||
|
|
||||||
# Set up offline speech to text model
|
# Set up offline speech to text model
|
||||||
|
|||||||
@@ -76,5 +76,6 @@
|
|||||||
"1.23.2": "0.15.0",
|
"1.23.2": "0.15.0",
|
||||||
"1.23.3": "0.15.0",
|
"1.23.3": "0.15.0",
|
||||||
"1.24.0": "0.15.0",
|
"1.24.0": "0.15.0",
|
||||||
"1.24.1": "0.15.0"
|
"1.24.1": "0.15.0",
|
||||||
|
"1.25.0": "0.15.0"
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user