Give Khoj Long Term Memories (#1168)

# Motivation
A major component of useful AI systems is adaptation to the user's context. This is a major reason why we enabled syncing knowledge bases. The next step in this direction is to dynamically update the evolving state of the user as conversations take place across time and topics. This allows for more personalized conversations and maintains context across conversations.

# Overview
This change introduces medium- and long-term memories in Khoj.
- The scope of a conversation can be thought of as short-term memory.
- Medium-term memory extends to the past week.
- Long-term memory extends to any time in the past where a search query results in a match.
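
These two longer-lived tiers map to two retrieval paths introduced later in this change: a time-windowed pull for medium-term memories and an embedding search over past memories for long-term recall. As a minimal sketch (assuming the `UserMemoryAdapters` helpers added in this diff), a chat turn might combine them like so:

```python
from khoj.database.adapters import UserMemoryAdapters

async def gather_relevant_memories(user, query, agent=None):
    # Medium term: memories updated within the default one-week window.
    recent = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
    # Long term: memories whose embeddings are close to the current query.
    matched = await UserMemoryAdapters.search_memories(query=query, user=user, agent=agent)
    # De-duplicate by memory id while preserving order, as api_chat.py does below.
    return list({m.id: m for m in recent + matched}.values())
```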

# Details
- Enable users to view and manage agent-generated memories from their settings page
- Fully integrate the memory object into all downstream usage: image generation, notes extraction, online search, etc.
- Scope memory per agent. The default agent also has access to memories created by other agents.
- Enable users and admins to enable/disable Khoj's memory system (precedence sketched below)
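
The enable/disable behavior resolves a server-level `memory_mode` against the per-user `enable_memory` flag: a server setting of disabled always wins, otherwise an explicit user preference is honored, and new users fall back to the server default. A condensed sketch of that precedence (not the actual adapter code, which appears in `ConversationAdapters.ais_memory_enabled` below):

```python
def resolve_memory_enabled(memory_mode: str | None, user_enable_memory: bool | None) -> bool:
    # Server-level "disabled" overrides any user preference.
    if memory_mode == "disabled":
        return False
    # An explicit user preference is honored for either remaining mode.
    if user_enable_memory is not None:
        return user_enable_memory
    # No user preference yet: fall back to the server default
    # (off for "enabled_default_off", on for "enabled_default_on" or no server config).
    return memory_mode != "enabled_default_off"
```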

---------

Co-authored-by: Debanjum <debanjum@gmail.com>
Authored by sabaimran on 2026-01-02 19:37:05 -08:00, committed by GitHub
parent d55a00288b · commit d6c2d1fa49
25 changed files with 1910 additions and 17 deletions

View File

@@ -63,6 +63,8 @@ export interface UserConfig {
enabled_content_source: SyncedContent;
has_documents: boolean;
notion_token: string | null;
enable_memory: boolean;
server_memory_mode: "disabled" | "enabled_default_off" | "enabled_default_on";
// user model settings
search_model_options: ModelOptions[];
selected_search_model_config: number;

View File

@@ -0,0 +1,90 @@
import { useState } from "react";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react";
import { useToast } from "@/components/ui/use-toast";
export interface UserMemorySchema {
id: number;
raw: string;
created_at: string;
}
interface UserMemoryProps {
memory: UserMemorySchema;
onDelete: (id: number) => void;
onUpdate: (id: number, raw: string) => void;
}
export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) {
const [isEditing, setIsEditing] = useState(false);
const [content, setContent] = useState(memory.raw);
const { toast } = useToast();
const handleUpdate = () => {
onUpdate(memory.id, content);
setIsEditing(false);
toast({
title: "Memory Updated",
description: "Your memory has been successfully updated.",
});
};
const handleDelete = () => {
onDelete(memory.id);
toast({
title: "Memory Deleted",
description: "Your memory has been successfully deleted.",
});
};
return (
<div className="flex items-center gap-2 w-full">
{isEditing ? (
<>
<Input
value={content}
onChange={(e) => setContent(e.target.value)}
className="flex-1"
/>
<Button
variant="ghost"
size="icon"
onClick={handleUpdate}
title="Save"
>
<FloppyDisk className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
onClick={() => setIsEditing(false)}
title="Cancel"
>
<X className="h-4 w-4" />
</Button>
</>
) : (
<>
<Input value={memory.raw} readOnly className="flex-1" />
<Button
variant="ghost"
size="icon"
onClick={() => setIsEditing(true)}
title="Edit"
>
<Pencil className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
onClick={handleDelete}
title="Delete"
>
<TrashSimple className="h-4 w-4" />
</Button>
</>
)}
</div>
);
}

View File

@@ -15,6 +15,8 @@ import { Button } from "@/components/ui/button";
import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp";
import { Input } from "@/components/ui/input";
import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card";
import { Switch } from "@/components/ui/switch";
import {
DropdownMenu,
DropdownMenuContent,
@@ -33,6 +35,15 @@ import {
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogTrigger
} from "@/components/ui/dialog";
import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table";
import {
@@ -74,6 +85,7 @@ import Loading from "../components/loading/loading";
import IntlTelInput from "intl-tel-input/react";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar";
import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory";
import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo";
import { Progress } from "@/components/ui/progress";
@@ -323,6 +335,9 @@ export default function SettingsView() {
const [numberValidationState, setNumberValidationState] = useState<PhoneNumberValidationState>(
PhoneNumberValidationState.Verified,
);
const [memories, setMemories] = useState<UserMemorySchema[]>([]);
const [enableMemory, setEnableMemory] = useState<boolean>(true);
const [serverMemoryMode, setServerMemoryMode] = useState<string>("enabled_default_on");
const [isExporting, setIsExporting] = useState(false);
const [exportProgress, setExportProgress] = useState(0);
const [exportedConversations, setExportedConversations] = useState(0);
@@ -347,6 +362,8 @@ export default function SettingsView() {
);
setName(initialUserConfig?.given_name);
setNotionToken(initialUserConfig?.notion_token ?? null);
setEnableMemory(initialUserConfig?.enable_memory ?? true);
setServerMemoryMode(initialUserConfig?.server_memory_mode ?? "enabled_default_on");
}, [initialUserConfig]);
const sendOTP = async () => {
@@ -621,6 +638,88 @@ export default function SettingsView() {
}
};
const fetchMemories = async () => {
try {
console.log("Fetching memories...");
const response = await fetch('/api/memories/');
if (!response.ok) throw new Error('Failed to fetch memories');
const data = await response.json();
setMemories(data);
} catch (error) {
console.error('Error fetching memories:', error);
toast({
title: "Error",
description: "Failed to fetch memories. Please try again.",
variant: "destructive"
});
}
};
const handleDeleteMemory = async (id: number) => {
try {
const response = await fetch(`/api/memories/${id}`, {
method: 'DELETE'
});
if (!response.ok) throw new Error('Failed to delete memory');
setMemories(memories.filter(memory => memory.id !== id));
} catch (error) {
console.error('Error deleting memory:', error);
toast({
title: "Error",
description: "Failed to delete memory. Please try again.",
variant: "destructive"
});
}
};
const handleUpdateMemory = async (id: number, raw: string) => {
try {
const response = await fetch(`/api/memories/${id}`, {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ raw, memory_id: id }),
});
if (!response.ok) throw new Error('Failed to update memory');
const updatedMemory: UserMemorySchema = await response.json();
setMemories(memories.map(memory =>
memory.id === id ? updatedMemory : memory
));
} catch (error) {
console.error('Error updating memory:', error);
toast({
title: "Error",
description: "Failed to update memory. Please try again.",
variant: "destructive"
});
}
};
const handleToggleMemory = async (enabled: boolean) => {
try {
const response = await fetch(`/api/user/memory?enable_memory=${enabled}`, {
method: 'PATCH',
});
if (!response.ok) throw new Error('Failed to update memory setting');
setEnableMemory(enabled);
toast({
title: enabled ? "Memory enabled" : "Memory disabled",
description: enabled
? "Khoj will learn and remember from your conversations."
: "Khoj will no longer learn or remember from your conversations.",
});
} catch (error) {
console.error('Error toggling memory:', error);
toast({
title: "Error",
description: "Failed to update memory setting. Please try again.",
variant: "destructive"
});
}
};
const syncContent = async (type: string) => {
try {
const response = await fetch(`/api/content?t=${type}`, {
@@ -1212,7 +1311,64 @@ export default function SettingsView() {
</Button>
</CardFooter>
</Card>
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<Brain className="h-7 w-7 mr-2" />
Memories
</CardHeader>
<CardContent className="overflow-hidden">
<p className="pb-4 text-gray-400">
View and manage your long-term memories
</p>
<div className="flex items-center justify-between">
<label
htmlFor="enable-memory"
className={`text-sm font-medium leading-none ${serverMemoryMode === "disabled" ? "text-gray-400" : ""}`}
>
Enable Memory
</label>
<Switch
id="enable-memory"
checked={enableMemory}
onCheckedChange={(checked) => handleToggleMemory(checked)}
disabled={serverMemoryMode === "disabled"}
/>
</div>
{serverMemoryMode === "disabled" && (
<p className="text-xs text-gray-400 mt-2">
Memory has been disabled by the server administrator.
</p>
)}
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
<Dialog onOpenChange={(open) => open && fetchMemories()}>
<DialogTrigger asChild>
<Button variant="outline">
<Brain className="h-5 w-5 mr-2" />
Browse Memories
</Button>
</DialogTrigger>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogHeader>
<DialogTitle>Your Memories</DialogTitle>
</DialogHeader>
<div className="grid gap-4 py-4">
{memories.map((memory) => (
<UserMemory
key={memory.id}
memory={memory}
onDelete={handleDeleteMemory}
onUpdate={handleUpdateMemory}
/>
))}
{memories.length === 0 && (
<p className="text-center text-gray-500">No memories found</p>
)}
</div>
</DialogContent>
</Dialog>
</CardFooter>
</Card>
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<TrashSimple className="h-7 w-7 mr-2 text-red-500" />

View File

@@ -24,6 +24,7 @@
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toast": "^1.2.15",
"@radix-ui/react-toggle": "^1.1.10",
@@ -274,7 +275,7 @@
"@radix-ui/react-slot": ["@radix-ui/react-slot@1.2.3", "", { "dependencies": { "@radix-ui/react-compose-refs": "1.1.2" }, "peerDependencies": { "@types/react": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react"] }, "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A=="],
"@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="],
"@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.6", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-bByzr1+ep1zk4VubeEVViV592vu2lHE2BZY5OnzehZqOOgogN80+mNtCqPkhn2gklJqOpxWgPoYTSnhBCqpOXQ=="],
"@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.13", "", { "dependencies": { "@radix-ui/primitive": "1.1.3", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.5", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.11", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A=="],
@@ -1478,8 +1479,6 @@
"@radix-ui/react-slider/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="],
"@radix-ui/react-switch/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="],
"@radix-ui/react-toggle-group/@radix-ui/primitive": ["@radix-ui/primitive@1.1.2", "", {}, "sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA=="],
"@radix-ui/react-toggle-group/@radix-ui/react-roving-focus": ["@radix-ui/react-roving-focus@1.1.10", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q=="],
@@ -1576,6 +1575,8 @@
"radix-ui/@radix-ui/react-select": ["@radix-ui/react-select@2.2.5", "", { "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-popper": "1.2.7", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA=="],
"radix-ui/@radix-ui/react-switch": ["@radix-ui/react-switch@1.2.5", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ=="],
"radix-ui/@radix-ui/react-tabs": ["@radix-ui/react-tabs@1.1.12", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-GTVAlRVrQrSw3cEARM0nAx73ixrWDPNZAruETn3oHCNP6SbZ/hNxdxp+u7VkIEv3/sFoLq1PfcHrl7Pnp0CDpw=="],
"radix-ui/@radix-ui/react-toast": ["@radix-ui/react-toast@1.2.14", "", { "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", "@types/react-dom": "*", "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-nAP5FBxBJGQ/YfUB+r+O6USFVkWq3gAInkxyEnmvEV5jtSbfDhfa4hwX8CraCnbjMLsE7XSf/K75l9xXY7joWg=="],

View File

@@ -0,0 +1,29 @@
"use client"
import * as React from "react"
import * as SwitchPrimitives from "@radix-ui/react-switch"
import { cn } from "@/lib/utils"
const Switch = React.forwardRef<
React.ElementRef<typeof SwitchPrimitives.Root>,
React.ComponentPropsWithoutRef<typeof SwitchPrimitives.Root>
>(({ className, ...props }, ref) => (
<SwitchPrimitives.Root
className={cn(
"peer inline-flex h-6 w-11 shrink-0 cursor-pointer items-center rounded-full border-2 border-transparent transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-gray-300 dark:data-[state=unchecked]:bg-gray-600",
className
)}
{...props}
ref={ref}
>
<SwitchPrimitives.Thumb
className={cn(
"pointer-events-none block h-5 w-5 rounded-full bg-background shadow-lg ring-0 transition-transform data-[state=checked]:translate-x-5 data-[state=unchecked]:translate-x-0"
)}
/>
</SwitchPrimitives.Root>
))
Switch.displayName = SwitchPrimitives.Root.displayName
export { Switch }

View File

@@ -38,6 +38,7 @@
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toast": "^1.2.15",
"@radix-ui/react-toggle": "^1.1.10",

View File

@@ -323,6 +323,7 @@ def configure_routes(app):
from khoj.routers.api_automation import api_automation
from khoj.routers.api_chat import api_chat
from khoj.routers.api_content import api_content
from khoj.routers.api_memories import api_memories
from khoj.routers.api_model import api_model
from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client
@@ -332,6 +333,7 @@ def configure_routes(app):
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_automation, prefix="/api/automation")
app.include_router(api_model, prefix="/api/model")
app.include_router(api_memories, prefix="/api/memories")
app.include_router(api_content, prefix="/api/content")
app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client)

View File

@@ -62,6 +62,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserMemory,
UserRequests,
UserTextToImageModelConfig,
UserVoiceModelConfig,
@@ -566,6 +567,16 @@ def get_default_search_model() -> SearchModelConfig:
return SearchModelConfig.objects.first()
async def aget_default_search_model() -> SearchModelConfig:
default_search_model = await SearchModelConfig.objects.filter(name="default").afirst()
if default_search_model:
return default_search_model
elif await SearchModelConfig.objects.acount() == 0:
await SearchModelConfig.objects.acreate()
return await SearchModelConfig.objects.afirst()
def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
@@ -1556,12 +1567,17 @@ class ConversationAdapters:
):
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id
).afirst()
conversation = (
await Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
.prefetch_related("agent", "agent__chat_model")
.afirst()
)
else:
conversation = (
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
await Conversation.objects.filter(user=user, client=client_application)
.prefetch_related("agent", "agent__chat_model")
.order_by("-updated_at")
.afirst()
)
existing_messages = conversation.messages if conversation else []
@@ -1597,6 +1613,85 @@ class ConversationAdapters:
return None
return config.setting
@staticmethod
async def ais_memory_enabled(user: KhojUser) -> bool:
"""
Check if memory is enabled for the user based on server config and user preference.
Logic:
- If server memory_mode is DISABLED: return False (overrides user preference)
- If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False
- If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True
- If no server config exists: default to True
"""
# Get server-level memory configuration
server_settings = await ServerChatSettings.objects.afirst()
if server_settings:
memory_mode = server_settings.memory_mode
# Disabled mode overrides all user preferences
if memory_mode == ServerChatSettings.MemoryMode.DISABLED:
return False
# Get user preference
user_config = await UserConversationConfig.objects.filter(user=user).afirst()
if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF:
# User must explicitly opt-in; check if user has set a preference
if user_config is None:
return False # Default off for new users
return user_config.enable_memory
# ENABLED_DEFAULT_ON: use user preference if set, else True
if user_config is None:
return True # Default on for new users
return user_config.enable_memory
# No server config - default behavior (enabled, default on)
user_config = await UserConversationConfig.objects.filter(user=user).afirst()
if user_config is None:
return True
return user_config.enable_memory
@staticmethod
def is_memory_enabled(user: KhojUser) -> bool:
"""
Sync version of ais_memory_enabled.
Check if memory is enabled for the user based on server config and user preference.
Logic:
- If server memory_mode is DISABLED: return False (overrides user preference)
- If server memory_mode is ENABLED_DEFAULT_OFF: use user preference if set, else False
- If server memory_mode is ENABLED_DEFAULT_ON: use user preference if set, else True
- If no server config exists: default to True
"""
# Get server-level memory configuration
server_settings = ServerChatSettings.objects.first()
if server_settings:
memory_mode = server_settings.memory_mode
# Disabled mode overrides all user preferences
if memory_mode == ServerChatSettings.MemoryMode.DISABLED:
return False
# Get user preference
user_config = UserConversationConfig.objects.filter(user=user).first()
if memory_mode == ServerChatSettings.MemoryMode.ENABLED_DEFAULT_OFF:
# User must explicitly opt-in; check if user has set a preference
if user_config is None:
return False # Default off for new users
return user_config.enable_memory
# ENABLED_DEFAULT_ON: use user preference if set, else True
if user_config is None:
return True # Default on for new users
return user_config.enable_memory
# No server config - default behavior (enabled, default on)
user_config = UserConversationConfig.objects.filter(user=user).first()
if user_config is None:
return True
return user_config.enable_memory
@staticmethod
async def get_speech_to_text_config():
return await SpeechToTextModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
@@ -2186,3 +2281,94 @@ class McpServerAdapters:
except Exception as e:
logger.error(f"Error retrieving MCP servers: {e}", exc_info=True)
return servers
class UserMemoryAdapters:
@staticmethod
@require_valid_user
async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]:
"""
Pulls memories from the database for a given user. Medium term memory.
"""
time_frame = datetime.now(timezone.utc) - timedelta(days=window)
default_agent = await AgentAdapters.aget_default_agent()
if agent and agent != default_agent:
memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by(
"-created_at"
)[:limit]
else:
memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit]
return await sync_to_async(list)(memories)
@staticmethod
@require_valid_user
async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory:
"""
Saves a memory to the database for a given user.
"""
embeddings_model = state.embeddings_model
model = await aget_default_search_model()
embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory)
default_agent = await AgentAdapters.aget_default_agent()
if agent and agent != default_agent:
memory_instance = await UserMemory.objects.acreate(
user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent
)
else:
memory_instance = await UserMemory.objects.acreate(
user=user, embeddings=embeddings, raw=memory, search_model=model
)
return memory_instance
@staticmethod
@require_valid_user
async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]:
"""
Searches for memories in the database for a given user. Long term memory.
"""
embeddings_model = state.embeddings_model
model = await aget_default_search_model()
max_distance = model.bi_encoder_confidence_threshold or math.inf
embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query)
default_agent = await AgentAdapters.aget_default_agent()
if agent and agent != default_agent:
relevant_memories = UserMemory.objects.filter(user=user, agent=agent)
else:
relevant_memories = UserMemory.objects.filter(user=user)
relevant_memories = (
relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query))
.order_by("distance")
.filter(distance__lte=max_distance)
)
return await sync_to_async(list)(relevant_memories[:limit])
@staticmethod
@require_valid_user
async def delete_memory(user: KhojUser, memory_id: str) -> bool:
"""
Deletes a memory from the database for a given user.
"""
try:
memory = await UserMemory.objects.aget(user=user, id=memory_id)
await memory.adelete()
return True
except UserMemory.DoesNotExist:
return False
@staticmethod
def to_dict(memories: List[UserMemory]) -> List[dict]:
"""
Converts a list of Memory objects to a list of dictionaries.
"""
return [
{
"id": f"{memory.id}",
"raw": memory.raw,
"updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"),
}
for memory in memories
]

View File

@@ -34,6 +34,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserMemory,
UserRequests,
UserVoiceModelConfig,
VoiceModelOption,
@@ -182,6 +183,7 @@ admin.site.register(UserVoiceModelConfig, unfold_admin.ModelAdmin)
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
admin.site.register(UserMemory, unfold_admin.ModelAdmin)
@admin.register(McpServer)
@@ -295,6 +297,7 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin):
"think_paid_fast",
"think_paid_deep",
"web_scraper",
"memory_mode",
)
ordering = ("priority",)

View File

@@ -0,0 +1,384 @@
import asyncio
from datetime import datetime, timedelta
from typing import List, Optional
from django.core.management.base import BaseCommand
from django.db.models import Q
from django.utils import timezone
from khoj.configure import initialize_server
from khoj.database.adapters import UserMemoryAdapters
from khoj.database.models import (
Conversation,
DataStore,
KhojUser,
)
from khoj.routers.helpers import extract_facts_from_query
class Command(BaseCommand):
help = "Manage user memories - generate from conversations or delete existing memories"
def add_arguments(self, parser):
parser.add_argument(
"--lookback-days",
type=int,
default=None,
help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories",
)
parser.add_argument(
"--users",
type=str,
help="Process specific users (comma-separated usernames or emails)",
)
parser.add_argument(
"--batch-size",
type=int,
default=10,
help="Number of conversations to process in each batch (default: 10)",
)
parser.add_argument(
"--apply",
action="store_true",
help="Actually perform the operation. Without this flag, only shows what would be processed.",
)
parser.add_argument(
"--delete",
action="store_true",
help="Delete all memories for specified users instead of generating new ones",
)
parser.add_argument(
"--resume",
action="store_true",
help="Resume from last checkpoint if process was interrupted",
)
parser.add_argument(
"--force",
action="store_true",
help="Force regenerate memories even if already processed",
)
def handle(self, *args, **options):
"""Main entry point for the command"""
initialize_server()
asyncio.run(self.async_handle(*args, **options))
async def async_handle(self, *args, **options):
"""Async handler for memory management"""
lookback_days = options["lookback_days"]
usernames = options["users"]
batch_size = options["batch_size"]
apply = options["apply"]
delete = options["delete"]
resume = options["resume"]
force = options["force"]
mode = "APPLY" if apply else "DRY RUN"
# Handle deletion mode
if delete:
# For deletion, only use cutoff_date if lookback_days is explicitly provided
cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days else None
await self.handle_delete_memories(usernames, cutoff_date, apply)
return
# Handle generation mode
# For generation, default to 7 days if not specified
if lookback_days is None:
lookback_days = 7
cutoff_date = timezone.now() - timedelta(days=lookback_days)
self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days")
# Get users to process
users = await self.get_users_to_process(usernames)
if not users:
self.stdout.write("No users found to process")
return
self.stdout.write(f"Found {len(users)} users to process")
# Initialize or retrieve checkpoint
checkpoint = await self.get_or_create_checkpoint(resume)
total_conversations = 0
total_memories = 0
for user in users:
# Check if user already processed in checkpoint
if not force and user.id in checkpoint.get("processed_users", []):
self.stdout.write(f"Skipping already processed user: {user.username}")
continue
self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})")
# Get conversations for this user
conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force)
if not conversations:
self.stdout.write(f" No conversations to process for {user.username}")
# Mark user as processed
if apply:
await self.update_checkpoint(checkpoint, user_id=user.id)
continue
self.stdout.write(f" Found {len(conversations)} conversations to process")
# Process conversations in batches
user_memories = 0
for i in range(0, len(conversations), batch_size):
batch = conversations[i : i + batch_size]
batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint)
user_memories += batch_memories
total_conversations += len(batch)
# Update progress
progress = min(i + batch_size, len(conversations))
self.stdout.write(
f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories"
)
total_memories += user_memories
self.stdout.write(
f" Completed user {user.username}: "
f"processed {len(conversations)} conversations, "
f"generated {user_memories} memories"
)
# Mark user as processed
if apply:
await self.update_checkpoint(checkpoint, user_id=user.id)
# Clear checkpoint on successful completion
if apply:
await self.clear_checkpoint()
action = "Generated" if apply else "Would generate"
self.stdout.write(
self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations")
)
async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]:
"""Get list of users to comma separated usernames or emails to process"""
if users_str:
usernames = [u.strip() for u in users_str.split(",") if u.strip()]
# Process specific users
users = [user async for user in KhojUser.objects.filter(Q(username__in=usernames) | Q(email__in=usernames))]
return users
else:
# Process all users with conversations
return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()]
async def get_user_conversations(
self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool
) -> List[Conversation]:
"""Get conversations for a user that need processing"""
if cutoff_date is None:
query = Conversation.objects.filter(user=user).order_by("updated_at")
else:
query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at")
# Filter out already processed conversations if resuming
if not force and user.id in checkpoint.get("processed_conversations", {}):
processed_ids = checkpoint["processed_conversations"][user.id]
query = query.exclude(id__in=processed_ids)
return [conv async for conv in query]
async def process_conversation_batch(
self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict
) -> int:
"""Process a batch of conversations and generate memories"""
total_memories = 0
for conversation in conversations:
try:
# Get conversation messages using sync_to_async for property access
from asgiref.sync import sync_to_async
# Access conversation_log synchronously
@sync_to_async
def get_messages():
return conversation.messages
messages = await get_messages()
if not messages:
continue
# Get agent if conversation has one
@sync_to_async
def get_agent():
return conversation.agent
agent = await get_agent()
# Get existing memories for context
# Process each conversation turn
conversation_memories = 0
i = 0
while i + 1 < len(messages):
# Only process user-assistant pairs as a valid turn for memory extraction
if messages[i].by != "you" or messages[i + 1].by != "khoj":
i += 1
continue
# Get the conversation history up to this point
history = messages[: i + 2]
# Extract user query text for memory search
q = ""
if messages[i].message is None:
i += 1
continue
elif isinstance(messages[i].message, str):
q = messages[i].message
elif isinstance(messages[i].message, list):
q = "\n\n".join(
content.get("text", "")
for content in messages[i].message
if isinstance(content, dict) and content.get("text")
)
if not q or not q.strip():
i += 1
continue
# Get unique recent and long term relevant memories
recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())
if apply:
# Ensure agent is fully loaded with its chat_model
if agent:
@sync_to_async
def load_agent_with_chat_model():
# Force load the chat_model relationship
_ = agent.chat_model
return agent
agent = await load_agent_with_chat_model()
# Update memories based on latest conversation turn
memory_updates = await extract_facts_from_query(
user=user,
conversation_history=history,
existing_facts=relevant_memories,
agent=agent,
tracer={},
)
# Save new memories
for memory in memory_updates.create:
await UserMemoryAdapters.save_memory(user, memory, agent=agent)
conversation_memories += 1
self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...")
# Delete outdated memories
for memory in memory_updates.delete:
await UserMemoryAdapters.delete_memory(user, memory)
self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...")
else:
# Dry run - estimate memories that would be created
conversation_memories += 1 # Rough estimate
# Move to next conversation turn pair
i += 2
total_memories += conversation_memories
# Update checkpoint after each conversation
if apply:
await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id))
except Exception as e:
import traceback
self.stderr.write(
f"Error processing conversation {conversation.id} for user {user.id}: {e}\n"
f"Traceback: {traceback.format_exc()}"
)
continue
return total_memories
async def get_or_create_checkpoint(self, resume: bool) -> dict:
"""Get or create checkpoint for resumable processing"""
checkpoint_key = "memory_generation_checkpoint"
if resume:
# Try to retrieve existing checkpoint
checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst()
if checkpoint_store:
self.stdout.write("Resuming from checkpoint...")
return checkpoint_store.value
# Create new checkpoint
return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}}
async def update_checkpoint(
self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None
):
"""Update checkpoint with progress"""
if user_id and user_id not in checkpoint["processed_users"]:
checkpoint["processed_users"].append(user_id)
if user_id and conversation_id:
if user_id not in checkpoint["processed_conversations"]:
checkpoint["processed_conversations"][user_id] = []
if conversation_id not in checkpoint["processed_conversations"][user_id]:
checkpoint["processed_conversations"][user_id].append(conversation_id)
# Save checkpoint to database
await DataStore.objects.aupdate_or_create(
key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True}
)
async def clear_checkpoint(self):
"""Clear checkpoint after successful completion"""
await DataStore.objects.filter(key="memory_generation_checkpoint").adelete()
self.stdout.write("Checkpoint cleared")
async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool):
"""Handle deletion of user memories"""
from khoj.database.models import UserMemory
# Get users to process
users = await self.get_users_to_process(usernames)
if not users:
self.stdout.write("No users found to process")
return
mode = "APPLY" if apply else "DRY RUN"
if cutoff_date:
# Calculate days from cutoff date
days_back = (timezone.now() - cutoff_date).days
self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users")
else:
self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users")
total_deleted = 0
for user in users:
# Count memories for this user
if cutoff_date is None:
user_memories = UserMemory.objects.filter(user=user)
else:
user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date)
memories_count = await user_memories.acount()
if memories_count == 0:
self.stdout.write(f" User {user.username} has no memories to delete")
continue
self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories")
if apply:
# Delete memories for this user (with date filter if specified)
deleted_count, _ = await user_memories.adelete()
self.stdout.write(f" Deleted {deleted_count} memories")
total_deleted += deleted_count
else:
self.stdout.write(f" Would delete {memories_count} memories")
total_deleted += memories_count
action = "Deleted" if apply else "Would delete"
self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total"))
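
The command above is a dry run by default; nothing is generated or deleted unless `--apply` is passed. A hedged usage sketch via Django's `call_command` follows; the command's registered name isn't visible in this view, so `manage_memories` below is a placeholder:

```python
from django.core.management import call_command

# "manage_memories" is a placeholder; the actual command name isn't shown in this diff view.

# Dry run: report how many memories would be generated from the last 14 days.
call_command("manage_memories", lookback_days=14, users="alice@example.com")

# Generate memories for real, resuming from the last checkpoint if interrupted earlier.
call_command("manage_memories", lookback_days=14, apply=True, resume=True)

# Delete ALL memories for a user (omitting --lookback-days means no date cutoff).
call_command("manage_memories", users="alice@example.com", delete=True, apply=True)
```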

View File

@@ -0,0 +1,82 @@
# Generated by Django 5.1.10 on 2025-08-29 22:57
import django.db.models.deletion
import pgvector.django
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0098_alter_texttoimagemodelconfig_model_type"),
]
operations = [
migrations.CreateModel(
name="UserMemory",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("embeddings", pgvector.django.VectorField()),
("raw", models.TextField()),
(
"agent",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.agent",
),
),
(
"search_model",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="database.searchmodelconfig",
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="userconversationconfig",
name="enable_memory",
field=models.BooleanField(default=True),
),
migrations.AddField(
model_name="serverchatsettings",
name="memory_mode",
field=models.CharField(
choices=[
("disabled", "Disabled"),
("enabled_default_off", "Enabled, default off"),
("enabled_default_on", "Enabled, default on"),
],
default="enabled_default_on",
help_text="Server-level memory feature configuration. Disabled overrides user preference.",
max_length=20,
),
),
]

View File

@@ -478,6 +478,13 @@ class ServerChatSettings(DbBaseModel):
THINK_PAID_FAST = "think_paid_fast"
THINK_PAID_DEEP = "think_paid_deep"
class MemoryMode(models.TextChoices):
"""Enum for server-level memory feature configuration"""
DISABLED = "disabled", "Disabled"
ENABLED_DEFAULT_OFF = "enabled_default_off", "Enabled, default off"
ENABLED_DEFAULT_ON = "enabled_default_on", "Enabled, default on"
chat_default = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
@@ -506,6 +513,12 @@ class ServerChatSettings(DbBaseModel):
unique=True,
help_text="Priority of the server chat settings. Lower numbers run first.",
)
memory_mode = models.CharField(
max_length=20,
choices=MemoryMode.choices,
default=MemoryMode.ENABLED_DEFAULT_ON,
help_text="Server-level memory feature configuration. Disabled overrides user preference.",
)
def clean(self):
error = {}
@@ -629,6 +642,7 @@ class SpeechToTextModelOptions(DbBaseModel):
class UserConversationConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)
enable_memory = models.BooleanField(default=True)
class UserVoiceModelConfig(DbBaseModel):
@@ -836,3 +850,15 @@ class McpServer(DbBaseModel):
def __str__(self):
return self.name
class UserMemory(DbBaseModel):
"""
Long term memory store derived from conversation between user and agent.
"""
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
embeddings = VectorField(dimensions=None)
raw = models.TextField()
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)

View File

@@ -1302,3 +1302,70 @@ user_name = PromptTemplate.from_template(
User's Name: {name}
""".strip()
)
extract_facts_from_query = PromptTemplate.from_template(
"""
You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user.
This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember.
You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action:
1. Create new facts
2. Delete existing facts
You should delete existing facts that are no longer true.
You can enhance new facts with information from existing facts.
You cannot update existing facts directly, instead create new facts and delete related existing ones to update them.
Your output should be a JSON object with two lists: create and delete.
- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective.
- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true.
- Leave the create or delete list empty if you have nothing important to add or remove.
# Example
Existing Facts:
[
{{
"id": "5283",
"raw": "I am not interested in sports",
"updated_at": "2023-10-01T12:00:00+00:00"
}},
{{
"id": "22",
"raw": "I am a software engineer",
"updated_at": "2023-10-31T14:00:00+00:00"
}},
{{
"id": "651",
"raw": "My mother works at the hospital",
"updated_at": "2023-10-02T17:00:00+00:00"
}}
]
Latest Chat Session:
- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline.
In between coding, I took my cat Whiskers out for a walk and played a game of football.
My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat.
- AI: That's great to hear!
Response:
{{
"create": [
"I am interested in AI and machine learning",
"I have a pet cat named Whiskers",
"I enjoy playing football",
"My mother works at the hospital and is a doctor"
],
"delete": [
"5283",
"651"
],
}}
# Input
Existing Facts:
{matched_facts}
Latest Chat Session:
{chat_history}
""".strip()
)

View File

@@ -28,6 +28,7 @@ from khoj.database.models import (
ClientApplication,
Intent,
KhojUser,
UserMemory,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.base_filter import BaseFilter
@@ -551,6 +552,7 @@ async def save_to_conversation_log(
client_application: ClientApplication = None,
conversation_id: str = None,
automation_id: str = None,
relevant_memories: List[UserMemory] = [],
query_images: List[str] = None,
raw_query_files: List[FileAttachment] = [],
generated_images: List[str] = [],
@@ -559,6 +561,8 @@ async def save_to_conversation_log(
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
from khoj.routers.helpers import ai_update_memories
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
@@ -605,6 +609,16 @@ async def save_to_conversation_log(
user_message=q,
)
if not automation_id:
# Don't update memories from automations, as this could get noisy.
await ai_update_memories(
user=user,
conversation_history=new_messages or [],
memories=relevant_memories,
agent=db_conversation.agent if db_conversation else None,
tracer=tracer,
)
if is_promptrace_enabled():
merge_message_into_conversation_trace(q, chat_response, tracer)
@@ -666,6 +680,7 @@ def generate_chatml_messages_with_context(
query_files: str = None,
query_images=None,
context_message="",
relevant_memories: List[UserMemory] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = [],
chat_history: list[ChatMessageModel] = [],
@@ -806,6 +821,14 @@ def generate_chatml_messages_with_context(
),
)
if not is_none_or_empty(relevant_memories):
memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. Ignore them if they are not relevant to the query.\n<retrieved_memories>\n"
for memory in relevant_memories:
friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S")
memory_context += f"- [{friendly_dt}]: {memory.raw}\n"
memory_context += "</retrieved_memories>"
messages.append(ChatMessage(content=memory_context, role="user"))
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(

View File

@@ -25,6 +25,7 @@ from khoj.database.models import (
Intent,
KhojUser,
TextToImageModelConfig,
UserMemory,
)
from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url
@@ -47,6 +48,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
agent: Agent = None,
tracer: dict = {},
):
@@ -91,6 +93,7 @@ async def text_to_image(
model_type=text_to_image_config.model_type,
query_images=query_images,
query_files=query_files,
relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,

View File

@@ -5,7 +5,7 @@ import os
from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser, UserMemory
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
@@ -42,6 +42,7 @@ async def operate_environment(
query_images: Optional[List[str]] = None, # TODO: Handle query images
agent: Agent = None,
query_files: str = None, # TODO: Handle query files
relevant_memories: Optional[List[UserMemory]] = None, # TODO: Handle relevant memories
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: Optional[str] = ChatEvent.END_EVENT.value,

View File

@@ -14,6 +14,7 @@ from khoj.database.models import (
Agent,
ChatMessageModel,
KhojUser,
UserMemory,
WebScraper,
)
from khoj.routers.helpers import (
@@ -63,6 +64,7 @@ async def search_online(
max_webpages_to_read: int = 1,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
previous_subqueries: Set = set(),
fast_model: bool = True,
agent: Agent = None,
@@ -82,6 +84,7 @@ async def search_online(
user,
query_images=query_images,
query_files=query_files,
relevant_memories=relevant_memories,
max_queries=max_online_searches,
fast_model=fast_model,
agent=agent,
@@ -160,7 +163,15 @@ async def search_online(
async for event in send_status_func(f"**Browsing**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
extract_from_webpage(link, data["queries"], data.get("content"), user=user, agent=agent, tracer=tracer)
extract_from_webpage(
link,
data["queries"],
data.get("content"),
relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,
)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@@ -438,6 +449,7 @@ async def read_webpages(
fast_model: bool = True,
agent: Agent = None,
max_webpages_to_read: int = 1,
relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
@@ -460,6 +472,7 @@ async def read_webpages(
user,
send_status_func=send_status_func,
agent=agent,
relevant_memories=relevant_memories,
tracer=tracer,
):
yield result
@@ -471,6 +484,7 @@ async def read_webpages_content(
user: KhojUser,
send_status_func: Optional[Callable] = None,
agent: Agent = None,
relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
logger.info(f"Reading web pages at: {urls}")
@@ -478,7 +492,10 @@ async def read_webpages_content(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Browsing**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [extract_from_webpage(url, {query}, user=user, agent=agent, tracer=tracer) for url in urls]
tasks = [
extract_from_webpage(url, {query}, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer)
for url in urls
]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -533,6 +550,7 @@ async def extract_from_webpage(
url: str,
subqueries: set[str] = None,
content: str = None,
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -546,7 +564,9 @@ async def extract_from_webpage(
extracted_info = None
if not is_none_or_empty(content):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent, tracer=tracer)
extracted_info = await extract_relevant_info(
subqueries, content, relevant_memories=relevant_memories, user=user, agent=agent, tracer=tracer
)
return subqueries, url, extracted_info

View File

@@ -21,7 +21,7 @@ from tenacity import (
)
from khoj.database.adapters import AgentAdapters, FileObjectAdapters
from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser
from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser, UserMemory
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
@@ -59,6 +59,7 @@ async def run_code(
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
tracer: dict = {},
@@ -79,6 +80,7 @@ async def run_code(
agent,
tracer,
query_files,
relevant_memories,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {instructions} with error: {e}")
@@ -126,6 +128,7 @@ async def generate_python_code(
agent: Agent = None,
tracer: dict = {},
query_files: str = None,
relevant_memories: List[UserMemory] = None,
) -> GeneratedCode:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@@ -161,6 +164,7 @@ async def generate_python_code(
code_generation_prompt,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,

View File

@@ -13,7 +13,7 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo
from khoj.database.models import KhojUser, SpeechToTextModelOptions
from khoj.database.models import KhojUser, SpeechToTextModelOptions, UserConversationConfig
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.routers.helpers import (
ApiUserRateLimiter,
@@ -215,6 +215,29 @@ def set_user_name(
return {"status": "ok"}
@api.patch("/user/memory", status_code=200)
@requires(["authenticated"])
def set_user_memory_enabled(
request: Request,
enable_memory: bool,
client: Optional[str] = None,
):
user = request.user.object
user_config, _ = UserConversationConfig.objects.get_or_create(user=user)
user_config.enable_memory = enable_memory
user_config.save()
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_user_memory_enabled",
client=client,
)
return {"status": "ok", "enable_memory": enable_memory}
@api.get("/health", response_class=Response)
@requires(["authenticated"], status_code=200)
def health_check(request: Request) -> Response:

View File

@@ -29,6 +29,7 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
PublicConversationAdapters,
UserMemoryAdapters,
aget_user_name,
)
from khoj.database.models import Agent, KhojUser
@@ -101,7 +102,6 @@ conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
)
api_chat = APIRouter()
@@ -963,6 +963,14 @@ async def event_generator(
location = LocationData(city=city, region=region, country=country, country_code=country_code)
chat_history = conversation.messages
# Get most recent memories and long term relevant memories if memory is enabled
relevant_memories = []
if await ConversationAdapters.ais_memory_enabled(user):
recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
# Create a de-duped set of memories
relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())
# If interrupted message in DB
if last_message := await conversation.pop_message(interrupted=True):
# Populate context from interrupted message
@@ -987,6 +995,7 @@ async def event_generator(
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
relevant_memories=relevant_memories,
tracer=tracer,
)
except ValueError as e:
@@ -1024,6 +1033,7 @@ async def event_generator(
previous_iterations=list(research_results),
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
user_name=user_name,
location=location,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1081,6 +1091,7 @@ async def event_generator(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1130,6 +1141,7 @@ async def event_generator(
max_online_searches=3,
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1157,6 +1169,7 @@ async def event_generator(
max_webpages_to_read=1,
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1198,6 +1211,7 @@ async def event_generator(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1223,6 +1237,7 @@ async def event_generator(
list(operator_results)[-1] if operator_results else None,
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent,
cancellation_event=cancellation_event,
@@ -1276,6 +1291,7 @@ async def event_generator(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1317,6 +1333,7 @@ async def event_generator(
online_results=online_results,
query_images=uploaded_images,
query_files=attached_file_context,
relevant_memories=relevant_memories,
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1375,6 +1392,7 @@ async def event_generator(
user_name,
uploaded_images,
attached_file_context,
relevant_memories,
program_execution_context,
generated_asset_results,
is_subscribed,
@@ -1434,6 +1452,7 @@ async def event_generator(
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
relevant_memories=relevant_memories,
generated_images=generated_images,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
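
Before any tool is invoked, the chat endpoint now collects the user's recent memories plus memories matched against the current query, de-duplicates them by id, and passes the merged list to every downstream actor. A standalone sketch of that merge pattern (the Memory dataclass here is just a stand-in for khoj.database.models.UserMemory):

from dataclasses import dataclass

@dataclass
class Memory:
    id: int
    raw: str

recent = [Memory(1, "Prefers metric units"), Memory(2, "Lives in Nairobi")]
matched = [Memory(2, "Lives in Nairobi"), Memory(5, "Is learning Rust")]

# Dicts keep insertion order and later duplicates overwrite earlier ones,
# so each memory id appears exactly once in the merged list.
relevant_memories = list({m.id: m for m in recent + matched}.values())
print([m.id for m in relevant_memories])  # [1, 2, 5]
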

View File

@@ -0,0 +1,114 @@
import json
import logging
from typing import Optional
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi.responses import Response
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.database.adapters import UserMemoryAdapters
from khoj.database.models import UserMemory
api_memories = APIRouter()
logger = logging.getLogger(__name__)
@api_memories.get("")
@requires(["authenticated"])
async def get_memories(
request: Request,
client: Optional[str] = None,
):
"""Get all memories for the authenticated user"""
user = request.user.object
memories = UserMemory.objects.filter(user=user)
all_memories = await sync_to_async(list)(memories)
# Convert memories to a list of dictionaries
formatted_memories = [
{
"id": memory.id,
"raw": memory.raw,
"created_at": memory.created_at.isoformat(),
}
for memory in all_memories
]
return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200)
@api_memories.delete("/{memory_id}")
@requires(["authenticated"])
async def delete_memory(
request: Request,
memory_id: int,
client: Optional[str] = None,
):
"""Delete a specific memory by ID"""
user = request.user.object
# Verify memory belongs to user before deleting
memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst()
if not memory:
return Response(
content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404
)
await memory.adelete()
return Response(status_code=204)
class UpdateMemoryBody(BaseModel):
"""Request model for updating a memory"""
raw: str
@api_memories.put("/{memory_id}")
@requires(["authenticated"])
async def update_memory(
request: Request,
body: UpdateMemoryBody,
memory_id: int,
client: Optional[str] = None,
):
"""Update a specific memory's content"""
user = request.user.object
# Get the memory and verify it belongs to the user
memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst()
if not memory:
return Response(
content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404
)
new_content = body.raw
if not new_content:
return Response(
content=json.dumps({"error": "Missing required field 'raw'"}),
media_type="application/json",
status_code=400,
)
await memory.adelete()
# Create a new memory with the updated content
new_memory = await UserMemoryAdapters.save_memory(
user=user,
memory=new_content,
)
return Response(
content=json.dumps(
{
"id": new_memory.id,
"raw": new_memory.raw,
}
),
media_type="application/json",
status_code=200,
)
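
The new api_memories router gives clients list, update, and delete operations over their memories; note that an update deletes the old row and re-saves the new text, so the response carries a fresh id. A hedged usage sketch, assuming the router is mounted at /api/memories (its registration is outside this diff) and with authentication elided:

import requests

BASE = "http://localhost:42110/api/memories"  # assumed mount point

# List every memory stored for the authenticated user
memories = requests.get(BASE).json()

if memories:
    memory_id = memories[0]["id"]

    # Rewrite the memory's text; the response contains the replacement's new id
    replacement = requests.put(f"{BASE}/{memory_id}", json={"raw": "Prefers tea over coffee"}).json()

    # Remove the replacement; a successful delete returns 204 with no body
    requests.delete(f"{BASE}/{replacement['id']}")
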

View File

@@ -46,6 +46,7 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
UserMemoryAdapters,
aget_user_by_email,
create_khoj_token,
get_default_search_model,
@@ -53,6 +54,7 @@ from khoj.database.adapters import (
get_user_name,
get_user_notion_config,
get_user_subscription_state,
require_valid_user,
run_with_process_lock,
)
from khoj.database.models import (
@@ -66,8 +68,10 @@ from khoj.database.models import (
NotionConfig,
ProcessLock,
RateLimitRecord,
ServerChatSettings,
Subscription,
TextToImageModelConfig,
UserMemory,
UserRequests,
)
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
@@ -354,6 +358,7 @@ async def aget_data_sources_and_output_format(
query_images: List[str] = None,
agent: Agent = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> Dict[str, Any]:
"""
@@ -415,6 +420,7 @@ async def aget_data_sources_and_output_format(
relevant_tools_prompt,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
response_type="json_object",
response_schema=PickTools,
fast_model=True,
@@ -470,6 +476,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
fast_model: bool = True,
agent: Agent = None,
tracer: dict = {},
@@ -506,6 +513,7 @@ async def infer_webpage_urls(
online_queries_prompt,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
response_type="json_object",
response_schema=WebpageUrls,
fast_model=fast_model,
@@ -536,6 +544,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
max_queries: int = 3,
fast_model: bool = True,
agent: Agent = None,
@@ -573,6 +582,7 @@ async def generate_online_subqueries(
online_queries_prompt,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
response_type="json_object",
response_schema=OnlineQueries,
fast_model=fast_model,
@@ -659,7 +669,12 @@ async def aschedule_query(
async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
qs: set[str],
corpus: str,
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -683,6 +698,7 @@ async def extract_relevant_info(
response = await send_message_to_model_wrapper(
extract_relevant_information,
system_message=prompts.system_prompt_extract_relevant_information,
relevant_memories=relevant_memories,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
@@ -802,6 +818,7 @@ async def generate_excalidraw_diagram(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -819,6 +836,7 @@ async def generate_excalidraw_diagram(
online_results=online_results,
query_images=query_images,
query_files=query_files,
relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,
@@ -854,6 +872,7 @@ async def generate_better_diagram_description(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -899,6 +918,7 @@ async def generate_better_diagram_description(
improve_diagram_description_prompt,
query_images=query_images,
query_files=query_files,
relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
@@ -956,6 +976,88 @@ async def generate_excalidraw_diagram_from_description(
return response
class MemoryUpdates(BaseModel):
"""Facts to add or remove from memory."""
create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.")
delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.")
async def extract_facts_from_query(
user: KhojUser,
conversation_history: List[ChatMessageModel],
existing_facts: List[UserMemory] = None,
agent: Agent = None,
tracer: dict = {},
) -> MemoryUpdates:
"""
Extract facts to add to or remove from memory, based on the recent chat history
"""
chat_history = construct_chat_history(conversation_history, n=2)
formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else []
extract_facts_prompt = prompts.extract_facts_from_query.format(
chat_history=chat_history,
matched_facts=formatted_memories,
)
with timer("Chat actor: Extract facts from query", logger):
response = await send_message_to_model_wrapper(
extract_facts_prompt,
response_schema=MemoryUpdates,
user=user,
fast_model=False,
agent_chat_model=agent.chat_model,
tracer=tracer,
)
response = response.text.strip()
# Parse the JSON response into a MemoryUpdates object
try:
response = clean_json(response)
response = json.loads(response)
parsed_response = MemoryUpdates(**response)
if not isinstance(parsed_response, MemoryUpdates):
raise ValueError(f"Invalid response for extracting facts: {response}")
return parsed_response
except Exception:
logger.error(f"Invalid response for extracting facts: {response}")
return MemoryUpdates(create=[], delete=[])
@require_valid_user
async def ai_update_memories(
user: KhojUser,
conversation_history: List[ChatMessageModel],
memories: List[UserMemory],
agent: Agent,
tracer: dict = {},
):
"""
Update the stored memories for a user based on their recent conversation history.
"""
# Skip memory updates if memory is disabled for the user
if not await ConversationAdapters.ais_memory_enabled(user):
return
memory_update = await extract_facts_from_query(
user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer
)
if not memory_update:
return
# Save the memory updates to the database
for memory in memory_update.create:
logger.info(f"Creating memory: {memory}")
await UserMemoryAdapters.save_memory(user, memory, agent=agent)
for memory in memory_update.delete:
logger.info(f"Deleting memory: {memory}")
await UserMemoryAdapters.delete_memory(user, memory)
async def generate_mermaidjs_diagram(
q: str,
chat_history: List[ChatMessageModel],
@@ -964,6 +1066,7 @@ async def generate_mermaidjs_diagram(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -981,6 +1084,7 @@ async def generate_mermaidjs_diagram(
online_results=online_results,
query_images=query_images,
query_files=query_files,
relevant_memories=relevant_memories,
user=user,
agent=agent,
tracer=tracer,
@@ -1010,6 +1114,7 @@ async def generate_better_mermaidjs_diagram_description(
online_results: Optional[dict] = None,
query_images: List[str] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -1055,6 +1160,7 @@ async def generate_better_mermaidjs_diagram_description(
improve_diagram_description_prompt,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
@@ -1104,6 +1210,7 @@ async def generate_better_image_prompt(
model_type: Optional[str] = None,
query_images: Optional[List[str]] = None,
query_files: str = "",
relevant_memories: List[UserMemory] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
@@ -1148,6 +1255,7 @@ async def generate_better_image_prompt(
q,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
chat_history=conversation_history,
system_message=enhance_image_system_message,
response_type="json_object",
@@ -1180,6 +1288,7 @@ async def search_documents(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = None,
previous_inferred_queries: Set = set(),
fast_model: bool = True,
agent: Agent = None,
@@ -1230,6 +1339,7 @@ async def search_documents(
user=user,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
personality_context=personality_context,
location_data=location_data,
chat_history=chat_history,
@@ -1283,6 +1393,7 @@ async def extract_questions(
query_files: str = None,
query_images: Optional[List[str]] = None,
personality_context: str = "",
relevant_memories: List[UserMemory] = None,
location_data: LocationData = None,
chat_history: List[ChatMessageModel] = [],
max_queries: int = 5,
@@ -1338,6 +1449,7 @@ async def extract_questions(
query=prompt,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
system_message=system_prompt,
response_type="json_object",
response_schema=DocumentQueries,
@@ -1518,6 +1630,7 @@ async def send_message_to_model_wrapper(
query_files: str = None,
query_images: List[str] = None,
context: str = "",
relevant_memories: List[UserMemory] = None,
chat_history: list[ChatMessageModel] = [],
system_message: str = "",
# Model Config
@@ -1568,6 +1681,7 @@ async def send_message_to_model_wrapper(
query_files=query_files,
query_images=query_images,
context_message=context,
relevant_memories=relevant_memories,
chat_history=chat_history,
system_message=system_message,
model_name=chat_model.name,
@@ -1685,6 +1799,7 @@ def build_conversation_context(
operator_results: List[OperatorRun],
query_files: str = None,
query_images: Optional[List[str]] = None,
relevant_memories: List[UserMemory] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
chat_history: List[ChatMessageModel] = [],
@@ -1761,6 +1876,7 @@ def build_conversation_context(
query_files=query_files,
query_images=query_images,
context_message=context_message,
relevant_memories=relevant_memories,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
@@ -1789,6 +1905,7 @@ async def agenerate_chat_response(
user_name: Optional[str] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
relevant_memories: List[UserMemory] = [],
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
@@ -1831,6 +1948,7 @@ async def agenerate_chat_response(
operator_results=operator_results,
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
@@ -2792,6 +2910,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
selected_chat_model_config = ConversationAdapters.get_chat_model(
user
) or ConversationAdapters.get_default_chat_model(user)
server_chat_settings = ServerChatSettings.objects.first()
server_memory_mode = (
server_chat_settings.memory_mode
if server_chat_settings
else ServerChatSettings.MemoryMode.ENABLED_DEFAULT_ON.value # type: ignore[attr-defined]
)
enable_memory = ConversationAdapters.is_memory_enabled(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models:
@@ -2848,6 +2973,8 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"enabled_content_source": enabled_content_sources,
"has_documents": has_documents,
"notion_token": notion_token,
"enable_memory": enable_memory,
"server_memory_mode": server_memory_mode,
# user model settings
"chat_model_options": chat_model_options,
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,

View File

@@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional
import yaml
from khoj.database.adapters import AgentAdapters, EntryAdapters, McpServerAdapters
from khoj.database.models import Agent, ChatMessageModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, KhojUser, UserMemory
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
@@ -285,6 +285,7 @@ async def apick_next_tool(
max_iterations: int = 5,
query_images: List[str] = [],
query_files: str = None,
relevant_memories: List[UserMemory] = [],
max_document_searches: int = 7,
max_online_searches: int = 3,
max_webpages_to_read: int = 3,
@@ -416,6 +417,7 @@ async def apick_next_tool(
query="",
query_files=query_files,
query_images=query_images,
relevant_memories=relevant_memories,
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
tools=tools,
@@ -479,6 +481,7 @@ async def research(
previous_iterations: List[ResearchIteration],
query_images: List[str],
query_files: str = None,
relevant_memories: List[UserMemory] = [],
user_name: str = None,
location: LocationData = None,
send_status_func: Optional[Callable] = None,
@@ -544,6 +547,7 @@ async def research(
MAX_ITERATIONS,
query_images=query_images,
query_files=query_files,
relevant_memories=relevant_memories,
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
max_webpages_to_read=max_webpages_to_read,