mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Create create_scheduled_task func to dedupe logic across ws, http APIs
Previously, both the websocket and HTTP endpoints implemented the same scheduling logic, which was becoming unwieldy.
This commit is contained in:
@@ -429,7 +429,7 @@ class ProcessLockAdapters:
|
||||
return ProcessLock.objects.filter(name=process_name).delete()
|
||||
|
||||
@staticmethod
|
||||
def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600):
|
||||
def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600, **kwargs):
|
||||
# Exit early if process lock is already taken
|
||||
if ProcessLockAdapters.is_process_locked(operation):
|
||||
logger.info(f"🔒 Skip executing {func} as {operation} lock is already taken")
|
||||
@@ -443,7 +443,7 @@ class ProcessLockAdapters:
|
||||
|
||||
# Execute Function
|
||||
with timer(f"🔒 Run {func} with {operation} process lock", logger):
|
||||
func()
|
||||
func(**kwargs)
|
||||
success = True
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error executing {func} with {operation} process lock: {e}", exc_info=True)
|
||||
@@ -454,11 +454,11 @@ class ProcessLockAdapters:
|
||||
logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}")
|
||||
|
||||
|
||||
def run_with_process_lock(*args):
|
||||
def run_with_process_lock(*args, **kwargs):
|
||||
"""Wrapper function used for scheduling jobs.
|
||||
Required as APScheduler can't discover the `ProcessLockAdapters.run_with_lock` method on its own.
|
||||
"""
|
||||
return ProcessLockAdapters.run_with_lock(*args)
|
||||
return ProcessLockAdapters.run_with_lock(*args, **kwargs)
|
||||
|
||||
|
||||
class ClientApplicationAdapters:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@@ -8,8 +6,6 @@ from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
import pytz
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
|
||||
from fastapi.requests import Request
|
||||
@@ -18,13 +14,8 @@ from starlette.authentication import requires
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets import ConnectionClosedOK
|
||||
|
||||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
aget_user_name,
|
||||
run_with_process_lock,
|
||||
)
|
||||
from khoj.database.models import KhojUser, ProcessLock
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.conversation.prompts import (
|
||||
help_message,
|
||||
no_entries_found,
|
||||
@@ -45,10 +36,9 @@ from khoj.routers.helpers import (
|
||||
agenerate_chat_response,
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
create_scheduled_task,
|
||||
get_conversation_command,
|
||||
is_ready_to_chat,
|
||||
schedule_query,
|
||||
scheduled_chat,
|
||||
text_to_image,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
@@ -399,36 +389,13 @@ async def websocket_endpoint(
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
if ConversationCommand.Reminder in conversation_commands:
|
||||
user_timezone = pytz.timezone(timezone)
|
||||
crontime, inferred_query, subject = await schedule_query(q, location, meta_log)
|
||||
try:
|
||||
trigger = CronTrigger.from_crontab(crontime, user_timezone)
|
||||
except ValueError as e:
|
||||
await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}")
|
||||
continue
|
||||
# Generate the job id from the hash of inferred_query and crontime
|
||||
job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
|
||||
query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest()
|
||||
partial_scheduled_chat = functools.partial(
|
||||
scheduled_chat, inferred_query, q, subject, websocket.user.object, websocket.url
|
||||
)
|
||||
try:
|
||||
job = state.scheduler.add_job(
|
||||
run_with_process_lock,
|
||||
trigger=trigger,
|
||||
args=(
|
||||
partial_scheduled_chat,
|
||||
f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
|
||||
),
|
||||
id=job_id,
|
||||
name=f"{inferred_query}",
|
||||
max_instances=2, # Allow second instance to kill any previous instance with stale lock
|
||||
jitter=30,
|
||||
)
|
||||
except:
|
||||
await send_complete_llm_response(
|
||||
f"Unable to schedule reminder. Ensure the reminder doesn't already exist."
|
||||
job, crontime, inferred_query, subject = await create_scheduled_task(
|
||||
q, location, timezone, user, websocket.url, meta_log
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
await send_complete_llm_response(f"Unable to schedule task. Ensure the task doesn't already exist.")
|
||||
continue
|
||||
# Display next run time in user timezone instead of UTC
|
||||
next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M %Z (%z)")
|
||||
@@ -670,36 +637,14 @@ async def chat(
|
||||
user_name = await aget_user_name(user)
|
||||
|
||||
if ConversationCommand.Reminder in conversation_commands:
|
||||
user_timezone = pytz.timezone(timezone)
|
||||
crontime, inferred_query, subject = await schedule_query(q, location, meta_log)
|
||||
try:
|
||||
trigger = CronTrigger.from_crontab(crontime, user_timezone)
|
||||
except ValueError as e:
|
||||
return Response(
|
||||
content=f"Unable to create reminder with crontime schedule: {crontime}",
|
||||
media_type="text/plain",
|
||||
status_code=500,
|
||||
job, crontime, inferred_query, subject = await create_scheduled_task(
|
||||
q, location, timezone, user, request.url, meta_log
|
||||
)
|
||||
|
||||
# Generate the job id from the hash of inferred_query and crontime
|
||||
job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
|
||||
query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest()
|
||||
partial_scheduled_chat = functools.partial(
|
||||
scheduled_chat, inferred_query, q, subject, request.user.object, request.url
|
||||
)
|
||||
try:
|
||||
job = state.scheduler.add_job(
|
||||
run_with_process_lock,
|
||||
trigger=trigger,
|
||||
args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}"),
|
||||
id=job_id,
|
||||
name=f"{inferred_query}",
|
||||
max_instances=2, # Allow second instance to kill any previous instance with stale lock
|
||||
jitter=30,
|
||||
)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
return Response(
|
||||
content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.",
|
||||
content=f"Unable to schedule task. Ensure the task doesn't already exist.",
|
||||
media_type="text/plain",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
@@ -21,7 +22,9 @@ from typing import (
|
||||
from urllib.parse import parse_qs, urlencode
|
||||
|
||||
import openai
|
||||
import pytz
|
||||
import requests
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from PIL import Image
|
||||
from starlette.authentication import has_required_scope
|
||||
@@ -33,12 +36,14 @@ from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
create_khoj_token,
|
||||
get_khoj_tokens,
|
||||
run_with_process_lock,
|
||||
)
|
||||
from khoj.database.models import (
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
KhojUser,
|
||||
ProcessLock,
|
||||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
UserRequests,
|
||||
@@ -912,3 +917,34 @@ def scheduled_chat(executing_query: str, scheduling_query: str, subject: str, us
|
||||
send_task_email(user.get_short_name(), user.email, scheduling_query, ai_response, subject)
|
||||
else:
|
||||
return raw_response
|
||||
|
||||
|
||||
async def create_scheduled_task(
    q: str, location: LocationData, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = None
):
    """Create a scheduled chat task for the user from their query.

    Shared by the websocket and http chat endpoints so both schedule tasks the same way.

    Args:
        q: The user's original (scheduling) query.
        location: Location data forwarded to `schedule_query`.
        timezone: Timezone name used to build the cron trigger.
        user: The user the task is scheduled for.
        calling_url: URL of the request that triggered scheduling; forwarded to `scheduled_chat`.
        meta_log: Conversation log forwarded to `schedule_query`. Defaults to a fresh empty dict.

    Returns:
        Tuple of (job, crontime, inferred_query, subject).

    Raises:
        Errors from parsing the crontime or adding the job are not caught here;
        callers are expected to handle them (e.g. invalid crontab, or presumably
        a job with the same id already existing - TODO confirm against scheduler config).
    """
    # Use a fresh dict per call instead of a shared mutable default argument
    meta_log = {} if meta_log is None else meta_log

    user_timezone = pytz.timezone(timezone)
    crontime, inferred_query, subject = await schedule_query(q, location, meta_log)
    trigger = CronTrigger.from_crontab(crontime, user_timezone)

    # Generate id and metadata used by task scheduler and process locks for the task runs
    job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
    query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest()

    job = state.scheduler.add_job(
        run_with_process_lock,
        trigger=trigger,
        # Positional args go to run_with_process_lock: the function to run and the lock operation name
        args=(
            scheduled_chat,
            f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}",
        ),
        # Keyword args are passed through run_with_process_lock to scheduled_chat
        kwargs={
            "executing_query": inferred_query,
            "scheduling_query": q,
            "subject": subject,
            "user": user,
            "calling_url": calling_url,
        },
        id=job_id,
        name=f"{inferred_query}",
        max_instances=2,  # Allow second instance to kill any previous instance with stale lock
        jitter=30,
    )
    return job, crontime, inferred_query, subject
|
||||
|
||||
Reference in New Issue
Block a user