mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Make scheduled jobs persistent and work in multiple worker setups
- Store scheduled job state in Postgres so job schedules persist across app restarts. - Use process locks to allow only a single worker to process a given job type. This prevents the same job run from being duplicated across workers.
This commit is contained in:
@@ -454,6 +454,13 @@ class ProcessLockAdapters:
|
|||||||
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
|
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_with_process_lock(*args):
    """Module-level entry point for running a scheduled job under a process lock.

    APScheduler cannot discover the `ProcessLockAdapters.run_with_lock` method
    on its own, so scheduled jobs reference this plain function instead.
    All positional arguments are forwarded unchanged and the result of
    `ProcessLockAdapters.run_with_lock` is returned as-is.
    """
    outcome = ProcessLockAdapters.run_with_lock(*args)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
class ClientApplicationAdapters:
|
class ClientApplicationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_client_application_by_id(client_id: str, client_secret: str):
|
async def aget_client_application_by_id(client_id: str, client_secret: str):
|
||||||
|
|||||||
@@ -128,7 +128,20 @@ def run(should_start_server=True):
|
|||||||
poll_task_scheduler()
|
poll_task_scheduler()
|
||||||
|
|
||||||
# Setup Background Scheduler
|
# Setup Background Scheduler
|
||||||
state.scheduler = BackgroundScheduler()
|
from django.conf import settings as django_settings
|
||||||
|
|
||||||
|
django_db = django_settings.DATABASES["default"]
|
||||||
|
state.scheduler = BackgroundScheduler(
|
||||||
|
{
|
||||||
|
"apscheduler.jobstores.default": {
|
||||||
|
"type": "sqlalchemy",
|
||||||
|
"url": f'postgresql://{django_db["USER"]}:{django_db["PASSWORD"]}@{django_db["HOST"]}:{django_db["PORT"]}/{django_db["NAME"]}',
|
||||||
|
},
|
||||||
|
"apscheduler.timezone": "UTC",
|
||||||
|
"apscheduler.job_defaults.misfire_grace_time": "60", # Useful to run scheduled jobs even when worker delayed because it was busy or down
|
||||||
|
"apscheduler.job_defaults.coalesce": "true", # Combine multiple jobs into one if they are scheduled at the same time
|
||||||
|
}
|
||||||
|
)
|
||||||
state.scheduler.start()
|
state.scheduler.start()
|
||||||
|
|
||||||
# Start Server
|
# Start Server
|
||||||
@@ -150,6 +163,9 @@ def run(should_start_server=True):
|
|||||||
if should_start_server:
|
if should_start_server:
|
||||||
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
||||||
|
|
||||||
|
# Teardown
|
||||||
|
state.scheduler.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def set_state(args):
|
def set_state(args):
|
||||||
state.config_file = args.config_file
|
state.config_file = args.config_file
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ def save_to_conversation_log(
|
|||||||
intent_type: str = "remember",
|
intent_type: str = "remember",
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
conversation_id: int = None,
|
conversation_id: int = None,
|
||||||
|
job_id: str = None,
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
@@ -112,6 +113,7 @@ def save_to_conversation_log(
|
|||||||
"context": compiled_references,
|
"context": compiled_references,
|
||||||
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
||||||
"onlineContext": online_results,
|
"onlineContext": online_results,
|
||||||
|
"jobId": job_id,
|
||||||
},
|
},
|
||||||
conversation_log=meta_log.get("chat", []),
|
conversation_log=meta_log.get("chat", []),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import functools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -13,8 +14,13 @@ from starlette.authentication import requires
|
|||||||
from starlette.websockets import WebSocketDisconnect
|
from starlette.websockets import WebSocketDisconnect
|
||||||
from websockets import ConnectionClosedOK
|
from websockets import ConnectionClosedOK
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
from khoj.database.adapters import (
|
||||||
from khoj.database.models import KhojUser
|
ConversationAdapters,
|
||||||
|
EntryAdapters,
|
||||||
|
aget_user_name,
|
||||||
|
run_with_process_lock,
|
||||||
|
)
|
||||||
|
from khoj.database.models import KhojUser, ProcessLock
|
||||||
from khoj.processor.conversation.prompts import (
|
from khoj.processor.conversation.prompts import (
|
||||||
help_message,
|
help_message,
|
||||||
no_entries_found,
|
no_entries_found,
|
||||||
@@ -38,6 +44,7 @@ from khoj.routers.helpers import (
|
|||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
schedule_query,
|
schedule_query,
|
||||||
|
scheduled_chat,
|
||||||
text_to_image,
|
text_to_image,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
@@ -386,35 +393,40 @@ async def websocket_endpoint(
|
|||||||
|
|
||||||
if ConversationCommand.Reminder in conversation_commands:
|
if ConversationCommand.Reminder in conversation_commands:
|
||||||
crontime, inferred_query = await schedule_query(q, location, meta_log)
|
crontime, inferred_query = await schedule_query(q, location, meta_log)
|
||||||
trigger = CronTrigger.from_crontab(crontime)
|
try:
|
||||||
common = CommonQueryParamsClass(
|
trigger = CronTrigger.from_crontab(crontime)
|
||||||
client=websocket.user.client_app,
|
except ValueError as e:
|
||||||
user_agent=websocket.headers.get("user-agent"),
|
await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}")
|
||||||
host=websocket.headers.get("host"),
|
continue
|
||||||
|
partial_scheduled_chat = functools.partial(
|
||||||
|
scheduled_chat, inferred_query, websocket.user.object, websocket.url
|
||||||
)
|
)
|
||||||
scope = websocket.scope.copy()
|
try:
|
||||||
scope["path"] = "/api/chat"
|
job = state.scheduler.add_job(
|
||||||
scope["type"] = "http"
|
run_with_process_lock,
|
||||||
request = Request(scope)
|
trigger=trigger,
|
||||||
|
args=(
|
||||||
|
partial_scheduled_chat,
|
||||||
|
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}",
|
||||||
|
),
|
||||||
|
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
|
||||||
|
name=f"{inferred_query}",
|
||||||
|
max_instances=2, # Allow second instance to kill any previous instance with stale lock
|
||||||
|
jitter=30,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
await send_complete_llm_response(
|
||||||
|
f"Unable to schedule reminder. Ensure the reminder doesn't already exist."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
llm_response = f"""
|
||||||
|
### 🕒 Scheduled Job
|
||||||
|
- Query: **"{inferred_query}"**
|
||||||
|
- Schedule: `{crontime}`
|
||||||
|
- Next Run At: **{next_run_time}** UTC.
|
||||||
|
""".strip()
|
||||||
|
|
||||||
state.scheduler.add_job(
|
|
||||||
async_to_sync(chat),
|
|
||||||
trigger=trigger,
|
|
||||||
args=(request, common, inferred_query),
|
|
||||||
kwargs={
|
|
||||||
"stream": False,
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
"city": city,
|
|
||||||
"region": region,
|
|
||||||
"country": country,
|
|
||||||
},
|
|
||||||
id=f"job_{user.uuid}_{inferred_query}",
|
|
||||||
replace_existing=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_response = (
|
|
||||||
f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
|
|
||||||
)
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
llm_response,
|
llm_response,
|
||||||
@@ -423,6 +435,13 @@ async def websocket_endpoint(
|
|||||||
intent_type="reminder",
|
intent_type="reminder",
|
||||||
client_application=websocket.user.client_app,
|
client_application=websocket.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
inferred_queries=[inferred_query],
|
||||||
|
job_id=job.id,
|
||||||
|
)
|
||||||
|
common = CommonQueryParamsClass(
|
||||||
|
client=websocket.user.client_app,
|
||||||
|
user_agent=websocket.headers.get("user-agent"),
|
||||||
|
host=websocket.headers.get("host"),
|
||||||
)
|
)
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=websocket,
|
request=websocket,
|
||||||
@@ -630,16 +649,41 @@ async def chat(
|
|||||||
|
|
||||||
if ConversationCommand.Reminder in conversation_commands:
|
if ConversationCommand.Reminder in conversation_commands:
|
||||||
crontime, inferred_query = await schedule_query(q, location, meta_log)
|
crontime, inferred_query = await schedule_query(q, location, meta_log)
|
||||||
trigger = CronTrigger.from_crontab(crontime)
|
try:
|
||||||
state.scheduler.add_job(
|
trigger = CronTrigger.from_crontab(crontime)
|
||||||
async_to_sync(chat),
|
except ValueError as e:
|
||||||
trigger=trigger,
|
return Response(
|
||||||
args=(request, common, inferred_query, n, d, False, title, conversation_id, city, region, country),
|
content=f"Unable to create reminder with crontime schedule: {crontime}",
|
||||||
id=f"job_{user.uuid}_{inferred_query}",
|
media_type="text/plain",
|
||||||
replace_existing=True,
|
status_code=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, request.user.object, request.url)
|
||||||
|
try:
|
||||||
|
job = state.scheduler.add_job(
|
||||||
|
run_with_process_lock,
|
||||||
|
trigger=trigger,
|
||||||
|
args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}"),
|
||||||
|
id=f"job_{user.uuid}_{inferred_query}_{crontime}",
|
||||||
|
name=f"{inferred_query}",
|
||||||
|
max_instances=2, # Allow second instance to kill any previous instance with stale lock
|
||||||
|
jitter=30,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
return Response(
|
||||||
|
content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.",
|
||||||
|
media_type="text/plain",
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
llm_response = f"""
|
||||||
|
### 🕒 Scheduled Job
|
||||||
|
- Query: **"{inferred_query}"**
|
||||||
|
- Schedule: `{crontime}`
|
||||||
|
- Next Run At: **{next_run_time}** UTC.'
|
||||||
|
""".strip()
|
||||||
|
|
||||||
llm_response = f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
llm_response,
|
llm_response,
|
||||||
@@ -648,6 +692,8 @@ async def chat(
|
|||||||
intent_type="reminder",
|
intent_type="reminder",
|
||||||
client_application=request.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
inferred_queries=[inferred_query],
|
||||||
|
job_id=job.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
|||||||
@@ -17,13 +17,22 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
from urllib.parse import parse_qs, urlencode
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
import requests
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
|
from starlette.requests import URL
|
||||||
|
|
||||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import (
|
||||||
|
AgentAdapters,
|
||||||
|
ConversationAdapters,
|
||||||
|
EntryAdapters,
|
||||||
|
create_khoj_token,
|
||||||
|
get_khoj_tokens,
|
||||||
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
ClientApplication,
|
ClientApplication,
|
||||||
@@ -779,3 +788,27 @@ class CommonQueryParamsClass:
|
|||||||
|
|
||||||
|
|
||||||
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||||
|
|
||||||
|
|
||||||
|
def scheduled_chat(query, user: KhojUser, calling_url: URL):
    """Call the chat API endpoint with the scheduled query on behalf of the user.

    The request targets `/api/chat` on the same host (netloc) as `calling_url`
    and reuses its query parameters, except that `q` is replaced by `query`.
    Unless the server runs in anonymous mode, a bearer token for `user` is
    attached: the first existing token, or a newly created one if none exist.

    Returns the response object produced by `requests.get`.
    """
    # Requests that originated over http or ws are replayed over http; all others over https
    if calling_url.scheme in ("http", "ws"):
        scheme = "http"
    else:
        scheme = "https"

    # Swap the original scheduling query for the query to run on schedule
    params = parse_qs(calling_url.query)
    params["q"] = [query]
    url = f"{scheme}://{calling_url.netloc}/api/chat?{urlencode(params, doseq=True)}"

    headers = {"User-Agent": "Khoj"}
    if not state.anonymous_mode:
        # Authenticate as the user when not in anonymous mode
        existing_tokens = get_khoj_tokens(user)
        # NOTE(review): the token object is formatted directly into the header;
        # this assumes its string form is the raw token value - confirm.
        token = create_khoj_token(user) if is_none_or_empty(existing_tokens) else existing_tokens[0]
        headers["Authorization"] = f"Bearer {token}"

    # NOTE(review): no timeout is passed to the request - confirm this is intended.
    return requests.get(url, headers=headers)
|
||||||
|
|||||||
Reference in New Issue
Block a user