Make scheduled jobs persistent and work in multiple worker setups

- Store scheduled job state in Postgres so job schedules persist
  across app restarts
- Use Process Locks to only allow single worker to process a given job
  type. This prevents duplicating job runs across all workers
This commit is contained in:
Debanjum Singh Solanky
2024-04-17 16:28:42 +05:30
parent fcf878e1f3
commit af0972c539
5 changed files with 144 additions and 40 deletions

View File

@@ -454,6 +454,13 @@ class ProcessLockAdapters:
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}") logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
def run_with_process_lock(*args):
"""Wrapper function used for scheduling jobs.
Required as APScheduler can't discover the `ProcessLockAdapter.run_with_lock' method on its own.
"""
return ProcessLockAdapters.run_with_lock(*args)
class ClientApplicationAdapters: class ClientApplicationAdapters:
@staticmethod @staticmethod
async def aget_client_application_by_id(client_id: str, client_secret: str): async def aget_client_application_by_id(client_id: str, client_secret: str):

View File

@@ -128,7 +128,20 @@ def run(should_start_server=True):
poll_task_scheduler() poll_task_scheduler()
# Setup Background Scheduler # Setup Background Scheduler
state.scheduler = BackgroundScheduler() from django.conf import settings as django_settings
django_db = django_settings.DATABASES["default"]
state.scheduler = BackgroundScheduler(
{
"apscheduler.jobstores.default": {
"type": "sqlalchemy",
"url": f'postgresql://{django_db["USER"]}:{django_db["PASSWORD"]}@{django_db["HOST"]}:{django_db["PORT"]}/{django_db["NAME"]}',
},
"apscheduler.timezone": "UTC",
"apscheduler.job_defaults.misfire_grace_time": "60", # Useful to run scheduled jobs even when worker delayed because it was busy or down
"apscheduler.job_defaults.coalesce": "true", # Combine multiple jobs into one if they are scheduled at the same time
}
)
state.scheduler.start() state.scheduler.start()
# Start Server # Start Server
@@ -150,6 +163,9 @@ def run(should_start_server=True):
if should_start_server: if should_start_server:
start_server(app, host=args.host, port=args.port, socket=args.socket) start_server(app, host=args.host, port=args.port, socket=args.socket)
# Teardown
state.scheduler.shutdown()
def set_state(args): def set_state(args):
state.config_file = args.config_file state.config_file = args.config_file

View File

@@ -102,6 +102,7 @@ def save_to_conversation_log(
intent_type: str = "remember", intent_type: str = "remember",
client_application: ClientApplication = None, client_application: ClientApplication = None,
conversation_id: int = None, conversation_id: int = None,
job_id: str = None,
): ):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log( updated_conversation = message_to_log(
@@ -112,6 +113,7 @@ def save_to_conversation_log(
"context": compiled_references, "context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type}, "intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results, "onlineContext": online_results,
"jobId": job_id,
}, },
conversation_log=meta_log.get("chat", []), conversation_log=meta_log.get("chat", []),
) )

View File

