mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Update agent knowledge base and configuration atomically
This should help prevent partial updates to agent. Especially useful for agent's with large knowledge bases being updated. Failing the call should raise an exception. This will allow your to retry save instead of losing your previous agent changes or saving only partial.
This commit is contained in:
@@ -24,6 +24,7 @@ import cron_descriptor
|
||||
from apscheduler.job import Job
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from django.db import transaction
|
||||
from django.db.models import Prefetch, Q
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.utils import IntegrityError
|
||||
@@ -850,6 +851,73 @@ class AgentAdapters:
|
||||
async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
|
||||
return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user)
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
@require_valid_user
|
||||
def atomic_update_agent(
|
||||
user: KhojUser,
|
||||
name: str,
|
||||
personality: str,
|
||||
privacy_level: str,
|
||||
icon: str,
|
||||
color: str,
|
||||
chat_model_option: ChatModel,
|
||||
files: List[str],
|
||||
input_tools: List[str],
|
||||
output_modes: List[str],
|
||||
slug: Optional[str] = None,
|
||||
is_hidden: Optional[bool] = False,
|
||||
):
|
||||
agent, created = Agent.objects.filter(slug=slug, creator=user).update_or_create(
|
||||
defaults={
|
||||
"name": name,
|
||||
"creator": user,
|
||||
"personality": personality,
|
||||
"privacy_level": privacy_level,
|
||||
"style_icon": icon,
|
||||
"style_color": color,
|
||||
"chat_model": chat_model_option,
|
||||
"input_tools": input_tools,
|
||||
"output_modes": output_modes,
|
||||
"is_hidden": is_hidden,
|
||||
}
|
||||
)
|
||||
|
||||
FileObject.objects.filter(agent=agent).delete()
|
||||
Entry.objects.filter(agent=agent).delete()
|
||||
|
||||
new_file_objects = []
|
||||
reference_files_qs = FileObject.objects.filter(file_name__in=files, user=agent.creator, agent=None)
|
||||
for ref_file in reference_files_qs:
|
||||
new_file_objects.append(FileObject(file_name=ref_file.file_name, agent=agent, raw_text=ref_file.raw_text))
|
||||
|
||||
if new_file_objects:
|
||||
FileObject.objects.bulk_create(new_file_objects, batch_size=100)
|
||||
|
||||
entries_to_create = []
|
||||
reference_entries_qs = Entry.objects.filter(file_path__in=files, user=agent.creator, agent=None)
|
||||
for entry in reference_entries_qs:
|
||||
entries_to_create.append(
|
||||
Entry(
|
||||
agent=agent,
|
||||
embeddings=entry.embeddings,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading,
|
||||
file_source=entry.file_source,
|
||||
file_type=entry.file_type,
|
||||
file_path=entry.file_path,
|
||||
file_name=entry.file_name,
|
||||
url=entry.url,
|
||||
hashed_value=entry.hashed_value,
|
||||
)
|
||||
)
|
||||
|
||||
if entries_to_create:
|
||||
Entry.objects.bulk_create(entries_to_create, batch_size=500)
|
||||
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aupdate_agent(
|
||||
@@ -870,54 +938,24 @@ class AgentAdapters:
|
||||
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()
|
||||
|
||||
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
|
||||
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
|
||||
defaults={
|
||||
"name": name,
|
||||
"creator": user,
|
||||
"personality": personality,
|
||||
"privacy_level": privacy_level,
|
||||
"style_icon": icon,
|
||||
"style_color": color,
|
||||
"chat_model": chat_model_option,
|
||||
"input_tools": input_tools,
|
||||
"output_modes": output_modes,
|
||||
"is_hidden": is_hidden,
|
||||
}
|
||||
)
|
||||
|
||||
# Delete all existing files and entries
|
||||
await FileObject.objects.filter(agent=agent).adelete()
|
||||
await Entry.objects.filter(agent=agent).adelete()
|
||||
|
||||
for file in files:
|
||||
reference_file = await FileObject.objects.filter(file_name=file, user=agent.creator).afirst()
|
||||
if reference_file:
|
||||
await FileObject.objects.acreate(file_name=file, agent=agent, raw_text=reference_file.raw_text)
|
||||
|
||||
# Duplicate all entries associated with the file
|
||||
entries: List[Entry] = []
|
||||
async for entry in Entry.objects.filter(file_path=file, user=agent.creator).aiterator():
|
||||
entries.append(
|
||||
Entry(
|
||||
agent=agent,
|
||||
embeddings=entry.embeddings,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading,
|
||||
file_source=entry.file_source,
|
||||
file_type=entry.file_type,
|
||||
file_path=entry.file_path,
|
||||
file_name=entry.file_name,
|
||||
url=entry.url,
|
||||
hashed_value=entry.hashed_value,
|
||||
)
|
||||
)
|
||||
|
||||
# Bulk create entries
|
||||
await Entry.objects.abulk_create(entries)
|
||||
|
||||
return agent
|
||||
try:
|
||||
return await sync_to_async(AgentAdapters.atomic_update_agent, thread_sensitive=True)(
|
||||
user=user,
|
||||
name=name,
|
||||
personality=personality,
|
||||
privacy_level=privacy_level,
|
||||
icon=icon,
|
||||
color=color,
|
||||
chat_model_option=chat_model_option,
|
||||
files=files,
|
||||
input_tools=input_tools,
|
||||
output_modes=output_modes,
|
||||
slug=slug,
|
||||
is_hidden=is_hidden,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
|
||||
Reference in New Issue
Block a user