Update agent knowledge base and configuration atomically

This should help prevent partial updates to agent. Especially useful
for agent's with large knowledge bases being updated. Failing the call
should raise an exception. This will allow your to retry save instead
of losing your previous agent changes or saving only partial.
This commit is contained in:
Debanjum
2025-07-02 14:41:50 -07:00
parent e6cc9b1182
commit 9774bb012e

View File

@@ -24,6 +24,7 @@ import cron_descriptor
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
from django.db import transaction
from django.db.models import Prefetch, Q
from django.db.models.manager import BaseManager
from django.db.utils import IntegrityError
@@ -850,6 +851,73 @@ class AgentAdapters:
async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user)
@staticmethod
@transaction.atomic
@require_valid_user
def atomic_update_agent(
user: KhojUser,
name: str,
personality: str,
privacy_level: str,
icon: str,
color: str,
chat_model_option: ChatModel,
files: List[str],
input_tools: List[str],
output_modes: List[str],
slug: Optional[str] = None,
is_hidden: Optional[bool] = False,
):
agent, created = Agent.objects.filter(slug=slug, creator=user).update_or_create(
defaults={
"name": name,
"creator": user,
"personality": personality,
"privacy_level": privacy_level,
"style_icon": icon,
"style_color": color,
"chat_model": chat_model_option,
"input_tools": input_tools,
"output_modes": output_modes,
"is_hidden": is_hidden,
}
)
FileObject.objects.filter(agent=agent).delete()
Entry.objects.filter(agent=agent).delete()
new_file_objects = []
reference_files_qs = FileObject.objects.filter(file_name__in=files, user=agent.creator, agent=None)
for ref_file in reference_files_qs:
new_file_objects.append(FileObject(file_name=ref_file.file_name, agent=agent, raw_text=ref_file.raw_text))
if new_file_objects:
FileObject.objects.bulk_create(new_file_objects, batch_size=100)
entries_to_create = []
reference_entries_qs = Entry.objects.filter(file_path__in=files, user=agent.creator, agent=None)
for entry in reference_entries_qs:
entries_to_create.append(
Entry(
agent=agent,
embeddings=entry.embeddings,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading,
file_source=entry.file_source,
file_type=entry.file_type,
file_path=entry.file_path,
file_name=entry.file_name,
url=entry.url,
hashed_value=entry.hashed_value,
)
)
if entries_to_create:
Entry.objects.bulk_create(entries_to_create, batch_size=500)
return agent
@staticmethod
@arequire_valid_user
async def aupdate_agent(
@@ -870,54 +938,24 @@ class AgentAdapters:
chat_model = await ConversationAdapters.aget_default_chat_model(user)
chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
defaults={
"name": name,
"creator": user,
"personality": personality,
"privacy_level": privacy_level,
"style_icon": icon,
"style_color": color,
"chat_model": chat_model_option,
"input_tools": input_tools,
"output_modes": output_modes,
"is_hidden": is_hidden,
}
)
# Delete all existing files and entries
await FileObject.objects.filter(agent=agent).adelete()
await Entry.objects.filter(agent=agent).adelete()
for file in files:
reference_file = await FileObject.objects.filter(file_name=file, user=agent.creator).afirst()
if reference_file:
await FileObject.objects.acreate(file_name=file, agent=agent, raw_text=reference_file.raw_text)
# Duplicate all entries associated with the file
entries: List[Entry] = []
async for entry in Entry.objects.filter(file_path=file, user=agent.creator).aiterator():
entries.append(
Entry(
agent=agent,
embeddings=entry.embeddings,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading,
file_source=entry.file_source,
file_type=entry.file_type,
file_path=entry.file_path,
file_name=entry.file_name,
url=entry.url,
hashed_value=entry.hashed_value,
)
)
# Bulk create entries
await Entry.objects.abulk_create(entries)
return agent
try:
return await sync_to_async(AgentAdapters.atomic_update_agent, thread_sensitive=True)(
user=user,
name=name,
personality=personality,
privacy_level=privacy_level,
icon=icon,
color=color,
chat_model_option=chat_model_option,
files=files,
input_tools=input_tools,
output_modes=output_modes,
slug=slug,
is_hidden=is_hidden,
)
except Exception as e:
logger.error(f"Error updating agent: {e}", exc_info=True)
raise
@staticmethod
@arequire_valid_user