mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Make scheduled jobs persistent and work in multiple worker setups
- Store scheduled job state in Postgres so job schedules persist across app restarts - Use Process Locks to only allow single worker to process a given job type. This prevents duplicating job runs across all workers
This commit is contained in:
@@ -454,6 +454,13 @@ class ProcessLockAdapters:
|
||||
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
|
||||
|
||||
|
||||
def run_with_process_lock(*args):
|
||||
"""Wrapper function used for scheduling jobs.
|
||||
Required as APScheduler can't discover the `ProcessLockAdapter.run_with_lock' method on its own.
|
||||
"""
|
||||
return ProcessLockAdapters.run_with_lock(*args)
|
||||
|
||||
|
||||
class ClientApplicationAdapters:
|
||||
@staticmethod
|
||||
async def aget_client_application_by_id(client_id: str, client_secret: str):
|
||||
|
||||
@@ -128,7 +128,20 @@ def run(should_start_server=True):
|
||||
poll_task_scheduler()
|
||||
|
||||
# Setup Background Scheduler
|
||||
state.scheduler = BackgroundScheduler()
|
||||
from django.conf import settings as django_settings
|
||||
|
||||
django_db = django_settings.DATABASES["default"]
|
||||
state.scheduler = BackgroundScheduler(
|
||||
{
|
||||
"apscheduler.jobstores.default": {
|
||||
"type": "sqlalchemy",
|
||||
"url": f'postgresql://{django_db["USER"]}:{django_db["PASSWORD"]}@{django_db["HOST"]}:{django_db["PORT"]}/{django_db["NAME"]}',
|
||||
},
|
||||
"apscheduler.timezone": "UTC",
|
||||
"apscheduler.job_defaults.misfire_grace_time": "60", # Useful to run scheduled jobs even when worker delayed because it was busy or down
|
||||
"apscheduler.job_defaults.coalesce": "true", # Combine multiple jobs into one if they are scheduled at the same time
|
||||
}
|
||||
)
|
||||
state.scheduler.start()
|
||||
|
||||
# Start Server
|
||||
@@ -150,6 +163,9 @@ def run(should_start_server=True):
|
||||
if should_start_server:
|
||||
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
||||
|
||||
# Teardown
|
||||
state.scheduler.shutdown()
|
||||
|
||||
|
||||
def set_state(args):
|
||||
state.config_file = args.config_file
|
||||
|
||||
@@ -102,6 +102,7 @@ def save_to_conversation_log(
|
||||
intent_type: str = "remember",
|
||||
client_application: ClientApplication = None,
|
||||
conversation_id: int = None,
|
||||
job_id: str = None,
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
@@ -112,6 +113,7 @@ def save_to_conversation_log(
|
||||
"context": compiled_references,
|
||||
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
||||
"onlineContext": online_results,
|
||||
"jobId": job_id,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@@ -13,8 +14,13 @@ from starlette.authentication import requires
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets import ConnectionClosedOK
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
aget_user_name,
|
||||
run_with_process_lock,
|
||||
)
|
||||
from khoj.database.models import KhojUser, ProcessLock
|
||||
from khoj.processor.conversation.prompts import (
|
||||
help_message,
|
||||
no_entries_found,
|
||||
@@ -38,6 +44,7 @@ from khoj.routers.helpers import (
|
||||
get_conversation_command,
|
||||
is_ready_to_chat,
|
||||
schedule_query,
|
||||
scheduled_chat,
|
||||
text_to_image,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
@@ -386,35 +393,40 @@ async def websocket_endpoint(
|
||||
|
||||
if ConversationCommand.Reminder in conversation_commands:
|
||||
crontime, inferred_query = await schedule_query(q, location, meta_log)
|
||||
trigger = CronTrigger.from_crontab(crontime)
|
||||
common = CommonQueryParamsClass(
|
||||
client=websocket.user.client_app,
|
||||
user_agent=websocket.headers.get("user-agent"),
|
||||
host=websocket.headers.get("host"),
|
||||
try:
|
||||
trigger = CronTrigger.from_crontab(crontime)
|
||||
except ValueError as e:
|
||||
await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}")
|
||||
continue
|
||||
partial_scheduled_chat = functools.partial(
|
||||
scheduled_chat, inferred_query, websocket.user.object, websocket.url
|
||||
)
|
||||
scope = websocket.scope.copy()
|
||||
scope["path"] = "/api/chat"
|
||||
scope["type"] = "http"
|
||||
request = Request(scope)
|
||||
try:
|
||||
job = state.scheduler.add_job(
|
||||
run_with_process_lock,
|
||||
trigger=trigger,
|
||||
args=(
|
||||
partial_scheduled_chat,
|
||||
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}",
|
||||
),
|
||||
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
|
||||
name=f"{inferred_query}",
|
||||
max_instances=2, # Allow second instance to kill any previous instance with stale lock
|
||||
jitter=30,
|
||||
)
|
||||
except:
|
||||
await send_complete_llm_response(
|
||||
f"Unable to schedule reminder. Ensure the reminder doesn't already exist."
|
||||
)
|
||||
continue
|
||||
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
llm_response = f"""
|
||||
### 🕒 Scheduled Job
|
||||
- Query: **"{inferred_query}"**
|
||||
- Schedule: `{crontime}`
|
||||
- Next Run At: **{next_run_time}** UTC.
|
||||
""".strip()
|
||||
|
||||
state.scheduler.add_job(
|
||||
async_to_sync(chat),
|
||||
trigger=trigger,
|
||||
args=(request, common, inferred_query),
|
||||
kwargs={
|
||||
"stream": False,
|
||||
"conversation_id": conversation_id,
|
||||
"city": city,
|
||||
"region": region,
|
||||
"country": country,
|
||||
},
|
||||
id=f"job_{user.uuid}_{inferred_query}",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
llm_response = (
|
||||
f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
|
||||
)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
@@ -423,6 +435,13 @@ async def websocket_endpoint(
|
||||
intent_type="reminder",
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
inferred_queries=[inferred_query],
|
||||
job_id=job.id,
|
||||
)
|
||||
common = CommonQueryParamsClass(
|
||||
client=websocket.user.client_app,
|
||||
user_agent=websocket.headers.get("user-agent"),
|
||||
host=websocket.headers.get("host"),
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
@@ -630,16 +649,41 @@ async def chat(
|
||||
|
||||
if ConversationCommand.Reminder in conversation_commands:
|
||||
crontime, inferred_query = await schedule_query(q, location, meta_log)
|
||||
trigger = CronTrigger.from_crontab(crontime)
|
||||
state.scheduler.add_job(
|
||||
async_to_sync(chat),
|
||||
trigger=trigger,
|
||||
args=(request, common, inferred_query, n, d, False, title, conversation_id, city, region, country),
|
||||
id=f"job_{user.uuid}_{inferred_query}",
|
||||
replace_existing=True,
|
||||
)
|
||||
try:
|
||||
trigger = CronTrigger.from_crontab(crontime)
|
||||
except ValueError as e:
|
||||
return Response(
|
||||
content=f"Unable to create reminder with crontime schedule: {crontime}",
|
||||
media_type="text/plain",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, request.user.object, request.url)
|
||||
try:
|
||||
job = state.scheduler.add_job(
|
||||
run_with_process_lock,
|
||||
trigger=trigger,
|
||||
args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}"),
|
||||
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
|
||||
name=f"{inferred_query}",
|
||||
max_instances=2, # Allow second instance to kill any previous instance with stale lock
|
||||
jitter=30,
|
||||
)
|
||||
except:
|
||||
return Response(
|
||||
content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.",
|
||||
media_type="text/plain",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
llm_response = f"""
|
||||
### 🕒 Scheduled Job
|
||||
- Query: **"{inferred_query}"**
|
||||
- Schedule: `{crontime}`
|
||||
- Next Run At: **{next_run_time}** UTC.'
|
||||
""".strip()
|
||||
|
||||
llm_response = f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
@@ -648,6 +692,8 @@ async def chat(
|
||||
intent_type="reminder",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
inferred_queries=[inferred_query],
|
||||
job_id=job.id,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
||||
@@ -17,13 +17,22 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import parse_qs, urlencode
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from PIL import Image
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||
from khoj.database.adapters import (
|
||||
AgentAdapters,
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
create_khoj_token,
|
||||
get_khoj_tokens,
|
||||
)
|
||||
from khoj.database.models import (
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
@@ -779,3 +788,27 @@ class CommonQueryParamsClass:
|
||||
|
||||
|
||||
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||
|
||||
|
||||
def scheduled_chat(query, user: KhojUser, calling_url: URL):
|
||||
# Construct the URL, header for the chat API
|
||||
scheme = "http" if calling_url.scheme == "http" or calling_url.scheme == "ws" else "https"
|
||||
# Replace the original scheduling query with the scheduled query
|
||||
query_dict = parse_qs(calling_url.query)
|
||||
query_dict["q"] = [query]
|
||||
# Convert the dictionary back into a query string
|
||||
scheduled_query = urlencode(query_dict, doseq=True)
|
||||
url = f"{scheme}://{calling_url.netloc}/api/chat?{scheduled_query}"
|
||||
|
||||
headers = {"User-Agent": "Khoj"}
|
||||
if not state.anonymous_mode:
|
||||
# Add authorization request header in non-anonymous mode
|
||||
token = get_khoj_tokens(user)
|
||||
if is_none_or_empty(token):
|
||||
token = create_khoj_token(user)
|
||||
else:
|
||||
token = token[0]
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
# Call the chat API endpoint with authenticated user token and query
|
||||
return requests.get(url, headers=headers)
|
||||
|
||||
Reference in New Issue
Block a user