Make scheduled jobs persistent and work in multiple worker setups

- Store scheduled job state in Postgres so job schedules persist
  across app restarts
- Use Process Locks to only allow single worker to process a given job
  type. This prevents duplicating job runs across all workers
This commit is contained in:
Debanjum Singh Solanky
2024-04-17 16:28:42 +05:30
parent fcf878e1f3
commit af0972c539
5 changed files with 144 additions and 40 deletions

View File

@@ -454,6 +454,13 @@ class ProcessLockAdapters:
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
def run_with_process_lock(*args):
"""Wrapper function used for scheduling jobs.
Required as APScheduler can't discover the `ProcessLockAdapter.run_with_lock' method on its own.
"""
return ProcessLockAdapters.run_with_lock(*args)
class ClientApplicationAdapters:
@staticmethod
async def aget_client_application_by_id(client_id: str, client_secret: str):

View File

@@ -128,7 +128,20 @@ def run(should_start_server=True):
poll_task_scheduler()
# Setup Background Scheduler
state.scheduler = BackgroundScheduler()
from django.conf import settings as django_settings
django_db = django_settings.DATABASES["default"]
state.scheduler = BackgroundScheduler(
{
"apscheduler.jobstores.default": {
"type": "sqlalchemy",
"url": f'postgresql://{django_db["USER"]}:{django_db["PASSWORD"]}@{django_db["HOST"]}:{django_db["PORT"]}/{django_db["NAME"]}',
},
"apscheduler.timezone": "UTC",
"apscheduler.job_defaults.misfire_grace_time": "60", # Useful to run scheduled jobs even when worker delayed because it was busy or down
"apscheduler.job_defaults.coalesce": "true", # Combine multiple jobs into one if they are scheduled at the same time
}
)
state.scheduler.start()
# Start Server
@@ -150,6 +163,9 @@ def run(should_start_server=True):
if should_start_server:
start_server(app, host=args.host, port=args.port, socket=args.socket)
# Teardown
state.scheduler.shutdown()
def set_state(args):
state.config_file = args.config_file

View File

@@ -102,6 +102,7 @@ def save_to_conversation_log(
intent_type: str = "remember",
client_application: ClientApplication = None,
conversation_id: int = None,
job_id: str = None,
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@@ -112,6 +113,7 @@ def save_to_conversation_log(
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
"jobId": job_id,
},
conversation_log=meta_log.get("chat", []),
)

View File

