Add a database lock for jobs that shouldn't be run by multiple workers (#706)

* Add a database lock for jobs that shouldn't be run by multiple workers

* Import relevant functions from utils.helpers
This commit is contained in:
sabaimran
2024-04-16 08:59:27 -07:00
committed by GitHub
parent adb2e8cc5f
commit 91c8b137f1
4 changed files with 88 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
import json import json
import logging import logging
import os import os
from datetime import datetime from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@@ -24,6 +24,7 @@ from khoj.database.adapters import (
AgentAdapters, AgentAdapters,
ClientApplicationAdapters, ClientApplicationAdapters,
ConversationAdapters, ConversationAdapters,
ProcessLockAdapters,
SubscriptionState, SubscriptionState,
aget_or_create_user_by_phone_number, aget_or_create_user_by_phone_number,
aget_user_by_phone_number, aget_user_by_phone_number,
@@ -32,14 +33,14 @@ from khoj.database.adapters import (
get_all_users, get_all_users,
get_or_create_search_models, get_or_create_search_models,
) )
from khoj.database.models import ClientApplication, KhojUser, Subscription from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search from khoj.routers.indexer import configure_content, configure_search
from khoj.routers.twilio import is_twilio_enabled from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled from khoj.utils.helpers import is_none_or_empty, telemetry_disabled, timer
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -306,18 +307,28 @@ def configure_middleware(app):
app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret")) app.add_middleware(SessionMiddleware, secret_key=os.environ.get("KHOJ_DJANGO_SECRET_KEY", "!secret"))
@schedule.repeat(schedule.every(22).to(26).hours) @schedule.repeat(schedule.every(22).to(25).hours)
def update_search_index(): def update_content_index():
try: try:
logger.info("📬 Updating content index via Scheduler") if ProcessLockAdapters.is_process_locked(ProcessLock.Operation.UPDATE_EMBEDDINGS):
for user in get_all_users(): logger.info("🔒 Skipping update content index due to lock")
all_files = collect_files(user=user) return
success = configure_content(all_files, user=user) ProcessLockAdapters.set_process_lock(
all_files = collect_files(user=None) ProcessLock.Operation.UPDATE_EMBEDDINGS, max_duration_in_seconds=60 * 60 * 2
success = configure_content(all_files, user=None) )
if not success:
raise RuntimeError("Failed to update content index") with timer("📬 Updating content index via Scheduler"):
for user in get_all_users():
all_files = collect_files(user=user)
success = configure_content(all_files, user=user)
all_files = collect_files(user=None)
success = configure_content(all_files, user=None)
if not success:
raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler") logger.info("📪 Content index updated via Scheduler")
ProcessLockAdapters.remove_process_lock(ProcessLock.Operation.UPDATE_EMBEDDINGS)
except Exception as e: except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True) logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)

View File

@@ -30,6 +30,7 @@ from khoj.database.models import (
NotionConfig, NotionConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ProcessLock,
ReflectiveQuestion, ReflectiveQuestion,
SearchModelConfig, SearchModelConfig,
SpeechToTextModelOptions, SpeechToTextModelOptions,
@@ -402,6 +403,32 @@ async def aget_user_search_model(user: KhojUser):
return config.setting return config.setting
class ProcessLockAdapters:
@staticmethod
def get_process_lock(process_name: str):
return ProcessLock.objects.filter(name=process_name).first()
@staticmethod
def set_process_lock(process_name: str, max_duration_in_seconds: int = 600):
return ProcessLock.objects.create(name=process_name, max_duration_in_seconds=max_duration_in_seconds)
@staticmethod
def is_process_locked(process_name: str):
process_lock = ProcessLock.objects.filter(name=process_name).first()
if not process_lock:
return False
if process_lock.started_at + timedelta(seconds=process_lock.max_duration_in_seconds) < datetime.now(
tz=timezone.utc
):
process_lock.delete()
return False
return True
@staticmethod
def remove_process_lock(process_name: str):
return ProcessLock.objects.filter(name=process_name).delete()
class ClientApplicationAdapters: class ClientApplicationAdapters:
@staticmethod @staticmethod
async def aget_client_application_by_id(client_id: str, client_secret: str): async def aget_client_application_by_id(client_id: str, client_secret: str):

View File

@@ -0,0 +1,26 @@
# Generated by Django 4.2.10 on 2024-04-15 08:48
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0034_alter_chatmodeloptions_chat_model"),
]
operations = [
migrations.CreateModel(
name="ProcessLock",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(choices=[("update_embeddings", "Update Embeddings")], max_length=200)),
("started_at", models.DateTimeField(auto_now_add=True)),
("max_duration_in_seconds", models.IntegerField(default=43200)),
],
options={
"abstract": False,
},
),
]

View File

@@ -98,6 +98,17 @@ class Agent(BaseModel):
slug = models.CharField(max_length=200) slug = models.CharField(max_length=200)
class ProcessLock(BaseModel):
class Operation(models.TextChoices):
UPDATE_EMBEDDINGS = "update_embeddings"
# We need to make sure that some operations are thread-safe. To do so, add locks for potentially shared operations.
# For example, we need to make sure that only one process is updating the embeddings at a time.
name = models.CharField(max_length=200, choices=Operation.choices)
started_at = models.DateTimeField(auto_now_add=True)
max_duration_in_seconds = models.IntegerField(default=60 * 60 * 12) # 12 hours
@receiver(pre_save, sender=Agent) @receiver(pre_save, sender=Agent)
def verify_agent(sender, instance, **kwargs): def verify_agent(sender, instance, **kwargs):
# check if this is a new instance # check if this is a new instance