Create create_scheduled_task func to dedupe logic across ws, http APIs

Previously, both the websocket and http endpoints were implementing
the same logic. This was becoming too unwieldy.
This commit is contained in:
Debanjum Singh Solanky
2024-04-29 15:52:39 +05:30
parent 8dfa0bf047
commit ae10ff4a5f
3 changed files with 53 additions and 72 deletions

View File

@@ -429,7 +429,7 @@ class ProcessLockAdapters:
return ProcessLock.objects.filter(name=process_name).delete()
@staticmethod
def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600):
def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600, **kwargs):
# Exit early if process lock is already taken
if ProcessLockAdapters.is_process_locked(operation):
logger.info(f"🔒 Skip executing {func} as {operation} lock is already taken")
@@ -443,7 +443,7 @@ class ProcessLockAdapters:
# Execute Function
with timer(f"🔒 Run {func} with {operation} process lock", logger):
func()
func(**kwargs)
success = True
except Exception as e:
logger.error(f"🚨 Error executing {func} with {operation} process lock: {e}", exc_info=True)
@@ -454,11 +454,11 @@ class ProcessLockAdapters:
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
def run_with_process_lock(*args):
def run_with_process_lock(*args, **kwargs):
"""Wrapper function used for scheduling jobs.
Required as APScheduler can't discover the `ProcessLockAdapters.run_with_lock` method on its own.
"""
return ProcessLockAdapters.run_with_lock(*args)
return ProcessLockAdapters.run_with_lock(*args, **kwargs)
class ClientApplicationAdapters:

View File

@@ -1,5 +1,3 @@
import functools
import hashlib
import json
import logging
import math
@@ -8,8 +6,6 @@ from datetime import datetime
from typing import Dict, Optional
from urllib.parse import unquote
import pytz
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
from fastapi.requests import Request
@@ -18,13 +14,8 @@ from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
aget_user_name,
run_with_process_lock,
)
from khoj.database.models import KhojUser, ProcessLock
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import (
help_message,
no_entries_found,
@@ -45,10 +36,9 @@ from khoj.routers.helpers import (
agenerate_chat_response,
aget_relevant_information_sources,
aget_relevant_output_modes,
create_scheduled_task,
get_conversation_command,
is_ready_to_chat,
schedule_query,
scheduled_chat,
text_to_image,
update_telemetry_state,
validate_conversation_config,
@@ -399,36 +389,13 @@ async def websocket_endpoint(
q = q.replace(f"/{cmd.value}", "").strip()
if ConversationCommand.Reminder in conversation_commands:
user_timezone = pytz.timezone(timezone)
crontime, inferred_query, subject = await schedule_query(q, location, meta_log)
try:
trigger = CronTrigger.from_crontab(crontime, user_timezone)
except ValueError as e:
await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}")
continue
# Generate the job id from the hash of inferred_query and crontime
job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest()
partial_scheduled_chat = functools.partial(
scheduled_chat, inferred_query, q, subject, websocket.user.object, websocket.url
)
try:
job = state.scheduler.add_job(
run_with_process_lock,
trigger=trigger,
args=(
partial_scheduled_chat,
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
),
id=job_id,
name=f"{inferred_query}",
max_instances=2, # Allow second instance to kill any previous instance with stale lock
jitter=30,
)
except:
await send_complete_llm_response(
f"Unable to schedule reminder. Ensure the reminder doesn't already exist."
job, crontime, inferred_query, subject = await create_scheduled_task(
q, location, timezone, user, websocket.url, meta_log
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
await send_complete_llm_response(f"Unable to schedule task. Ensure the task doesn't already exist.")
continue
# Display next run time in user timezone instead of UTC
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M %Z (%z)")
@@ -670,36 +637,14 @@ async def chat(
user_name = await aget_user_name(user)
if ConversationCommand.Reminder in conversation_commands:
user_timezone = pytz.timezone(timezone)
crontime, inferred_query, subject = await schedule_query(q, location, meta_log)
try:
trigger = CronTrigger.from_crontab(crontime, user_timezone)
except ValueError as e:
return Response(
content=f"Unable to create reminder with crontime schedule: {crontime}",
media_type="text/plain",
status_code=500,
job, crontime, inferred_query, subject = await create_scheduled_task(
q, location, timezone, user, request.url, meta_log
)
# Generate the job id from the hash of inferred_query and crontime
job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest()
partial_scheduled_chat = functools.partial(
scheduled_chat, inferred_query, q, subject, request.user.object, request.url
)
try:
job = state.scheduler.add_job(
run_with_process_lock,
trigger=trigger,
args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}"),
id=job_id,
name=f"{inferred_query}",
max_instances=2, # Allow second instance to kill any previous instance with stale lock
jitter=30,
)
except:
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
return Response(
content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.",
content=f"Unable to schedule task. Ensure the task doesn't already exist.",
media_type="text/plain",
status_code=500,
)

View File

@@ -1,5 +1,6 @@
import asyncio
import base64
import hashlib
import io
import json
import logging
@@ -21,7 +22,9 @@ from typing import (
from urllib.parse import parse_qs, urlencode
import openai
import pytz
import requests
from apscheduler.triggers.cron import CronTrigger
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image
from starlette.authentication import has_required_scope
@@ -33,12 +36,14 @@ from khoj.database.adapters import (
EntryAdapters,
create_khoj_token,
get_khoj_tokens,
run_with_process_lock,
)
from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
KhojUser,
ProcessLock,
Subscription,
TextToImageModelConfig,
UserRequests,
@@ -912,3 +917,34 @@ def scheduled_chat(executing_query: str, scheduling_query: str, subject: str, us
send_task_email(user.get_short_name(), user.email, scheduling_query, ai_response, subject)
else:
return raw_response
async def create_scheduled_task(
    q: str, location: LocationData, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = None
):
    """Infer a schedule from the user's query and register it as a recurring job.

    Shared by the websocket and http chat endpoints so both schedule tasks the same way.

    Args:
        q: The user's original scheduling query.
        location: Location data forwarded to `schedule_query`.
        timezone: Timezone name used to interpret the inferred crontime.
        user: The user the task is scheduled for. Their uuid namespaces the job id and lock name.
        calling_url: URL of the originating request, forwarded to `scheduled_chat` on each run.
        meta_log: Conversation log forwarded to `schedule_query`. Defaults to an empty dict.

    Returns:
        Tuple of (job, crontime, inferred_query, subject), where `job` is the value
        returned by the scheduler's `add_job`.

    Raises:
        Nothing is caught here. Errors from timezone lookup, crontab parsing or job
        registration propagate to the caller, which is expected to report them to the user.
    """
    # Normalize here rather than using a mutable `{}` default, which would be shared across calls
    meta_log = {} if meta_log is None else meta_log

    user_timezone = pytz.timezone(timezone)
    crontime, inferred_query, subject = await schedule_query(q, location, meta_log)
    trigger = CronTrigger.from_crontab(crontime, user_timezone)

    # Generate id and metadata used by task scheduler and process locks for the task runs.
    # The job id is keyed on query + crontime; the lock is keyed on the query alone.
    job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
    query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest()

    job = state.scheduler.add_job(
        run_with_process_lock,
        trigger=trigger,
        args=(
            scheduled_chat,
            f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
        ),
        # Passed through run_with_process_lock to scheduled_chat as keyword arguments
        kwargs={
            "executing_query": inferred_query,
            "scheduling_query": q,
            "subject": subject,
            "user": user,
            "calling_url": calling_url,
        },
        id=job_id,
        name=f"{inferred_query}",
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
        jitter=30,
    )
    return job, crontime, inferred_query, subject