diff --git a/src/interface/web/app/common/auth.ts b/src/interface/web/app/common/auth.ts index defbf536..6344fbfe 100644 --- a/src/interface/web/app/common/auth.ts +++ b/src/interface/web/app/common/auth.ts @@ -63,6 +63,8 @@ export interface UserConfig { enabled_content_source: SyncedContent; has_documents: boolean; notion_token: string | null; + enable_memory: boolean; + server_memory_mode: "disabled" | "enabled_default_off" | "enabled_default_on"; // user model settings search_model_options: ModelOptions[]; selected_search_model_config: number; diff --git a/src/interface/web/app/components/userMemory/userMemory.tsx b/src/interface/web/app/components/userMemory/userMemory.tsx new file mode 100644 index 00000000..db360e7b --- /dev/null +++ b/src/interface/web/app/components/userMemory/userMemory.tsx @@ -0,0 +1,90 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Button } from "@/components/ui/button"; +import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react"; +import { useToast } from "@/components/ui/use-toast"; + +export interface UserMemorySchema { + id: number; + raw: string; + created_at: string; +} + +interface UserMemoryProps { + memory: UserMemorySchema; + onDelete: (id: number) => void; + onUpdate: (id: number, raw: string) => void; +} + +export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) { + const [isEditing, setIsEditing] = useState(false); + const [content, setContent] = useState(memory.raw); + const { toast } = useToast(); + + const handleUpdate = () => { + onUpdate(memory.id, content); + setIsEditing(false); + toast({ + title: "Memory Updated", + description: "Your memory has been successfully updated.", + }); + }; + + const handleDelete = () => { + onDelete(memory.id); + toast({ + title: "Memory Deleted", + description: "Your memory has been successfully deleted.", + }); + }; + + return ( +
+ {isEditing ? ( + <> + setContent(e.target.value)} + className="flex-1" + /> + + + + ) : ( + <> + + + + + )} +
+ ); +} diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index ad8b88f5..75ea7ee2 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -15,6 +15,8 @@ import { Button } from "@/components/ui/button"; import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp"; import { Input } from "@/components/ui/input"; import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card"; +import { Switch } from "@/components/ui/switch"; + import { DropdownMenu, DropdownMenuContent, @@ -33,6 +35,15 @@ import { AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; + +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog"; + import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table"; import { @@ -74,6 +85,7 @@ import Loading from "../components/loading/loading"; import IntlTelInput from "intl-tel-input/react"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; +import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory"; import { Separator } from "@/components/ui/separator"; import { KhojLogoType } from "../components/logo/khojLogo"; import { Progress } from "@/components/ui/progress"; @@ -323,6 +335,9 @@ export default function SettingsView() { const [numberValidationState, setNumberValidationState] = useState( PhoneNumberValidationState.Verified, ); + const [memories, setMemories] = useState([]); + const [enableMemory, setEnableMemory] = useState(true); + const [serverMemoryMode, setServerMemoryMode] = useState("enabled_default_on"); const [isExporting, setIsExporting] = useState(false); const [exportProgress, setExportProgress] = useState(0); const [exportedConversations, setExportedConversations] = useState(0); @@ -347,6 +362,8 @@ export default function SettingsView() { ); setName(initialUserConfig?.given_name); setNotionToken(initialUserConfig?.notion_token ?? null); + setEnableMemory(initialUserConfig?.enable_memory ?? true); + setServerMemoryMode(initialUserConfig?.server_memory_mode ?? "enabled_default_on"); }, [initialUserConfig]); const sendOTP = async () => { @@ -621,6 +638,88 @@ export default function SettingsView() { } }; + const fetchMemories = async () => { + try { + console.log("Fetching memories..."); + const response = await fetch('/api/memories/'); + if (!response.ok) throw new Error('Failed to fetch memories'); + const data = await response.json(); + setMemories(data); + } catch (error) { + console.error('Error fetching memories:', error); + toast({ + title: "Error", + description: "Failed to fetch memories. Please try again.", + variant: "destructive" + }); + } + }; + + const handleDeleteMemory = async (id: number) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'DELETE' + }); + if (!response.ok) throw new Error('Failed to delete memory'); + setMemories(memories.filter(memory => memory.id !== id)); + } catch (error) { + console.error('Error deleting memory:', error); + toast({ + title: "Error", + description: "Failed to delete memory. Please try again.", + variant: "destructive" + }); + } + }; + + const handleUpdateMemory = async (id: number, raw: string) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ raw, memory_id: id }), + }); + if (!response.ok) throw new Error('Failed to update memory'); + const updatedMemory: UserMemorySchema = await response.json(); + setMemories(memories.map(memory => + memory.id === id ? updatedMemory : memory + )); + } catch (error) { + console.error('Error updating memory:', error); + toast({ + title: "Error", + description: "Failed to update memory. Please try again.", + variant: "destructive" + }); + } + }; + + const handleToggleMemory = async (enabled: boolean) => { + try { + const response = await fetch(`/api/user/memory?enable_memory=${enabled}`, { + method: 'PATCH', + }); + if (!response.ok) throw new Error('Failed to update memory setting'); + setEnableMemory(enabled); + toast({ + title: enabled ? "Memory enabled" : "Memory disabled", + description: enabled + ? "Khoj will learn and remember from your conversations." + : "Khoj will no longer learn or remember from your conversations.", + }); + } catch (error) { + console.error('Error toggling memory:', error); + toast({ + title: "Error", + description: "Failed to update memory setting. Please try again.", + variant: "destructive" + }); + } + }; + + const syncContent = async (type: string) => { try { const response = await fetch(`/api/content?t=${type}`, { @@ -1212,7 +1311,64 @@ export default function SettingsView() { - + + + + Memories + + +

+ View and manage your long-term memories +

+
+ + handleToggleMemory(checked)} + disabled={serverMemoryMode === "disabled"} + /> +
+ {serverMemoryMode === "disabled" && ( +

+ Memory has been disabled by the server administrator. +

+ )} +
+ + open && fetchMemories()}> + + + + + + Your Memories + +
+ {memories.map((memory) => ( + + ))} + {memories.length === 0 && ( +

No memories found

+ )} +
+
+
+
+
diff --git a/src/interface/web/bun.lock b/src/interface/web/bun.lock index 36f1d3cf..9287dd9f 100644 --- a/src/interface/web/bun.lock +++ b/src/interface/web/bun.lock @@ -24,6 +24,7 @@ "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-separator": "^1.1.7", "@radix-ui/react-slot": "^1.2.3", + "@radix-ui/react-switch": "^1.2.6", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-toast": "^1.2.15", "@radix-ui/react-toggle": "^1.1.10", @@ -274,7 +275,7 @@ "@radix-ui/react-slot": ["@radix-ui/react-slot@1.2.3", "", { "dependencies": { "@radix-ui/react-compose-refs": "1.1.2" }, "peerDependencies": { "@types/react": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react"] }, "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A=="], - "@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="], + "@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.6", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-bByzr1+ep1zk4VubeEVViV592vu2lHE2BZY5OnzehZqOOgogN80+mNtCqPkhn2gklJqOpxWgPoYTSnhBCqpOXQ=="], "@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.13", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.5", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.11", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A=="], @@ -1478,8 +1479,6 @@ "@radix-ui/react-slider/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="], - "@radix-ui/react-switch/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="], - "@radix-ui/react-toggle-group/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="], "@radix-ui/react-toggle-group/@radix-ui/react-roving-focus": ["@radix-ui/react-roving-focus@1.1.10", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q=="], @@ -1576,6 +1575,8 @@ "radix-ui/@radix-ui/react-select": ["@radix-ui/react-select@2.2.5", "", { "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-popper": "1.2.7", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA=="], + "radix-ui/@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="], + "radix-ui/@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.12", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-GTVAlRVrQrSw3cEARM0nAx73ixrWDPNZAruETn3oHCNP6SbZ/hNxdxp+u7VkIEv3/sFoLq1PfcHrl7Pnp0CDpw=="], "radix-ui/@radix-ui/react-toast": ["@radix-ui/react-toast@1.2.14", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-nAP5FBxBJGQ/YfUB+r+O6USFVkWq3gAInkxyEnmvEV5jtSbfDhfa4hwX8CraCnbjMLsE7XSf/K75l9xXY7joWg=="], diff --git a/src/interface/web/components/ui/switch.tsx b/src/interface/web/components/ui/switch.tsx new file mode 100644 index 00000000..812af0df --- /dev/null +++ b/src/interface/web/components/ui/switch.tsx @@ -0,0 +1,29 @@ +"use client" + +import * as React from "react" +import * as SwitchPrimitives from "@radix-ui/react-switch" + +import { cn } from "@/lib/utils" + +const Switch = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)) +Switch.displayName = SwitchPrimitives.Root.displayName + +export { Switch } diff --git a/src/interface/web/package.json b/src/interface/web/package.json index 3eceac6f..ccbbd4bf 100644 --- a/src/interface/web/package.json +++ b/src/interface/web/package.json @@ -38,6 +38,7 @@ "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-separator": "^1.1.7", "@radix-ui/react-slot": "^1.2.3", + "@radix-ui/react-switch": "^1.2.6", "@radix-ui/react-tabs": "^1.1.13", "@radix-ui/react-toast": "^1.2.15", "@radix-ui/react-toggle": "^1.1.10", diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 40d1eeb5..50b2dd31 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -323,6 +323,7 @@ def configure_routes(app): from khoj.routers.api_automation import api_automation from khoj.routers.api_chat import api_chat from khoj.routers.api_content import api_content + from khoj.routers.api_memories import api_memories from khoj.routers.api_model import api_model from khoj.routers.notion import notion_router from khoj.routers.web_client import web_client @@ -332,6 +333,7 @@ def configure_routes(app): app.include_router(api_agents, prefix="/api/agents") app.include_router(api_automation, prefix="/api/automation") app.include_router(api_model, prefix="/api/model") + app.include_router(api_memories, prefix="/api/memories") app.include_router(api_content, prefix="/api/content") app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 3ab08ef2..abcc226f 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -62,6 +62,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserTextToImageModelConfig, UserVoiceModelConfig, @@ -566,6 +567,16 @@ def get_default_search_model() -> SearchModelConfig: return SearchModelConfig.objects.first() +async def aget_default_search_model() -> SearchModelConfig: + default_search_model = await SearchModelConfig.objects.filter(name="default").afirst() + + if default_search_model: + return default_search_model + elif await SearchModelConfig.objects.count() == 0: + await SearchModelConfig.objects.acreate() + return await SearchModelConfig.objects.afirst() + + def get_or_create_search_models(): search_models = SearchModelConfig.objects.all() if search_models.count() == 0: @@ -1556,12 +1567,17 @@ class ConversationAdapters: ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = await Conversation.objects.filter( - user=user, client=client_application, id=conversation_id - ).afirst() + conversation = ( + await Conversation.objects.filter(user=user, client=client_application, id=conversation_id) + .prefetch_related("agent", "agent__chat_model") + .afirst() + ) else: conversation = ( - await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() + await Conversation.objects.filter(user=user, client=client_application) + .prefetch_related("agent", "agent__chat_model") + .order_by("-updated_at") + .afirst() ) existing_messages = conversation.messages if conversation else [] @@ -1597,6 +1613,85 @@ class ConversationAdapters: return None return config.setting + @staticmethod + async def ais_memory_enabled(user: KhojUser) -> bool: + """ + Check if memory is enabled for the user based on server config and user preference. + + Logic: + - If server memory_mode is DISABLED: return False (overrides user preference) + - If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False + - If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True + - If no server config exists: default to True + """ + # Get server-level memory configuration + server_settings = await ServerChatSettings.objects.afirst() + if server_settings: + memory_mode = server_settings.memory_mode + # Disabled mode overrides all user preferences + if memory_mode == ServerChatSettings.MemoryMode.DISABLED: + return False + + # Get user preference + user_config = await UserConversationConfig.objects.filter(user=user).afirst() + + if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF: + # User must explicitly opt-in; check if user has set a preference + if user_config is None: + return False # Default off for new users + return user_config.enable_memory + + # ENABLED_DEFAULT_ON: use user preference if set, else True + if user_config is None: + return True # Default on for new users + return user_config.enable_memory + + # No server config - default behavior (enabled, default on) + user_config = await UserConversationConfig.objects.filter(user=user).afirst() + if user_config is None: + return True + return user_config.enable_memory + + @staticmethod + def is_memory_enabled(user: KhojUser) -> bool: + """ + Sync version of ais_memory_enabled. + Check if memory is enabled for the user based on server config and user preference. + + Logic: + - If server memory_mode is DISABLED: return False (overrides user preference) + - If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False + - If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True + - If no server config exists: default to True + """ + # Get server-level memory configuration + server_settings = ServerChatSettings.objects.first() + if server_settings: + memory_mode = server_settings.memory_mode + # Disabled mode overrides all user preferences + if memory_mode == ServerChatSettings.MemoryMode.DISABLED: + return False + + # Get user preference + user_config = UserConversationConfig.objects.filter(user=user).first() + + if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF: + # User must explicitly opt-in; check if user has set a preference + if user_config is None: + return False # Default off for new users + return user_config.enable_memory + + # ENABLED_DEFAULT_ON: use user preference if set, else True + if user_config is None: + return True # Default on for new users + return user_config.enable_memory + + # No server config - default behavior (enabled, default on) + user_config = UserConversationConfig.objects.filter(user=user).first() + if user_config is None: + return True + return user_config.enable_memory + @staticmethod async def get_speech_to_text_config(): return await SpeechToTextModelOptions.objects.filter().prefetch_related("ai_model_api").afirst() @@ -2186,3 +2281,94 @@ class McpServerAdapters: except Exception as e: logger.error(f"Error retrieving MCP servers: {e}", exc_info=True) return servers + + +class UserMemoryAdapters: + @staticmethod + @require_valid_user + async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]: + """ + Pulls memories from the database for a given user. Medium term memory. + """ + time_frame = datetime.now(timezone.utc) - timedelta(days=window) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by( + "-created_at" + )[:limit] + else: + memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit] + return await sync_to_async(list)(memories) + + @staticmethod + @require_valid_user + async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory: + """ + Saves a memory to the database for a given user. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent + ) + else: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model + ) + + return memory_instance + + @staticmethod + @require_valid_user + async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]: + """ + Searches for memories in the database for a given user. Long term memory. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + max_distance = model.bi_encoder_confidence_threshold or math.inf + embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query) + default_agent = await AgentAdapters.aget_default_agent() + + if agent and agent != default_agent: + relevant_memories = UserMemory.objects.filter(user=user, agent=agent) + else: + relevant_memories = UserMemory.objects.filter(user=user) + + relevant_memories = ( + relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query)) + .order_by("distance") + .filter(distance__lte=max_distance) + ) + + return await sync_to_async(list)(relevant_memories[:limit]) + + @staticmethod + @require_valid_user + async def delete_memory(user: KhojUser, memory_id: str) -> bool: + """ + Deletes a memory from the database for a given user. + """ + try: + memory = await UserMemory.objects.aget(user=user, id=memory_id) + await memory.adelete() + return True + except UserMemory.DoesNotExist: + return False + + @staticmethod + def to_dict(memories: List[UserMemory]) -> List[dict]: + """ + Converts a list of Memory objects to a list of dictionaries. + """ + return [ + { + "id": f"{memory.id}", + "raw": memory.raw, + "updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"), + } + for memory in memories + ] diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 872c57ca..345b03fd 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -34,6 +34,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserVoiceModelConfig, VoiceModelOption, @@ -182,6 +183,7 @@ admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin) admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin) admin.site.register(UserRequests, unfold_admin.ModelAdmin) admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin) +admin.site.register(UserMemory, unfold_admin.ModelAdmin) @admin.register(McpServer) @@ -295,6 +297,7 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin): "think_paid_fast", "think_paid_deep", "web_scraper", + "memory_mode", ) ordering = ("priority",) diff --git a/src/khoj/database/management/commands/manage_memories.py b/src/khoj/database/management/commands/manage_memories.py new file mode 100644 index 00000000..62e14c97 --- /dev/null +++ b/src/khoj/database/management/commands/manage_memories.py @@ -0,0 +1,384 @@ +import asyncio +from datetime import datetime, timedelta +from typing import List, Optional + +from django.core.management.base import BaseCommand +from django.db.models import Q +from django.utils import timezone + +from khoj.configure import initialize_server +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import ( + Conversation, + DataStore, + KhojUser, +) +from khoj.routers.helpers import extract_facts_from_query + + +class Command(BaseCommand): + help = "Manage user memories - generate from conversations or delete existing memories" + + def add_arguments(self, parser): + parser.add_argument( + "--lookback-days", + type=int, + default=None, + help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories", + ) + parser.add_argument( + "--users", + type=str, + help="Process specific users (comma-separated usernames or emails)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=10, + help="Number of conversations to process in each batch (default: 10)", + ) + parser.add_argument( + "--apply", + action="store_true", + help="Actually perform the operation. Without this flag, only shows what would be processed.", + ) + parser.add_argument( + "--delete", + action="store_true", + help="Delete all memories for specified users instead of generating new ones", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from last checkpoint if process was interrupted", + ) + parser.add_argument( + "--force", + action="store_true", + help="Force regenerate memories even if already processed", + ) + + def handle(self, *args, **options): + """Main entry point for the command""" + initialize_server() + asyncio.run(self.async_handle(*args, **options)) + + async def async_handle(self, *args, **options): + """Async handler for memory management""" + lookback_days = options["lookback_days"] + usernames = options["users"] + batch_size = options["batch_size"] + apply = options["apply"] + delete = options["delete"] + resume = options["resume"] + force = options["force"] + + mode = "APPLY" if apply else "DRY RUN" + + # Handle deletion mode + if delete: + # For deletion, only use cutoff_date if lookback_days is explicitly provided + cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days else None + await self.handle_delete_memories(usernames, cutoff_date, apply) + return + + # Handle generation mode + # For generation, default to 7 days if not specified + if lookback_days is None: + lookback_days = 7 + cutoff_date = timezone.now() - timedelta(days=lookback_days) + self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days") + + # Get users to process + users = await self.get_users_to_process(usernames) + if not users: + self.stdout.write("No users found to process") + return + + self.stdout.write(f"Found {len(users)} users to process") + + # Initialize or retrieve checkpoint + checkpoint = await self.get_or_create_checkpoint(resume) + + total_conversations = 0 + total_memories = 0 + + for user in users: + # Check if user already processed in checkpoint + if not force and user.id in checkpoint.get("processed_users", []): + self.stdout.write(f"Skipping already processed user: {user.username}") + continue + + self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})") + + # Get conversations for this user + conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force) + + if not conversations: + self.stdout.write(f" No conversations to process for {user.username}") + # Mark user as processed + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id) + continue + + self.stdout.write(f" Found {len(conversations)} conversations to process") + + # Process conversations in batches + user_memories = 0 + for i in range(0, len(conversations), batch_size): + batch = conversations[i : i + batch_size] + batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint) + user_memories += batch_memories + total_conversations += len(batch) + + # Update progress + progress = min(i + batch_size, len(conversations)) + self.stdout.write( + f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories" + ) + + total_memories += user_memories + self.stdout.write( + f" Completed user {user.username}: " + f"processed {len(conversations)} conversations, " + f"generated {user_memories} memories" + ) + + # Mark user as processed + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id) + + # Clear checkpoint on successful completion + if apply: + await self.clear_checkpoint() + + action = "Generated" if apply else "Would generate" + self.stdout.write( + self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations") + ) + + async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]: + """Get list of users to comma separated usernames or emails to process""" + if users_str: + usernames = [u.strip() for u in users_str.split(",") if u.strip()] + # Process specific users + users = [user async for user in KhojUser.objects.filter(Q(username__in=usernames) | Q(email__in=usernames))] + return users + else: + # Process all users with conversations + return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()] + + async def get_user_conversations( + self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool + ) -> List[Conversation]: + """Get conversations for a user that need processing""" + if cutoff_date is None: + query = Conversation.objects.filter(user=user).order_by("updated_at") + else: + query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at") + + # Filter out already processed conversations if resuming + if not force and user.id in checkpoint.get("processed_conversations", {}): + processed_ids = checkpoint["processed_conversations"][user.id] + query = query.exclude(id__in=processed_ids) + + return [conv async for conv in query] + + async def process_conversation_batch( + self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict + ) -> int: + """Process a batch of conversations and generate memories""" + total_memories = 0 + + for conversation in conversations: + try: + # Get conversation messages using sync_to_async for property access + from asgiref.sync import sync_to_async + + # Access conversation_log synchronously + @sync_to_async + def get_messages(): + return conversation.messages + + messages = await get_messages() + if not messages: + continue + + # Get agent if conversation has one + @sync_to_async + def get_agent(): + return conversation.agent + + agent = await get_agent() + + # Get existing memories for context + # Process each conversation turn + conversation_memories = 0 + i = 0 + while i + 1 < len(messages): + # Only process user-assistant pairs as a valid turn for memory extraction + if messages[i].by != "you" or messages[i + 1].by != "khoj": + i += 1 + continue + + # Get the conversation history up to this point + history = messages[: i + 2] + + # Extract user query text for memory search + q = "" + if messages[i].message is None: + i += 1 + continue + elif isinstance(messages[i].message, str): + q = messages[i].message + elif isinstance(messages[i].message, list): + q = "\n\n".join( + content.get("text", "") + for content in messages[i].message + if isinstance(content, dict) and content.get("text") + ) + + if not q or not q.strip(): + i += 1 + continue + + # Get unique recent and long term relevant memories + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) + + if apply: + # Ensure agent is fully loaded with its chat_model + if agent: + + @sync_to_async + def load_agent_with_chat_model(): + # Force load the chat_model relationship + _ = agent.chat_model + return agent + + agent = await load_agent_with_chat_model() + + # Update memories based on latest conversation turn + memory_updates = await extract_facts_from_query( + user=user, + conversation_history=history, + existing_facts=relevant_memories, + agent=agent, + tracer={}, + ) + + # Save new memories + for memory in memory_updates.create: + await UserMemoryAdapters.save_memory(user, memory, agent=agent) + conversation_memories += 1 + self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...") + + # Delete outdated memories + for memory in memory_updates.delete: + await UserMemoryAdapters.delete_memory(user, memory) + self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...") + else: + # Dry run - estimate memories that would be created + conversation_memories += 1 # Rough estimate + + # Move to next conversation turn pair + i += 2 + + total_memories += conversation_memories + + # Update checkpoint after each conversation + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id)) + except Exception as e: + import traceback + + self.stderr.write( + f"Error processing conversation {conversation.id} for user {user.id}: {e}\n" + f"Traceback: {traceback.format_exc()}" + ) + continue + + return total_memories + + async def get_or_create_checkpoint(self, resume: bool) -> dict: + """Get or create checkpoint for resumable processing""" + checkpoint_key = "memory_generation_checkpoint" + + if resume: + # Try to retrieve existing checkpoint + checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst() + if checkpoint_store: + self.stdout.write("Resuming from checkpoint...") + return checkpoint_store.value + + # Create new checkpoint + return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}} + + async def update_checkpoint( + self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None + ): + """Update checkpoint with progress""" + if user_id and user_id not in checkpoint["processed_users"]: + checkpoint["processed_users"].append(user_id) + + if user_id and conversation_id: + if user_id not in checkpoint["processed_conversations"]: + checkpoint["processed_conversations"][user_id] = [] + if conversation_id not in checkpoint["processed_conversations"][user_id]: + checkpoint["processed_conversations"][user_id].append(conversation_id) + + # Save checkpoint to database + await DataStore.objects.aupdate_or_create( + key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True} + ) + + async def clear_checkpoint(self): + """Clear checkpoint after successful completion""" + await DataStore.objects.filter(key="memory_generation_checkpoint").adelete() + self.stdout.write("Checkpoint cleared") + + async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool): + """Handle deletion of user memories""" + from khoj.database.models import UserMemory + + # Get users to process + users = await self.get_users_to_process(usernames) + if not users: + self.stdout.write("No users found to process") + return + + mode = "APPLY" if apply else "DRY RUN" + if cutoff_date: + # Calculate days from cutoff date + days_back = (timezone.now() - cutoff_date).days + self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users") + else: + self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users") + + total_deleted = 0 + for user in users: + # Count memories for this user + if cutoff_date is None: + user_memories = UserMemory.objects.filter(user=user) + else: + user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date) + + memories_count = await user_memories.acount() + if memories_count == 0: + self.stdout.write(f" User {user.username} has no memories to delete") + continue + + self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories") + + if apply: + # Delete memories for this user (with date filter if specified) + deleted_count, _ = await user_memories.adelete() + self.stdout.write(f" Deleted {deleted_count} memories") + total_deleted += deleted_count + else: + self.stdout.write(f" Would delete {memories_count} memories") + total_deleted += memories_count + + action = "Deleted" if apply else "Would delete" + self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total")) diff --git a/src/khoj/database/migrations/0099_usermemory.py b/src/khoj/database/migrations/0099_usermemory.py new file mode 100644 index 00000000..b1d6939a --- /dev/null +++ b/src/khoj/database/migrations/0099_usermemory.py @@ -0,0 +1,82 @@ +# Generated by Django 5.1.10 on 2025-08-29 22:57 + +import django.db.models.deletion +import pgvector.django +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0098_alter_texttoimagemodelconfig_model_type"), + ] + + operations = [ + migrations.CreateModel( + name="UserMemory", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("embeddings", pgvector.django.VectorField()), + ("raw", models.TextField()), + ( + "agent", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.agent", + ), + ), + ( + "search_model", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="database.searchmodelconfig", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="userconversationconfig", + name="enable_memory", + field=models.BooleanField(default=True), + ), + migrations.AddField( + model_name="serverchatsettings", + name="memory_mode", + field=models.CharField( + choices=[ + ("disabled", "Disabled"), + ("enabled_default_off", "Enabled, default off"), + ("enabled_default_on", "Enabled, default on"), + ], + default="enabled_default_on", + help_text="Server-level memory feature configuration. Disabled overrides user preference.", + max_length=20, + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 404b0f38..56516b2d 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -478,6 +478,13 @@ class ServerChatSettings(DbBaseModel): THINK_PAID_FAST = "think_paid_fast" THINK_PAID_DEEP = "think_paid_deep" + class MemoryMode(models.TextChoices): + """Enum for server-level memory feature configuration""" + + DISABLED = "disabled", "Disabled" + ENABLED_DEFAULT_OFF = "enabled_default_off", "Enabled, default off" + ENABLED_DEFAULT_ON = "enabled_default_on", "Enabled, default on" + chat_default = models.ForeignKey( ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" ) @@ -506,6 +513,12 @@ class ServerChatSettings(DbBaseModel): unique=True, help_text="Priority of the server chat settings. Lower numbers run first.", ) + memory_mode = models.CharField( + max_length=20, + choices=MemoryMode.choices, + default=MemoryMode.ENABLED_DEFAULT_ON, + help_text="Server-level memory feature configuration. Disabled overrides user preference.", + ) def clean(self): error = {} @@ -629,6 +642,7 @@ class SpeechToTextModelOptions(DbBaseModel): class UserConversationConfig(DbBaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True) + enable_memory = models.BooleanField(default=True) class UserVoiceModelConfig(DbBaseModel): @@ -836,3 +850,15 @@ class McpServer(DbBaseModel): def __str__(self): return self.name + + +class UserMemory(DbBaseModel): + """ + Long term memory store derived from conversation between user and agent. + """ + + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True) + embeddings = VectorField(dimensions=None) + raw = models.TextField() + search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 611d41a8..0ed601ea 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1302,3 +1302,70 @@ user_name = PromptTemplate.from_template( User's Name: {name} """.strip() ) + +extract_facts_from_query = PromptTemplate.from_template( + """ +You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user. +This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember. + +You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action: +1. Create new facts +2. Delete existing facts + +You should delete existing facts that are no longer true. +You can enhance new facts with information from existing facts. +You cannot update existing facts directly, instead create new facts and delete related existing ones to update them. + +Your output should be a JSON object with two lists: create and delete. +- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective. +- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true. +- Leave the create or delete list empty if you have nothing important to add or remove. + +# Example +Existing Facts: +[ + {{ + "id": "5283", + "raw": "I am not interested in sports", + "updated_at": "2023-10-01T12:00:00+00:00" + }}, + {{ + "id": "22", + "raw": "I am a software engineer", + "updated_at": "2023-10-31T14:00:00+00:00" + }}, + {{ + "id": "651", + "raw": "My mother works at the hospital", + "updated_at": "2023-10-02T17:00:00+00:00" + }} +] + +Latest Chat Session: +- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline. +In between coding, I took my cat Whiskers out for a walk and played a game of football. +My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat. +- AI: That's great to hear! + +Response: +{{ + "create": [ + "I am interested in AI and machine learning", + "I have a pet cat named Whiskers", + "I enjoy playing football", + "My mother works at the hospital and is a doctor" + ], + "delete": [ + "5283", + "651" + ], +}} + +# Input +Existing Facts: +{matched_facts} + +Latest Chat Session: +{chat_history} +""".strip() +) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a2069243..dae182c2 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -28,6 +28,7 @@ from khoj.database.models import ( ClientApplication, Intent, KhojUser, + UserMemory, ) from khoj.processor.conversation import prompts from khoj.search_filter.base_filter import BaseFilter @@ -551,6 +552,7 @@ async def save_to_conversation_log( client_application: ClientApplication = None, conversation_id: str = None, automation_id: str = None, + relevant_memories: List[UserMemory] = [], query_images: List[str] = None, raw_query_files: List[FileAttachment] = [], generated_images: List[str] = [], @@ -559,6 +561,8 @@ async def save_to_conversation_log( train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): + from khoj.routers.helpers import ai_update_memories + user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") turn_id = tracer.get("mid") or str(uuid.uuid4()) @@ -605,6 +609,16 @@ async def save_to_conversation_log( user_message=q, ) + if not automation_id: + # Don't update memories from automations, as this could get noisy. + await ai_update_memories( + user=user, + conversation_history=new_messages or [], + memories=relevant_memories, + agent=db_conversation.agent if db_conversation else None, + tracer=tracer, + ) + if is_promptrace_enabled(): merge_message_into_conversation_trace(q, chat_response, tracer) @@ -666,6 +680,7 @@ def generate_chatml_messages_with_context( query_files: str = None, query_images=None, context_message="", + relevant_memories: List[UserMemory] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = [], chat_history: list[ChatMessageModel] = [], @@ -806,6 +821,14 @@ def generate_chatml_messages_with_context( ), ) + if not is_none_or_empty(relevant_memories): + memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. Ignore them if they are not relevant to the query.\n\n" + for memory in relevant_memories: + friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") + memory_context += f"- [{friendly_dt}]: {memory.raw}\n" + memory_context += "" + messages.append(ChatMessage(content=memory_context, role="user")) + if not is_none_or_empty(user_message): messages.append( ChatMessage( diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 45a640c1..38c2cd84 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -25,6 +25,7 @@ from khoj.database.models import ( Intent, KhojUser, TextToImageModelConfig, + UserMemory, ) from khoj.processor.conversation.google.utils import _is_retryable_error from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url @@ -47,6 +48,7 @@ async def text_to_image( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, agent: Agent = None, tracer: dict = {}, ): @@ -91,6 +93,7 @@ async def text_to_image( model_type=text_to_image_config.model_type, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer, diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 0aa12ca4..ccd780aa 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -5,7 +5,7 @@ import os from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters -from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser, UserMemory from khoj.processor.conversation.utils import ( AgentMessage, OperatorRun, @@ -42,6 +42,7 @@ async def operate_environment( query_images: Optional[List[str]] = None, # TODO: Handle query images agent: Agent = None, query_files: str = None, # TODO: Handle query files + relevant_memories: Optional[List[UserMemory]] = None, # TODO: Handle relevant memories cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, abort_message: Optional[str] = ChatEvent.END_EVENT.value, diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index beeb640c..96a587ae 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -14,6 +14,7 @@ from khoj.database.models import ( Agent, ChatMessageModel, KhojUser, + UserMemory, WebScraper, ) from khoj.routers.helpers import ( @@ -63,6 +64,7 @@ async def search_online( max_webpages_to_read: int = 1, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, previous_subqueries: Set = set(), fast_model: bool = True, agent: Agent = None, @@ -82,6 +84,7 @@ async def search_online( user, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, max_queries=max_online_searches, fast_model=fast_model, agent=agent, @@ -160,7 +163,15 @@ async def search_online( async for event in send_status_func(f"**Browsing**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} tasks = [ - extract_from_webpage(link, data["queries"], data.get("content"), user=user, agent=agent, tracer=tracer) + extract_from_webpage( + link, + data["queries"], + data.get("content"), + relevant_memories=relevant_memories, + user=user, + agent=agent, + tracer=tracer, + ) for link, data in webpages.items() ] results = await asyncio.gather(*tasks) @@ -438,6 +449,7 @@ async def read_webpages( fast_model: bool = True, agent: Agent = None, max_webpages_to_read: int = 1, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" @@ -460,6 +472,7 @@ async def read_webpages( user, send_status_func=send_status_func, agent=agent, + relevant_memories=relevant_memories, tracer=tracer, ): yield result @@ -471,6 +484,7 @@ async def read_webpages_content( user: KhojUser, send_status_func: Optional[Callable] = None, agent: Agent = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): logger.info(f"Reading web pages at: {urls}") @@ -478,7 +492,10 @@ async def read_webpages_content( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Browsing**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [extract_from_webpage(url, {query}, user=user, agent=agent, tracer=tracer) for url in urls] + tasks = [ + extract_from_webpage(url, {query}, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer) + for url in urls + ] results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -533,6 +550,7 @@ async def extract_from_webpage( url: str, subqueries: set[str] = None, content: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -546,7 +564,9 @@ async def extract_from_webpage( extracted_info = None if not is_none_or_empty(content): with timer(f"Extracting relevant information from web page at '{url}' took", logger): - extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent, tracer=tracer) + extracted_info = await extract_relevant_info( + subqueries, content, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer + ) return subqueries, url, extracted_info diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 2ec392ff..45c95115 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -21,7 +21,7 @@ from tenacity import ( ) from khoj.database.adapters import AgentAdapters, FileObjectAdapters -from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser +from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ChatEvent, @@ -59,6 +59,7 @@ async def run_code( send_status_func: Optional[Callable] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, agent: Agent = None, sandbox_url: str = SANDBOX_URL, tracer: dict = {}, @@ -79,6 +80,7 @@ async def run_code( agent, tracer, query_files, + relevant_memories, ) except Exception as e: raise ValueError(f"Failed to generate code for {instructions} with error: {e}") @@ -126,6 +128,7 @@ async def generate_python_code( agent: Agent = None, tracer: dict = {}, query_files: str = None, + relevant_memories: List[UserMemory] = None, ) -> GeneratedCode: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" @@ -161,6 +164,7 @@ async def generate_python_code( code_generation_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index f7345d94..31950b1f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -13,7 +13,7 @@ from starlette.authentication import has_required_scope, requires from khoj.configure import initialize_content from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo -from khoj.database.models import KhojUser, SpeechToTextModelOptions +from khoj.database.models import KhojUser, SpeechToTextModelOptions, UserConversationConfig from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.routers.helpers import ( ApiUserRateLimiter, @@ -215,6 +215,29 @@ def set_user_name( return {"status": "ok"} +@api.patch("/user/memory", status_code=200) +@requires(["authenticated"]) +def set_user_memory_enabled( + request: Request, + enable_memory: bool, + client: Optional[str] = None, +): + user = request.user.object + + user_config, _ = UserConversationConfig.objects.get_or_create(user=user) + user_config.enable_memory = enable_memory + user_config.save() + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_user_memory_enabled", + client=client, + ) + + return {"status": "ok", "enable_memory": enable_memory} + + @api.get("/health", response_class=Response) @requires(["authenticated"], status_code=200) def health_check(request: Request) -> Response: diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 04cb84de..2e326690 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -29,6 +29,7 @@ from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, PublicConversationAdapters, + UserMemoryAdapters, aget_user_name, ) from khoj.database.models import Agent, KhojUser @@ -101,7 +102,6 @@ conversation_command_rate_limiter = ConversationCommandRateLimiter( trial_rate_limit=20, subscribed_rate_limit=75, slug="command" ) - api_chat = APIRouter() @@ -963,6 +963,14 @@ async def event_generator( location = LocationData(city=city, region=region, country=country, country_code=country_code) chat_history = conversation.messages + # Get most recent memories and long term relevant memories if memory is enabled + relevant_memories = [] + if await ConversationAdapters.ais_memory_enabled(user): + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) + # Create a de-duped set of memories + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) + # If interrupted message in DB if last_message := await conversation.pop_message(interrupted=True): # Populate context from interrupted message @@ -987,6 +995,7 @@ async def event_generator( query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ) except ValueError as e: @@ -1024,6 +1033,7 @@ async def event_generator( previous_iterations=list(research_results), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, user_name=user_name, location=location, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1081,6 +1091,7 @@ async def event_generator( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1130,6 +1141,7 @@ async def event_generator( max_online_searches=3, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1157,6 +1169,7 @@ async def event_generator( max_webpages_to_read=1, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1198,6 +1211,7 @@ async def event_generator( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1223,6 +1237,7 @@ async def event_generator( list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, send_status_func=partial(send_event, ChatEvent.STATUS), agent=agent, cancellation_event=cancellation_event, @@ -1276,6 +1291,7 @@ async def event_generator( send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1317,6 +1333,7 @@ async def event_generator( online_results=online_results, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1375,6 +1392,7 @@ async def event_generator( user_name, uploaded_images, attached_file_context, + relevant_memories, program_execution_context, generated_asset_results, is_subscribed, @@ -1434,6 +1452,7 @@ async def event_generator( query_images=uploaded_images, train_of_thought=train_of_thought, raw_query_files=raw_query_files, + relevant_memories=relevant_memories, generated_images=generated_images, generated_mermaidjs_diagram=generated_mermaidjs_diagram, tracer=tracer, diff --git a/src/khoj/routers/api_memories.py b/src/khoj/routers/api_memories.py new file mode 100644 index 00000000..d7527827 --- /dev/null +++ b/src/khoj/routers/api_memories.py @@ -0,0 +1,114 @@ +import json +import logging +from typing import Optional + +from asgiref.sync import sync_to_async +from fastapi import APIRouter, Request +from fastapi.responses import Response +from pydantic import BaseModel +from starlette.authentication import requires + +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import UserMemory + +api_memories = APIRouter() +logger = logging.getLogger(__name__) + + +@api_memories.get("") +@requires(["authenticated"]) +async def get_memories( + request: Request, + client: Optional[str] = None, +): + """Get all memories for the authenticated user""" + user = request.user.object + + memories = UserMemory.objects.filter(user=user) + all_memories = await sync_to_async(list)(memories) + + # Convert memories to a list of dictionaries + formatted_memories = [ + { + "id": memory.id, + "raw": memory.raw, + "created_at": memory.created_at.isoformat(), + } + for memory in all_memories + ] + + return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200) + + +@api_memories.delete("/{memory_id}") +@requires(["authenticated"]) +async def delete_memory( + request: Request, + memory_id: int, + client: Optional[str] = None, +): + """Delete a specific memory by ID""" + user = request.user.object + + # Verify memory belongs to user before deleting + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + await memory.adelete() + + return Response(status_code=204) + + +class UpdateMemoryBody(BaseModel): + """Request model for updating a memory""" + + raw: str + + +@api_memories.put("/{memory_id}") +@requires(["authenticated"]) +async def update_memory( + request: Request, + body: UpdateMemoryBody, + memory_id: int, + client: Optional[str] = None, +): + """Update a specific memory's content""" + user = request.user.object + + # Get the memory and verify it belongs to the user + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + new_content = body.raw + if not new_content: + return Response( + content=json.dumps({"error": "Missing required field 'raw'"}), + media_type="application/json", + status_code=400, + ) + + await memory.adelete() + + # Create a new memory with the updated content + new_memory = await UserMemoryAdapters.save_memory( + user=user, + memory=new_content, + ) + + return Response( + content=json.dumps( + { + "id": new_memory.id, + "raw": new_memory.raw, + } + ), + media_type="application/json", + status_code=200, + ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 1c9f9b33..29a0df0f 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -46,6 +46,7 @@ from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, FileObjectAdapters, + UserMemoryAdapters, aget_user_by_email, create_khoj_token, get_default_search_model, @@ -53,6 +54,7 @@ from khoj.database.adapters import ( get_user_name, get_user_notion_config, get_user_subscription_state, + require_valid_user, run_with_process_lock, ) from khoj.database.models import ( @@ -66,8 +68,10 @@ from khoj.database.models import ( NotionConfig, ProcessLock, RateLimitRecord, + ServerChatSettings, Subscription, TextToImageModelConfig, + UserMemory, UserRequests, ) from khoj.processor.content.docx.docx_to_entries import DocxToEntries @@ -354,6 +358,7 @@ async def aget_data_sources_and_output_format( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Dict[str, Any]: """ @@ -415,6 +420,7 @@ async def aget_data_sources_and_output_format( relevant_tools_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=PickTools, fast_model=True, @@ -470,6 +476,7 @@ async def infer_webpage_urls( user: KhojUser, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, fast_model: bool = True, agent: Agent = None, tracer: dict = {}, @@ -506,6 +513,7 @@ async def infer_webpage_urls( online_queries_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=WebpageUrls, fast_model=fast_model, @@ -536,6 +544,7 @@ async def generate_online_subqueries( user: KhojUser, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, max_queries: int = 3, fast_model: bool = True, agent: Agent = None, @@ -573,6 +582,7 @@ async def generate_online_subqueries( online_queries_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=OnlineQueries, fast_model=fast_model, @@ -659,7 +669,12 @@ async def aschedule_query( async def extract_relevant_info( - qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {} + qs: set[str], + corpus: str, + relevant_memories: List[UserMemory] = None, + user: KhojUser = None, + agent: Agent = None, + tracer: dict = {}, ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -683,6 +698,7 @@ async def extract_relevant_info( response = await send_message_to_model_wrapper( extract_relevant_information, system_message=prompts.system_prompt_extract_relevant_information, + relevant_memories=relevant_memories, fast_model=True, agent_chat_model=agent_chat_model, user=user, @@ -802,6 +818,7 @@ async def generate_excalidraw_diagram( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -819,6 +836,7 @@ async def generate_excalidraw_diagram( online_results=online_results, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer, @@ -854,6 +872,7 @@ async def generate_better_diagram_description( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -899,6 +918,7 @@ async def generate_better_diagram_description( improve_diagram_description_prompt, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, @@ -956,6 +976,88 @@ async def generate_excalidraw_diagram_from_description( return response +class MemoryUpdates(BaseModel): + """Facts to add or remove from memory.""" + + create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.") + delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.") + + +async def extract_facts_from_query( + user: KhojUser, + conversation_history: List[ChatMessageModel], + existing_facts: List[UserMemory] = None, + agent: Agent = None, + tracer: dict = {}, +) -> MemoryUpdates: + """ + Extract facts from the given query + """ + chat_history = construct_chat_history(conversation_history, n=2) + + formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else [] + + extract_facts_prompt = prompts.extract_facts_from_query.format( + chat_history=chat_history, + matched_facts=formatted_memories, + ) + + with timer("Chat actor: Extract facts from query", logger): + response = await send_message_to_model_wrapper( + extract_facts_prompt, + response_schema=MemoryUpdates, + user=user, + fast_model=False, + agent_chat_model=agent.chat_model, + tracer=tracer, + ) + response = response.text.strip() + # JSON parse the list of strings + try: + response = clean_json(response) + response = json.loads(response) + parsed_response = MemoryUpdates(**response) + if not isinstance(parsed_response, MemoryUpdates): + raise ValueError(f"Invalid response for extracting facts: {response}") + return parsed_response + + except Exception: + logger.error(f"Invalid response for extracting facts: {response}") + return MemoryUpdates(create=[], delete=[]) + + +@require_valid_user +async def ai_update_memories( + user: KhojUser, + conversation_history: List[ChatMessageModel], + memories: List[UserMemory], + agent: Agent, + tracer: dict = {}, +): + """ + Updates the memories for a given user, based on their latest input query. + """ + # Skip memory updates if memory is disabled for the user + if not await ConversationAdapters.ais_memory_enabled(user): + return + + memory_update = await extract_facts_from_query( + user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer + ) + + if not memory_update: + return + + # Save the memory updates to the database + for memory in memory_update.create: + logger.info(f"Creating memory: {memory}") + await UserMemoryAdapters.save_memory(user, memory, agent=agent) + + for memory in memory_update.delete: + logger.info(f"Deleting memory: {memory}") + await UserMemoryAdapters.delete_memory(user, memory) + + async def generate_mermaidjs_diagram( q: str, chat_history: List[ChatMessageModel], @@ -964,6 +1066,7 @@ async def generate_mermaidjs_diagram( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -981,6 +1084,7 @@ async def generate_mermaidjs_diagram( online_results=online_results, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer, @@ -1010,6 +1114,7 @@ async def generate_better_mermaidjs_diagram_description( online_results: Optional[dict] = None, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -1055,6 +1160,7 @@ async def generate_better_mermaidjs_diagram_description( improve_diagram_description_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, @@ -1104,6 +1210,7 @@ async def generate_better_image_prompt( model_type: Optional[str] = None, query_images: Optional[List[str]] = None, query_files: str = "", + relevant_memories: List[UserMemory] = None, user: KhojUser = None, agent: Agent = None, tracer: dict = {}, @@ -1148,6 +1255,7 @@ async def generate_better_image_prompt( q, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, chat_history=conversation_history, system_message=enhance_image_system_message, response_type="json_object", @@ -1180,6 +1288,7 @@ async def search_documents( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, previous_inferred_queries: Set = set(), fast_model: bool = True, agent: Agent = None, @@ -1230,6 +1339,7 @@ async def search_documents( user=user, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, personality_context=personality_context, location_data=location_data, chat_history=chat_history, @@ -1283,6 +1393,7 @@ async def extract_questions( query_files: str = None, query_images: Optional[List[str]] = None, personality_context: str = "", + relevant_memories: List[UserMemory] = None, location_data: LocationData = None, chat_history: List[ChatMessageModel] = [], max_queries: int = 5, @@ -1338,6 +1449,7 @@ async def extract_questions( query=prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, system_message=system_prompt, response_type="json_object", response_schema=DocumentQueries, @@ -1518,6 +1630,7 @@ async def send_message_to_model_wrapper( query_files: str = None, query_images: List[str] = None, context: str = "", + relevant_memories: List[UserMemory] = None, chat_history: list[ChatMessageModel] = [], system_message: str = "", # Model Config @@ -1568,6 +1681,7 @@ async def send_message_to_model_wrapper( query_files=query_files, query_images=query_images, context_message=context, + relevant_memories=relevant_memories, chat_history=chat_history, system_message=system_message, model_name=chat_model.name, @@ -1685,6 +1799,7 @@ def build_conversation_context( operator_results: List[OperatorRun], query_files: str = None, query_images: Optional[List[str]] = None, + relevant_memories: List[UserMemory] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, chat_history: List[ChatMessageModel] = [], @@ -1761,6 +1876,7 @@ def build_conversation_context( query_files=query_files, query_images=query_images, context_message=context_message, + relevant_memories=relevant_memories, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, chat_history=chat_history, @@ -1789,6 +1905,7 @@ async def agenerate_chat_response( user_name: Optional[str] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = [], program_execution_context: List[str] = [], generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, @@ -1831,6 +1948,7 @@ async def agenerate_chat_response( operator_results=operator_results, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, chat_history=chat_history, @@ -2792,6 +2910,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) selected_chat_model_config = ConversationAdapters.get_chat_model( user ) or ConversationAdapters.get_default_chat_model(user) + server_chat_settings = ServerChatSettings.objects.first() + server_memory_mode = ( + server_chat_settings.memory_mode + if server_chat_settings + else ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON.value # type: ignore[attr-defined] + ) + enable_memory = ConversationAdapters.is_memory_enabled(user) chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_model_options = list() for chat_model in chat_models: @@ -2848,6 +2973,8 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) "enabled_content_source": enabled_content_sources, "has_documents": has_documents, "notion_token": notion_token, + "enable_memory": enable_memory, + "server_memory_mode": server_memory_mode, # user model settings "chat_model_options": chat_model_options, "selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index e103b330..3185e8c0 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional import yaml from khoj.database.adapters import AgentAdapters, EntryAdapters, McpServerAdapters -from khoj.database.models import Agent, ChatMessageModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, @@ -285,6 +285,7 @@ async def apick_next_tool( max_iterations: int = 5, query_images: List[str] = [], query_files: str = None, + relevant_memories: List[UserMemory] = [], max_document_searches: int = 7, max_online_searches: int = 3, max_webpages_to_read: int = 3, @@ -416,6 +417,7 @@ async def apick_next_tool( query="", query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, system_message=function_planning_prompt, chat_history=chat_and_research_history, tools=tools, @@ -479,6 +481,7 @@ async def research( previous_iterations: List[ResearchIteration], query_images: List[str], query_files: str = None, + relevant_memories: List[UserMemory] = [], user_name: str = None, location: LocationData = None, send_status_func: Optional[Callable] = None, @@ -544,6 +547,7 @@ async def research( MAX_ITERATIONS, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, max_document_searches=max_document_searches, max_online_searches=max_online_searches, max_webpages_to_read=max_webpages_to_read, diff --git a/tests/helpers.py b/tests/helpers.py index 985ba00a..d5740d53 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,9 +4,12 @@ import os from datetime import datetime import factory +from asgiref.sync import sync_to_async from django.utils.timezone import make_aware +from khoj.database.adapters import AgentAdapters from khoj.database.models import ( + Agent, AiModelApi, ChatMessageModel, ChatModel, @@ -15,8 +18,10 @@ from khoj.database.models import ( KhojUser, ProcessLock, SearchModelConfig, + ServerChatSettings, Subscription, UserConversationConfig, + UserMemory, ) from khoj.processor.conversation.utils import message_to_log from khoj.utils.helpers import get_absolute_path, is_none_or_empty @@ -277,3 +282,45 @@ class ProcessLockFactory(factory.django.DjangoModelFactory): model = ProcessLock name = "test_lock" + + +class ServerChatSettingsFactory(factory.django.DjangoModelFactory): + class Meta: + model = ServerChatSettings + + memory_mode = ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON + + +# Async-safe wrappers for factories and ORM operations +async def acreate_user(): + return await sync_to_async(UserFactory)() + + +async def acreate_subscription(user): + return await sync_to_async(SubscriptionFactory)(user=user) + + +async def acreate_chat_model(): + return await sync_to_async(ChatModelFactory)() + + +async def acreate_default_agent(): + return await sync_to_async(AgentAdapters.create_default_agent)() + + +async def acreate_agent(name, chat_model, personality): + return await sync_to_async(Agent.objects.create)( + name=name, + chat_model=chat_model, + personality=personality, + ) + + +async def acreate_test_memory(user, agent=None, raw_text="test memory"): + """Create a memory directly in DB without embeddings for testing.""" + return await sync_to_async(UserMemory.objects.create)( + user=user, + agent=agent, + raw=raw_text, + embeddings=[0.1] * 384, # Dummy embeddings + ) diff --git a/tests/test_memory_settings.py b/tests/test_memory_settings.py new file mode 100644 index 00000000..28621670 --- /dev/null +++ b/tests/test_memory_settings.py @@ -0,0 +1,479 @@ +""" +Tests for memory enable/disable settings and memory scoping by user+agent. + +These tests verify: +1. The behavior of ConversationAdapters.is_memory_enabled() for different combinations of: + - ServerChatSettings.memory_mode (DISABLED, ENABLED_DEFAULT_OFF, ENABLED_DEFAULT_ON) + - UserConversationConfig.enable_memory (True, False, or not set) + +2. Memory scoping by user and agent: + - Memories are scoped to user + agent + - Default agent has access to ALL memories across all agents for a user + - Non-default agents only see their own memories +""" + +import pytest +from unittest.mock import MagicMock + +from khoj.database.adapters import ConversationAdapters, UserMemoryAdapters +from khoj.database.models import ServerChatSettings, UserConversationConfig +from khoj.routers.helpers import get_user_config +from tests.helpers import ( + acreate_user, + acreate_subscription, + acreate_chat_model, + acreate_default_agent, + acreate_agent, + acreate_test_memory, + ServerChatSettingsFactory, + SubscriptionFactory, + UserFactory, +) + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with no server config (default behavior) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_enabled_no_server_config_no_user_config(): + """When no server config and no user config exists, memory should be enabled (default on).""" + user = UserFactory() + SubscriptionFactory(user=user) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_no_server_config_user_enabled(): + """When no server config but user has explicitly enabled memory.""" + user = UserFactory() + SubscriptionFactory(user=user) + user_config = UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_no_server_config_user_disabled(): + """When no server config but user has explicitly disabled memory.""" + user = UserFactory() + SubscriptionFactory(user=user) + user_config = UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with server mode DISABLED +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_disabled_server_disabled_no_user_config(): + """When server disables memory, it should override everything - no user config.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +@pytest.mark.django_db +def test_memory_disabled_server_disabled_user_enabled(): + """When server disables memory, it should override user preference (enabled).""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +@pytest.mark.django_db +def test_memory_disabled_server_disabled_user_disabled(): + """When server disables memory, user disabled too - should be disabled.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with server mode ENABLED_DEFAULT_OFF +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_enabled_default_off_no_user_config(): + """When server is enabled_default_off and no user config, memory should be off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +@pytest.mark.django_db +def test_memory_enabled_default_off_user_enabled(): + """When server is enabled_default_off and user opts in, memory should be on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_default_off_user_disabled(): + """When server is enabled_default_off and user explicitly disabled, memory should be off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test is_memory_enabled with server mode ENABLED_DEFAULT_ON +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_memory_enabled_default_on_no_user_config(): + """When server is enabled_default_on and no user config, memory should be on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_default_on_user_enabled(): + """When server is enabled_default_on and user enabled, memory should be on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + UserConversationConfig.objects.create(user=user, enable_memory=True) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is True + + +@pytest.mark.django_db +def test_memory_enabled_default_on_user_disabled(): + """When server is enabled_default_on and user opts out, memory should be off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + UserConversationConfig.objects.create(user=user, enable_memory=False) + + result = ConversationAdapters.is_memory_enabled(user) + + assert result is False + + +# ---------------------------------------------------------------------------------------------------- +# Test get_user_config returns correct enable_memory and server_memory_mode +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db +def test_get_user_config_memory_no_server_config(): + """get_user_config should return default values when no server config.""" + user = UserFactory() + SubscriptionFactory(user=user) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is True + assert config["server_memory_mode"] == "enabled_default_on" + + +@pytest.mark.django_db +def test_get_user_config_memory_server_disabled(): + """get_user_config should reflect server disabled mode.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.DISABLED) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is False + assert config["server_memory_mode"] == "disabled" + + +@pytest.mark.django_db +def test_get_user_config_memory_server_enabled_default_off_user_opted_in(): + """get_user_config should show user opted in when server is default off.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF) + UserConversationConfig.objects.create(user=user, enable_memory=True) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is True + assert config["server_memory_mode"] == "enabled_default_off" + + +@pytest.mark.django_db +def test_get_user_config_memory_server_enabled_default_on_user_opted_out(): + """get_user_config should show user opted out when server is default on.""" + user = UserFactory() + SubscriptionFactory(user=user) + ServerChatSettingsFactory(memory_mode=ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON) + UserConversationConfig.objects.create(user=user, enable_memory=False) + request = MagicMock() + request.url = MagicMock() + request.url.path = "/api/config" + request.session = {} + + config = get_user_config(user, request, is_detailed=True) + + assert config["enable_memory"] is False + assert config["server_memory_mode"] == "enabled_default_on" + + +# ---------------------------------------------------------------------------------------------------- +# Test memory scoping by user and agent +# ---------------------------------------------------------------------------------------------------- + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_pull_memories_default_agent_sees_all_memories(): + """Default agent should see ALL memories for the user, including those from other agents.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create a custom agent + custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent") + + # Create memories for different agents + await acreate_test_memory(user, agent=None, raw_text="memory without agent") + await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent") + await acreate_test_memory(user, agent=custom_agent, raw_text="memory for custom agent") + + # Act: Pull memories with default agent + memories = await UserMemoryAdapters.pull_memories(user=user, agent=default_agent) + + # Assert: Default agent sees ALL memories + memory_texts = [m.raw for m in memories] + assert "memory without agent" in memory_texts + assert "memory for default agent" in memory_texts + assert "memory for custom agent" in memory_texts + assert len(memories) == 3 + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_pull_memories_custom_agent_sees_only_own_memories(): + """Custom (non-default) agent should only see its own memories.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create custom agents + custom_agent_1 = await acreate_agent("Custom Agent 1", chat_model, "First custom agent") + custom_agent_2 = await acreate_agent("Custom Agent 2", chat_model, "Second custom agent") + + # Create memories for different agents + await acreate_test_memory(user, agent=None, raw_text="memory without agent") + await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent") + await acreate_test_memory(user, agent=custom_agent_1, raw_text="memory for custom agent 1") + await acreate_test_memory(user, agent=custom_agent_2, raw_text="memory for custom agent 2") + + # Act: Pull memories with custom_agent_1 + memories = await UserMemoryAdapters.pull_memories(user=user, agent=custom_agent_1) + + # Assert: Custom agent 1 only sees its own memories + memory_texts = [m.raw for m in memories] + assert "memory for custom agent 1" in memory_texts + assert "memory without agent" not in memory_texts + assert "memory for default agent" not in memory_texts + assert "memory for custom agent 2" not in memory_texts + assert len(memories) == 1 + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_pull_memories_no_agent_same_as_default_agent(): + """Pulling memories with agent=None should behave same as default agent (see all).""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create a custom agent + custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent") + + # Create memories + await acreate_test_memory(user, agent=None, raw_text="memory without agent") + await acreate_test_memory(user, agent=default_agent, raw_text="memory for default agent") + await acreate_test_memory(user, agent=custom_agent, raw_text="memory for custom agent") + + # Act: Pull memories with agent=None + memories = await UserMemoryAdapters.pull_memories(user=user, agent=None) + + # Assert: Should see all memories (same as default agent behavior) + memory_texts = [m.raw for m in memories] + assert "memory without agent" in memory_texts + assert "memory for default agent" in memory_texts + assert "memory for custom agent" in memory_texts + assert len(memories) == 3 + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_save_memory_with_custom_agent_scopes_to_agent(): + """Memories saved with a custom agent should be scoped to that agent.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create custom agent + custom_agent = await acreate_agent("Custom Agent", chat_model, "A custom agent") + + # Create memory with custom agent (directly in DB to avoid embeddings) + memory = await acreate_test_memory(user, agent=custom_agent, raw_text="custom agent memory") + + # Assert: Memory is scoped to the custom agent + assert memory.agent == custom_agent + assert memory.user == user + + # Verify custom agent can see it + custom_memories = await UserMemoryAdapters.pull_memories(user=user, agent=custom_agent) + assert len(custom_memories) == 1 + assert custom_memories[0].raw == "custom agent memory" + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_save_memory_with_default_agent_has_no_agent_scope(): + """Memories saved with default agent should have agent=None (global scope).""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + await acreate_chat_model() # Required for default agent creation + + # Create default agent + default_agent = await acreate_default_agent() + assert default_agent is not None + + # Create memory with default agent (directly in DB) + # Based on save_memory logic: if agent == default_agent, agent is not set + memory = await acreate_test_memory(user, agent=None, raw_text="default agent memory") + + # Assert: Memory has no agent (global scope) + assert memory.agent is None + assert memory.user == user + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_memories_isolated_between_users(): + """Memories should be isolated between different users.""" + # Setup + user1 = await acreate_user() + user2 = await acreate_user() + await acreate_subscription(user1) + await acreate_subscription(user2) + + # Create default agent + await acreate_default_agent() + + # Create memories for each user + await acreate_test_memory(user1, agent=None, raw_text="user1 memory") + await acreate_test_memory(user2, agent=None, raw_text="user2 memory") + + # Act: Pull memories for each user + user1_memories = await UserMemoryAdapters.pull_memories(user=user1) + user2_memories = await UserMemoryAdapters.pull_memories(user=user2) + + # Assert: Each user only sees their own memories + assert len(user1_memories) == 1 + assert user1_memories[0].raw == "user1 memory" + + assert len(user2_memories) == 1 + assert user2_memories[0].raw == "user2 memory" + + +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_custom_agent_cannot_see_other_custom_agent_memories(): + """One custom agent should not see another custom agent's memories.""" + # Setup + user = await acreate_user() + await acreate_subscription(user) + chat_model = await acreate_chat_model() + + # Create default agent + await acreate_default_agent() + + # Create two custom agents + agent_accountant = await acreate_agent("Accountant", chat_model, "Financial advisor") + agent_chef = await acreate_agent("Chef", chat_model, "Cooking expert") + + # Create memories for each agent + await acreate_test_memory(user, agent=agent_accountant, raw_text="user spent $500 on groceries") + await acreate_test_memory(user, agent=agent_chef, raw_text="user likes Italian food") + + # Act & Assert: Accountant only sees financial memories + accountant_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent_accountant) + assert len(accountant_memories) == 1 + assert accountant_memories[0].raw == "user spent $500 on groceries" + + # Act & Assert: Chef only sees food memories + chef_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent_chef) + assert len(chef_memories) == 1 + assert chef_memories[0].raw == "user likes Italian food"