diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index ffd24caf..f3bbf4fe 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -24,6 +24,7 @@ import cron_descriptor from apscheduler.job import Job from asgiref.sync import sync_to_async from django.contrib.sessions.backends.db import SessionStore +from django.db import transaction from django.db.models import Prefetch, Q from django.db.models.manager import BaseManager from django.db.utils import IntegrityError @@ -850,6 +851,73 @@ class AgentAdapters: async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]: return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user) + @staticmethod + @transaction.atomic + @require_valid_user + def atomic_update_agent( + user: KhojUser, + name: str, + personality: str, + privacy_level: str, + icon: str, + color: str, + chat_model_option: ChatModel, + files: List[str], + input_tools: List[str], + output_modes: List[str], + slug: Optional[str] = None, + is_hidden: Optional[bool] = False, + ): + agent, created = Agent.objects.filter(slug=slug, creator=user).update_or_create( + defaults={ + "name": name, + "creator": user, + "personality": personality, + "privacy_level": privacy_level, + "style_icon": icon, + "style_color": color, + "chat_model": chat_model_option, + "input_tools": input_tools, + "output_modes": output_modes, + "is_hidden": is_hidden, + } + ) + + FileObject.objects.filter(agent=agent).delete() + Entry.objects.filter(agent=agent).delete() + + new_file_objects = [] + reference_files_qs = FileObject.objects.filter(file_name__in=files, user=agent.creator, agent=None) + for ref_file in reference_files_qs: + new_file_objects.append(FileObject(file_name=ref_file.file_name, agent=agent, raw_text=ref_file.raw_text)) + + if new_file_objects: + FileObject.objects.bulk_create(new_file_objects, batch_size=100) + + entries_to_create = [] + reference_entries_qs = Entry.objects.filter(file_path__in=files, user=agent.creator, agent=None) + for entry in reference_entries_qs: + entries_to_create.append( + Entry( + agent=agent, + embeddings=entry.embeddings, + raw=entry.raw, + compiled=entry.compiled, + heading=entry.heading, + file_source=entry.file_source, + file_type=entry.file_type, + file_path=entry.file_path, + file_name=entry.file_name, + url=entry.url, + hashed_value=entry.hashed_value, + ) + ) + + if entries_to_create: + Entry.objects.bulk_create(entries_to_create, batch_size=500) + + return agent + @staticmethod @arequire_valid_user async def aupdate_agent( @@ -870,54 +938,24 @@ class AgentAdapters: chat_model = await ConversationAdapters.aget_default_chat_model(user) chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst() - # Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug - agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create( - defaults={ - "name": name, - "creator": user, - "personality": personality, - "privacy_level": privacy_level, - "style_icon": icon, - "style_color": color, - "chat_model": chat_model_option, - "input_tools": input_tools, - "output_modes": output_modes, - "is_hidden": is_hidden, - } - ) - - # Delete all existing files and entries - await FileObject.objects.filter(agent=agent).adelete() - await Entry.objects.filter(agent=agent).adelete() - - for file in files: - reference_file = await FileObject.objects.filter(file_name=file, user=agent.creator).afirst() - if reference_file: - await FileObject.objects.acreate(file_name=file, agent=agent, raw_text=reference_file.raw_text) - - # Duplicate all entries associated with the file - entries: List[Entry] = [] - async for entry in Entry.objects.filter(file_path=file, user=agent.creator).aiterator(): - entries.append( - Entry( - agent=agent, - embeddings=entry.embeddings, - raw=entry.raw, - compiled=entry.compiled, - heading=entry.heading, - file_source=entry.file_source, - file_type=entry.file_type, - file_path=entry.file_path, - file_name=entry.file_name, - url=entry.url, - hashed_value=entry.hashed_value, - ) - ) - - # Bulk create entries - await Entry.objects.abulk_create(entries) - - return agent + try: + return await sync_to_async(AgentAdapters.atomic_update_agent, thread_sensitive=True)( + user=user, + name=name, + personality=personality, + privacy_level=privacy_level, + icon=icon, + color=color, + chat_model_option=chat_model_option, + files=files, + input_tools=input_tools, + output_modes=output_modes, + slug=slug, + is_hidden=is_hidden, + ) + except Exception as e: + logger.error(f"Error updating agent: {e}", exc_info=True) + raise @staticmethod @arequire_valid_user