Merge branch 'master' into features/advanced-reasoning

This commit is contained in:
Debanjum Singh Solanky
2024-10-12 21:01:22 -07:00
20 changed files with 289 additions and 416 deletions

View File

@@ -1,7 +1,7 @@
{ {
"id": "khoj", "id": "khoj",
"name": "Khoj", "name": "Khoj",
"version": "1.24.1", "version": "1.25.0",
"minAppVersion": "0.15.0", "minAppVersion": "0.15.0",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Khoj Inc.", "author": "Khoj Inc.",

View File

@@ -1,6 +1,6 @@
{ {
"name": "Khoj", "name": "Khoj",
"version": "1.24.1", "version": "1.25.0",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Khoj Inc. <team@khoj.dev>", "author": "Khoj Inc. <team@khoj.dev>",
"license": "GPL-3.0-or-later", "license": "GPL-3.0-or-later",

View File

@@ -6,7 +6,7 @@
;; Saba Imran <saba@khoj.dev> ;; Saba Imran <saba@khoj.dev>
;; Description: Your Second Brain ;; Description: Your Second Brain
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image ;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
;; Version: 1.24.1 ;; Version: 1.25.0
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1")) ;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs ;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs

View File

@@ -1,7 +1,7 @@
{ {
"id": "khoj", "id": "khoj",
"name": "Khoj", "name": "Khoj",
"version": "1.24.1", "version": "1.25.0",
"minAppVersion": "0.15.0", "minAppVersion": "0.15.0",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Khoj Inc.", "author": "Khoj Inc.",

View File

@@ -1,6 +1,6 @@
{ {
"name": "Khoj", "name": "Khoj",
"version": "1.24.1", "version": "1.25.0",
"description": "Your Second Brain", "description": "Your Second Brain",
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>", "author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
"license": "GPL-3.0-or-later", "license": "GPL-3.0-or-later",

View File

@@ -76,5 +76,6 @@
"1.23.2": "0.15.0", "1.23.2": "0.15.0",
"1.23.3": "0.15.0", "1.23.3": "0.15.0",
"1.24.0": "0.15.0", "1.24.0": "0.15.0",
"1.24.1": "0.15.0" "1.24.1": "0.15.0",
"1.25.0": "0.15.0"
} }

View File

@@ -32,7 +32,6 @@ import {
Globe, Globe,
LockOpen, LockOpen,
FloppyDisk, FloppyDisk,
DotsThreeCircleVertical,
DotsThreeVertical, DotsThreeVertical,
Pencil, Pencil,
Trash, Trash,
@@ -46,16 +45,6 @@ import {
DialogHeader, DialogHeader,
DialogTrigger, DialogTrigger,
} from "@/components/ui/dialog"; } from "@/components/ui/dialog";
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerFooter,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "@/components/ui/drawer";
import LoginPrompt from "../components/loginPrompt/loginPrompt"; import LoginPrompt from "../components/loginPrompt/loginPrompt";
import { InlineLoading } from "../components/loading/loading"; import { InlineLoading } from "../components/loading/loading";
import SidePanel from "../components/sidePanel/chatHistorySidePanel"; import SidePanel from "../components/sidePanel/chatHistorySidePanel";
@@ -340,281 +329,149 @@ function AgentCard(props: AgentCardProps) {
)} )}
<CardHeader> <CardHeader>
<CardTitle> <CardTitle>
{!props.isMobileWidth ? ( <Dialog
<Dialog open={showModal}
open={showModal} onOpenChange={() => {
onOpenChange={() => { setShowModal(!showModal);
setShowModal(!showModal); window.history.pushState({}, `Khoj AI - Agents`, `/agents`);
window.history.pushState({}, `Khoj AI - Agents`, `/agents`); }}
}} >
> <DialogTrigger>
<DialogTrigger> <div className="flex items-center relative top-2">
<div className="flex items-center relative top-2"> {getIconFromIconName(props.data.icon, props.data.color)}
{getIconFromIconName(props.data.icon, props.data.color)} {props.data.name}
{props.data.name} </div>
</div> </DialogTrigger>
</DialogTrigger> <div className="flex float-right">
<div className="flex float-right"> {props.editCard && (
{props.editCard && ( <div className="float-right">
<div className="float-right"> <Popover>
<Popover> <PopoverTrigger>
<PopoverTrigger> <Button
<Button className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
>
<DotsThreeVertical
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
</PopoverTrigger>
<PopoverContent
className="w-fit grid p-1"
side={"bottom"}
align={"end"}
> >
<DotsThreeVertical
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
</PopoverTrigger>
<PopoverContent
className="w-fit grid p-1"
side={"bottom"}
align={"end"}
>
<Button
className="items-center justify-start"
variant={"ghost"}
onClick={() => setShowModal(true)}
>
<Pencil className="w-4 h-4 mr-2" />
Edit
</Button>
{props.editCard &&
props.data.privacy_level !== "private" && (
<ShareLink
buttonTitle="Share"
title="Share Agent"
description="Share a link to this agent with others. They'll be able to chat with it, and ask questions to all of its knowledge base."
buttonVariant={"ghost" as const}
includeIcon={true}
url={`${window.location.origin}/agents?agent=${props.data.slug}`}
/>
)}
{props.data.creator === userData?.username && (
<Button <Button
className="items-center justify-start" className="items-center justify-start"
variant={"ghost"} variant={"destructive"}
onClick={() => setShowModal(true)} onClick={() => {
fetch(`/api/agents/${props.data.slug}`, {
method: "DELETE",
}).then(() => {
props.setAgentChangeTriggered(true);
});
}}
> >
<Pencil className="w-4 h-4 mr-2" /> <Trash className="w-4 h-4 mr-2" />
Edit Delete
</Button> </Button>
{props.editCard && )}
props.data.privacy_level !== "private" && ( </PopoverContent>
<ShareLink </Popover>
buttonTitle="Share"
title="Share Agent"
description="Share a link to this agent with others. They'll be able to chat with it, and ask questions to all of its knowledge base."
buttonVariant={"ghost" as const}
includeIcon={true}
url={`${window.location.origin}/agents?agent=${props.data.slug}`}
/>
)}
{props.data.creator === userData?.username && (
<Button
className="items-center justify-start"
variant={"destructive"}
onClick={() => {
fetch(
`/api/agents/${props.data.slug}`,
{
method: "DELETE",
},
).then(() => {
props.setAgentChangeTriggered(true);
});
}}
>
<Trash className="w-4 h-4 mr-2" />
Delete
</Button>
)}
</PopoverContent>
</Popover>
</div>
)}
<div className="float-right">
{props.userProfile ? (
<Button
className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
onClick={() => openChat(props.data.slug, userData)}
>
<PaperPlaneTilt
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
) : (
<Button
className={`bg-[hsl(var(--background))] w-14 h-14 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
onClick={() => setShowLoginPrompt(true)}
>
<PaperPlaneTilt
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
)}
</div> </div>
</div>
{props.editCard ? (
<DialogContent
className={"lg:max-w-screen-lg overflow-y-scroll max-h-screen"}
>
<DialogTitle>
Edit <b>{props.data.name}</b>
</DialogTitle>
<AgentModificationForm
form={form}
onSubmit={onSubmit}
create={false}
errors={errors}
filesOptions={props.filesOptions}
modelOptions={props.modelOptions}
slug={props.data.slug}
inputToolOptions={props.inputToolOptions}
isSubscribed={props.isSubscribed}
outputModeOptions={props.outputModeOptions}
/>
</DialogContent>
) : (
<DialogContent className="whitespace-pre-line max-h-[80vh]">
<DialogHeader>
<div className="flex items-center">
{getIconFromIconName(props.data.icon, props.data.color)}
<p className="font-bold text-lg">{props.data.name}</p>
</div>
</DialogHeader>
<div className="max-h-[60vh] overflow-y-scroll text-neutral-500 dark:text-white">
{props.data.persona}
</div>
<div className="flex flex-wrap items-center gap-1">
{makeBadgeFooter()}
</div>
<DialogFooter>
<Button
className={`pt-6 pb-6 ${stylingString} bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white border-2 border-stone-100 shadow-sm rounded-xl hover:bg-stone-100 dark:hover:bg-neutral-900 dark:border-neutral-700`}
onClick={() => {
openChat(props.data.slug, userData);
setShowModal(false);
}}
>
<PaperPlaneTilt
className={`w-6 h-6 m-2 ${convertColorToTextClass(props.data.color)}`}
/>
Start Chatting
</Button>
</DialogFooter>
</DialogContent>
)} )}
</Dialog> <div className="float-right">
) : ( {props.userProfile ? (
<Drawer <Button
open={showModal} className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
onOpenChange={(open) => { onClick={() => openChat(props.data.slug, userData)}
setShowModal(open); >
window.history.pushState({}, `Khoj AI - Agents`, `/agents`); <PaperPlaneTilt
}} className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
> />
<DrawerTrigger> </Button>
<div className="flex items-center"> ) : (
{getIconFromIconName(props.data.icon, props.data.color)} <Button
{props.data.name} className={`bg-[hsl(var(--background))] w-14 h-14 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
</div> onClick={() => setShowLoginPrompt(true)}
</DrawerTrigger> >
<div className="flex float-right"> <PaperPlaneTilt
{props.editCard && ( className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
<div className="float-right"> />
<Popover> </Button>
<PopoverTrigger>
<Button
className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
>
<DotsThreeVertical
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
</PopoverTrigger>
<PopoverContent
className="w-fit grid p-1"
side={"bottom"}
align={"end"}
>
<Button
className="items-center justify-start"
variant={"ghost"}
onClick={() => setShowModal(true)}
>
<Pencil className="w-4 h-4 mr-2" />
Edit
</Button>
{props.editCard &&
props.data.privacy_level !== "private" && (
<ShareLink
buttonTitle="Share"
title="Share Agent"
description="Share a link to this agent with others. They'll be able to chat with it, and ask questions to all of its knowledge base."
buttonVariant={"ghost" as const}
includeIcon={true}
url={`${window.location.origin}/agents?agent=${props.data.slug}`}
/>
)}
{props.data.creator === userData?.username && (
<Button
className="items-center justify-start"
variant={"destructive"}
onClick={() => {
fetch(
`/api/agents/${props.data.slug}`,
{
method: "DELETE",
},
).then(() => {
props.setAgentChangeTriggered(true);
});
}}
>
<Trash className="w-4 h-4 mr-2" />
Delete
</Button>
)}
</PopoverContent>
</Popover>
</div>
)} )}
<div className="float-right">
{props.userProfile ? (
<Button
className={`bg-[hsl(var(--background))] w-10 h-10 p-0 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
onClick={() => openChat(props.data.slug, userData)}
>
<PaperPlaneTilt
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
) : (
<Button
className={`bg-[hsl(var(--background))] w-14 h-14 rounded-xl border dark:border-neutral-700 shadow-sm hover:bg-stone-100 dark:hover:bg-neutral-900`}
onClick={() => setShowLoginPrompt(true)}
>
<PaperPlaneTilt
className={`w-6 h-6 ${convertColorToTextClass(props.data.color)}`}
/>
</Button>
)}
</div>
</div> </div>
{props.editCard ? ( </div>
<DrawerContent className="whitespace-pre-line p-2"> {props.editCard ? (
<AgentModificationForm <DialogContent
form={form} className={"lg:max-w-screen-lg overflow-y-scroll max-h-screen"}
onSubmit={onSubmit} >
create={false} <DialogTitle>
errors={errors} Edit <b>{props.data.name}</b>
filesOptions={props.filesOptions} </DialogTitle>
modelOptions={props.modelOptions} <AgentModificationForm
slug={props.data.slug} form={form}
inputToolOptions={props.inputToolOptions} onSubmit={onSubmit}
outputModeOptions={props.outputModeOptions} create={false}
isSubscribed={props.isSubscribed} errors={errors}
/> filesOptions={props.filesOptions}
</DrawerContent> modelOptions={props.modelOptions}
) : ( slug={props.data.slug}
<DrawerContent className="whitespace-pre-line p-2"> inputToolOptions={props.inputToolOptions}
<DrawerHeader> isSubscribed={props.isSubscribed}
<DrawerTitle>{props.data.name}</DrawerTitle> outputModeOptions={props.outputModeOptions}
<DrawerDescription>Persona</DrawerDescription> />
</DrawerHeader> </DialogContent>
) : (
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] rounded-lg">
<DialogHeader>
<div className="flex items-center">
{getIconFromIconName(props.data.icon, props.data.color)}
<p className="font-bold text-lg">{props.data.name}</p>
</div>
</DialogHeader>
<div className="max-h-[60vh] overflow-y-scroll text-neutral-500 dark:text-white">
{props.data.persona} {props.data.persona}
<div className="flex flex-wrap items-center gap-1"> </div>
{makeBadgeFooter()} <div className="flex flex-wrap items-center gap-1">
</div> {makeBadgeFooter()}
<DrawerFooter> </div>
<DrawerClose>Done</DrawerClose> <DialogFooter>
</DrawerFooter> <Button
</DrawerContent> className={`pt-6 pb-6 ${stylingString} bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white border-2 border-stone-100 shadow-sm rounded-xl hover:bg-stone-100 dark:hover:bg-neutral-900 dark:border-neutral-700`}
)} onClick={() => {
</Drawer> openChat(props.data.slug, userData);
)} setShowModal(false);
}}
>
<PaperPlaneTilt
className={`w-6 h-6 m-2 ${convertColorToTextClass(props.data.color)}`}
/>
Start Chatting
</Button>
</DialogFooter>
</DialogContent>
)}
</Dialog>
</CardTitle> </CardTitle>
</CardHeader> </CardHeader>
<CardContent> <CardContent>
@@ -930,7 +787,7 @@ function AgentModificationForm(props: AgentModificationFormProps) {
/> />
<div className="grid"> <div className="grid">
<FormLabel className="mb-2">Look & Feel</FormLabel> <FormLabel className="mb-2">Look & Feel</FormLabel>
<div className="flex gap-1 justify-left"> <div className="flex gap-1 justify-left flex-col md:flex-row">
<FormField <FormField
control={props.form.control} control={props.form.control}
name="color" name="color"
@@ -1378,44 +1235,6 @@ function CreateAgentCard(props: CreateAgentCardProps) {
}); });
}; };
if (props.isMobileWidth) {
return (
<Drawer open={showModal} onOpenChange={setShowModal}>
<DrawerTrigger>
<div className="flex items-center">
<Plus />
Create Agent
</div>
</DrawerTrigger>
<DrawerContent className="p-2">
<DrawerHeader>
<DrawerTitle>Create Agent</DrawerTitle>
</DrawerHeader>
{!props.userProfile && showLoginPrompt && (
<LoginPrompt
loginRedirectMessage="Sign in to start chatting with a specialized agent"
onOpenChange={setShowLoginPrompt}
/>
)}
<AgentModificationForm
form={form}
onSubmit={onSubmit}
create={true}
errors={errors}
filesOptions={props.filesOptions}
modelOptions={props.modelOptions}
inputToolOptions={props.inputToolOptions}
outputModeOptions={props.outputModeOptions}
isSubscribed={props.isSubscribed}
/>
<DrawerFooter>
<DrawerClose>Dismiss</DrawerClose>
</DrawerFooter>
</DrawerContent>
</Drawer>
);
}
return ( return (
<Dialog open={showModal} onOpenChange={setShowModal}> <Dialog open={showModal} onOpenChange={setShowModal}>
<DialogTrigger> <DialogTrigger>

View File

@@ -1,6 +1,6 @@
{ {
"name": "khoj-ai", "name": "khoj-ai",
"version": "1.24.1", "version": "1.25.0",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "next dev", "dev": "next dev",

View File

@@ -42,6 +42,7 @@ from khoj.database.adapters import (
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.api_content import configure_content, configure_search from khoj.routers.api_content import configure_content, configure_search
from khoj.routers.helpers import update_telemetry_state
from khoj.routers.twilio import is_twilio_enabled from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
@@ -165,7 +166,15 @@ class UserAuthenticationBackend(AuthenticationBackend):
create_if_not_exists = request.query_params.get("create_if_not_exists") create_if_not_exists = request.query_params.get("create_if_not_exists")
if create_if_not_exists: if create_if_not_exists:
user = await aget_or_create_user_by_phone_number(phone_number) user, is_new = await aget_or_create_user_by_phone_number(phone_number)
if user and is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
else: else:
user = await aget_user_by_phone_number(phone_number) user = await aget_user_by_phone_number(phone_number)
@@ -244,7 +253,7 @@ def configure_server(
state.SearchType = configure_search_types() state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type) state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent() setup_default_agent(user)
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled" message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
logger.info(message) logger.info(message)
@@ -256,8 +265,8 @@ def configure_server(
raise e raise e
def setup_default_agent(): def setup_default_agent(user: KhojUser):
AgentAdapters.create_default_agent() AgentAdapters.create_default_agent(user)
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):

View File

@@ -113,13 +113,15 @@ async def get_or_create_user(token: dict) -> KhojUser:
return user return user
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser: async def aget_or_create_user_by_phone_number(phone_number: str) -> tuple[KhojUser, bool]:
is_new = False
if is_none_or_empty(phone_number): if is_none_or_empty(phone_number):
return None return None, is_new
user = await aget_user_by_phone_number(phone_number) user = await aget_user_by_phone_number(phone_number)
if not user: if not user:
user = await acreate_user_by_phone_number(phone_number) user = await acreate_user_by_phone_number(phone_number)
return user is_new = True
return user, is_new
async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser: async def aset_user_phone_number(user: KhojUser, phone_number: str) -> KhojUser:
@@ -165,8 +167,10 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
return user return user
async def aget_or_create_user_by_email(email: str) -> KhojUser: async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email}) user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
defaults={"username": email, "email": email}
)
await user.asave() await user.asave()
if user: if user:
@@ -177,7 +181,7 @@ async def aget_or_create_user_by_email(email: str) -> KhojUser:
if not user_subscription: if not user_subscription:
await Subscription.objects.acreate(user=user, type="trial") await Subscription.objects.acreate(user=user, type="trial")
return user return user, is_new
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser: async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
@@ -248,9 +252,9 @@ def get_user_subscription(email: str) -> Optional[Subscription]:
async def set_user_subscription( async def set_user_subscription(
email: str, is_recurring=None, renewal_date=None, type="standard" email: str, is_recurring=None, renewal_date=None, type="standard"
) -> Optional[Subscription]: ) -> tuple[Optional[Subscription], bool]:
# Get or create the user object and their subscription # Get or create the user object and their subscription
user = await aget_or_create_user_by_email(email) user, is_new = await aget_or_create_user_by_email(email)
user_subscription = await Subscription.objects.filter(user=user).afirst() user_subscription = await Subscription.objects.filter(user=user).afirst()
# Update the user subscription state # Update the user subscription state
@@ -262,7 +266,7 @@ async def set_user_subscription(
elif renewal_date is not None: elif renewal_date is not None:
user_subscription.renewal_date = renewal_date user_subscription.renewal_date = renewal_date
await user_subscription.asave() await user_subscription.asave()
return user_subscription return user_subscription, is_new
def subscription_to_state(subscription: Subscription) -> str: def subscription_to_state(subscription: Subscription) -> str:
@@ -643,8 +647,8 @@ class AgentAdapters:
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
@staticmethod @staticmethod
def create_default_agent(): def create_default_agent(user: KhojUser):
default_conversation_config = ConversationAdapters.get_default_conversation_config() default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
if default_conversation_config is None: if default_conversation_config is None:
logger.info("No default conversation config found, skipping default agent creation") logger.info("No default conversation config found, skipping default agent creation")
return None return None
@@ -968,29 +972,51 @@ class ConversationAdapters:
return VoiceModelOption.objects.first() return VoiceModelOption.objects.first()
@staticmethod @staticmethod
def get_default_conversation_config(): def get_default_conversation_config(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings = ServerChatSettings.objects.first() server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.chat_default is None: if server_chat_settings is not None and server_chat_settings.chat_default is not None:
return ChatModelOptions.objects.filter().first() return server_chat_settings.chat_default
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
if user_chat_settings is not None and user_chat_settings.setting is not None:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return ChatModelOptions.objects.filter().first()
@staticmethod @staticmethod
async def aget_default_conversation_config(): async def aget_default_conversation_config(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings: ServerChatSettings = ( server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter() await ServerChatSettings.objects.filter()
.prefetch_related("chat_default", "chat_default__openai_config") .prefetch_related("chat_default", "chat_default__openai_config")
.afirst() .afirst()
) )
if server_chat_settings is None or server_chat_settings.chat_default is None: if server_chat_settings is not None and server_chat_settings.chat_default is not None:
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() return server_chat_settings.chat_default
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = (
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst())
if user
else None
)
if user_chat_settings is not None and user_chat_settings.setting is not None:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod @staticmethod
def get_advanced_conversation_config(): def get_advanced_conversation_config():
server_chat_settings = ServerChatSettings.objects.first() server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.chat_advanced is None: if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return ConversationAdapters.get_default_conversation_config() return server_chat_settings.chat_advanced
return server_chat_settings.chat_advanced return ConversationAdapters.get_default_conversation_config()
@staticmethod @staticmethod
async def aget_advanced_conversation_config(): async def aget_advanced_conversation_config():
@@ -999,9 +1025,9 @@ class ConversationAdapters:
.prefetch_related("chat_advanced", "chat_advanced__openai_config") .prefetch_related("chat_advanced", "chat_advanced__openai_config")
.afirst() .afirst()
) )
if server_chat_settings is None or server_chat_settings.chat_advanced is None: if server_chat_settings is not None or server_chat_settings.chat_advanced is not None:
return await ConversationAdapters.aget_default_conversation_config() return server_chat_settings.chat_advanced
return server_chat_settings.chat_advanced return await ConversationAdapters.aget_default_conversation_config()
@staticmethod @staticmethod
def create_conversation_from_public_conversation( def create_conversation_from_public_conversation(

View File

@@ -25,7 +25,6 @@ async def text_to_image(
location_data: LocationData, location_data: LocationData,
references: List[Dict[str, Any]], references: List[Dict[str, Any]],
online_results: Dict[str, Any], online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None, uploaded_image_url: Optional[str] = None,
agent: Agent = None, agent: Agent = None,
@@ -66,8 +65,8 @@ async def text_to_image(
note_references=references, note_references=references,
online_results=online_results, online_results=online_results,
model_type=text_to_image_config.model_type, model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
user=user,
agent=agent, agent=agent,
) )

View File

@@ -104,7 +104,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [ tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent) read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
for link, subquery, content in webpages for link, subquery, content in webpages
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -160,7 +160,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls)) webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls] tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict) response: Dict[str, Dict] = defaultdict(dict)
@@ -171,14 +171,14 @@ async def read_webpages(
async def read_webpage_and_extract_content( async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
) -> Tuple[str, Union[None, str], str]: ) -> Tuple[str, Union[None, str], str]:
try: try:
if is_none_or_empty(content): if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger): with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url) content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger): with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent) extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
return subquery, extracted_info, url return subquery, extracted_info, url
except Exception as e: except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}") logger.error(f"Failed to read web page at '{url}' with {e}")

View File

@@ -394,7 +394,7 @@ async def extract_references_and_questions(
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
vision_enabled = conversation_config.vision_enabled vision_enabled = conversation_config.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:

View File

@@ -198,7 +198,7 @@ def chat_history(
n: Optional[int] = None, n: Optional[int] = None,
): ):
user = request.user.object user = request.user.object
validate_conversation_config() validate_conversation_config(user)
# Load Conversation History # Load Conversation History
conversation = ConversationAdapters.get_conversation_by_user( conversation = ConversationAdapters.get_conversation_by_user(
@@ -309,7 +309,7 @@ def get_shared_chat(
update_telemetry_state( update_telemetry_state(
request=request, request=request,
telemetry_type="api", telemetry_type="api",
api="chat_history", api="get_shared_chat_history",
**common.__dict__, **common.__dict__,
) )
@@ -742,12 +742,12 @@ async def chat(
q, q,
meta_log, meta_log,
is_automated_task, is_automated_task,
subscribed=subscribed, user=user,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent, agent=agent,
) )
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent) mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result yield result
if mode not in conversation_commands: if mode not in conversation_commands:
@@ -1001,7 +1001,6 @@ async def chat(
location_data=location, location_data=location,
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent, agent=agent,

View File

@@ -40,7 +40,7 @@ def get_user_chat_model(
chat_model = ConversationAdapters.get_conversation_config(user) chat_model = ConversationAdapters.get_conversation_config(user)
if chat_model is None: if chat_model is None:
chat_model = ConversationAdapters.get_default_conversation_config() chat_model = ConversationAdapters.get_default_conversation_config(user)
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model})) return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))

View File

@@ -80,11 +80,19 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
request.session.pop("user", None) request.session.pop("user", None)
email = form.email email = form.email
user = await aget_or_create_user_by_email(email) user, is_new = await aget_or_create_user_by_email(email)
unique_id = user.email_verification_code unique_id = user.email_verification_code
if user: if user:
await send_magic_link_email(email, unique_id, request.base_url) await send_magic_link_email(email, unique_id, request.base_url)
if is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
return Response(status_code=200) return Response(status_code=200)

View File

@@ -124,20 +124,20 @@ def is_query_empty(query: str) -> bool:
return is_none_or_empty(query.strip()) return is_none_or_empty(query.strip())
def validate_conversation_config(): def validate_conversation_config(user: KhojUser):
default_config = ConversationAdapters.get_default_conversation_config() default_config = ConversationAdapters.get_default_conversation_config(user)
if default_config is None: if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
if default_config.model_type == "openai" and not default_config.openai_config: if default_config.model_type == "openai" and not default_config.openai_config:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser): async def is_ready_to_chat(user: KhojUser):
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or ( user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
await ConversationAdapters.aget_default_conversation_config() if user_conversation_config == None:
) user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model chat_model = user_conversation_config.chat_model
@@ -239,19 +239,19 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args) return await loop.run_in_executor(executor, generate_chat_response, *args)
async def acreate_title_from_query(query: str) -> str: async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
""" """
Create a title from the given query Create a title from the given query
""" """
title_generation_prompt = prompts.subject_generation.format(query=query) title_generation_prompt = prompts.subject_generation.format(query=query)
with timer("Chat actor: Generate title from query", logger): with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt) response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip() return response.strip()
async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]: async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
""" """
Check if the system prompt is safe to use Check if the system prompt is safe to use
""" """
@@ -260,7 +260,7 @@ async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
reason = "" reason = ""
with timer("Chat actor: Check if safe prompt", logger): with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check) response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
response = response.strip() response = response.strip()
try: try:
@@ -281,7 +281,7 @@ async def aget_relevant_information_sources(
query: str, query: str,
conversation_history: dict, conversation_history: dict,
is_task: bool, is_task: bool,
subscribed: bool, user: KhojUser,
uploaded_image_url: str = None, uploaded_image_url: str = None,
agent: Agent = None, agent: Agent = None,
): ):
@@ -319,7 +319,7 @@ async def aget_relevant_information_sources(
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
relevant_tools_prompt, relevant_tools_prompt,
response_type="json_object", response_type="json_object",
subscribed=subscribed, user=user,
) )
try: try:
@@ -355,7 +355,12 @@ async def aget_relevant_information_sources(
async def aget_relevant_output_modes( async def aget_relevant_output_modes(
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None query: str,
conversation_history: dict,
is_task: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
agent: Agent = None,
): ):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -391,7 +396,7 @@ async def aget_relevant_output_modes(
) )
with timer("Chat actor: Infer output mode for chat response", logger): with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object") response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
try: try:
response = response.strip() response = response.strip()
@@ -446,7 +451,7 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger): with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object" online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
) )
# Validate that the response is a non-empty, JSON-serializable list of URLs # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -493,7 +498,7 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger): with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object" online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@@ -511,7 +516,9 @@ async def generate_online_subqueries(
return [q] return [q]
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]: async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
) -> Tuple[str, ...]:
""" """
Schedule the date, time to run the query. Assume the server timezone is UTC. Schedule the date, time to run the query. Assume the server timezone is UTC.
""" """
@@ -523,7 +530,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
) )
raw_response = await send_message_to_model_wrapper( raw_response = await send_message_to_model_wrapper(
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object" crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@@ -537,7 +544,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}") raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]: async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
""" """
Extract relevant information for a given query from the target corpus Extract relevant information for a given query from the target corpus
""" """
@@ -555,14 +562,11 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Ag
personality_context=personality_context, personality_context=personality_context,
) )
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger): with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
extract_relevant_information, extract_relevant_information,
prompts.system_prompt_extract_relevant_information, prompts.system_prompt_extract_relevant_information,
chat_model_option=chat_model, user=user,
subscribed=subscribed,
) )
return response.strip() return response.strip()
@@ -571,8 +575,8 @@ async def extract_relevant_summary(
q: str, q: str,
corpus: str, corpus: str,
conversation_history: dict, conversation_history: dict,
subscribed: bool = False,
uploaded_image_url: str = None, uploaded_image_url: str = None,
user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
) -> Union[str, None]: ) -> Union[str, None]:
""" """
@@ -595,14 +599,11 @@ async def extract_relevant_summary(
personality_context=personality_context, personality_context=personality_context,
) )
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger): with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
extract_relevant_information, extract_relevant_information,
prompts.system_prompt_extract_relevant_summary, prompts.system_prompt_extract_relevant_summary,
chat_model_option=chat_model, user=user,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
) )
return response.strip() return response.strip()
@@ -667,8 +668,8 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]], note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None, online_results: Optional[dict] = None,
model_type: Optional[str] = None, model_type: Optional[str] = None,
subscribed: bool = False,
uploaded_image_url: Optional[str] = None, uploaded_image_url: Optional[str] = None,
user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
) -> str: ) -> str:
""" """
@@ -718,12 +719,8 @@ async def generate_better_image_prompt(
personality_context=personality_context, personality_context=personality_context,
) )
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
image_prompt, chat_model_option=chat_model, subscribed=subscribed, uploaded_image_url=uploaded_image_url
)
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1] response = response[1:-1]
@@ -735,8 +732,9 @@ def send_message_to_model_wrapper_sync(
message: str, message: str,
system_message: str = "", system_message: str = "",
response_type: str = "text", response_type: str = "text",
user: KhojUser = None,
): ):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config() conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
if conversation_config is None: if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
@@ -1124,7 +1122,7 @@ class CommonQueryParamsClass:
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()] CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool: def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
""" """
Decide whether to notify the user of the AI response. Decide whether to notify the user of the AI response.
Default to notifying the user for now. Default to notifying the user for now.
@@ -1141,7 +1139,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str) ->
with timer("Chat actor: Decide to notify user of automation response", logger): with timer("Chat actor: Decide to notify user of automation response", logger):
try: try:
# TODO Replace with async call so we don't have to maintain a sync version # TODO Replace with async call so we don't have to maintain a sync version
response = send_message_to_model_wrapper_sync(to_notify_or_not) response = send_message_to_model_wrapper_sync(to_notify_or_not, user)
should_notify_result = "no" not in response.lower() should_notify_result = "no" not in response.lower()
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.') logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
return should_notify_result return should_notify_result
@@ -1233,7 +1231,9 @@ def scheduled_chat(
ai_response = raw_response.text ai_response = raw_response.text
# Notify user if the AI response is satisfactory # Notify user if the AI response is satisfactory
if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response): if should_notify(
original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
):
if is_resend_enabled(): if is_resend_enabled():
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image) send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
else: else:
@@ -1243,7 +1243,7 @@ def scheduled_chat(
async def create_automation( async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
): ):
crontime, query_to_run, subject = await schedule_query(q, meta_log) crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject return job, crontime, query_to_run, subject
@@ -1429,9 +1429,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user) current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else "" notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = ( selected_chat_model_config = ConversationAdapters.get_conversation_config(
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config() user
) ) or ConversationAdapters.get_default_conversation_config(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list() chat_model_options = list()
for chat_model in chat_models: for chat_model in chat_models:

View File

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Request
from starlette.authentication import requires from starlette.authentication import requires
from khoj.database import adapters from khoj.database import adapters
from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state from khoj.utils import state
# Stripe integration for Khoj Cloud Subscription # Stripe integration for Khoj Cloud Subscription
@@ -48,6 +49,8 @@ async def subscribe(request: Request):
customer_id = subscription["customer"] customer_id = subscription["customer"]
customer = stripe.Customer.retrieve(customer_id) customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"] customer_email = customer["email"]
user = None
is_new = False
# Handle valid stripe webhook events # Handle valid stripe webhook events
success = True success = True
@@ -55,7 +58,9 @@ async def subscribe(request: Request):
# Mark the user as subscribed and update the next renewal date on payment # Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0] subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date) user, is_new = await adapters.set_user_subscription(
customer_email, is_recurring=True, renewal_date=renewal_date
)
success = user is not None success = user is not None
elif event_type in {"customer.subscription.updated"}: elif event_type in {"customer.subscription.updated"}:
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email) user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
@@ -63,15 +68,24 @@ async def subscribe(request: Request):
if user_subscription and user_subscription.renewal_date: if user_subscription and user_subscription.renewal_date:
# Mark user as unsubscribed or resubscribed # Mark user as unsubscribed or resubscribed
is_recurring = not subscription["cancel_at_period_end"] is_recurring = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring) user, is_new = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
success = updated_user is not None success = user is not None
elif event_type in {"customer.subscription.deleted"}: elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state # Reset the user to trial state
user = await adapters.set_user_subscription( user, is_new = await adapters.set_user_subscription(
customer_email, is_recurring=False, renewal_date=False, type="trial" customer_email, is_recurring=False, renewal_date=False, type="trial"
) )
success = user is not None success = user is not None
if user and is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(user.user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
logger.info(f'Stripe subscription {event["type"]} for {customer_email}') logger.info(f'Stripe subscription {event["type"]} for {customer_email}')
return {"success": success} return {"success": success}

View File

@@ -129,9 +129,6 @@ def initialization(interactive: bool = True):
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists(): if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
default_chat_model_name = user_chat_model_name default_chat_model_name = user_chat_model_name
# Create a server chat settings object with the default chat model
default_chat_model = ChatModelOptions.objects.filter(chat_model=default_chat_model_name).first()
ServerChatSettings.objects.create(chat_default=default_chat_model)
logger.info("🗣️ Chat model configuration complete") logger.info("🗣️ Chat model configuration complete")
# Set up offline speech to text model # Set up offline speech to text model

View File

@@ -76,5 +76,6 @@
"1.23.2": "0.15.0", "1.23.2": "0.15.0",
"1.23.3": "0.15.0", "1.23.3": "0.15.0",
"1.24.0": "0.15.0", "1.24.0": "0.15.0",
"1.24.1": "0.15.0" "1.24.1": "0.15.0",
"1.25.0": "0.15.0"
} }