@@ -1,3 +1,4 @@
import functools
import json import json
import logging import logging
import math import math
@@ -13,8 +14,13 @@ from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK from websockets import ConnectionClosedOK
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name from khoj.database.adapters import (
from khoj.database.models import KhojUser ConversationAdapters,
EntryAdapters,
aget_user_name,
run_with_process_lock,
)
from khoj.database.models import KhojUser, ProcessLock
from khoj.processor.conversation.prompts import ( from khoj.processor.conversation.prompts import (
help_message, help_message,
no_entries_found, no_entries_found,
@@ -38,6 +44,7 @@ from khoj.routers.helpers import (
get_conversation_command, get_conversation_command,
is_ready_to_chat, is_ready_to_chat,
schedule_query, schedule_query,
scheduled_chat,
text_to_image, text_to_image,
update_telemetry_state, update_telemetry_state,
validate_conversation_config, validate_conversation_config,
@@ -386,35 +393,40 @@ async def websocket_endpoint(
if ConversationCommand.Reminder in conversation_commands: if ConversationCommand.Reminder in conversation_commands:
crontime, inferred_query = await schedule_query(q, location, meta_log) crontime, inferred_query = await schedule_query(q, location, meta_log)
trigger = CronTrigger.from_crontab(crontime) try:
common = CommonQueryParamsClass( trigger = CronTrigger.from_crontab(crontime)
client=websocket.user.client_app, except ValueError as e:
user_agent=websocket.headers.get("user-agent"), await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}")
host=websocket.headers.get("host"), continue
partial_scheduled_chat = functools.partial(
scheduled_chat, inferred_query, websocket.user.object, websocket.url
) )
scope = websocket.scope.copy() try:
scope["path"] = "/api/chat" job = state.scheduler.add_job(
scope["type"] = "http" run_with_process_lock,
request = Request(scope) trigger=trigger,
args=(
partial_scheduled_chat,
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}",
),
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
name=f"{inferred_query}",
max_instances=2, # Allow second instance to kill any previous instance with stale lock
jitter=30,
)
except:
await send_complete_llm_response(
f"Unable to schedule reminder. Ensure the reminder doesn't already exist."
)
continue
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
llm_response = f"""
### 🕒 Scheduled Job
- Query: **"{inferred_query}"**
- Schedule: `{crontime}`
- Next Run At: **{next_run_time}** UTC.
""".strip()
state.scheduler.add_job(
async_to_sync(chat),
trigger=trigger,
args=(request, common, inferred_query),
kwargs={
"stream": False,
"conversation_id": conversation_id,
"city": city,
"region": region,
"country": country,
},
id=f"job_{user.uuid}_{inferred_query}",
replace_existing=True,
)
llm_response = (
f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
)
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
llm_response, llm_response,
@@ -423,6 +435,13 @@ async def websocket_endpoint(
intent_type="reminder", intent_type="reminder",
client_application=websocket.user.client_app, client_application=websocket.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
inferred_queries=[inferred_query],
job_id=job.id,
)
common = CommonQueryParamsClass(
client=websocket.user.client_app,
user_agent=websocket.headers.get("user-agent"),
host=websocket.headers.get("host"),
) )
update_telemetry_state( update_telemetry_state(
request=websocket, request=websocket,
@@ -630,16 +649,41 @@ async def chat(
if ConversationCommand.Reminder in conversation_commands: if ConversationCommand.Reminder in conversation_commands:
crontime, inferred_query = await schedule_query(q, location, meta_log) crontime, inferred_query = await schedule_query(q, location, meta_log)
trigger = CronTrigger.from_crontab(crontime) try:
state.scheduler.add_job( trigger = CronTrigger.from_crontab(crontime)
async_to_sync(chat), except ValueError as e:
trigger=trigger, return Response(
args=(request, common, inferred_query, n, d, False, title, conversation_id, city, region, country), content=f"Unable to create reminder with crontime schedule: {crontime}",
id=f"job_{user.uuid}_{inferred_query}", media_type="text/plain",
replace_existing=True, status_code=500,
) )
partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, request.user.object, request.url)
try:
job = state.scheduler.add_job(
run_with_process_lock,
trigger=trigger,
args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}"),
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
name=f"{inferred_query}",
max_instances=2, # Allow second instance to kill any previous instance with stale lock
jitter=30,
)
except:
return Response(
content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.",
media_type="text/plain",
status_code=500,
)
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
llm_response = f"""
### 🕒 Scheduled Job
- Query: **"{inferred_query}"**
- Schedule: `{crontime}`
- Next Run At: **{next_run_time}** UTC.'
""".strip()
llm_response = f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
llm_response, llm_response,
@@ -648,6 +692,8 @@ async def chat(
intent_type="reminder", intent_type="reminder",
client_application=request.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
inferred_queries=[inferred_query],
job_id=job.id,
) )
if stream: if stream:

View File

@@ -17,13 +17,22 @@ from typing import (
Tuple, Tuple,
Union, Union,
) )
from urllib.parse import parse_qs, urlencode
import openai import openai
import requests
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image from PIL import Image
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from starlette.requests import URL
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters from khoj.database.adapters import (
AgentAdapters,
ConversationAdapters,
EntryAdapters,
create_khoj_token,
get_khoj_tokens,
)
from khoj.database.models import ( from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
ClientApplication, ClientApplication,
@@ -779,3 +788,27 @@ class CommonQueryParamsClass:
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()] CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def scheduled_chat(query, user: KhojUser, calling_url: URL):
# Construct the URL, header for the chat API
scheme = "http" if calling_url.scheme == "http" or calling_url.scheme == "ws" else "https"
# Replace the original scheduling query with the scheduled query
query_dict = parse_qs(calling_url.query)
query_dict["q"] = [query]
# Convert the dictionary back into a query string
scheduled_query = urlencode(query_dict, doseq=True)
url = f"{scheme}://{calling_url.netloc}/api/chat?{scheduled_query}"
headers = {"User-Agent": "Khoj"}
if not state.anonymous_mode:
# Add authorization request header in non-anonymous mode
token = get_khoj_tokens(user)
if is_none_or_empty(token):
token = create_khoj_token(user)
else:
token = token[0]
headers["Authorization"] = f"Bearer {token}"
# Call the chat API endpoint with authenticated user token and query
return requests.get(url, headers=headers)