@@ -1,3 +1,4 @@
import functools
import json
import logging
import math
@@ -13,8 +14,13 @@ from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
from khoj.database.models import KhojUser
from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
aget_user_name,
run_with_process_lock,
)
from khoj.database.models import KhojUser, ProcessLock
from khoj.processor.conversation.prompts import (
help_message,
no_entries_found,
@@ -38,6 +44,7 @@ from khoj.routers.helpers import (
get_conversation_command,
is_ready_to_chat,
schedule_query,
scheduled_chat,
text_to_image,
update_telemetry_state,
validate_conversation_config,
@@ -386,35 +393,40 @@ async def websocket_endpoint(
if ConversationCommand.Reminder in conversation_commands:
crontime, inferred_query = await schedule_query(q, location, meta_log)
trigger = CronTrigger.from_crontab(crontime)
common = CommonQueryParamsClass(
client=websocket.user.client_app,
user_agent=websocket.headers.get("user-agent"),
host=websocket.headers.get("host"),
try:
trigger = CronTrigger.from_crontab(crontime)
except ValueError as e:
await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}")
continue
partial_scheduled_chat = functools.partial(
scheduled_chat, inferred_query, websocket.user.object, websocket.url
)
scope = websocket.scope.copy()
scope["path"] = "/api/chat"
scope["type"] = "http"
request = Request(scope)
try:
job = state.scheduler.add_job(
run_with_process_lock,
trigger=trigger,
args=(
partial_scheduled_chat,
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}",
),
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
name=f"{inferred_query}",
max_instances=2, # Allow second instance to kill any previous instance with stale lock
jitter=30,
)
except:
await send_complete_llm_response(
f"Unable to schedule reminder. Ensure the reminder doesn't already exist."
)
continue
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
llm_response = f"""
### 🕒 Scheduled Job
- Query: **"{inferred_query}"**
- Schedule: `{crontime}`
- Next Run At: **{next_run_time}** UTC.
""".strip()
state.scheduler.add_job(
async_to_sync(chat),
trigger=trigger,
args=(request, common, inferred_query),
kwargs={
"stream": False,
"conversation_id": conversation_id,
"city": city,
"region": region,
"country": country,
},
id=f"job_{user.uuid}_{inferred_query}",
replace_existing=True,
)
llm_response = (
f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
)
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
@@ -423,6 +435,13 @@ async def websocket_endpoint(
intent_type="reminder",
client_application=websocket.user.client_app,
conversation_id=conversation_id,
inferred_queries=[inferred_query],
job_id=job.id,
)
common = CommonQueryParamsClass(
client=websocket.user.client_app,
user_agent=websocket.headers.get("user-agent"),
host=websocket.headers.get("host"),
)
update_telemetry_state(
request=websocket,
@@ -630,16 +649,41 @@ async def chat(
if ConversationCommand.Reminder in conversation_commands:
crontime, inferred_query = await schedule_query(q, location, meta_log)
trigger = CronTrigger.from_crontab(crontime)
state.scheduler.add_job(
async_to_sync(chat),
trigger=trigger,
args=(request, common, inferred_query, n, d, False, title, conversation_id, city, region, country),
id=f"job_{user.uuid}_{inferred_query}",
replace_existing=True,
)
try:
trigger = CronTrigger.from_crontab(crontime)
except ValueError as e:
return Response(
content=f"Unable to create reminder with crontime schedule: {crontime}",
media_type="text/plain",
status_code=500,
)
partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, request.user.object, request.url)
try:
job = state.scheduler.add_job(
run_with_process_lock,
trigger=trigger,
args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}"),
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
name=f"{inferred_query}",
max_instances=2, # Allow second instance to kill any previous instance with stale lock
jitter=30,
)
except:
return Response(
content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.",
media_type="text/plain",
status_code=500,
)
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
llm_response = f"""
### 🕒 Scheduled Job
- Query: **"{inferred_query}"**
- Schedule: `{crontime}`
- Next Run At: **{next_run_time}** UTC.'
""".strip()
llm_response = f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
await sync_to_async(save_to_conversation_log)(
q,
llm_response,
@@ -648,6 +692,8 @@ async def chat(
intent_type="reminder",
client_application=request.user.client_app,
conversation_id=conversation_id,
inferred_queries=[inferred_query],
job_id=job.id,
)
if stream:

View File

@@ -17,13 +17,22 @@ from typing import (
Tuple,
Union,
)
from urllib.parse import parse_qs, urlencode
import openai
import requests
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image
from starlette.authentication import has_required_scope
from starlette.requests import URL
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.adapters import (
AgentAdapters,
ConversationAdapters,
EntryAdapters,
create_khoj_token,
get_khoj_tokens,
)
from khoj.database.models import (
ChatModelOptions,
ClientApplication,
@@ -779,3 +788,27 @@ class CommonQueryParamsClass:
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def scheduled_chat(query, user: KhojUser, calling_url: URL):
# Construct the URL, header for the chat API
scheme = "http" if calling_url.scheme == "http" or calling_url.scheme == "ws" else "https"
# Replace the original scheduling query with the scheduled query
query_dict = parse_qs(calling_url.query)
query_dict["q"] = [query]
# Convert the dictionary back into a query string
scheduled_query = urlencode(query_dict, doseq=True)
url = f"{scheme}://{calling_url.netloc}/api/chat?{scheduled_query}"
headers = {"User-Agent": "Khoj"}
if not state.anonymous_mode:
# Add authorization request header in non-anonymous mode
token = get_khoj_tokens(user)
if is_none_or_empty(token):
token = create_khoj_token(user)
else:
token = token[0]
headers["Authorization"] = f"Bearer {token}"
# Call the chat API endpoint with authenticated user token and query
return requests.get(url, headers=headers)