From ae10ff4a5f86fd8c0aaf9bf9635ec1ea339531ca Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 29 Apr 2024 15:52:39 +0530 Subject: [PATCH] Create create_scheduled_task func to dedupe logic across ws, http APIs Previously, both the websocket and http endpoint were implementing the same logic. This was becoming too unwieldy --- src/khoj/database/adapters/__init__.py | 8 +-- src/khoj/routers/api_chat.py | 81 +++++--------------------- src/khoj/routers/helpers.py | 36 ++++++++++++ 3 files changed, 53 insertions(+), 72 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index f32e3b8b..d4175704 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -429,7 +429,7 @@ class ProcessLockAdapters: return ProcessLock.objects.filter(name=process_name).delete() @staticmethod - def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600): + def run_with_lock(func: Callable, operation: ProcessLock.Operation, max_duration_in_seconds: int = 600, **kwargs): # Exit early if process lock is already taken if ProcessLockAdapters.is_process_locked(operation): logger.info(f"🔒 Skip executing {func} as {operation} lock is already taken") @@ -443,7 +443,7 @@ class ProcessLockAdapters: # Execute Function with timer(f"🔒 Run {func} with {operation} process lock", logger): - func() + func(**kwargs) success = True except Exception as e: logger.error(f"🚨 Error executing {func} with {operation} process lock: {e}", exc_info=True) @@ -454,11 +454,11 @@ class ProcessLockAdapters: logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}") -def run_with_process_lock(*args): +def run_with_process_lock(*args, **kwargs): """Wrapper function used for scheduling jobs. Required as APScheduler can't discover the `ProcessLockAdapter.run_with_lock' method on its own. """ - return ProcessLockAdapters.run_with_lock(*args) + return ProcessLockAdapters.run_with_lock(*args, **kwargs) class ClientApplicationAdapters: diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index b6119e12..fe628e69 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,5 +1,3 @@ -import functools -import hashlib import json import logging import math @@ -8,8 +6,6 @@ from datetime import datetime from typing import Dict, Optional from urllib.parse import unquote -import pytz -from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket from fastapi.requests import Request @@ -18,13 +14,8 @@ from starlette.authentication import requires from starlette.websockets import WebSocketDisconnect from websockets import ConnectionClosedOK -from khoj.database.adapters import ( - ConversationAdapters, - EntryAdapters, - aget_user_name, - run_with_process_lock, -) -from khoj.database.models import KhojUser, ProcessLock +from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name +from khoj.database.models import KhojUser from khoj.processor.conversation.prompts import ( help_message, no_entries_found, @@ -45,10 +36,9 @@ from khoj.routers.helpers import ( agenerate_chat_response, aget_relevant_information_sources, aget_relevant_output_modes, + create_scheduled_task, get_conversation_command, is_ready_to_chat, - schedule_query, - scheduled_chat, text_to_image, update_telemetry_state, validate_conversation_config, @@ -399,36 +389,13 @@ async def websocket_endpoint( q = q.replace(f"/{cmd.value}", "").strip() if ConversationCommand.Reminder in conversation_commands: - user_timezone = pytz.timezone(timezone) - crontime, inferred_query, subject = await schedule_query(q, location, meta_log) try: - trigger = CronTrigger.from_crontab(crontime, user_timezone) - except ValueError as e: - await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}") - continue - # Generate the job id from the hash of inferred_query and crontime - job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest() - query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest() - partial_scheduled_chat = functools.partial( - scheduled_chat, inferred_query, q, subject, websocket.user.object, websocket.url - ) - try: - job = state.scheduler.add_job( - run_with_process_lock, - trigger=trigger, - args=( - partial_scheduled_chat, - f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}", - ), - id=job_id, - name=f"{inferred_query}", - max_instances=2, # Allow second instance to kill any previous instance with stale lock - jitter=30, - ) - except: - await send_complete_llm_response( - f"Unable to schedule reminder. Ensure the reminder doesn't already exist." + job, crontime, inferred_query, subject = await create_scheduled_task( + q, location, timezone, user, websocket.url, meta_log ) + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") + await send_complete_llm_response(f"Unable to schedule task. Ensure the task doesn't already exist.") continue # Display next run time in user timezone instead of UTC next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M %Z (%z)") @@ -670,36 +637,14 @@ async def chat( user_name = await aget_user_name(user) if ConversationCommand.Reminder in conversation_commands: - user_timezone = pytz.timezone(timezone) - crontime, inferred_query, subject = await schedule_query(q, location, meta_log) try: - trigger = CronTrigger.from_crontab(crontime, user_timezone) - except ValueError as e: - return Response( - content=f"Unable to create reminder with crontime schedule: {crontime}", - media_type="text/plain", - status_code=500, + job, crontime, inferred_query, subject = await create_scheduled_task( + q, location, timezone, user, request.url, meta_log ) - - # Generate the job id from the hash of inferred_query and crontime - job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest() - query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest() - partial_scheduled_chat = functools.partial( - scheduled_chat, inferred_query, q, subject, request.user.object, request.url - ) - try: - job = state.scheduler.add_job( - run_with_process_lock, - trigger=trigger, - args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}"), - id=job_id, - name=f"{inferred_query}", - max_instances=2, # Allow second instance to kill any previous instance with stale lock - jitter=30, - ) - except: + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") return Response( - content=f"Unable to schedule reminder. Ensure the reminder doesn't already exist.", + content=f"Unable to schedule task. Ensure the task doesn't already exist.", media_type="text/plain", status_code=500, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c6974ef5..194dae8a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,5 +1,6 @@ import asyncio import base64 +import hashlib import io import json import logging @@ -21,7 +22,9 @@ from typing import ( from urllib.parse import parse_qs, urlencode import openai +import pytz import requests +from apscheduler.triggers.cron import CronTrigger from fastapi import Depends, Header, HTTPException, Request, UploadFile from PIL import Image from starlette.authentication import has_required_scope @@ -33,12 +36,14 @@ from khoj.database.adapters import ( EntryAdapters, create_khoj_token, get_khoj_tokens, + run_with_process_lock, ) from khoj.database.models import ( ChatModelOptions, ClientApplication, Conversation, KhojUser, + ProcessLock, Subscription, TextToImageModelConfig, UserRequests, @@ -912,3 +917,34 @@ def scheduled_chat(executing_query: str, scheduling_query: str, subject: str, us send_task_email(user.get_short_name(), user.email, scheduling_query, ai_response, subject) else: return raw_response + + +async def create_scheduled_task( + q: str, location: LocationData, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {} +): + user_timezone = pytz.timezone(timezone) + crontime, inferred_query, subject = await schedule_query(q, location, meta_log) + trigger = CronTrigger.from_crontab(crontime, user_timezone) + # Generate id and metadata used by task scheduler and process locks for the task runs + job_id = f"job_{user.uuid}_" + hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest() + query_id = hashlib.md5(f"{inferred_query}".encode("utf-8")).hexdigest() + job = state.scheduler.add_job( + run_with_process_lock, + trigger=trigger, + args=( + scheduled_chat, + f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}", + ), + kwargs={ + "executing_query": inferred_query, + "scheduling_query": q, + "subject": subject, + "user": user, + "calling_url": calling_url, + }, + id=job_id, + name=f"{inferred_query}", + max_instances=2, # Allow second instance to kill any previous instance with stale lock + jitter=30, + ) + return job, crontime, inferred_query, subject