From d6c2d1fa49cb585259480b993c521b3542aa5ae9 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Fri, 2 Jan 2026 19:37:05 -0800 Subject: [PATCH] Give Khoj Long Term Memories (#1168) # Motivation A major component of useful AI systems is adaptation to the user context. This is a major reason why we'd enabled syncing knowledge bases. The next steps in this direction is to dynamically update the evolving state of the user as conversations take place across time and topics. This allows for more personalized conversations and to maintain context across conversations. # Overview This change introduces medium and long term memories in Khoj. - The scope of a conversation can be thought of as short term memory. - Medium term memory extends to the past week. - Long term memory extends to anytime in the past, where a search query results in a match. # Details - Enable user to view and manage agent generated memories from their settings page - Fully integrate the memory object into all downstream usage, from image generation, notes extraction, online search, etc. - Scope memory per agent. The default agent has access to memories created by other agents as well. - Enable users and admins to enable/disable Khoj's memory system --------- Co-authored-by: Debanjum --- src/interface/web/app/common/auth.ts | 2 + .../app/components/userMemory/userMemory.tsx | 90 ++++ src/interface/web/app/settings/page.tsx | 158 +++++- src/interface/web/bun.lock | 7 +- src/interface/web/components/ui/switch.tsx | 29 ++ src/interface/web/package.json | 1 + src/khoj/configure.py | 2 + src/khoj/database/adapters/__init__.py | 194 ++++++- src/khoj/database/admin.py | 3 + .../management/commands/manage_memories.py | 384 ++++++++++++++ .../database/migrations/0099_usermemory.py | 82 +++ src/khoj/database/models/__init__.py | 26 + src/khoj/processor/conversation/prompts.py | 67 +++ src/khoj/processor/conversation/utils.py | 23 + src/khoj/processor/image/generate.py | 3 + src/khoj/processor/operator/__init__.py | 3 +- src/khoj/processor/tools/online_search.py | 26 +- src/khoj/processor/tools/run_code.py | 6 +- src/khoj/routers/api.py | 25 +- src/khoj/routers/api_chat.py | 21 +- src/khoj/routers/api_memories.py | 114 +++++ src/khoj/routers/helpers.py | 129 ++++- src/khoj/routers/research.py | 6 +- tests/helpers.py | 47 ++ tests/test_memory_settings.py | 479 ++++++++++++++++++ 25 files changed, 1910 insertions(+), 17 deletions(-) create mode 100644 src/interface/web/app/components/userMemory/userMemory.tsx create mode 100644 src/interface/web/components/ui/switch.tsx create mode 100644 src/khoj/database/management/commands/manage_memories.py create mode 100644 src/khoj/database/migrations/0099_usermemory.py create mode 100644 src/khoj/routers/api_memories.py create mode 100644 tests/test_memory_settings.py diff --git a/src/interface/web/app/common/auth.ts b/src/interface/web/app/common/auth.ts index defbf536..6344fbfe 100644 --- a/src/interface/web/app/common/auth.ts +++ b/src/interface/web/app/common/auth.ts @@ -63,6 +63,8 @@ export interface UserConfig { enabled_content_source: SyncedContent; has_documents: boolean; notion_token: string | null; + enable_memory: boolean; + server_memory_mode: "disabled" | "enabled_default_off" | "enabled_default_on"; // user model settings search_model_options: ModelOptions[]; selected_search_model_config: number; diff --git a/src/interface/web/app/components/userMemory/userMemory.tsx b/src/interface/web/app/components/userMemory/userMemory.tsx new file mode 100644 index 00000000..db360e7b --- /dev/null +++ b/src/interface/web/app/components/userMemory/userMemory.tsx @@ -0,0 +1,90 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Button } from "@/components/ui/button"; +import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react"; +import { useToast } from "@/components/ui/use-toast"; + +export interface UserMemorySchema { + id: number; + raw: string; + created_at: string; +} + +interface UserMemoryProps { + memory: UserMemorySchema; + onDelete: (id: number) => void; + onUpdate: (id: number, raw: string) => void; +} + +export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) { + const [isEditing, setIsEditing] = useState(false); + const [content, setContent] = useState(memory.raw); + const { toast } = useToast(); + + const handleUpdate = () => { + onUpdate(memory.id, content); + setIsEditing(false); + toast({ + title: "Memory Updated", + description: "Your memory has been successfully updated.", + }); + }; + + const handleDelete = () => { + onDelete(memory.id); + toast({ + title: "Memory Deleted", + description: "Your memory has been successfully deleted.", + }); + }; + + return ( +
+ {isEditing ? ( + <> + setContent(e.target.value)} + className="flex-1" + /> + + + + ) : ( + <> + + + + + )} +
+ ); +} diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index ad8b88f5..75ea7ee2 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -15,6 +15,8 @@ import { Button } from "@/components/ui/button"; import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp"; import { Input } from "@/components/ui/input"; import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card"; +import { Switch } from "@/components/ui/switch"; + import { DropdownMenu, DropdownMenuContent, @@ -33,6 +35,15 @@ import { AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; + +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog"; + import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table"; import { @@ -74,6 +85,7 @@ import Loading from "../components/loading/loading"; import IntlTelInput from "intl-tel-input/react"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; +import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory"; import { Separator } from "@/components/ui/separator"; import { KhojLogoType } from "../components/logo/khojLogo"; import { Progress } from "@/components/ui/progress"; @@ -323,6 +335,9 @@ export default function SettingsView() { const [numberValidationState, setNumberValidationState] = useState( PhoneNumberValidationState.Verified, ); + const [memories, setMemories] = useState([]); + const [enableMemory, setEnableMemory] = useState(true); + const [serverMemoryMode, setServerMemoryMode] = useState("enabled_default_on"); const [isExporting, setIsExporting] = useState(false); const [exportProgress, setExportProgress] = useState(0); const [exportedConversations, setExportedConversations] = useState(0); @@ -347,6 +362,8 @@ export default function SettingsView() { ); setName(initialUserConfig?.given_name); setNotionToken(initialUserConfig?.notion_token ?? null); + setEnableMemory(initialUserConfig?.enable_memory ?? true); + setServerMemoryMode(initialUserConfig?.server_memory_mode ?? "enabled_default_on"); }, [initialUserConfig]); const sendOTP = async () => { @@ -621,6 +638,88 @@ export default function SettingsView() { } }; + const fetchMemories = async () => { + try { + console.log("Fetching memories..."); + const response = await fetch('/api/memories/'); + if (!response.ok) throw new Error('Failed to fetch memories'); + const data = await response.json(); + setMemories(data); + } catch (error) { + console.error('Error fetching memories:', error); + toast({ + title: "Error", + description: "Failed to fetch memories. Please try again.", + variant: "destructive" + }); + } + }; + + const handleDeleteMemory = async (id: number) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'DELETE' + }); + if (!response.ok) throw new Error('Failed to delete memory'); + setMemories(memories.filter(memory => memory.id !== id)); + } catch (error) { + console.error('Error deleting memory:', error); + toast({ + title: "Error", + description: "Failed to delete memory. Please try again.", + variant: "destructive" + }); + } + }; + + const handleUpdateMemory = async (id: number, raw: string) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ raw, memory_id: id }), + }); + if (!response.ok) throw new Error('Failed to update memory'); + const updatedMemory: UserMemorySchema = await response.json(); + setMemories(memories.map(memory => + memory.id === id ? updatedMemory : memory + )); + } catch (error) { + console.error('Error updating memory:', error); + toast({ + title: "Error", + description: "Failed to update memory. Please try again.", + variant: "destructive" + }); + } + }; + + const handleToggleMemory = async (enabled: boolean) => { + try { + const response = await fetch(`/api/user/memory?enable_memory=${enabled}`, { + method: 'PATCH', + }); + if (!response.ok) throw new Error('Failed to update memory setting'); + setEnableMemory(enabled); + toast({ + title: enabled ? "Memory enabled" : "Memory disabled", + description: enabled + ? "Khoj will learn and remember from your conversations." + : "Khoj will no longer learn or remember from your conversations.", + }); + } catch (error) { + console.error('Error toggling memory:', error); + toast({ + title: "Error", + description: "Failed to update memory setting. Please try again.", + variant: "destructive" + }); + } + }; + + const syncContent = async (type: string) => { try { const response = await fetch(`/api/content?t=${type}`, { @@ -1212,7 +1311,64 @@ export default function SettingsView() { - + + + + Memories + + +

+ View and manage your long-term memories +

+
+ + handleToggleMemory(checked)} + disabled={serverMemoryMode === "disabled"} + /> +
+ {serverMemoryMode === "disabled" && ( +

+ Memory has been disabled by the server administrator. +

+ )} +
+ + open && fetchMemories()}> + + + + + + Your Memories + +
+ {memories.map((memory) => ( + + ))} + {memories.length === 0 && ( +

No memories found

+ )} +
+
+
+
+
diff --git a/src/interface/web/bun.lock b/src/interface/web/bun.lock index 36f1d3cf..9287dd9f 100644 --- a/src/interface/web/bun.lock +++ b/src/interface/web/bun.lock @@ -24,6 +24,7 @@ "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-separator": "^1.1.7", "@radix-ui/react-slot": "^1.2.3", + "@radix-ui/react-switch": "^1.2.6", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-toast": "^1.2.15", "@radix-ui/react-toggle": "^1.1.10", @@ -274,7 +275,7 @@ "@radix-ui/react-slot": ["@radix-ui/react-slot@1.2.3", "", { "dependencies": { "@radix-ui/react-compose-refs": "1.1.2" }, "peerDependencies": { "@types/react": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react"] }, "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A=="], - "@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="], + "@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.6", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-bByzr1+ep1zk4VubeEVViV592vu2lHE2BZY5OnzehZqOOgogN80+mNtCqPkhn2gklJqOpxWgPoYTSnhBCqpOXQ=="], "@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.13", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.5", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.11", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A=="], @@ -1478,8 +1479,6 @@ "@radix-ui/react-slider/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="], - "@radix-ui/react-switch/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="], - "@radix-ui/react-toggle-group/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="], "@radix-ui/react-toggle-group/@radix-ui/react-roving-focus": ["@radix-ui/react-roving-focus@1.1.10", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q=="], @@ -1576,6 +1575,8 @@ "radix-ui/@radix-ui/react-select": ["@radix-ui/react-select@2.2.5", "", { "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-popper": "1.2.7", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA=="], + "radix-ui/@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="], + "radix-ui/@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.12", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-GTVAlRVrQrSw3cEARM0nAx73ixrWDPNZAruETn3oHCNP6SbZ/hNxdxp+u7VkIEv3/sFoLq1PfcHrl7Pnp0CDpw=="], "radix-ui/@radix-ui/react-toast": ["@radix-ui/react-toast@1.2.14", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-nAP5FBxBJGQ/YfUB+r+O6USFVkWq3gAInkxyEnmvEV5jtSbfDhfa4hwX8CraCnbjMLsE7XSf/K75l9xXY7joWg=="], diff --git a/src/interface/web/components/ui/switch.tsx b/src/interface/web/components/ui/switch.tsx new file mode 100644 index 00000000..812af0df --- /dev/null +++ b/src/interface/web/components/ui/switch.tsx @@ -0,0 +1,29 @@ +"use client" + +import * as React from "react" +import * as SwitchPrimitives from "@radix-ui/react-switch" + +import { cn } from "@/lib/utils" + +const Switch = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)) +Switch.displayName = SwitchPrimitives.Root.displayName + +export { Switch } diff --git a/src/interface/web/package.json b/src/interface/web/package.json index 3eceac6f..ccbbd4bf 100644 --- a/src/interface/web/package.json +++ b/src/interface/web/package.json @@ -38,6 +38,7 @@ "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-separator": "^1.1.7", "@radix-ui/react-slot": "^1.2.3", + "@radix-ui/react-switch": "^1.2.6", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-toast": "^1.2.15", "@radix-ui/react-toggle": "^1.1.10", diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 40d1eeb5..50b2dd31 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -323,6 +323,7 @@ def configure_routes(app): from khoj.routers.api_automation import api_automation from khoj.routers.api_chat import api_chat from khoj.routers.api_content import api_content + from khoj.routers.api_memories import api_memories from khoj.routers.api_model import api_model from khoj.routers.notion import notion_router from khoj.routers.web_client import web_client @@ -332,6 +333,7 @@ def configure_routes(app): app.include_router(api_agents, prefix="/api/agents") app.include_router(api_automation, prefix="/api/automation") app.include_router(api_model, prefix="/api/model") + app.include_router(api_memories, prefix="/api/memories") app.include_router(api_content, prefix="/api/content") app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 3ab08ef2..abcc226f 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -62,6 +62,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserTextToImageModelConfig, UserVoiceModelConfig, @@ -566,6 +567,16 @@ def get_default_search_model() -> SearchModelConfig: return SearchModelConfig.objects.first() +async def aget_default_search_model() -> SearchModelConfig: + default_search_model = await SearchModelConfig.objects.filter(name="default").afirst() + + if default_search_model: + return default_search_model + elif await SearchModelConfig.objects.count() == 0: + await SearchModelConfig.objects.acreate() + return await SearchModelConfig.objects.afirst() + + def get_or_create_search_models(): search_models = SearchModelConfig.objects.all() if search_models.count() == 0: @@ -1556,12 +1567,17 @@ class ConversationAdapters: ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = await Conversation.objects.filter( - user=user, client=client_application, id=conversation_id - ).afirst() + conversation = ( + await Conversation.objects.filter(user=user, client=client_application, id=conversation_id) + .prefetch_related("agent", "agent__chat_model") + .afirst() + ) else: conversation = ( - await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() + await Conversation.objects.filter(user=user, client=client_application) + .prefetch_related("agent", "agent__chat_model") + .order_by("-updated_at") + .afirst() ) existing_messages = conversation.messages if conversation else [] @@ -1597,6 +1613,85 @@ class ConversationAdapters: return None return config.setting + @staticmethod + async def ais_memory_enabled(user: KhojUser) -> bool: + """ + Check if memory is enabled for the user based on server config and user preference. + + Logic: + - If server memory_mode is DISABLED: return False (overrides user preference) + - If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False + - If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True + - If no server config exists: default to True + """ + # Get server-level memory configuration + server_settings = await ServerChatSettings.objects.afirst() + if server_settings: + memory_mode = server_settings.memory_mode + # Disabled mode overrides all user preferences + if memory_mode == ServerChatSettings.MemoryMode.DISABLED: + return False + + # Get user preference + user_config = await UserConversationConfig.objects.filter(user=user).afirst() + + if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF: + # User must explicitly opt-in; check if user has set a preference + if user_config is None: + return False # Default off for new users + return user_config.enable_memory + + # ENABLED_DEFAULT_ON: use user preference if set, else True + if user_config is None: + return True # Default on for new users + return user_config.enable_memory + + # No server config - default behavior (enabled, default on) + user_config = await UserConversationConfig.objects.filter(user=user).afirst() + if user_config is None: + return True + return user_config.enable_memory + + @staticmethod + def is_memory_enabled(user: KhojUser) -> bool: + """ + Sync version of ais_memory_enabled. + Check if memory is enabled for the user based on server config and user preference. + + Logic: + - If server memory_mode is DISABLED: return False (overrides user preference) + - If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False + - If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True + - If no server config exists: default to True + """ + # Get server-level memory configuration + server_settings = ServerChatSettings.objects.first() + if server_settings: + memory_mode = server_settings.memory_mode + # Disabled mode overrides all user preferences + if memory_mode == ServerChatSettings.MemoryMode.DISABLED: + return False + + # Get user preference + user_config = UserConversationConfig.objects.filter(user=user).first() + + if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF: + # User must explicitly opt-in; check if user has set a preference + if user_config is None: + return False # Default off for new users + return user_config.enable_memory + + # ENABLED_DEFAULT_ON: use user preference if set, else True + if user_config is None: + return True # Default on for new users + return user_config.enable_memory + + # No server config - default behavior (enabled, default on) + user_config = UserConversationConfig.objects.filter(user=user).first() + if user_config is None: + return True + return user_config.enable_memory + @staticmethod async def get_speech_to_text_config(): return await SpeechToTextModelOptions.objects.filter().prefetch_related("ai_model_api").afirst() @@ -2186,3 +2281,94 @@ class McpServerAdapters: except Exception as e: logger.error(f"Error retrieving MCP servers: {e}", exc_info=True) return servers + + +class UserMemoryAdapters: + @staticmethod + @require_valid_user + async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]: + """ + Pulls memories from the database for a given user. Medium term memory. + """ + time_frame = datetime.now(timezone.utc) - timedelta(days=window) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by( + "-created_at" + )[:limit] + else: + memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit] + return await sync_to_async(list)(memories) + + @staticmethod + @require_valid_user + async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory: + """ + Saves a memory to the database for a given user. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent + ) + else: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model + ) + + return memory_instance + + @staticmethod + @require_valid_user + async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]: + """ + Searches for memories in the database for a given user. Long term memory. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + max_distance = model.bi_encoder_confidence_threshold or math.inf + embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query) + default_agent = await AgentAdapters.aget_default_agent() + + if agent and agent != default_agent: + relevant_memories = UserMemory.objects.filter(user=user, agent=agent) + else: + relevant_memories = UserMemory.objects.filter(user=user) + + relevant_memories = ( + relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query)) + .order_by("distance") + .filter(distance__lte=max_distance) + ) + + return await sync_to_async(list)(relevant_memories[:limit]) + + @staticmethod + @require_valid_user + async def delete_memory(user: KhojUser, memory_id: str) -> bool: + """ + Deletes a memory from the database for a given user. + """ + try: + memory = await UserMemory.objects.aget(user=user, id=memory_id) + await memory.adelete() + return True + except UserMemory.DoesNotExist: + return False + + @staticmethod + def to_dict(memories: List[UserMemory]) -> List[dict]: + """ + Converts a list of Memory objects to a list of dictionaries. + """ + return [ + { + "id": f"{memory.id}", + "raw": memory.raw, + "updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"), + } + for memory in memories + ] diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 872c57ca..345b03fd 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -34,6 +34,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserVoiceModelConfig, VoiceModelOption, @@ -182,6 +183,7 @@ admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin) admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin) admin.site.register(UserRequests, unfold_admin.ModelAdmin) admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin) +admin.site.register(UserMemory, unfold_admin.ModelAdmin) @admin.register(McpServer) @@ -295,6 +297,7 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin): "think_paid_fast", "think_paid_deep", "web_scraper", + "memory_mode", ) ordering = ("priority",) diff --git a/src/khoj/database/management/commands/manage_memories.py b/src/khoj/database/management/commands/manage_memories.py new file mode 100644 index 00000000..62e14c97 --- /dev/null +++ b/src/khoj/database/management/commands/manage_memories.py @@ -0,0 +1,384 @@ +import asyncio +from datetime import datetime, timedelta +from typing import List, Optional + +from django.core.management.base import BaseCommand +from django.db.models import Q +from django.utils import timezone + +from khoj.configure import initialize_server +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import ( + Conversation, + DataStore, + KhojUser, +) +from khoj.routers.helpers import extract_facts_from_query + + +class Command(BaseCommand): + help = "Manage user memories - generate from conversations or delete existing memories" + + def add_arguments(self, parser): + parser.add_argument( + "--lookback-days", + type=int, + default=None, + help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories", + ) + parser.add_argument( + "--users", + type=str, + help="Process specific users (comma-separated usernames or emails)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=10, + help="Number of conversations to process in each batch (default: 10)", + ) + parser.add_argument( + "--apply", + action="store_true", + help="Actually perform the operation. Without this flag, only shows what would be processed.", + ) + parser.add_argument( + "--delete", + action="store_true", + help="Delete all memories for specified users instead of generating new ones", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from last checkpoint if process was interrupted", + ) + parser.add_argument( + "--force", + action="store_true", + help="Force regenerate memories even if already processed", + ) + + def handle(self, *args, **options): + """Main entry point for the command""" + initialize_server() + asyncio.run(self.async_handle(*args, **options)) + + async def async_handle(self, *args, **options): + """Async handler for memory management""" + lookback_days = options["lookback_days"] + usernames = options["users"] + batch_size = options["batch_size"] + apply = options["apply"] + delete = options["delete"] + resume = options["resume"] + force = options["force"] + + mode = "APPLY" if apply else "DRY RUN" + + # Handle deletion mode + if delete: + # For deletion, only use cutoff_date if lookback_days is explicitly provided + cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days else None + await self.handle_delete_memories(usernames, cutoff_date, apply) + return + + # Handle generation mode + # For generation, default to 7 days if not specified + if lookback_days is None: + lookback_days = 7 + cutoff_date = timezone.now() - timedelta(days=lookback_days) + self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days") + + # Get users to process + users = await self.get_users_to_process(usernames) + if not users: + self.stdout.write("No users found to process") + return + + self.stdout.write(f"Found {len(users)} users to process") + + # Initialize or retrieve checkpoint + checkpoint = await self.get_or_create_checkpoint(resume) + + total_conversations = 0 + total_memories = 0 + + for user in users: + # Check if user already processed in checkpoint + if not force and user.id in checkpoint.get("processed_users", []): + self.stdout.write(f"Skipping already processed user: {user.username}") + continue + + self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})") + + # Get conversations for this user + conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force) + + if not conversations: + self.stdout.write(f" No conversations to process for {user.username}") + # Mark user as processed + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id) + continue + + self.stdout.write(f" Found {len(conversations)} conversations to process") + + # Process conversations in batches + user_memories = 0 + for i in range(0, len(conversations), batch_size): + batch = conversations[i : i + batch_size] + batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint) + user_memories += batch_memories + total_conversations += len(batch) + + # Update progress + progress = min(i + batch_size, len(conversations)) + self.stdout.write( + f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories" + ) + + total_memories += user_memories + self.stdout.write( + f" Completed user {user.username}: " + f"processed {len(conversations)} conversations, " + f"generated {user_memories} memories" + ) + + # Mark user as processed + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id) + + # Clear checkpoint on successful completion + if apply: + await self.clear_checkpoint() + + action = "Generated" if apply else "Would generate" + self.stdout.write( + self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations") + ) + + async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]: + """Get list of users to comma separated usernames or emails to process""" + if users_str: + usernames = [u.strip() for u in users_str.split(",") if u.strip()] + # Process specific users + users = [user async for user in KhojUser.objects.filter(Q(username__in=usernames) | Q(email__in=usernames))] + return users + else: + # Process all users with conversations + return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()] + + async def get_user_conversations( + self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool + ) -> List[Conversation]: + """Get conversations for a user that need processing""" + if cutoff_date is None: + query = Conversation.objects.filter(user=user).order_by("updated_at") + else: + query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at") + + # Filter out already processed conversations if resuming + if not force and user.id in checkpoint.get("processed_conversations", {}): + processed_ids = checkpoint["processed_conversations"][user.id] + query = query.exclude(id__in=processed_ids) + + return [conv async for conv in query] + + async def process_conversation_batch( + self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict + ) -> int: + """Process a batch of conversations and generate memories""" + total_memories = 0 + + for conversation in conversations: + try: + # Get conversation messages using sync_to_async for property access + from asgiref.sync import sync_to_async + + # Access conversation_log synchronously + @sync_to_async + def get_messages(): + return conversation.messages + + messages = await get_messages() + if not messages: + continue + + # Get agent if conversation has one + @sync_to_async + def get_agent(): + return conversation.agent + + agent = await get_agent() + + # Get existing memories for context + # Process each conversation turn + conversation_memories = 0 + i = 0 + while i + 1 < len(messages): + # Only process user-assistant pairs as a valid turn for memory extraction + if messages[i].by != "you" or messages[i + 1].by != "khoj": + i += 1 + continue + + # Get the conversation history up to this point + history = messages[: i + 2] + + # Extract user query text for memory search + q = "" + if messages[i].message is None: + i += 1 + continue + elif isinstance(messages[i].message, str): + q = messages[i].message + elif isinstance(messages[i].message, list): + q = "\n\n".join( + content.get("text", "") + for content in messages[i].message + if isinstance(content, dict) and content.get("text") + ) + + if not q or not q.strip(): + i += 1 + continue + + # Get unique recent and long term relevant memories + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) + + if apply: + # Ensure agent is fully loaded with its chat_model + if agent: + + @sync_to_async + def load_agent_with_chat_model(): + # Force load the chat_model relationship + _ = agent.chat_model + return agent + + agent = await load_agent_with_chat_model() + + # Update memories based on latest conversation turn + memory_updates = await extract_facts_from_query( + user=user, + conversation_history=history, + existing_facts=relevant_memories, + agent=agent, + tracer={}, + ) + + # Save new memories + for memory in memory_updates.create: + await UserMemoryAdapters.save_memory(user, memory, agent=agent) + conversation_memories += 1 + self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...") + + # Delete outdated memories + for memory in memory_updates.delete: + await UserMemoryAdapters.delete_memory(user, memory) + self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...") + else: + # Dry run - estimate memories that would be created + conversation_memories += 1 # Rough estimate + + # Move to next conversation turn pair + i += 2 + + total_memories += conversation_memories + + # Update checkpoint after each conversation + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id)) + except Exception as e: + import traceback + + self.stderr.write( + f"Error processing conversation {conversation.id} for user {user.id}: {e}\n" + f"Traceback: {traceback.format_exc()}" + ) + continue + + return total_memories + + async def get_or_create_checkpoint(self, resume: bool) -> dict: + """Get or create checkpoint for resumable processing""" + checkpoint_key = "memory_generation_checkpoint" + + if resume: + # Try to retrieve existing checkpoint + checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst() + if checkpoint_store: + self.stdout.write("Resuming from checkpoint...") + return checkpoint_store.value + + # Create new checkpoint + return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}} + + async def update_checkpoint( + self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None + ): + """Update checkpoint with progress""" + if user_id and user_id not in checkpoint["processed_users"]: + checkpoint["processed_users"].append(user_id) + + if user_id and conversation_id: + if user_id not in checkpoint["processed_conversations"]: + checkpoint["processed_conversations"][user_id] = [] + if conversation_id not in checkpoint["processed_conversations"][user_id]: + checkpoint["processed_conversations"][user_id].append(conversation_id) + + # Save checkpoint to database + await DataStore.objects.aupdate_or_create( + key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True} + ) + + async def clear_checkpoint(self): + """Clear checkpoint after successful completion""" + await DataStore.objects.filter(key="memory_generation_checkpoint").adelete() + self.stdout.write("Checkpoint cleared") + + async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool): + """Handle deletion of user memories""" + from khoj.database.models import UserMemory + + # Get users to process + users = await self.get_users_to_process(usernames) + if not users: + self.stdout.write("No users found to process") + return + + mode = "APPLY" if apply else "DRY RUN" + if cutoff_date: + # Calculate days from cutoff date + days_back = (timezone.now() - cutoff_date).days + self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users") + else: + self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users") + + total_deleted = 0 + for user in users: + # Count memories for this user + if cutoff_date is None: + user_memories = UserMemory.objects.filter(user=user) + else: + user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date) + + memories_count = await user_memories.acount() + if memories_count == 0: + self.stdout.write(f" User {user.username} has no memories to delete") + continue + + self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories") + + if apply: + # Delete memories for this user (with date filter if specified) + deleted_count, _ = await user_memories.adelete() + self.stdout.write(f" Deleted {deleted_count} memories") + total_deleted += deleted_count + else: + self.stdout.write(f" Would delete {memories_count} memories") + total_deleted += memories_count + + action = "Deleted" if apply else "Would delete" + self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total")) diff --git a/src/khoj/database/migrations/0099_usermemory.py b/src/khoj/database/migrations/0099_usermemory.py new file mode 100644 index 00000000..b1d6939a --- /dev/null +++ b/src/khoj/database/migrations/0099_usermemory.py @@ -0,0 +1,82 @@ +# Generated by Django 5.1.10 on 2025-08-29 22:57 + +import django.db.models.deletion +import pgvector.django +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0098_alter_texttoimagemodelconfig_model_type"), + ] + + operations = [ + migrations.CreateModel( + name="UserMemory", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("embeddings", pgvector.django.VectorField()), + ("raw", models.TextField()), + ( + "agent", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.agent", + ), + ), + ( + "search_model", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="database.searchmodelconfig", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="userconversationconfig", + name="enable_memory", + field=models.BooleanField(default=True), + ), + migrations.AddField( + model_name="serverchatsettings", + name="memory_mode", + field=models.CharField( + choices=[ + ("disabled", "Disabled"), + ("enabled_default_off", "Enabled, default off"), + ("enabled_default_on", "Enabled, default on"), + ], + default="enabled_default_on", + help_text="Server-level memory feature configuration. Disabled overrides user preference.", + max_length=20, + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 404b0f38..56516b2d 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -478,6 +478,13 @@ class ServerChatSettings(DbBaseModel): THINK_PAID_FAST = "think_paid_fast" THINK_PAID_DEEP = "think_paid_deep" + class MemoryMode(models.TextChoices): + """Enum for server-level memory feature configuration""" + + DISABLED = "disabled", "Disabled" + ENABLED_DEFAULT_OFF = "enabled_default_off", "Enabled, default off" + ENABLED_DEFAULT_ON = "enabled_default_on", "Enabled, default on" + chat_default = models.ForeignKey( ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" ) @@ -506,6 +513,12 @@ class ServerChatSettings(DbBaseModel): unique=True, help_text="Priority of the server chat settings. Lower numbers run first.", ) + memory_mode = models.CharField( + max_length=20, + choices=MemoryMode.choices, + default=MemoryMode.ENABLED_DEFAULT_ON, + help_text="Server-level memory feature configuration. Disabled overrides user preference.", + ) def clean(self): error = {} @@ -629,6 +642,7 @@ class SpeechToTextModelOptions(DbBaseModel): class UserConversationConfig(DbBaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True) + enable_memory = models.BooleanField(default=True) class UserVoiceModelConfig(DbBaseModel): @@ -836,3 +850,15 @@ class McpServer(DbBaseModel): def __str__(self): return self.name + + +class UserMemory(DbBaseModel): + """ + Long term memory store derived from conversation between user and agent. + """ + + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True) + embeddings = VectorField(dimensions=None) + raw = models.TextField() + search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 611d41a8..0ed601ea 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1302,3 +1302,70 @@ user_name = PromptTemplate.from_template( User's Name: {name} """.strip() ) + +extract_facts_from_query = PromptTemplate.from_template( + """ +You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user. +This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember. + +You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action: +1. Create new facts +2. Delete existing facts + +You should delete existing facts that are no longer true. +You can enhance new facts with information from existing facts. +You cannot update existing facts directly, instead create new facts and delete related existing ones to update them. + +Your output should be a JSON object with two lists: create and delete. +- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective. +- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true. +- Leave the create or delete list empty if you have nothing important to add or remove. + +# Example +Existing Facts: +[ + {{ + "id": "5283", + "raw": "I am not interested in sports", + "updated_at": "2023-10-01T12:00:00+00:00" + }}, + {{ + "id": "22", + "raw": "I am a software engineer", + "updated_at": "2023-10-31T14:00:00+00:00" + }}, + {{ + "id": "651", + "raw": "My mother works at the hospital", + "updated_at": "2023-10-02T17:00:00+00:00" + }} +] + +Latest Chat Session: +- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline. +In between coding, I took my cat Whiskers out for a walk and played a game of football. +My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat. +- AI: That's great to hear! + +Response: +{{ + "create": [ + "I am interested in AI and machine learning", + "I have a pet cat named Whiskers", + "I enjoy playing football", + "My mother works at the hospital and is a doctor" + ], + "delete": [ + "5283", + "651" + ], +}} + +# Input +Existing Facts: +{matched_facts} + +Latest Chat Session: +{chat_history} +""".strip() +) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a2069243..dae182c2 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -28,6 +28,7 @@ from khoj.database.models import ( ClientApplication, Intent, KhojUser, + UserMemory, ) from khoj.processor.conversation import prompts from khoj.search_filter.base_filter import BaseFilter @@ -551,6 +552,7 @@ async def save_to_conversation_log( client_application: ClientApplication = None, conversation_id: str = None, automation_id: str = None, + relevant_memories: List[UserMemory] = [], query_images: List[str] = None, raw_query_files: List[FileAttachment] = [], generated_images: List[str] = [], @@ -559,6 +561,8 @@ async def save_to_conversation_log( train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): + from khoj.routers.helpers import ai_update_memories + user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") turn_id = tracer.get("mid") or str(uuid.uuid4()) @@ -605,6 +609,16 @@ async def save_to_conversation_log( user_message=q, ) + if not automation_id: + # Don't update memories from automations, as this could get noisy. + await ai_update_memories( + user=user, + conversation_history=new_messages or [], + memories=relevant_memories, + agent=db_conversation.agent if db_conversation else None, + tracer=tracer, + ) + if is_promptrace_enabled(): merge_message_into_conversation_trace(q, chat_response, tracer) @@ -666,6 +680,7 @@ def generate_chatml_messages_with_context( query_files: str = None, query_images=None, context_message="", + relevant_memories: List[UserMemory] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = [], chat_history: list[ChatMessageModel] = [], @@ -806,6 +821,14 @@ def generate_chatml_messages_with_context( ), ) + if not is_none_or_empty(relevant_memories): + memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. Ignore them if they are not relevant to the query.\n\n" + for memory in relevant_memories: + friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") + memory_context += f"- [{friendly_dt}]: {memory.raw}\n" + memory_context += "" + messages.append(ChatMessage(content=memory_context, role="user")) + if not is_none_or_empty(user_message): messages.append( ChatMessage( diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 45a640c1..38c2cd84 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -25,6 +25,7 @@ from khoj.database.models import ( Intent, KhojUser, TextToImageModelConfig, + UserMemory, ) from khoj.processor.conversation.google.utils import _is_retryable_error from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url @@ -47,6 +48,7 @@ async def text_to_image( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, agent: Agent = None, tracer: dict = {}, ): @@ -91,6 +93,7 @@ async def text_to_image( model_type=text_to_image_config.model_type, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer, diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 0aa12ca4..ccd780aa 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -5,7 +5,7 @@ import os from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters -from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser, UserMemory from khoj.processor.conversation.utils import ( AgentMessage, OperatorRun, @@ -42,6 +42,7 @@ async def operate_environment( query_images: Optional[List[str]] = None, # TODO: Handle query images agent: Agent = None, query_files: str = None, # TODO: Handle query files + relevant_memories: Optional[List[UserMemory]] = None, # TODO: Handle relevant memories cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, abort_message: Optional[str] = ChatEvent.END_EVENT.value, diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index beeb640c..96a587ae 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -14,6 +14,7 @@ from khoj.database.models import ( Agent, ChatMessageModel, KhojUser, + UserMemory, WebScraper, ) from khoj.routers.helpers import ( @@ -63,6 +64,7 @@ async def search_online( max_webpages_to_read: int = 1, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, previous_subqueries: Set = set(), fast_model: bool = True, agent: Agent = None, @@ -82,6 +84,7 @@ async def search_online( user, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, max_queries=max_online_searches, fast_model=fast_model, agent=agent, @@ -160,7 +163,15 @@ async def search_online( async for event in send_status_func(f"**Browsing**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} tasks = [ - extract_from_webpage(link, data["queries"], data.get("content"), user=user, agent=agent, tracer=tracer) + extract_from_webpage( + link, + data["queries"], + data.get("content"), + relevant_memories=relevant_memories, + user=user, + agent=agent, + tracer=tracer, + ) for link, data in webpages.items() ] results = await asyncio.gather(*tasks) @@ -438,6 +449,7 @@ async def read_webpages( fast_model: bool = True, agent: Agent = None, max_webpages_to_read: int = 1, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" @@ -460,6 +472,7 @@ async def read_webpages( user, send_status_func=send_status_func, agent=agent, + relevant_memories=relevant_memories, tracer=tracer, ): yield result @@ -471,6 +484,7 @@ async def read_webpages_content( user: KhojUser, send_status_func: Optional[Callable] = None, agent: Agent = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): logger.info(f"Reading web pages at: {urls}") @@ -478,7 +492,10 @@ async def read_webpages_content( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Browsing**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [extract_from_webpage(url, {query}, user=user, agent=agent, tracer=tracer) for url in urls] + tasks = [ + extract_from_webpage(url, {query}, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer) + for url in urls + ] results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -533,6 +550,7 @@ async def extract_from_webpage( url: str, subqueries: set[str] = None, content: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -546,7 +564,9 @@ async def extract_from_webpage( extracted_info = None if not is_none_or_empty(content): with timer(f"Extracting relevant information from web page at '{url}' took", logger): - extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent, tracer=tracer) + extracted_info = await extract_relevant_info( + subqueries, content, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer + ) return subqueries, url, extracted_info diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 2ec392ff..45c95115 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -21,7 +21,7 @@ from tenacity import ( ) from khoj.database.adapters import AgentAdapters, FileObjectAdapters -from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser +from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ChatEvent, @@ -59,6 +59,7 @@ async def run_code( send_status_func: Optional[Callable] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, agent: Agent = None, sandbox_url: str = SANDBOX_URL, tracer: dict = {}, @@ -79,6 +80,7 @@ async def run_code( agent, tracer, query_files, + relevant_memories, ) except Exception as e: raise ValueError(f"Failed to generate code for {instructions} with error: {e}") @@ -126,6 +128,7 @@ async def generate_python_code( agent: Agent = None, tracer: dict = {}, query_files: str = None, + relevant_memories: List[UserMemory] = None, ) -> GeneratedCode: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" @@ -161,6 +164,7 @@ async def generate_python_code( code_generation_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index f7345d94..31950b1f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -13,7 +13,7 @@ from starlette.authentication import has_required_scope, requires from khoj.configure import initialize_content from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo -from khoj.database.models import KhojUser, SpeechToTextModelOptions +from khoj.database.models import KhojUser, SpeechToTextModelOptions, UserConversationConfig from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, @@ -215,6 +215,29 @@ def set_user_name( return {"status": "ok"} +@api.patch("/user/memory", status_code=200) +@requires(["authenticated"]) +def set_user_memory_enabled( + request: Request, + enable_memory: bool, + client: Optional[str] = None, +): + user = request.user.object + + user_config, _ = UserConversationConfig.objects.get_or_create(user=user) + user_config.enable_memory = enable_memory + user_config.save() + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_user_memory_enabled", + client=client, + ) + + return {"status": "ok", "enable_memory": enable_memory} + + @api.get("/health", response_class=Response) @requires(["authenticated"], status_code=200) def health_check(request: Request) -> Response: diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 04cb84de..2e326690 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -29,6 +29,7 @@ from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, PublicConversationAdapters, + UserMemoryAdapters, aget_user_name, ) from khoj.database.models import Agent, KhojUser @@ -101,7 +102,6 @@ conversation_command_rate_limiter = ConversationCommandRateLimiter( trial_rate_limit=20, subscribed_rate_limit=75, slug="command" ) - api_chat = APIRouter() @@ -963,6 +963,14 @@ async def event_generator( location = LocationData(city=city, region=region, country=country, country_code=country_code) chat_history = conversation.messages + # Get most recent memories and long term relevant memories if memory is enabled + relevant_memories = [] + if await ConversationAdapters.ais_memory_enabled(user): + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) + # Create a de-duped set of memories + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) + # If interrupted message in DB if last_message := await conversation.pop_message(interrupted=True): # Populate context from interrupted message @@ -987,6 +995,7 @@ async def event_generator( query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ) except ValueError as e: @@ -1024,6 +1033,7 @@ async def event_generator( previous_iterations=list(research_results), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, user_name=user_name, location=location, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1081,6 +1091,7 @@ async def event_generator( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1130,6 +1141,7 @@ async def event_generator( max_online_searches=3, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1157,6 +1169,7 @@ async def event_generator( max_webpages_to_read=1, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1198,6 +1211,7 @@ async def event_generator( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1223,6 +1237,7 @@ async def event_generator( list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, send_status_func=partial(send_event, ChatEvent.STATUS), agent=agent, cancellation_event=cancellation_event, @@ -1276,6 +1291,7 @@ async def event_generator( send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1317,6 +1333,7 @@ async def event_generator( online_results=online_results, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1375,6 +1392,7 @@ async def event_generator( user_name, uploaded_images, attached_file_context, + relevant_memories, program_execution_context, generated_asset_results, is_subscribed, @@ -1434,6 +1452,7 @@ async def event_generator( query_images=uploaded_images, train_of_thought=train_of_thought, raw_query_files=raw_query_files, + relevant_memories=relevant_memories, generated_images=generated_images, generated_mermaidjs_diagram=generated_mermaidjs_diagram, tracer=tracer, diff --git a/src/khoj/routers/api_memories.py b/src/khoj/routers/api_memories.py new file mode 100644 index 00000000..d7527827 --- /dev/null +++ b/src/khoj/routers/api_memories.py @@ -0,0 +1,114 @@ +import json +import logging +from typing import Optional + +from asgiref.sync import sync_to_async +from fastapi import APIRouter, Request +from fastapi.responses import Response +from pydantic import BaseModel +from starlette.authentication import requires + +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import UserMemory + +api_memories = APIRouter() +logger = logging.getLogger(__name__) + + +@api_memories.get("") +@requires(["authenticated"]) +async def get_memories( + request: Request, + client: Optional[str] = None, +): + """Get all memories for the authenticated user""" + user = request.user.object + + memories = UserMemory.objects.filter(user=user) + all_memories = await sync_to_async(list)(memories) + + # Convert memories to a list of dictionaries + formatted_memories = [ + { + "id": memory.id, + "raw": memory.raw, + "created_at": memory.created_at.isoformat(), + } + for memory in all_memories + ] + + return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200) + + +@api_memories.delete("/{memory_id}") +@requires(["authenticated"]) +async def delete_memory( + request: Request, + memory_id: int, + client: Optional[str] = None, +): + """Delete a specific memory by ID""" + user = request.user.object + + # Verify memory belongs to user before deleting + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + await memory.adelete() + + return Response(status_code=204) + + +class UpdateMemoryBody(BaseModel): + """Request model for updating a memory""" + + raw: str + + +@api_memories.put("/{memory_id}") +@requires(["authenticated"]) +async def update_memory( + request: Request, + body: UpdateMemoryBody, + memory_id: int, + client: Optional[str] = None, +): + """Update a specific memory's content""" + user = request.user.object + + # Get the memory and verify it belongs to the user + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + new_content = body.raw + if not new_content: + return Response( + content=json.dumps({"error": "Missing required field 'raw'"}), + media_type="application/json", + status_code=400, + ) + + await memory.adelete() + + # Create a new memory with the updated content + new_memory = await UserMemoryAdapters.save_memory( + user=user, + memory=new_content, + ) + + return Response( + content=json.dumps( + { + "id": new_memory.id, + "raw": new_memory.raw, + } + ), + media_type="application/json", + status_code=200, + ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 1c9f9b33..29a0df0f 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -46,6 +46,7 @@ from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, FileObjectAdapters, + UserMemoryAdapters, aget_user_by_email, create_khoj_token, get_default_search_model, @@ -53,6 +54,7 @@ from khoj.database.adapters import ( get_user_name, get_user_notion_config, get_user_subscription_state, + require_valid_user, run_with_process_lock, ) from khoj.database.models import ( @@ -66,8 +68,10 @@ from khoj.database.models import ( NotionConfig, ProcessLock, RateLimitRecord, + ServerChatSettings, Subscription, TextToImageModelConfig, + UserMemory, UserRequests, ) from khoj.processor.content.docx.docx_to_entries import DocxToEntries @@ -354,6 +358,7 @@ async def aget_data_sources_and_output_format( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Dict[str, Any]: """ @@ -415,6 +420,7 @@ async def aget_data_sources_and_output_format( relevant_tools_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=PickTools, fast_model=True, @@ -470,6 +476,7 @@ async def infer_webpage_urls( user: KhojUser, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, fast_model: bool = True, agent: Agent = None, tracer: dict = {}, @@ -506,6 +513,7 @@ async def infer_webpage_urls( online_queries_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=WebpageUrls, fast_model=fast_model, @@ -536,6 +544,7 @@ async def generate_online_subqueries( user: KhojUser, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, max_queries: int = 3, fast_model: bool = True, agent: Agent = None, @@ -573,6 +582,7 @@ async def generate_online_subqueries( online_queries_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=OnlineQueries, fast_model=fast_model, @@ -659,7 +669,12 @@ async def aschedule_query( async def extract_relevant_info( - qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {} + qs: set[str], + corpus: str, + relevant_memories: List[UserMemory] = None, + user: KhojUser = None, + agent: Agent = None, + tracer: dict = {}, ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -683,6 +698,7 @@ async def extract_relevant_info( response = await send_message_to_model_wrapper( extract_relevant_information, system_message=prompts.system_prompt_extract_relevant_information, + relevant_memories=relevant_memories, fast_model=True, agent_chat_model=agent_chat_model, user=user, @@ -802,6 +818,7 @@ async def generate_excalidraw_diagram( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -819,6 +836,7 @@ async def generate_excalidraw_diagram( online_results=online_results, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer, @@ -854,6 +872,7 @@ async def generate_better_diagram_description( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -899,6 +918,7 @@ async def generate_better_diagram_description( improve_diagram_description_prompt, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, @@ -956,6 +976,88 @@ async def generate_excalidraw_diagram_from_description( return response +class MemoryUpdates(BaseModel): + """Facts to add or remove from memory.""" + + create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.") + delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.") + + +async def extract_facts_from_query( + user: KhojUser, + conversation_history: List[ChatMessageModel], + existing_facts: List[UserMemory] = None, + agent: Agent = None, + tracer: dict = {}, +) -> MemoryUpdates: + """ + Extract facts from the given query + """ + chat_history = construct_chat_history(conversation_history, n=2) + + formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else [] + + extract_facts_prompt = prompts.extract_facts_from_query.format( + chat_history=chat_history, + matched_facts=formatted_memories, + ) + + with timer("Chat actor: Extract facts from query", logger): + response = await send_message_to_model_wrapper( + extract_facts_prompt, + response_schema=MemoryUpdates, + user=user, + fast_model=False, + agent_chat_model=agent.chat_model, + tracer=tracer, + ) + response = response.text.strip() + # JSON parse the list of strings + try: + response = clean_json(response) + response = json.loads(response) + parsed_response = MemoryUpdates(**response) + if not isinstance(parsed_response, MemoryUpdates): + raise ValueError(f"Invalid response for extracting facts: {response}") + return parsed_response + + except Exception: + logger.error(f"Invalid response for extracting facts: {response}") + return MemoryUpdates(create=[], delete=[]) + + +@require_valid_user +async def ai_update_memories( + user: KhojUser, + conversation_history: List[ChatMessageModel], + memories: List[UserMemory], + agent: Agent, + tracer: dict = {}, +): + """ + Updates the memories for a given user, based on their latest input query. + """ + # Skip memory updates if memory is disabled for the user + if not await ConversationAdapters.ais_memory_enabled(user): + return + + memory_update = await extract_facts_from_query( + user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer + ) + + if not memory_update: + return + + # Save the memory updates to the database + for memory in memory_update.create: + logger.info(f"Creating memory: {memory}") + await UserMemoryAdapters.save_memory(user, memory, agent=agent) + + for memory in memory_update.delete: + logger.info(f"Deleting memory: {memory}") + await UserMemoryAdapters.delete_memory(user, memory) + + async def generate_mermaidjs_diagram( q: str, chat_history: List[ChatMessageModel], @@ -964,6 +1066,7 @@ async def generate_mermaidjs_diagram( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -981,6 +1084,7 @@ async def generate_mermaidjs_diagram( online_results=online_results, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer, @@ -1010,6 +1114,7 @@ async def generate_better_mermaidjs_diagram_description( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -1055,6 +1160,7 @@ async def generate_better_mermaidjs_diagram_description( improve_diagram_description_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, @@ -1104,6 +1210,7 @@ async def generate_better_image_prompt( model_type: Optional[str] = None, query_images: Optional[List[str]] = None, query_files: str = "", + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -1148,6 +1255,7 @@ async def generate_better_image_prompt( q, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, chat_history=conversation_history, system_message=enhance_image_system_message, response_type="json_object", @@ -1180,6 +1288,7 @@ async def search_documents( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, previous_inferred_queries: Set = set(), fast_model: bool = True, agent: Agent = None, @@ -1230,6 +1339,7 @@ async def search_documents( user=user, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, personality_context=personality_context, location_data=location_data, chat_history=chat_history, @@ -1283,6 +1393,7 @@ async def extract_questions( query_files: str = None, query_images: Optional[List[str]] = None, personality_context: str = "", + relevant_memories: List[UserMemory] = None, location_data: LocationData = None, chat_history: List[ChatMessageModel] = [], max_queries: int = 5, @@ -1338,6 +1449,7 @@ async def extract_questions( query=prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, system_message=system_prompt, response_type="json_object", response_schema=DocumentQueries, @@ -1518,6 +1630,7 @@ async def send_message_to_model_wrapper( query_files: str = None, query_images: List[str] = None, context: str = "", + relevant_memories: List[UserMemory] = None, chat_history: list[ChatMessageModel] = [], system_message: str = "", # Model Config @@ -1568,6 +1681,7 @@ async def send_message_to_model_wrapper( query_files=query_files, query_images=query_images, context_message=context, + relevant_memories=relevant_memories, chat_history=chat_history, system_message=system_message, model_name=chat_model.name, @@ -1685,6 +1799,7 @@ def build_conversation_context( operator_results: List[OperatorRun], query_files: str = None, query_images: Optional[List[str]] = None, + relevant_memories: List[UserMemory] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, chat_history: List[ChatMessageModel] = [], @@ -1761,6 +1876,7 @@ def build_conversation_context( query_files=query_files, query_images=query_images, context_message=context_message, + relevant_memories=relevant_memories, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, chat_history=chat_history, @@ -1789,6 +1905,7 @@ async def agenerate_chat_response( user_name: Optional[str] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = [], program_execution_context: List[str] = [], generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, @@ -1831,6 +1948,7 @@ async def agenerate_chat_response( operator_results=operator_results, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, chat_history=chat_history, @@ -2792,6 +2910,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) selected_chat_model_config = ConversationAdapters.get_chat_model( user ) or ConversationAdapters.get_default_chat_model(user) + server_chat_settings = ServerChatSettings.objects.first() + server_memory_mode = ( + server_chat_settings.memory_mode + if server_chat_settings + else ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON.value # type: ignore[attr-defined] + ) + enable_memory = ConversationAdapters.is_memory_enabled(user) chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_model_options = list() for chat_model in chat_models: @@ -2848,6 +2973,8 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) "enabled_content_source": enabled_content_sources, "has_documents": has_documents, "notion_token": notion_token, + "enable_memory": enable_memory, + "server_memory_mode": server_memory_mode, # user model settings "chat_model_options": chat_model_options, "selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index e103b330..3185e8c0 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional import yaml from khoj.database.adapters import AgentAdapters, EntryAdapters, McpServerAdapters -from khoj.database.models import Agent, ChatMessageModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, @@ -285,6 +285,7 @@ async def apick_next_tool( max_iterations: int = 5, query_images: List[str] = [], query_files: str = None, + relevant_memories: List[UserMemory] = [], max_document_searches: int = 7, max_online_searches: int = 3, max_webpages_to_read: int = 3, @@ -416,6 +417,7 @@ async def apick_next_tool( query="", query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, system_message=function_planning_prompt, chat_history=chat_and_research_history, tools=tools, @@ -479,6 +481,7 @@ async def research( previous_iterations: List[ResearchIteration], query_images: List[str], query_files: str = None, + relevant_memories: List[UserMemory] = [], user_name: str = None, location: LocationData = None, send_status_func: Optional[Callable] = None, @@ -544,6 +547,7 @@ async def research( MAX_ITERATIONS, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, max_document_searches=max_document_searches, max_online_searches=max_online_searches, max_webpages_to_read=max_webpages_to_read, diff --git a/tests/helpers.py b/tests/helpers.py index 985ba00a..d5740d53 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,9 +4,12 @@ import os from datetime import datetime import factory +from asgiref.sync import sync_to_async from django.utils.timezone import make_aware +from khoj.database.adapters import AgentAdapters from khoj.database.models import ( + Agent, AiModelApi, ChatMessageModel, ChatModel, @@ -15,8 +18,10 @@ from khoj.database.models import ( KhojUser, ProcessLock, SearchModelConfig, + ServerChatSettings, Subscription, UserConversationConfig, + UserMemory, ) from khoj.processor.conversation.utils import message_to_log from khoj.utils.helpers import get_absolute_path, is_none_or_empty @@ -277,3 +282,45 @@ class ProcessLockFactory(factory.django.DjangoModelFactory): model = ProcessLock name = "test_lock" + + +class ServerChatSettingsFactory(factory.django.DjangoModelFactory): + class Meta: + model = ServerChatSettings + + memory_mode = ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON + + +# Async-safe wrappers for factories and ORM operations +async def acreate_user(): + return await sync_to_async(UserFactory)() + + +async def acreate_subscription(user): + return await sync_to_async(SubscriptionFactory)(user=user) + + +async def acreate_chat_model(): + return await sync_to_async(ChatModelFactory)() + + +async def acreate_default_agent(): + return await sync_to_async(AgentAdapters.create_default_agent)() + + +async def acreate_agent(name, chat_model, personality): + return await sync_to_async(Agent.objects.create)( + name=name, + chat_model=chat_model, + personality=personality, + ) + + +async def acreate_test_memory(user, agent=None, raw_text="test memory"): + """Create a memory directly in DB without embeddings for testing.""" + return await sync_to_async(UserMemory.objects.create)( + user=user, + agent=agent, + raw=raw_text, + embeddings=[0.1] * 384, # Dummy embeddings + ) diff --git a/tests/test_memory_settings.py b/tests/test_memory_settings.py new file mode 100644 index 00000000..28621670 --- /dev/null +++ b/tests/test_memory_settings.py @@ -0,0 +1,479 @@ +""" +Tests for memory enable/disable settings and memory scoping by user+agent. + +These tests verify: +1. The behavior of ConversationAdapters.is_memory_enabled() for different combinations of: + - ServerChatSettings.memory_mode (DISABLED, ENABLED_DEFAULT_OFF, ENABLED_DEFAULT_ON) + - UserConversationConfig.enable_memory (True, False, or not set) + +2. Memory scoping by user and agent: + - Memories are scoped to user + agent + - Default agent has access to ALL memories across all agents for a user + - Non-default agents only see their own memories +""" + +import pytest +from unittest.mock import MagicMock + +from khoj.database.adapters import ConversationAdapters, UserMemoryAdapters +from khoj.database.models import ServerChatSettings, UserConversationConfig +from khoj.routers.helpers import get_user_config +from tests.helpers import ( + acreate_user, + acreate_subscription, + acreate_chat_model, + acreate_default_agent, + acreate_agent, + acreate_test_memory, + ServerChatSettingsFactory, + SubscriptionFactory, + UserFactory, +) + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with no server config (default behavior) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_enabled_no_server_config_no_user_config(): + """When no server config and no user config exists, memory should be enabled (default on).""" + user = UserFactory() + SubscriptionFactory(user=user) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_no_server_config_user_enabled(): + """When no server config but user has explicitly enabled memory.""" + user = UserFactory() + SubscriptionFactory(user=user) + user_config = UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_no_server_config_user_disabled(): + """When no server config but user has explicitly disabled memory.""" + user = UserFactory() + SubscriptionFactory(user=user) + user_config = UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with server mode DISABLED +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_disabled_server_disabled_no_user_config(): + """When server disables memory, it should override everything - no user config.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +@pytest.mark.django_db +def test_memory_disabled_server_disabled_user_enabled(): + """When server disables memory, it should override user preference (enabled).""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +@pytest.mark.django_db +def test_memory_disabled_server_disabled_user_disabled(): + """When server disables memory, user disabled too - should be disabled.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with server mode ENABLED_DEFAULT_OFF +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_enabled_default_off_no_user_config(): + """When server is enabled_default_off and no user config, memory should be off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +@pytest.mark.django_db +def test_memory_enabled_default_off_user_enabled(): + """When server is enabled_default_off and user opts in, memory should be on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_default_off_user_disabled(): + """When server is enabled_default_off and user explicitly disabled, memory should be off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with server mode ENABLED_DEFAULT_ON +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_enabled_default_on_no_user_config(): + """When server is enabled_default_on and no user config, memory should be on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_default_on_user_enabled(): + """When server is enabled_default_on and user enabled, memory should be on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_default_on_user_disabled(): + """When server is enabled_default_on and user opts out, memory should be off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test get_user_config returns correct enable_memory and server_memory_mode +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_get_user_config_memory_no_server_config(): + """get_user_config should return default values when no server config.""" + user = UserFactory() + SubscriptionFactory(user=user) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is True + assert config["server_memory_mode"] == "enabled_default_on" + + +@pytest.mark.django_db +def test_get_user_config_memory_server_disabled(): + """get_user_config should reflect server disabled mode.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is False + assert config["server_memory_mode"] == "disabled" + + +@pytest.mark.django_db +def test_get_user_config_memory_server_enabled_default_off_user_opted_in(): + """get_user_config should show user opted in when server is default off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + UserConversationConfig.objects.create(user=user, enable_memory=True) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is True + assert config["server_memory_mode"] == "enabled_default_off" + + +@pytest.mark.django_db +def test_get_user_config_memory_server_enabled_default_on_user_opted_out(): + """get_user_config should show user opted out when server is default on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + UserConversationConfig.objects.create(user=user, enable_memory=False) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is False + assert config["server_memory_mode"] == "enabled_default_on" + + +# ---------------------------------------------------------------------------------------------------- +# Test memory scoping by user and agent +# ---------------------------------------------------------------------------------------------------- + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_pull_memories_default_agent_sees_all_memories(): + """Default agent should see ALL memories for the user, including those from other agents.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create a custom agent + custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent") + + # Create memories for different agents + await acreate_test_memory(user, agent=None, raw_text="memory without agent") + await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent") + await acreate_test_memory(user, agent=custom_agent, raw_text="memory for custom agent") + + # Act: Pull memories with default agent + memories = await UserMemoryAdapters.pull_memories(user=user, agent=default_agent) + + # Assert: Default agent sees ALL memories + memory_texts = [m.raw for m in memories] + assert "memory without agent" in memory_texts + assert "memory for default agent" in memory_texts + assert "memory for custom agent" in memory_texts + assert len(memories) == 3 + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_pull_memories_custom_agent_sees_only_own_memories(): + """Custom (non-default) agent should only see its own memories.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create custom agents + custom_agent_1 = await acreate_agent("Custom Agent 1", chat_model, "First custom agent") + custom_agent_2 = await acreate_agent("Custom Agent 2", chat_model, "Second custom agent") + + # Create memories for different agents + await acreate_test_memory(user, agent=None, raw_text="memory without agent") + await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent") + await acreate_test_memory(user, agent=custom_agent_1, raw_text="memory for custom agent 1") + await acreate_test_memory(user, agent=custom_agent_2, raw_text="memory for custom agent 2") + + # Act: Pull memories with custom_agent_1 + memories = await UserMemoryAdapters.pull_memories(user=user, agent=custom_agent_1) + + # Assert: Custom agent 1 only sees its own memories + memory_texts = [m.raw for m in memories] + assert "memory for custom agent 1" in memory_texts + assert "memory without agent" not in memory_texts + assert "memory for default agent" not in memory_texts + assert "memory for custom agent 2" not in memory_texts + assert len(memories) == 1 + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_pull_memories_no_agent_same_as_default_agent(): + """Pulling memories with agent=None should behave same as default agent (see all).""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create a custom agent + custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent") + + # Create memories + await acreate_test_memory(user, agent=None, raw_text="memory without agent") + await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent") + await acreate_test_memory(user, agent=custom_agent, raw_text="memory for custom agent") + + # Act: Pull memories with agent=None + memories = await UserMemoryAdapters.pull_memories(user=user, agent=None) + + # Assert: Should see all memories (same as default agent behavior) + memory_texts = [m.raw for m in memories] + assert "memory without agent" in memory_texts + assert "memory for default agent" in memory_texts + assert "memory for custom agent" in memory_texts + assert len(memories) == 3 + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_save_memory_with_custom_agent_scopes_to_agent(): + """Memories saved with a custom agent should be scoped to that agent.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create custom agent + custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent") + + # Create memory with custom agent (directly in DB to avoid embeddings) + memory = await acreate_test_memory(user, agent=custom_agent, raw_text="custom agent memory") + + # Assert: Memory is scoped to the custom agent + assert memory.agent == custom_agent + assert memory.user == user + + # Verify custom agent can see it + custom_memories = await UserMemoryAdapters.pull_memories(user=user, agent=custom_agent) + assert len(custom_memories) == 1 + assert custom_memories[0].raw == "custom agent memory" + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_save_memory_with_default_agent_has_no_agent_scope(): + """Memories saved with default agent should have agent=None (global scope).""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + await acreate_chat_model() # Required for default agent creation + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create memory with default agent (directly in DB) + # Based on save_memory logic: if agent == default_agent, agent is not set + memory = await acreate_test_memory(user, agent=None, raw_text="default agent memory") + + # Assert: Memory has no agent (global scope) + assert memory.agent is None + assert memory.user == user + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_memories_isolated_between_users(): + """Memories should be isolated between different users.""" + # Setup + user1 = await acreate_user() + user2 = await acreate_user() + await acreate_subscription(user1) + await acreate_subscription(user2) + + # Create default agent + await acreate_default_agent() + + # Create memories for each user + await acreate_test_memory(user1, agent=None, raw_text="user1 memory") + await acreate_test_memory(user2, agent=None, raw_text="user2 memory") + + # Act: Pull memories for each user + user1_memories = await UserMemoryAdapters.pull_memories(user=user1) + user2_memories = await UserMemoryAdapters.pull_memories(user=user2) + + # Assert: Each user only sees their own memories + assert len(user1_memories) == 1 + assert user1_memories[0].raw == "user1 memory" + + assert len(user2_memories) == 1 + assert user2_memories[0].raw == "user2 memory" + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_custom_agent_cannot_see_other_custom_agent_memories(): + """One custom agent should not see another custom agent's memories.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + await acreate_default_agent() + + # Create two custom agents + agent_accountant = await acreate_agent("Accountant", chat_model, "Financial advisor") + agent_chef = await acreate_agent("Chef", chat_model, "Cooking expert") + + # Create memories for each agent + await acreate_test_memory(user, agent=agent_accountant, raw_text="user spent $500 on groceries") + await acreate_test_memory(user, agent=agent_chef, raw_text="user likes Italian food") + + # Act & Assert: Accountant only sees financial memories + accountant_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent_accountant) + assert len(accountant_memories) == 1 + assert accountant_memories[0].raw == "user spent $500 on groceries" + + # Act & Assert: Chef only sees food memories + chef_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent_chef) + assert len(chef_memories) == 1 + assert chef_memories[0].raw == "user likes Italian food"