diff --git a/src/interface/web/app/common/auth.ts b/src/interface/web/app/common/auth.ts
index defbf536..6344fbfe 100644
--- a/src/interface/web/app/common/auth.ts
+++ b/src/interface/web/app/common/auth.ts
@@ -63,6 +63,8 @@ export interface UserConfig {
enabled_content_source: SyncedContent;
has_documents: boolean;
notion_token: string | null;
+ enable_memory: boolean;
+ server_memory_mode: "disabled" | "enabled_default_off" | "enabled_default_on";
// user model settings
search_model_options: ModelOptions[];
selected_search_model_config: number;
diff --git a/src/interface/web/app/components/userMemory/userMemory.tsx b/src/interface/web/app/components/userMemory/userMemory.tsx
new file mode 100644
index 00000000..db360e7b
--- /dev/null
+++ b/src/interface/web/app/components/userMemory/userMemory.tsx
@@ -0,0 +1,90 @@
+import { useState } from "react";
+import { Input } from "@/components/ui/input";
+import { Button } from "@/components/ui/button";
+import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react";
+import { useToast } from "@/components/ui/use-toast";
+
+export interface UserMemorySchema {
+ id: number;
+ raw: string;
+ created_at: string;
+}
+
+interface UserMemoryProps {
+ memory: UserMemorySchema;
+ onDelete: (id: number) => void;
+ onUpdate: (id: number, raw: string) => void;
+}
+
+export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) {
+ const [isEditing, setIsEditing] = useState(false);
+ const [content, setContent] = useState(memory.raw);
+ const { toast } = useToast();
+
+ const handleUpdate = () => {
+ onUpdate(memory.id, content);
+ setIsEditing(false);
+ toast({
+ title: "Memory Updated",
+ description: "Your memory has been successfully updated.",
+ });
+ };
+
+ const handleDelete = () => {
+ onDelete(memory.id);
+ toast({
+ title: "Memory Deleted",
+ description: "Your memory has been successfully deleted.",
+ });
+ };
+
+ return (
+
+ {isEditing ? (
+ <>
+
setContent(e.target.value)}
+ className="flex-1"
+ />
+
+
+ >
+ ) : (
+ <>
+
+
+
+ >
+ )}
+
+ );
+}
diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx
index ad8b88f5..75ea7ee2 100644
--- a/src/interface/web/app/settings/page.tsx
+++ b/src/interface/web/app/settings/page.tsx
@@ -15,6 +15,8 @@ import { Button } from "@/components/ui/button";
import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp";
import { Input } from "@/components/ui/input";
import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card";
+import { Switch } from "@/components/ui/switch";
+
import {
DropdownMenu,
DropdownMenuContent,
@@ -33,6 +35,15 @@ import {
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
+
+import {
+ Dialog,
+ DialogContent,
+ DialogHeader,
+ DialogTitle,
+ DialogTrigger
+} from "@/components/ui/dialog";
+
import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table";
import {
@@ -74,6 +85,7 @@ import Loading from "../components/loading/loading";
import IntlTelInput from "intl-tel-input/react";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar";
+import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory";
import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo";
import { Progress } from "@/components/ui/progress";
@@ -323,6 +335,9 @@ export default function SettingsView() {
const [numberValidationState, setNumberValidationState] = useState(
PhoneNumberValidationState.Verified,
);
+ const [memories, setMemories] = useState([]);
+ const [enableMemory, setEnableMemory] = useState(true);
+ const [serverMemoryMode, setServerMemoryMode] = useState("enabled_default_on");
const [isExporting, setIsExporting] = useState(false);
const [exportProgress, setExportProgress] = useState(0);
const [exportedConversations, setExportedConversations] = useState(0);
@@ -347,6 +362,8 @@ export default function SettingsView() {
);
setName(initialUserConfig?.given_name);
setNotionToken(initialUserConfig?.notion_token ?? null);
+ setEnableMemory(initialUserConfig?.enable_memory ?? true);
+ setServerMemoryMode(initialUserConfig?.server_memory_mode ?? "enabled_default_on");
}, [initialUserConfig]);
const sendOTP = async () => {
@@ -621,6 +638,88 @@ export default function SettingsView() {
}
};
+ const fetchMemories = async () => {
+ try {
+ console.log("Fetching memories...");
+ const response = await fetch('/api/memories/');
+ if (!response.ok) throw new Error('Failed to fetch memories');
+ const data = await response.json();
+ setMemories(data);
+ } catch (error) {
+ console.error('Error fetching memories:', error);
+ toast({
+ title: "Error",
+ description: "Failed to fetch memories. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+ const handleDeleteMemory = async (id: number) => {
+ try {
+ const response = await fetch(`/api/memories/${id}`, {
+ method: 'DELETE'
+ });
+ if (!response.ok) throw new Error('Failed to delete memory');
+ setMemories(memories.filter(memory => memory.id !== id));
+ } catch (error) {
+ console.error('Error deleting memory:', error);
+ toast({
+ title: "Error",
+ description: "Failed to delete memory. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+ const handleUpdateMemory = async (id: number, raw: string) => {
+ try {
+ const response = await fetch(`/api/memories/${id}`, {
+ method: 'PUT',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({ raw, memory_id: id }),
+ });
+ if (!response.ok) throw new Error('Failed to update memory');
+ const updatedMemory: UserMemorySchema = await response.json();
+ setMemories(memories.map(memory =>
+ memory.id === id ? updatedMemory : memory
+ ));
+ } catch (error) {
+ console.error('Error updating memory:', error);
+ toast({
+ title: "Error",
+ description: "Failed to update memory. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+ const handleToggleMemory = async (enabled: boolean) => {
+ try {
+ const response = await fetch(`/api/user/memory?enable_memory=${enabled}`, {
+ method: 'PATCH',
+ });
+ if (!response.ok) throw new Error('Failed to update memory setting');
+ setEnableMemory(enabled);
+ toast({
+ title: enabled ? "Memory enabled" : "Memory disabled",
+ description: enabled
+ ? "Khoj will learn and remember from your conversations."
+ : "Khoj will no longer learn or remember from your conversations.",
+ });
+ } catch (error) {
+ console.error('Error toggling memory:', error);
+ toast({
+ title: "Error",
+ description: "Failed to update memory setting. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+
const syncContent = async (type: string) => {
try {
const response = await fetch(`/api/content?t=${type}`, {
@@ -1212,7 +1311,64 @@ export default function SettingsView() {
-
+
+
+
+ Memories
+
+
+
+ View and manage your long-term memories
+
+
+
+ handleToggleMemory(checked)}
+ disabled={serverMemoryMode === "disabled"}
+ />
+
+ {serverMemoryMode === "disabled" && (
+
+ Memory has been disabled by the server administrator.
+
+ )}
+
+
+
+
+
diff --git a/src/interface/web/bun.lock b/src/interface/web/bun.lock
index 36f1d3cf..9287dd9f 100644
--- a/src/interface/web/bun.lock
+++ b/src/interface/web/bun.lock
@@ -24,6 +24,7 @@
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
+ "@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toast": "^1.2.15",
"@radix-ui/react-toggle": "^1.1.10",
@@ -274,7 +275,7 @@
"@radix-ui/react-slot": ["@radix-ui/react-slot@1.2.3", "", { "dependencies": { "@radix-ui/react-compose-refs": "1.1.2" }, "peerDependencies": { "@types/react": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react"] }, "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A=="],
- "@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="],
+ "@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.6", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-bByzr1+ep1zk4VubeEVViV592vu2lHE2BZY5OnzehZqOOgogN80+mNtCqPkhn2gklJqOpxWgPoYTSnhBCqpOXQ=="],
"@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.13", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.5", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.11", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A=="],
@@ -1478,8 +1479,6 @@
"@radix-ui/react-slider/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="],
- "@radix-ui/react-switch/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="],
-
"@radix-ui/react-toggle-group/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="],
"@radix-ui/react-toggle-group/@radix-ui/react-roving-focus": ["@radix-ui/react-roving-focus@1.1.10", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q=="],
@@ -1576,6 +1575,8 @@
"radix-ui/@radix-ui/react-select": ["@radix-ui/react-select@2.2.5", "", { "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-popper": "1.2.7", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA=="],
+ "radix-ui/@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="],
+
"radix-ui/@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.12", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-GTVAlRVrQrSw3cEARM0nAx73ixrWDPNZAruETn3oHCNP6SbZ/hNxdxp+u7VkIEv3/sFoLq1PfcHrl7Pnp0CDpw=="],
"radix-ui/@radix-ui/react-toast": ["@radix-ui/react-toast@1.2.14", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-nAP5FBxBJGQ/YfUB+r+O6USFVkWq3gAInkxyEnmvEV5jtSbfDhfa4hwX8CraCnbjMLsE7XSf/K75l9xXY7joWg=="],
diff --git a/src/interface/web/components/ui/switch.tsx b/src/interface/web/components/ui/switch.tsx
new file mode 100644
index 00000000..812af0df
--- /dev/null
+++ b/src/interface/web/components/ui/switch.tsx
@@ -0,0 +1,29 @@
+"use client"
+
+import * as React from "react"
+import * as SwitchPrimitives from "@radix-ui/react-switch"
+
+import { cn } from "@/lib/utils"
+
+const Switch = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+
+
+))
+Switch.displayName = SwitchPrimitives.Root.displayName
+
+export { Switch }
diff --git a/src/interface/web/package.json b/src/interface/web/package.json
index 3eceac6f..ccbbd4bf 100644
--- a/src/interface/web/package.json
+++ b/src/interface/web/package.json
@@ -38,6 +38,7 @@
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
+ "@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toast": "^1.2.15",
"@radix-ui/react-toggle": "^1.1.10",
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 40d1eeb5..50b2dd31 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -323,6 +323,7 @@ def configure_routes(app):
from khoj.routers.api_automation import api_automation
from khoj.routers.api_chat import api_chat
from khoj.routers.api_content import api_content
+ from khoj.routers.api_memories import api_memories
from khoj.routers.api_model import api_model
from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client
@@ -332,6 +333,7 @@ def configure_routes(app):
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_automation, prefix="/api/automation")
app.include_router(api_model, prefix="/api/model")
+ app.include_router(api_memories, prefix="/api/memories")
app.include_router(api_content, prefix="/api/content")
app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client)
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 3ab08ef2..abcc226f 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -62,6 +62,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
+ UserMemory,
UserRequests,
UserTextToImageModelConfig,
UserVoiceModelConfig,
@@ -566,6 +567,16 @@ def get_default_search_model() -> SearchModelConfig:
return SearchModelConfig.objects.first()
+async def aget_default_search_model() -> SearchModelConfig:
+ default_search_model = await SearchModelConfig.objects.filter(name="default").afirst()
+
+ if default_search_model:
+ return default_search_model
+ elif await SearchModelConfig.objects.count() == 0:
+ await SearchModelConfig.objects.acreate()
+ return await SearchModelConfig.objects.afirst()
+
+
def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
@@ -1556,12 +1567,17 @@ class ConversationAdapters:
):
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
- conversation = await Conversation.objects.filter(
- user=user, client=client_application, id=conversation_id
- ).afirst()
+ conversation = (
+ await Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
+ .prefetch_related("agent", "agent__chat_model")
+ .afirst()
+ )
else:
conversation = (
- await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
+ await Conversation.objects.filter(user=user, client=client_application)
+ .prefetch_related("agent", "agent__chat_model")
+ .order_by("-updated_at")
+ .afirst()
)
existing_messages = conversation.messages if conversation else []
@@ -1597,6 +1613,85 @@ class ConversationAdapters:
return None
return config.setting
+ @staticmethod
+ async def ais_memory_enabled(user: KhojUser) -> bool:
+ """
+ Check if memory is enabled for the user based on server config and user preference.
+
+ Logic:
+ - If server memory_mode is DISABLED: return False (overrides user preference)
+ - If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False
+ - If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True
+ - If no server config exists: default to True
+ """
+ # Get server-level memory configuration
+ server_settings = await ServerChatSettings.objects.afirst()
+ if server_settings:
+ memory_mode = server_settings.memory_mode
+ # Disabled mode overrides all user preferences
+ if memory_mode == ServerChatSettings.MemoryMode.DISABLED:
+ return False
+
+ # Get user preference
+ user_config = await UserConversationConfig.objects.filter(user=user).afirst()
+
+ if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF:
+ # User must explicitly opt-in; check if user has set a preference
+ if user_config is None:
+ return False # Default off for new users
+ return user_config.enable_memory
+
+ # ENABLED_DEFAULT_ON: use user preference if set, else True
+ if user_config is None:
+ return True # Default on for new users
+ return user_config.enable_memory
+
+ # No server config - default behavior (enabled, default on)
+ user_config = await UserConversationConfig.objects.filter(user=user).afirst()
+ if user_config is None:
+ return True
+ return user_config.enable_memory
+
+ @staticmethod
+ def is_memory_enabled(user: KhojUser) -> bool:
+ """
+ Sync version of ais_memory_enabled.
+ Check if memory is enabled for the user based on server config and user preference.
+
+ Logic:
+ - If server memory_mode is DISABLED: return False (overrides user preference)
+ - If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False
+ - If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True
+ - If no server config exists: default to True
+ """
+ # Get server-level memory configuration
+ server_settings = ServerChatSettings.objects.first()
+ if server_settings:
+ memory_mode = server_settings.memory_mode
+ # Disabled mode overrides all user preferences
+ if memory_mode == ServerChatSettings.MemoryMode.DISABLED:
+ return False
+
+ # Get user preference
+ user_config = UserConversationConfig.objects.filter(user=user).first()
+
+ if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF:
+ # User must explicitly opt-in; check if user has set a preference
+ if user_config is None:
+ return False # Default off for new users
+ return user_config.enable_memory
+
+ # ENABLED_DEFAULT_ON: use user preference if set, else True
+ if user_config is None:
+ return True # Default on for new users
+ return user_config.enable_memory
+
+ # No server config - default behavior (enabled, default on)
+ user_config = UserConversationConfig.objects.filter(user=user).first()
+ if user_config is None:
+ return True
+ return user_config.enable_memory
+
@staticmethod
async def get_speech_to_text_config():
return await SpeechToTextModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
@@ -2186,3 +2281,94 @@ class McpServerAdapters:
except Exception as e:
logger.error(f"Error retrieving MCP servers: {e}", exc_info=True)
return servers
+
+
+class UserMemoryAdapters:
+ @staticmethod
+ @require_valid_user
+ async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]:
+ """
+ Pulls memories from the database for a given user. Medium term memory.
+ """
+ time_frame = datetime.now(timezone.utc) - timedelta(days=window)
+ default_agent = await AgentAdapters.aget_default_agent()
+ if agent and agent != default_agent:
+ memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by(
+ "-created_at"
+ )[:limit]
+ else:
+ memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit]
+ return await sync_to_async(list)(memories)
+
+ @staticmethod
+ @require_valid_user
+ async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory:
+ """
+ Saves a memory to the database for a given user.
+ """
+ embeddings_model = state.embeddings_model
+ model = await aget_default_search_model()
+ embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory)
+ default_agent = await AgentAdapters.aget_default_agent()
+ if agent and agent != default_agent:
+ memory_instance = await UserMemory.objects.acreate(
+ user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent
+ )
+ else:
+ memory_instance = await UserMemory.objects.acreate(
+ user=user, embeddings=embeddings, raw=memory, search_model=model
+ )
+
+ return memory_instance
+
+ @staticmethod
+ @require_valid_user
+ async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]:
+ """
+ Searches for memories in the database for a given user. Long term memory.
+ """
+ embeddings_model = state.embeddings_model
+ model = await aget_default_search_model()
+ max_distance = model.bi_encoder_confidence_threshold or math.inf
+ embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query)
+ default_agent = await AgentAdapters.aget_default_agent()
+
+ if agent and agent != default_agent:
+ relevant_memories = UserMemory.objects.filter(user=user, agent=agent)
+ else:
+ relevant_memories = UserMemory.objects.filter(user=user)
+
+ relevant_memories = (
+ relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query))
+ .order_by("distance")
+ .filter(distance__lte=max_distance)
+ )
+
+ return await sync_to_async(list)(relevant_memories[:limit])
+
+ @staticmethod
+ @require_valid_user
+ async def delete_memory(user: KhojUser, memory_id: str) -> bool:
+ """
+ Deletes a memory from the database for a given user.
+ """
+ try:
+ memory = await UserMemory.objects.aget(user=user, id=memory_id)
+ await memory.adelete()
+ return True
+ except UserMemory.DoesNotExist:
+ return False
+
+ @staticmethod
+ def to_dict(memories: List[UserMemory]) -> List[dict]:
+ """
+ Converts a list of Memory objects to a list of dictionaries.
+ """
+ return [
+ {
+ "id": f"{memory.id}",
+ "raw": memory.raw,
+ "updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"),
+ }
+ for memory in memories
+ ]
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 872c57ca..345b03fd 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -34,6 +34,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
+ UserMemory,
UserRequests,
UserVoiceModelConfig,
VoiceModelOption,
@@ -182,6 +183,7 @@ admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin)
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
+admin.site.register(UserMemory, unfold_admin.ModelAdmin)
@admin.register(McpServer)
@@ -295,6 +297,7 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin):
"think_paid_fast",
"think_paid_deep",
"web_scraper",
+ "memory_mode",
)
ordering = ("priority",)
diff --git a/src/khoj/database/management/commands/manage_memories.py b/src/khoj/database/management/commands/manage_memories.py
new file mode 100644
index 00000000..62e14c97
--- /dev/null
+++ b/src/khoj/database/management/commands/manage_memories.py
@@ -0,0 +1,384 @@
+import asyncio
+from datetime import datetime, timedelta
+from typing import List, Optional
+
+from django.core.management.base import BaseCommand
+from django.db.models import Q
+from django.utils import timezone
+
+from khoj.configure import initialize_server
+from khoj.database.adapters import UserMemoryAdapters
+from khoj.database.models import (
+ Conversation,
+ DataStore,
+ KhojUser,
+)
+from khoj.routers.helpers import extract_facts_from_query
+
+
+class Command(BaseCommand):
+ help = "Manage user memories - generate from conversations or delete existing memories"
+
+ def add_arguments(self, parser):
+ parser.add_argument(
+ "--lookback-days",
+ type=int,
+ default=None,
+ help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories",
+ )
+ parser.add_argument(
+ "--users",
+ type=str,
+ help="Process specific users (comma-separated usernames or emails)",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=10,
+ help="Number of conversations to process in each batch (default: 10)",
+ )
+ parser.add_argument(
+ "--apply",
+ action="store_true",
+ help="Actually perform the operation. Without this flag, only shows what would be processed.",
+ )
+ parser.add_argument(
+ "--delete",
+ action="store_true",
+ help="Delete all memories for specified users instead of generating new ones",
+ )
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="Resume from last checkpoint if process was interrupted",
+ )
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Force regenerate memories even if already processed",
+ )
+
+ def handle(self, *args, **options):
+ """Main entry point for the command"""
+ initialize_server()
+ asyncio.run(self.async_handle(*args, **options))
+
+ async def async_handle(self, *args, **options):
+ """Async handler for memory management"""
+ lookback_days = options["lookback_days"]
+ usernames = options["users"]
+ batch_size = options["batch_size"]
+ apply = options["apply"]
+ delete = options["delete"]
+ resume = options["resume"]
+ force = options["force"]
+
+ mode = "APPLY" if apply else "DRY RUN"
+
+ # Handle deletion mode
+ if delete:
+ # For deletion, only use cutoff_date if lookback_days is explicitly provided
+ cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days else None
+ await self.handle_delete_memories(usernames, cutoff_date, apply)
+ return
+
+ # Handle generation mode
+ # For generation, default to 7 days if not specified
+ if lookback_days is None:
+ lookback_days = 7
+ cutoff_date = timezone.now() - timedelta(days=lookback_days)
+ self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days")
+
+ # Get users to process
+ users = await self.get_users_to_process(usernames)
+ if not users:
+ self.stdout.write("No users found to process")
+ return
+
+ self.stdout.write(f"Found {len(users)} users to process")
+
+ # Initialize or retrieve checkpoint
+ checkpoint = await self.get_or_create_checkpoint(resume)
+
+ total_conversations = 0
+ total_memories = 0
+
+ for user in users:
+ # Check if user already processed in checkpoint
+ if not force and user.id in checkpoint.get("processed_users", []):
+ self.stdout.write(f"Skipping already processed user: {user.username}")
+ continue
+
+ self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})")
+
+ # Get conversations for this user
+ conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force)
+
+ if not conversations:
+ self.stdout.write(f" No conversations to process for {user.username}")
+ # Mark user as processed
+ if apply:
+ await self.update_checkpoint(checkpoint, user_id=user.id)
+ continue
+
+ self.stdout.write(f" Found {len(conversations)} conversations to process")
+
+ # Process conversations in batches
+ user_memories = 0
+ for i in range(0, len(conversations), batch_size):
+ batch = conversations[i : i + batch_size]
+ batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint)
+ user_memories += batch_memories
+ total_conversations += len(batch)
+
+ # Update progress
+ progress = min(i + batch_size, len(conversations))
+ self.stdout.write(
+ f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories"
+ )
+
+ total_memories += user_memories
+ self.stdout.write(
+ f" Completed user {user.username}: "
+ f"processed {len(conversations)} conversations, "
+ f"generated {user_memories} memories"
+ )
+
+ # Mark user as processed
+ if apply:
+ await self.update_checkpoint(checkpoint, user_id=user.id)
+
+ # Clear checkpoint on successful completion
+ if apply:
+ await self.clear_checkpoint()
+
+ action = "Generated" if apply else "Would generate"
+ self.stdout.write(
+ self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations")
+ )
+
+ async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]:
+ """Get list of users to comma separated usernames or emails to process"""
+ if users_str:
+ usernames = [u.strip() for u in users_str.split(",") if u.strip()]
+ # Process specific users
+ users = [user async for user in KhojUser.objects.filter(Q(username__in=usernames) | Q(email__in=usernames))]
+ return users
+ else:
+ # Process all users with conversations
+ return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()]
+
+ async def get_user_conversations(
+ self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool
+ ) -> List[Conversation]:
+ """Get conversations for a user that need processing"""
+ if cutoff_date is None:
+ query = Conversation.objects.filter(user=user).order_by("updated_at")
+ else:
+ query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at")
+
+ # Filter out already processed conversations if resuming
+ if not force and user.id in checkpoint.get("processed_conversations", {}):
+ processed_ids = checkpoint["processed_conversations"][user.id]
+ query = query.exclude(id__in=processed_ids)
+
+ return [conv async for conv in query]
+
+ async def process_conversation_batch(
+ self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict
+ ) -> int:
+ """Process a batch of conversations and generate memories"""
+ total_memories = 0
+
+ for conversation in conversations:
+ try:
+ # Get conversation messages using sync_to_async for property access
+ from asgiref.sync import sync_to_async
+
+ # Access conversation_log synchronously
+ @sync_to_async
+ def get_messages():
+ return conversation.messages
+
+ messages = await get_messages()
+ if not messages:
+ continue
+
+ # Get agent if conversation has one
+ @sync_to_async
+ def get_agent():
+ return conversation.agent
+
+ agent = await get_agent()
+
+ # Get existing memories for context
+ # Process each conversation turn
+ conversation_memories = 0
+ i = 0
+ while i + 1 < len(messages):
+ # Only process user-assistant pairs as a valid turn for memory extraction
+ if messages[i].by != "you" or messages[i + 1].by != "khoj":
+ i += 1
+ continue
+
+ # Get the conversation history up to this point
+ history = messages[: i + 2]
+
+ # Extract user query text for memory search
+ q = ""
+ if messages[i].message is None:
+ i += 1
+ continue
+ elif isinstance(messages[i].message, str):
+ q = messages[i].message
+ elif isinstance(messages[i].message, list):
+ q = "\n\n".join(
+ content.get("text", "")
+ for content in messages[i].message
+ if isinstance(content, dict) and content.get("text")
+ )
+
+ if not q or not q.strip():
+ i += 1
+ continue
+
+ # Get unique recent and long term relevant memories
+ recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
+ long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
+ relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())
+
+ if apply:
+ # Ensure agent is fully loaded with its chat_model
+ if agent:
+
+ @sync_to_async
+ def load_agent_with_chat_model():
+ # Force load the chat_model relationship
+ _ = agent.chat_model
+ return agent
+
+ agent = await load_agent_with_chat_model()
+
+ # Update memories based on latest conversation turn
+ memory_updates = await extract_facts_from_query(
+ user=user,
+ conversation_history=history,
+ existing_facts=relevant_memories,
+ agent=agent,
+ tracer={},
+ )
+
+ # Save new memories
+ for memory in memory_updates.create:
+ await UserMemoryAdapters.save_memory(user, memory, agent=agent)
+ conversation_memories += 1
+ self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...")
+
+ # Delete outdated memories
+ for memory in memory_updates.delete:
+ await UserMemoryAdapters.delete_memory(user, memory)
+ self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...")
+ else:
+ # Dry run - estimate memories that would be created
+ conversation_memories += 1 # Rough estimate
+
+ # Move to next conversation turn pair
+ i += 2
+
+ total_memories += conversation_memories
+
+ # Update checkpoint after each conversation
+ if apply:
+ await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id))
+ except Exception as e:
+ import traceback
+
+ self.stderr.write(
+ f"Error processing conversation {conversation.id} for user {user.id}: {e}\n"
+ f"Traceback: {traceback.format_exc()}"
+ )
+ continue
+
+ return total_memories
+
+ async def get_or_create_checkpoint(self, resume: bool) -> dict:
+ """Get or create checkpoint for resumable processing"""
+ checkpoint_key = "memory_generation_checkpoint"
+
+ if resume:
+ # Try to retrieve existing checkpoint
+ checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst()
+ if checkpoint_store:
+ self.stdout.write("Resuming from checkpoint...")
+ return checkpoint_store.value
+
+ # Create new checkpoint
+ return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}}
+
+ async def update_checkpoint(
+ self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None
+ ):
+ """Update checkpoint with progress"""
+ if user_id and user_id not in checkpoint["processed_users"]:
+ checkpoint["processed_users"].append(user_id)
+
+ if user_id and conversation_id:
+ if user_id not in checkpoint["processed_conversations"]:
+ checkpoint["processed_conversations"][user_id] = []
+ if conversation_id not in checkpoint["processed_conversations"][user_id]:
+ checkpoint["processed_conversations"][user_id].append(conversation_id)
+
+ # Save checkpoint to database
+ await DataStore.objects.aupdate_or_create(
+ key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True}
+ )
+
+ async def clear_checkpoint(self):
+ """Clear checkpoint after successful completion"""
+ await DataStore.objects.filter(key="memory_generation_checkpoint").adelete()
+ self.stdout.write("Checkpoint cleared")
+
+ async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool):
+ """Handle deletion of user memories"""
+ from khoj.database.models import UserMemory
+
+ # Get users to process
+ users = await self.get_users_to_process(usernames)
+ if not users:
+ self.stdout.write("No users found to process")
+ return
+
+ mode = "APPLY" if apply else "DRY RUN"
+ if cutoff_date:
+ # Calculate days from cutoff date
+ days_back = (timezone.now() - cutoff_date).days
+ self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users")
+ else:
+ self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users")
+
+ total_deleted = 0
+ for user in users:
+ # Count memories for this user
+ if cutoff_date is None:
+ user_memories = UserMemory.objects.filter(user=user)
+ else:
+ user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date)
+
+ memories_count = await user_memories.acount()
+ if memories_count == 0:
+ self.stdout.write(f" User {user.username} has no memories to delete")
+ continue
+
+ self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories")
+
+ if apply:
+ # Delete memories for this user (with date filter if specified)
+ deleted_count, _ = await user_memories.adelete()
+ self.stdout.write(f" Deleted {deleted_count} memories")
+ total_deleted += deleted_count
+ else:
+ self.stdout.write(f" Would delete {memories_count} memories")
+ total_deleted += memories_count
+
+ action = "Deleted" if apply else "Would delete"
+ self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total"))
diff --git a/src/khoj/database/migrations/0099_usermemory.py b/src/khoj/database/migrations/0099_usermemory.py
new file mode 100644
index 00000000..b1d6939a
--- /dev/null
+++ b/src/khoj/database/migrations/0099_usermemory.py
@@ -0,0 +1,82 @@
+# Generated by Django 5.1.10 on 2025-08-29 22:57
+
+import django.db.models.deletion
+import pgvector.django
+from django.conf import settings
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0098_alter_texttoimagemodelconfig_model_type"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="UserMemory",
+ fields=[
+ (
+ "id",
+ models.BigAutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("embeddings", pgvector.django.VectorField()),
+ ("raw", models.TextField()),
+ (
+ "agent",
+ models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ to="database.agent",
+ ),
+ ),
+ (
+ "search_model",
+ models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ to="database.searchmodelconfig",
+ ),
+ ),
+ (
+ "user",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ migrations.AddField(
+ model_name="userconversationconfig",
+ name="enable_memory",
+ field=models.BooleanField(default=True),
+ ),
+ migrations.AddField(
+ model_name="serverchatsettings",
+ name="memory_mode",
+ field=models.CharField(
+ choices=[
+ ("disabled", "Disabled"),
+ ("enabled_default_off", "Enabled, default off"),
+ ("enabled_default_on", "Enabled, default on"),
+ ],
+ default="enabled_default_on",
+ help_text="Server-level memory feature configuration. Disabled overrides user preference.",
+ max_length=20,
+ ),
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 404b0f38..56516b2d 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -478,6 +478,13 @@ class ServerChatSettings(DbBaseModel):
THINK_PAID_FAST = "think_paid_fast"
THINK_PAID_DEEP = "think_paid_deep"
+ class MemoryMode(models.TextChoices):
+ """Enum for server-level memory feature configuration"""
+
+ DISABLED = "disabled", "Disabled"
+ ENABLED_DEFAULT_OFF = "enabled_default_off", "Enabled, default off"
+ ENABLED_DEFAULT_ON = "enabled_default_on", "Enabled, default on"
+
chat_default = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
@@ -506,6 +513,12 @@ class ServerChatSettings(DbBaseModel):
unique=True,
help_text="Priority of the server chat settings. Lower numbers run first.",
)
+ memory_mode = models.CharField(
+ max_length=20,
+ choices=MemoryMode.choices,
+ default=MemoryMode.ENABLED_DEFAULT_ON,
+ help_text="Server-level memory feature configuration. Disabled overrides user preference.",
+ )
def clean(self):
error = {}
@@ -629,6 +642,7 @@ class SpeechToTextModelOptions(DbBaseModel):
class UserConversationConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)
+ enable_memory = models.BooleanField(default=True)
class UserVoiceModelConfig(DbBaseModel):
@@ -836,3 +850,15 @@ class McpServer(DbBaseModel):
def __str__(self):
return self.name
+
+
+class UserMemory(DbBaseModel):
+ """
+ Long term memory store derived from conversation between user and agent.
+ """
+
+ user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+ agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
+ embeddings = VectorField(dimensions=None)
+ raw = models.TextField()
+ search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 611d41a8..0ed601ea 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -1302,3 +1302,70 @@ user_name = PromptTemplate.from_template(
User's Name: {name}
""".strip()
)
+
+extract_facts_from_query = PromptTemplate.from_template(
+ """
+You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user.
+This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember.
+
+You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action:
+1. Create new facts
+2. Delete existing facts
+
+You should delete existing facts that are no longer true.
+You can enhance new facts with information from existing facts.
+You cannot update existing facts directly, instead create new facts and delete related existing ones to update them.
+
+Your output should be a JSON object with two lists: create and delete.
+- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective.
+- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true.
+- Leave the create or delete list empty if you have nothing important to add or remove.
+
+# Example
+Existing Facts:
+[
+ {{
+ "id": "5283",
+ "raw": "I am not interested in sports",
+ "updated_at": "2023-10-01T12:00:00+00:00"
+ }},
+ {{
+ "id": "22",
+ "raw": "I am a software engineer",
+ "updated_at": "2023-10-31T14:00:00+00:00"
+ }},
+ {{
+ "id": "651",
+ "raw": "My mother works at the hospital",
+ "updated_at": "2023-10-02T17:00:00+00:00"
+ }}
+]
+
+Latest Chat Session:
+- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline.
+In between coding, I took my cat Whiskers out for a walk and played a game of football.
+My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat.
+- AI: That's great to hear!
+
+Response:
+{{
+ "create": [
+ "I am interested in AI and machine learning",
+ "I have a pet cat named Whiskers",
+ "I enjoy playing football",
+ "My mother works at the hospital and is a doctor"
+ ],
+ "delete": [
+ "5283",
+ "651"
+ ],
+}}
+
+# Input
+Existing Facts:
+{matched_facts}
+
+Latest Chat Session:
+{chat_history}
+""".strip()
+)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index a2069243..dae182c2 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -28,6 +28,7 @@ from khoj.database.models import (
ClientApplication,
Intent,
KhojUser,
+ UserMemory,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.base_filter import BaseFilter
@@ -551,6 +552,7 @@ async def save_to_conversation_log(
client_application: ClientApplication = None,
conversation_id: str = None,
automation_id: str = None,
+ relevant_memories: List[UserMemory] = [],
query_images: List[str] = None,
raw_query_files: List[FileAttachment] = [],
generated_images: List[str] = [],
@@ -559,6 +561,8 @@ async def save_to_conversation_log(
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
+ from khoj.routers.helpers import ai_update_memories
+
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
@@ -605,6 +609,16 @@ async def save_to_conversation_log(
user_message=q,
)
+ if not automation_id:
+ # Don't update memories from automations, as this could get noisy.
+ await ai_update_memories(
+ user=user,
+ conversation_history=new_messages or [],
+ memories=relevant_memories,
+ agent=db_conversation.agent if db_conversation else None,
+ tracer=tracer,
+ )
+
if is_promptrace_enabled():
merge_message_into_conversation_trace(q, chat_response, tracer)
@@ -666,6 +680,7 @@ def generate_chatml_messages_with_context(
query_files: str = None,
query_images=None,
context_message="",
+ relevant_memories: List[UserMemory] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = [],
chat_history: list[ChatMessageModel] = [],
@@ -806,6 +821,14 @@ def generate_chatml_messages_with_context(
),
)
+ if not is_none_or_empty(relevant_memories):
+ memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. Ignore them if they are not relevant to the query.\n\n"
+ for memory in relevant_memories:
+ friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S")
+ memory_context += f"- [{friendly_dt}]: {memory.raw}\n"
+ memory_context += ""
+ messages.append(ChatMessage(content=memory_context, role="user"))
+
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 45a640c1..38c2cd84 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -25,6 +25,7 @@ from khoj.database.models import (
Intent,
KhojUser,
TextToImageModelConfig,
+ UserMemory,
)
from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url
@@ -47,6 +48,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
agent: Agent = None,
tracer: dict = {},
):
@@ -91,6 +93,7 @@ async def text_to_image(
model_type=text_to_image_config.model_type,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,
diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py
index 0aa12ca4..ccd780aa 100644
--- a/src/khoj/processor/operator/__init__.py
+++ b/src/khoj/processor/operator/__init__.py
@@ -5,7 +5,7 @@ import os
from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
-from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
+from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser, UserMemory
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
@@ -42,6 +42,7 @@ async def operate_environment(
query_images: Optional[List[str]] = None, # TODO: Handle query images
agent: Agent = None,
query_files: str = None, # TODO: Handle query files
+ relevant_memories: Optional[List[UserMemory]] = None, # TODO: Handle relevant memories
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: Optional[str] = ChatEvent.END_EVENT.value,
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index beeb640c..96a587ae 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -14,6 +14,7 @@ from khoj.database.models import (
Agent,
ChatMessageModel,
KhojUser,
+ UserMemory,
WebScraper,
)
from khoj.routers.helpers import (
@@ -63,6 +64,7 @@ async def search_online(
max_webpages_to_read: int = 1,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
previous_subqueries: Set = set(),
fast_model: bool = True,
agent: Agent = None,
@@ -82,6 +84,7 @@ async def search_online(
user,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
max_queries=max_online_searches,
fast_model=fast_model,
agent=agent,
@@ -160,7 +163,15 @@ async def search_online(
async for event in send_status_func(f"**Browsing**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
- extract_from_webpage(link, data["queries"], data.get("content"), user=user, agent=agent, tracer=tracer)
+ extract_from_webpage(
+ link,
+ data["queries"],
+ data.get("content"),
+ relevant_memories=relevant_memories,
+ user=user,
+ agent=agent,
+ tracer=tracer,
+ )
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@@ -438,6 +449,7 @@ async def read_webpages(
fast_model: bool = True,
agent: Agent = None,
max_webpages_to_read: int = 1,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
@@ -460,6 +472,7 @@ async def read_webpages(
user,
send_status_func=send_status_func,
agent=agent,
+ relevant_memories=relevant_memories,
tracer=tracer,
):
yield result
@@ -471,6 +484,7 @@ async def read_webpages_content(
user: KhojUser,
send_status_func: Optional[Callable] = None,
agent: Agent = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
logger.info(f"Reading web pages at: {urls}")
@@ -478,7 +492,10 @@ async def read_webpages_content(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Browsing**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
- tasks = [extract_from_webpage(url, {query}, user=user, agent=agent, tracer=tracer) for url in urls]
+ tasks = [
+ extract_from_webpage(url, {query}, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer)
+ for url in urls
+ ]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -533,6 +550,7 @@ async def extract_from_webpage(
url: str,
subqueries: set[str] = None,
content: str = None,
+ relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -546,7 +564,9 @@ async def extract_from_webpage(
extracted_info = None
if not is_none_or_empty(content):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
- extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent, tracer=tracer)
+ extracted_info = await extract_relevant_info(
+ subqueries, content, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer
+ )
return subqueries, url, extracted_info
diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py
index 2ec392ff..45c95115 100644
--- a/src/khoj/processor/tools/run_code.py
+++ b/src/khoj/processor/tools/run_code.py
@@ -21,7 +21,7 @@ from tenacity import (
)
from khoj.database.adapters import AgentAdapters, FileObjectAdapters
-from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser
+from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser, UserMemory
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
@@ -59,6 +59,7 @@ async def run_code(
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
tracer: dict = {},
@@ -79,6 +80,7 @@ async def run_code(
agent,
tracer,
query_files,
+ relevant_memories,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {instructions} with error: {e}")
@@ -126,6 +128,7 @@ async def generate_python_code(
agent: Agent = None,
tracer: dict = {},
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
) -> GeneratedCode:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@@ -161,6 +164,7 @@ async def generate_python_code(
code_generation_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index f7345d94..31950b1f 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -13,7 +13,7 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo
-from khoj.database.models import KhojUser, SpeechToTextModelOptions
+from khoj.database.models import KhojUser, SpeechToTextModelOptions, UserConversationConfig
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import (
ApiUserRateLimiter,
@@ -215,6 +215,29 @@ def set_user_name(
return {"status": "ok"}
+@api.patch("/user/memory", status_code=200)
+@requires(["authenticated"])
+def set_user_memory_enabled(
+ request: Request,
+ enable_memory: bool,
+ client: Optional[str] = None,
+):
+ user = request.user.object
+
+ user_config, _ = UserConversationConfig.objects.get_or_create(user=user)
+ user_config.enable_memory = enable_memory
+ user_config.save()
+
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="set_user_memory_enabled",
+ client=client,
+ )
+
+ return {"status": "ok", "enable_memory": enable_memory}
+
+
@api.get("/health", response_class=Response)
@requires(["authenticated"], status_code=200)
def health_check(request: Request) -> Response:
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 04cb84de..2e326690 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -29,6 +29,7 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
PublicConversationAdapters,
+ UserMemoryAdapters,
aget_user_name,
)
from khoj.database.models import Agent, KhojUser
@@ -101,7 +102,6 @@ conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
)
-
api_chat = APIRouter()
@@ -963,6 +963,14 @@ async def event_generator(
location = LocationData(city=city, region=region, country=country, country_code=country_code)
chat_history = conversation.messages
+ # Get most recent memories and long term relevant memories if memory is enabled
+ relevant_memories = []
+ if await ConversationAdapters.ais_memory_enabled(user):
+ recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
+ long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
+ # Create a de-duped set of memories
+ relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())
+
# If interrupted message in DB
if last_message := await conversation.pop_message(interrupted=True):
# Populate context from interrupted message
@@ -987,6 +995,7 @@ async def event_generator(
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
)
except ValueError as e:
@@ -1024,6 +1033,7 @@ async def event_generator(
previous_iterations=list(research_results),
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
user_name=user_name,
location=location,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1081,6 +1091,7 @@ async def event_generator(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1130,6 +1141,7 @@ async def event_generator(
max_online_searches=3,
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1157,6 +1169,7 @@ async def event_generator(
max_webpages_to_read=1,
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1198,6 +1211,7 @@ async def event_generator(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1223,6 +1237,7 @@ async def event_generator(
list(operator_results)[-1] if operator_results else None,
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent,
cancellation_event=cancellation_event,
@@ -1276,6 +1291,7 @@ async def event_generator(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1317,6 +1333,7 @@ async def event_generator(
online_results=online_results,
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1375,6 +1392,7 @@ async def event_generator(
user_name,
uploaded_images,
attached_file_context,
+ relevant_memories,
program_execution_context,
generated_asset_results,
is_subscribed,
@@ -1434,6 +1452,7 @@ async def event_generator(
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
+ relevant_memories=relevant_memories,
generated_images=generated_images,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
diff --git a/src/khoj/routers/api_memories.py b/src/khoj/routers/api_memories.py
new file mode 100644
index 00000000..d7527827
--- /dev/null
+++ b/src/khoj/routers/api_memories.py
@@ -0,0 +1,114 @@
+import json
+import logging
+from typing import Optional
+
+from asgiref.sync import sync_to_async
+from fastapi import APIRouter, Request
+from fastapi.responses import Response
+from pydantic import BaseModel
+from starlette.authentication import requires
+
+from khoj.database.adapters import UserMemoryAdapters
+from khoj.database.models import UserMemory
+
+api_memories = APIRouter()
+logger = logging.getLogger(__name__)
+
+
+@api_memories.get("")
+@requires(["authenticated"])
+async def get_memories(
+ request: Request,
+ client: Optional[str] = None,
+):
+ """Get all memories for the authenticated user"""
+ user = request.user.object
+
+ memories = UserMemory.objects.filter(user=user)
+ all_memories = await sync_to_async(list)(memories)
+
+ # Convert memories to a list of dictionaries
+ formatted_memories = [
+ {
+ "id": memory.id,
+ "raw": memory.raw,
+ "created_at": memory.created_at.isoformat(),
+ }
+ for memory in all_memories
+ ]
+
+ return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200)
+
+
+@api_memories.delete("/{memory_id}")
+@requires(["authenticated"])
+async def delete_memory(
+ request: Request,
+ memory_id: int,
+ client: Optional[str] = None,
+):
+ """Delete a specific memory by ID"""
+ user = request.user.object
+
+ # Verify memory belongs to user before deleting
+ memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst()
+ if not memory:
+ return Response(
+ content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404
+ )
+
+ await memory.adelete()
+
+ return Response(status_code=204)
+
+
+class UpdateMemoryBody(BaseModel):
+ """Request model for updating a memory"""
+
+ raw: str
+
+
+@api_memories.put("/{memory_id}")
+@requires(["authenticated"])
+async def update_memory(
+ request: Request,
+ body: UpdateMemoryBody,
+ memory_id: int,
+ client: Optional[str] = None,
+):
+ """Update a specific memory's content"""
+ user = request.user.object
+
+ # Get the memory and verify it belongs to the user
+ memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst()
+ if not memory:
+ return Response(
+ content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404
+ )
+
+ new_content = body.raw
+ if not new_content:
+ return Response(
+ content=json.dumps({"error": "Missing required field 'raw'"}),
+ media_type="application/json",
+ status_code=400,
+ )
+
+ await memory.adelete()
+
+ # Create a new memory with the updated content
+ new_memory = await UserMemoryAdapters.save_memory(
+ user=user,
+ memory=new_content,
+ )
+
+ return Response(
+ content=json.dumps(
+ {
+ "id": new_memory.id,
+ "raw": new_memory.raw,
+ }
+ ),
+ media_type="application/json",
+ status_code=200,
+ )
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 1c9f9b33..29a0df0f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -46,6 +46,7 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
+ UserMemoryAdapters,
aget_user_by_email,
create_khoj_token,
get_default_search_model,
@@ -53,6 +54,7 @@ from khoj.database.adapters import (
get_user_name,
get_user_notion_config,
get_user_subscription_state,
+ require_valid_user,
run_with_process_lock,
)
from khoj.database.models import (
@@ -66,8 +68,10 @@ from khoj.database.models import (
NotionConfig,
ProcessLock,
RateLimitRecord,
+ ServerChatSettings,
Subscription,
TextToImageModelConfig,
+ UserMemory,
UserRequests,
)
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
@@ -354,6 +358,7 @@ async def aget_data_sources_and_output_format(
query_images: List[str] = None,
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> Dict[str, Any]:
"""
@@ -415,6 +420,7 @@ async def aget_data_sources_and_output_format(
relevant_tools_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
response_type="json_object",
response_schema=PickTools,
fast_model=True,
@@ -470,6 +476,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
@@ -506,6 +513,7 @@ async def infer_webpage_urls(
online_queries_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
response_type="json_object",
response_schema=WebpageUrls,
fast_model=fast_model,
@@ -536,6 +544,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
max_queries: int = 3,
fast_model: bool = True,
agent: Agent = None,
@@ -573,6 +582,7 @@ async def generate_online_subqueries(
online_queries_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
response_type="json_object",
response_schema=OnlineQueries,
fast_model=fast_model,
@@ -659,7 +669,12 @@ async def aschedule_query(
async def extract_relevant_info(
- qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
+ qs: set[str],
+ corpus: str,
+ relevant_memories: List[UserMemory] = None,
+ user: KhojUser = None,
+ agent: Agent = None,
+ tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -683,6 +698,7 @@ async def extract_relevant_info(
response = await send_message_to_model_wrapper(
extract_relevant_information,
system_message=prompts.system_prompt_extract_relevant_information,
+ relevant_memories=relevant_memories,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
@@ -802,6 +818,7 @@ async def generate_excalidraw_diagram(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -819,6 +836,7 @@ async def generate_excalidraw_diagram(
online_results=online_results,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,
@@ -854,6 +872,7 @@ async def generate_better_diagram_description(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -899,6 +918,7 @@ async def generate_better_diagram_description(
improve_diagram_description_prompt,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
@@ -956,6 +976,88 @@ async def generate_excalidraw_diagram_from_description(
return response
+class MemoryUpdates(BaseModel):
+ """Facts to add or remove from memory."""
+
+ create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.")
+ delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.")
+
+
+async def extract_facts_from_query(
+ user: KhojUser,
+ conversation_history: List[ChatMessageModel],
+ existing_facts: List[UserMemory] = None,
+ agent: Agent = None,
+ tracer: dict = {},
+) -> MemoryUpdates:
+ """
+ Extract facts from the given query
+ """
+ chat_history = construct_chat_history(conversation_history, n=2)
+
+ formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else []
+
+ extract_facts_prompt = prompts.extract_facts_from_query.format(
+ chat_history=chat_history,
+ matched_facts=formatted_memories,
+ )
+
+ with timer("Chat actor: Extract facts from query", logger):
+ response = await send_message_to_model_wrapper(
+ extract_facts_prompt,
+ response_schema=MemoryUpdates,
+ user=user,
+ fast_model=False,
+ agent_chat_model=agent.chat_model,
+ tracer=tracer,
+ )
+ response = response.text.strip()
+ # JSON parse the list of strings
+ try:
+ response = clean_json(response)
+ response = json.loads(response)
+ parsed_response = MemoryUpdates(**response)
+ if not isinstance(parsed_response, MemoryUpdates):
+ raise ValueError(f"Invalid response for extracting facts: {response}")
+ return parsed_response
+
+ except Exception:
+ logger.error(f"Invalid response for extracting facts: {response}")
+ return MemoryUpdates(create=[], delete=[])
+
+
+@require_valid_user
+async def ai_update_memories(
+ user: KhojUser,
+ conversation_history: List[ChatMessageModel],
+ memories: List[UserMemory],
+ agent: Agent,
+ tracer: dict = {},
+):
+ """
+ Updates the memories for a given user, based on their latest input query.
+ """
+ # Skip memory updates if memory is disabled for the user
+ if not await ConversationAdapters.ais_memory_enabled(user):
+ return
+
+ memory_update = await extract_facts_from_query(
+ user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer
+ )
+
+ if not memory_update:
+ return
+
+ # Save the memory updates to the database
+ for memory in memory_update.create:
+ logger.info(f"Creating memory: {memory}")
+ await UserMemoryAdapters.save_memory(user, memory, agent=agent)
+
+ for memory in memory_update.delete:
+ logger.info(f"Deleting memory: {memory}")
+ await UserMemoryAdapters.delete_memory(user, memory)
+
+
async def generate_mermaidjs_diagram(
q: str,
chat_history: List[ChatMessageModel],
@@ -964,6 +1066,7 @@ async def generate_mermaidjs_diagram(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -981,6 +1084,7 @@ async def generate_mermaidjs_diagram(
online_results=online_results,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,
@@ -1010,6 +1114,7 @@ async def generate_better_mermaidjs_diagram_description(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -1055,6 +1160,7 @@ async def generate_better_mermaidjs_diagram_description(
improve_diagram_description_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
@@ -1104,6 +1210,7 @@ async def generate_better_image_prompt(
model_type: Optional[str] = None,
query_images: Optional[List[str]] = None,
query_files: str = "",
+ relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -1148,6 +1255,7 @@ async def generate_better_image_prompt(
q,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
chat_history=conversation_history,
system_message=enhance_image_system_message,
response_type="json_object",
@@ -1180,6 +1288,7 @@ async def search_documents(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
previous_inferred_queries: Set = set(),
fast_model: bool = True,
agent: Agent = None,
@@ -1230,6 +1339,7 @@ async def search_documents(
user=user,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
personality_context=personality_context,
location_data=location_data,
chat_history=chat_history,
@@ -1283,6 +1393,7 @@ async def extract_questions(
query_files: str = None,
query_images: Optional[List[str]] = None,
personality_context: str = "",
+ relevant_memories: List[UserMemory] = None,
location_data: LocationData = None,
chat_history: List[ChatMessageModel] = [],
max_queries: int = 5,
@@ -1338,6 +1449,7 @@ async def extract_questions(
query=prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
system_message=system_prompt,
response_type="json_object",
response_schema=DocumentQueries,
@@ -1518,6 +1630,7 @@ async def send_message_to_model_wrapper(
query_files: str = None,
query_images: List[str] = None,
context: str = "",
+ relevant_memories: List[UserMemory] = None,
chat_history: list[ChatMessageModel] = [],
system_message: str = "",
# Model Config
@@ -1568,6 +1681,7 @@ async def send_message_to_model_wrapper(
query_files=query_files,
query_images=query_images,
context_message=context,
+ relevant_memories=relevant_memories,
chat_history=chat_history,
system_message=system_message,
model_name=chat_model.name,
@@ -1685,6 +1799,7 @@ def build_conversation_context(
operator_results: List[OperatorRun],
query_files: str = None,
query_images: Optional[List[str]] = None,
+ relevant_memories: List[UserMemory] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
chat_history: List[ChatMessageModel] = [],
@@ -1761,6 +1876,7 @@ def build_conversation_context(
query_files=query_files,
query_images=query_images,
context_message=context_message,
+ relevant_memories=relevant_memories,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
@@ -1789,6 +1905,7 @@ async def agenerate_chat_response(
user_name: Optional[str] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = [],
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
@@ -1831,6 +1948,7 @@ async def agenerate_chat_response(
operator_results=operator_results,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
@@ -2792,6 +2910,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
selected_chat_model_config = ConversationAdapters.get_chat_model(
user
) or ConversationAdapters.get_default_chat_model(user)
+ server_chat_settings = ServerChatSettings.objects.first()
+ server_memory_mode = (
+ server_chat_settings.memory_mode
+ if server_chat_settings
+ else ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON.value # type: ignore[attr-defined]
+ )
+ enable_memory = ConversationAdapters.is_memory_enabled(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models:
@@ -2848,6 +2973,8 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"enabled_content_source": enabled_content_sources,
"has_documents": has_documents,
"notion_token": notion_token,
+ "enable_memory": enable_memory,
+ "server_memory_mode": server_memory_mode,
# user model settings
"chat_model_options": chat_model_options,
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index e103b330..3185e8c0 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional
import yaml
from khoj.database.adapters import AgentAdapters, EntryAdapters, McpServerAdapters
-from khoj.database.models import Agent, ChatMessageModel, KhojUser
+from khoj.database.models import Agent, ChatMessageModel, KhojUser, UserMemory
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
@@ -285,6 +285,7 @@ async def apick_next_tool(
max_iterations: int = 5,
query_images: List[str] = [],
query_files: str = None,
+ relevant_memories: List[UserMemory] = [],
max_document_searches: int = 7,
max_online_searches: int = 3,
max_webpages_to_read: int = 3,
@@ -416,6 +417,7 @@ async def apick_next_tool(
query="",
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
tools=tools,
@@ -479,6 +481,7 @@ async def research(
previous_iterations: List[ResearchIteration],
query_images: List[str],
query_files: str = None,
+ relevant_memories: List[UserMemory] = [],
user_name: str = None,
location: LocationData = None,
send_status_func: Optional[Callable] = None,
@@ -544,6 +547,7 @@ async def research(
MAX_ITERATIONS,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
max_webpages_to_read=max_webpages_to_read,
diff --git a/tests/helpers.py b/tests/helpers.py
index 985ba00a..d5740d53 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -4,9 +4,12 @@ import os
from datetime import datetime
import factory
+from asgiref.sync import sync_to_async
from django.utils.timezone import make_aware
+from khoj.database.adapters import AgentAdapters
from khoj.database.models import (
+ Agent,
AiModelApi,
ChatMessageModel,
ChatModel,
@@ -15,8 +18,10 @@ from khoj.database.models import (
KhojUser,
ProcessLock,
SearchModelConfig,
+ ServerChatSettings,
Subscription,
UserConversationConfig,
+ UserMemory,
)
from khoj.processor.conversation.utils import message_to_log
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
@@ -277,3 +282,45 @@ class ProcessLockFactory(factory.django.DjangoModelFactory):
model = ProcessLock
name = "test_lock"
+
+
+class ServerChatSettingsFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = ServerChatSettings
+
+ memory_mode = ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON
+
+
+# Async-safe wrappers for factories and ORM operations
+async def acreate_user():
+ return await sync_to_async(UserFactory)()
+
+
+async def acreate_subscription(user):
+ return await sync_to_async(SubscriptionFactory)(user=user)
+
+
+async def acreate_chat_model():
+ return await sync_to_async(ChatModelFactory)()
+
+
+async def acreate_default_agent():
+ return await sync_to_async(AgentAdapters.create_default_agent)()
+
+
+async def acreate_agent(name, chat_model, personality):
+ return await sync_to_async(Agent.objects.create)(
+ name=name,
+ chat_model=chat_model,
+ personality=personality,
+ )
+
+
+async def acreate_test_memory(user, agent=None, raw_text="test memory"):
+ """Create a memory directly in DB without embeddings for testing."""
+ return await sync_to_async(UserMemory.objects.create)(
+ user=user,
+ agent=agent,
+ raw=raw_text,
+ embeddings=[0.1] * 384, # Dummy embeddings
+ )
diff --git a/tests/test_memory_settings.py b/tests/test_memory_settings.py
new file mode 100644
index 00000000..28621670
--- /dev/null
+++ b/tests/test_memory_settings.py
@@ -0,0 +1,479 @@
+"""
+Tests for memory enable/disable settings and memory scoping by user+agent.
+
+These tests verify:
+1. The behavior of ConversationAdapters.is_memory_enabled() for different combinations of:
+ - ServerChatSettings.memory_mode (DISABLED, ENABLED_DEFAULT_OFF, ENABLED_DEFAULT_ON)
+ - UserConversationConfig.enable_memory (True, False, or not set)
+
+2. Memory scoping by user and agent:
+ - Memories are scoped to user + agent
+ - Default agent has access to ALL memories across all agents for a user
+ - Non-default agents only see their own memories
+"""
+
+import pytest
+from unittest.mock import MagicMock
+
+from khoj.database.adapters import ConversationAdapters, UserMemoryAdapters
+from khoj.database.models import ServerChatSettings, UserConversationConfig
+from khoj.routers.helpers import get_user_config
+from tests.helpers import (
+ acreate_user,
+ acreate_subscription,
+ acreate_chat_model,
+ acreate_default_agent,
+ acreate_agent,
+ acreate_test_memory,
+ ServerChatSettingsFactory,
+ SubscriptionFactory,
+ UserFactory,
+)
+
+
+# ----------------------------------------------------------------------------------------------------
+# Test is_memory_enabled with no server config (default behavior)
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+def test_memory_enabled_no_server_config_no_user_config():
+ """When no server config and no user config exists, memory should be enabled (default on)."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is True
+
+
+@pytest.mark.django_db
+def test_memory_enabled_no_server_config_user_enabled():
+ """When no server config but user has explicitly enabled memory."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ user_config = UserConversationConfig.objects.create(user=user, enable_memory=True)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is True
+
+
+@pytest.mark.django_db
+def test_memory_enabled_no_server_config_user_disabled():
+ """When no server config but user has explicitly disabled memory."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ user_config = UserConversationConfig.objects.create(user=user, enable_memory=False)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+# ----------------------------------------------------------------------------------------------------
+# Test is_memory_enabled with server mode DISABLED
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+def test_memory_disabled_server_disabled_no_user_config():
+ """When server disables memory, it should override everything - no user config."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+@pytest.mark.django_db
+def test_memory_disabled_server_disabled_user_enabled():
+ """When server disables memory, it should override user preference (enabled)."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED)
+ UserConversationConfig.objects.create(user=user, enable_memory=True)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+@pytest.mark.django_db
+def test_memory_disabled_server_disabled_user_disabled():
+ """When server disables memory, user disabled too - should be disabled."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED)
+ UserConversationConfig.objects.create(user=user, enable_memory=False)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+# ----------------------------------------------------------------------------------------------------
+# Test is_memory_enabled with server mode ENABLED_DEFAULT_OFF
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+def test_memory_enabled_default_off_no_user_config():
+ """When server is enabled_default_off and no user config, memory should be off."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+@pytest.mark.django_db
+def test_memory_enabled_default_off_user_enabled():
+ """When server is enabled_default_off and user opts in, memory should be on."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF)
+ UserConversationConfig.objects.create(user=user, enable_memory=True)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is True
+
+
+@pytest.mark.django_db
+def test_memory_enabled_default_off_user_disabled():
+ """When server is enabled_default_off and user explicitly disabled, memory should be off."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF)
+ UserConversationConfig.objects.create(user=user, enable_memory=False)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+# ----------------------------------------------------------------------------------------------------
+# Test is_memory_enabled with server mode ENABLED_DEFAULT_ON
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+def test_memory_enabled_default_on_no_user_config():
+ """When server is enabled_default_on and no user config, memory should be on."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is True
+
+
+@pytest.mark.django_db
+def test_memory_enabled_default_on_user_enabled():
+ """When server is enabled_default_on and user enabled, memory should be on."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON)
+ UserConversationConfig.objects.create(user=user, enable_memory=True)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is True
+
+
+@pytest.mark.django_db
+def test_memory_enabled_default_on_user_disabled():
+ """When server is enabled_default_on and user opts out, memory should be off."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON)
+ UserConversationConfig.objects.create(user=user, enable_memory=False)
+
+ result = ConversationAdapters.is_memory_enabled(user)
+
+ assert result is False
+
+
+# ----------------------------------------------------------------------------------------------------
+# Test get_user_config returns correct enable_memory and server_memory_mode
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+def test_get_user_config_memory_no_server_config():
+ """get_user_config should return default values when no server config."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ request = MagicMock()
+ request.url = MagicMock()
+ request.url.path = "/api/config"
+ request.session = {}
+
+ config = get_user_config(user, request, is_detailed=True)
+
+ assert config["enable_memory"] is True
+ assert config["server_memory_mode"] == "enabled_default_on"
+
+
+@pytest.mark.django_db
+def test_get_user_config_memory_server_disabled():
+ """get_user_config should reflect server disabled mode."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED)
+ request = MagicMock()
+ request.url = MagicMock()
+ request.url.path = "/api/config"
+ request.session = {}
+
+ config = get_user_config(user, request, is_detailed=True)
+
+ assert config["enable_memory"] is False
+ assert config["server_memory_mode"] == "disabled"
+
+
+@pytest.mark.django_db
+def test_get_user_config_memory_server_enabled_default_off_user_opted_in():
+ """get_user_config should show user opted in when server is default off."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF)
+ UserConversationConfig.objects.create(user=user, enable_memory=True)
+ request = MagicMock()
+ request.url = MagicMock()
+ request.url.path = "/api/config"
+ request.session = {}
+
+ config = get_user_config(user, request, is_detailed=True)
+
+ assert config["enable_memory"] is True
+ assert config["server_memory_mode"] == "enabled_default_off"
+
+
+@pytest.mark.django_db
+def test_get_user_config_memory_server_enabled_default_on_user_opted_out():
+ """get_user_config should show user opted out when server is default on."""
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON)
+ UserConversationConfig.objects.create(user=user, enable_memory=False)
+ request = MagicMock()
+ request.url = MagicMock()
+ request.url.path = "/api/config"
+ request.session = {}
+
+ config = get_user_config(user, request, is_detailed=True)
+
+ assert config["enable_memory"] is False
+ assert config["server_memory_mode"] == "enabled_default_on"
+
+
+# ----------------------------------------------------------------------------------------------------
+# Test memory scoping by user and agent
+# ----------------------------------------------------------------------------------------------------
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_pull_memories_default_agent_sees_all_memories():
+ """Default agent should see ALL memories for the user, including those from other agents."""
+ # Setup
+ user = await acreate_user()
+ await acreate_subscription(user)
+ chat_model = await acreate_chat_model()
+
+ # Create default agent
+ default_agent = await acreate_default_agent()
+ assert default_agent is not None
+
+ # Create a custom agent
+ custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent")
+
+ # Create memories for different agents
+ await acreate_test_memory(user, agent=None, raw_text="memory without agent")
+ await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent")
+ await acreate_test_memory(user, agent=custom_agent, raw_text="memory for custom agent")
+
+ # Act: Pull memories with default agent
+ memories = await UserMemoryAdapters.pull_memories(user=user, agent=default_agent)
+
+ # Assert: Default agent sees ALL memories
+ memory_texts = [m.raw for m in memories]
+ assert "memory without agent" in memory_texts
+ assert "memory for default agent" in memory_texts
+ assert "memory for custom agent" in memory_texts
+ assert len(memories) == 3
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_pull_memories_custom_agent_sees_only_own_memories():
+ """Custom (non-default) agent should only see its own memories."""
+ # Setup
+ user = await acreate_user()
+ await acreate_subscription(user)
+ chat_model = await acreate_chat_model()
+
+ # Create default agent
+ default_agent = await acreate_default_agent()
+ assert default_agent is not None
+
+ # Create custom agents
+ custom_agent_1 = await acreate_agent("Custom Agent 1", chat_model, "First custom agent")
+ custom_agent_2 = await acreate_agent("Custom Agent 2", chat_model, "Second custom agent")
+
+ # Create memories for different agents
+ await acreate_test_memory(user, agent=None, raw_text="memory without agent")
+ await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent")
+ await acreate_test_memory(user, agent=custom_agent_1, raw_text="memory for custom agent 1")
+ await acreate_test_memory(user, agent=custom_agent_2, raw_text="memory for custom agent 2")
+
+ # Act: Pull memories with custom_agent_1
+ memories = await UserMemoryAdapters.pull_memories(user=user, agent=custom_agent_1)
+
+ # Assert: Custom agent 1 only sees its own memories
+ memory_texts = [m.raw for m in memories]
+ assert "memory for custom agent 1" in memory_texts
+ assert "memory without agent" not in memory_texts
+ assert "memory for default agent" not in memory_texts
+ assert "memory for custom agent 2" not in memory_texts
+ assert len(memories) == 1
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_pull_memories_no_agent_same_as_default_agent():
+ """Pulling memories with agent=None should behave same as default agent (see all)."""
+ # Setup
+ user = await acreate_user()
+ await acreate_subscription(user)
+ chat_model = await acreate_chat_model()
+
+ # Create default agent
+ default_agent = await acreate_default_agent()
+ assert default_agent is not None
+
+ # Create a custom agent
+ custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent")
+
+ # Create memories
+ await acreate_test_memory(user, agent=None, raw_text="memory without agent")
+ await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent")
+ await acreate_test_memory(user, agent=custom_agent, raw_text="memory for custom agent")
+
+ # Act: Pull memories with agent=None
+ memories = await UserMemoryAdapters.pull_memories(user=user, agent=None)
+
+ # Assert: Should see all memories (same as default agent behavior)
+ memory_texts = [m.raw for m in memories]
+ assert "memory without agent" in memory_texts
+ assert "memory for default agent" in memory_texts
+ assert "memory for custom agent" in memory_texts
+ assert len(memories) == 3
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_save_memory_with_custom_agent_scopes_to_agent():
+ """Memories saved with a custom agent should be scoped to that agent."""
+ # Setup
+ user = await acreate_user()
+ await acreate_subscription(user)
+ chat_model = await acreate_chat_model()
+
+ # Create default agent
+ default_agent = await acreate_default_agent()
+ assert default_agent is not None
+
+ # Create custom agent
+ custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent")
+
+ # Create memory with custom agent (directly in DB to avoid embeddings)
+ memory = await acreate_test_memory(user, agent=custom_agent, raw_text="custom agent memory")
+
+ # Assert: Memory is scoped to the custom agent
+ assert memory.agent == custom_agent
+ assert memory.user == user
+
+ # Verify custom agent can see it
+ custom_memories = await UserMemoryAdapters.pull_memories(user=user, agent=custom_agent)
+ assert len(custom_memories) == 1
+ assert custom_memories[0].raw == "custom agent memory"
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_save_memory_with_default_agent_has_no_agent_scope():
+ """Memories saved with default agent should have agent=None (global scope)."""
+ # Setup
+ user = await acreate_user()
+ await acreate_subscription(user)
+ await acreate_chat_model() # Required for default agent creation
+
+ # Create default agent
+ default_agent = await acreate_default_agent()
+ assert default_agent is not None
+
+ # Create memory with default agent (directly in DB)
+ # Based on save_memory logic: if agent == default_agent, agent is not set
+ memory = await acreate_test_memory(user, agent=None, raw_text="default agent memory")
+
+ # Assert: Memory has no agent (global scope)
+ assert memory.agent is None
+ assert memory.user == user
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_memories_isolated_between_users():
+ """Memories should be isolated between different users."""
+ # Setup
+ user1 = await acreate_user()
+ user2 = await acreate_user()
+ await acreate_subscription(user1)
+ await acreate_subscription(user2)
+
+ # Create default agent
+ await acreate_default_agent()
+
+ # Create memories for each user
+ await acreate_test_memory(user1, agent=None, raw_text="user1 memory")
+ await acreate_test_memory(user2, agent=None, raw_text="user2 memory")
+
+ # Act: Pull memories for each user
+ user1_memories = await UserMemoryAdapters.pull_memories(user=user1)
+ user2_memories = await UserMemoryAdapters.pull_memories(user=user2)
+
+ # Assert: Each user only sees their own memories
+ assert len(user1_memories) == 1
+ assert user1_memories[0].raw == "user1 memory"
+
+ assert len(user2_memories) == 1
+ assert user2_memories[0].raw == "user2 memory"
+
+
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_custom_agent_cannot_see_other_custom_agent_memories():
+ """One custom agent should not see another custom agent's memories."""
+ # Setup
+ user = await acreate_user()
+ await acreate_subscription(user)
+ chat_model = await acreate_chat_model()
+
+ # Create default agent
+ await acreate_default_agent()
+
+ # Create two custom agents
+ agent_accountant = await acreate_agent("Accountant", chat_model, "Financial advisor")
+ agent_chef = await acreate_agent("Chef", chat_model, "Cooking expert")
+
+ # Create memories for each agent
+ await acreate_test_memory(user, agent=agent_accountant, raw_text="user spent $500 on groceries")
+ await acreate_test_memory(user, agent=agent_chef, raw_text="user likes Italian food")
+
+ # Act & Assert: Accountant only sees financial memories
+ accountant_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent_accountant)
+ assert len(accountant_memories) == 1
+ assert accountant_memories[0].raw == "user spent $500 on groceries"
+
+ # Act & Assert: Chef only sees food memories
+ chef_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent_chef)
+ assert len(chef_memories) == 1
+ assert chef_memories[0].raw == "user likes Italian food"