diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 10fde9e8..f32e3b8b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -454,6 +454,13 @@ class ProcessLockAdapters: logger.info(f"🔓 Unlocked {operation} process after executing {func} {'Succeeded' if success else 'Failed'}") +def run_with_process_lock(*args): + """Wrapper function used for scheduling jobs. + Required as APScheduler can't discover the `ProcessLockAdapters.run_with_lock` method on its own. + """ + return ProcessLockAdapters.run_with_lock(*args) + + class ClientApplicationAdapters: @staticmethod async def aget_client_application_by_id(client_id: str, client_secret: str): diff --git a/src/khoj/main.py b/src/khoj/main.py index 74807137..6ce30c7a 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -128,7 +128,20 @@ def run(should_start_server=True): poll_task_scheduler() # Setup Background Scheduler - state.scheduler = BackgroundScheduler() + from django.conf import settings as django_settings + + django_db = django_settings.DATABASES["default"] + state.scheduler = BackgroundScheduler( + { + "apscheduler.jobstores.default": { + "type": "sqlalchemy", + "url": f'postgresql://{django_db["USER"]}:{django_db["PASSWORD"]}@{django_db["HOST"]}:{django_db["PORT"]}/{django_db["NAME"]}', + }, + "apscheduler.timezone": "UTC", + "apscheduler.job_defaults.misfire_grace_time": "60", # Useful to run scheduled jobs even when the worker is delayed because it was busy or down + "apscheduler.job_defaults.coalesce": "true", # Combine multiple jobs into one if they are scheduled at the same time + } + ) state.scheduler.start() # Start Server @@ -150,6 +163,9 @@ def run(should_start_server=True): if should_start_server: start_server(app, host=args.host, port=args.port, socket=args.socket) + # Teardown + state.scheduler.shutdown() + def set_state(args): state.config_file = args.config_file diff --git
a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index c970c421..6ef7016d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -102,6 +102,7 @@ def save_to_conversation_log( intent_type: str = "remember", client_application: ClientApplication = None, conversation_id: int = None, + job_id: str = None, ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") updated_conversation = message_to_log( @@ -112,6 +113,7 @@ def save_to_conversation_log( "context": compiled_references, "intent": {"inferred-queries": inferred_queries, "type": intent_type}, "onlineContext": online_results, + "jobId": job_id, }, conversation_log=meta_log.get("chat", []), ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index bb164b13..34bbc5fa 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,3 +1,4 @@ +import functools import json import logging import math @@ -13,8 +14,13 @@ from starlette.authentication import requires from starlette.websockets import WebSocketDisconnect from websockets import ConnectionClosedOK -from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name -from khoj.database.models import KhojUser +from khoj.database.adapters import ( + ConversationAdapters, + EntryAdapters, + aget_user_name, + run_with_process_lock, +) +from khoj.database.models import KhojUser, ProcessLock from khoj.processor.conversation.prompts import ( help_message, no_entries_found, @@ -38,6 +44,7 @@ from khoj.routers.helpers import ( get_conversation_command, is_ready_to_chat, schedule_query, + scheduled_chat, text_to_image, update_telemetry_state, validate_conversation_config, @@ -386,35 +393,40 @@ async def websocket_endpoint( if ConversationCommand.Reminder in conversation_commands: crontime, inferred_query = await schedule_query(q, location, meta_log) - trigger = 
CronTrigger.from_crontab(crontime) - common = CommonQueryParamsClass( - client=websocket.user.client_app, - user_agent=websocket.headers.get("user-agent"), - host=websocket.headers.get("host"), + try: + trigger = CronTrigger.from_crontab(crontime) + except ValueError as e: + await send_complete_llm_response(f"Unable to create reminder with crontime schedule: {crontime}") + continue + partial_scheduled_chat = functools.partial( + scheduled_chat, inferred_query, websocket.user.object, websocket.url ) - scope = websocket.scope.copy() - scope["path"] = "/api/chat" - scope["type"] = "http" - request = Request(scope) + try: + job = state.scheduler.add_job( + run_with_process_lock, + trigger=trigger, + args=( + partial_scheduled_chat, + f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}", + ), + id=f"job_{user.uuid}_{inferred_query}_{crontime}", + name=f"{inferred_query}", + max_instances=2, # Allow second instance to kill any previous instance with stale lock + jitter=30, + ) + except: + await send_complete_llm_response( + f"Unable to schedule reminder. Ensure the reminder doesn't already exist." + ) + continue + next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S") + llm_response = f""" + ### 🕒 Scheduled Job +- Query: **"{inferred_query}"** +- Schedule: `{crontime}` +- Next Run At: **{next_run_time}** UTC. + """.strip() - state.scheduler.add_job( - async_to_sync(chat), - trigger=trigger, - args=(request, common, inferred_query), - kwargs={ - "stream": False, - "conversation_id": conversation_id, - "city": city, - "region": region, - "country": country, - }, - id=f"job_{user.uuid}_{inferred_query}", - replace_existing=True, - ) - - llm_response = ( - f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).' 
- ) await sync_to_async(save_to_conversation_log)( q, llm_response, @@ -423,6 +435,13 @@ async def websocket_endpoint( intent_type="reminder", client_application=websocket.user.client_app, conversation_id=conversation_id, + inferred_queries=[inferred_query], + job_id=job.id, + ) + common = CommonQueryParamsClass( + client=websocket.user.client_app, + user_agent=websocket.headers.get("user-agent"), + host=websocket.headers.get("host"), ) update_telemetry_state( request=websocket, @@ -630,16 +649,41 @@ async def chat( if ConversationCommand.Reminder in conversation_commands: crontime, inferred_query = await schedule_query(q, location, meta_log) - trigger = CronTrigger.from_crontab(crontime) - state.scheduler.add_job( - async_to_sync(chat), - trigger=trigger, - args=(request, common, inferred_query, n, d, False, title, conversation_id, city, region, country), - id=f"job_{user.uuid}_{inferred_query}", - replace_existing=True, - ) + try: + trigger = CronTrigger.from_crontab(crontime) + except ValueError as e: + return Response( + content=f"Unable to create reminder with crontime schedule: {crontime}", + media_type="text/plain", + status_code=500, + ) + + partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, request.user.object, request.url) + try: + job = state.scheduler.add_job( + run_with_process_lock, + trigger=trigger, + args=(partial_scheduled_chat, f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{inferred_query}"), + id=f"job_{user.uuid}_{inferred_query}_{crontime}", + name=f"{inferred_query}", + max_instances=2, # Allow second instance to kill any previous instance with stale lock + jitter=30, + ) + except: + return Response( + content=f"Unable to schedule reminder. 
Ensure the reminder doesn't already exist.", + media_type="text/plain", + status_code=500, + ) + + next_run_time = job.next_run_time.strftime("%Y-%m-%d %H:%M:%S") + llm_response = f""" + ### 🕒 Scheduled Job +- Query: **"{inferred_query}"** +- Schedule: `{crontime}` +- Next Run At: **{next_run_time}** UTC.' + """.strip() - llm_response = f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).' await sync_to_async(save_to_conversation_log)( q, llm_response, @@ -648,6 +692,8 @@ async def chat( intent_type="reminder", client_application=request.user.client_app, conversation_id=conversation_id, + inferred_queries=[inferred_query], + job_id=job.id, ) if stream: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 1dab6c53..ce49d3da 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -17,13 +17,22 @@ from typing import ( Tuple, Union, ) +from urllib.parse import parse_qs, urlencode import openai +import requests from fastapi import Depends, Header, HTTPException, Request, UploadFile from PIL import Image from starlette.authentication import has_required_scope +from starlette.requests import URL -from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters +from khoj.database.adapters import ( + AgentAdapters, + ConversationAdapters, + EntryAdapters, + create_khoj_token, + get_khoj_tokens, +) from khoj.database.models import ( ChatModelOptions, ClientApplication, @@ -779,3 +788,27 @@ class CommonQueryParamsClass: CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()] + + +def scheduled_chat(query, user: KhojUser, calling_url: URL): + # Construct the URL, header for the chat API + scheme = "http" if calling_url.scheme == "http" or calling_url.scheme == "ws" else "https" + # Replace the original scheduling query with the scheduled query + query_dict = parse_qs(calling_url.query) + query_dict["q"] = [query] + # Convert the dictionary back into a query 
string + scheduled_query = urlencode(query_dict, doseq=True) + url = f"{scheme}://{calling_url.netloc}/api/chat?{scheduled_query}" + + headers = {"User-Agent": "Khoj"} + if not state.anonymous_mode: + # Add authorization request header in non-anonymous mode + token = get_khoj_tokens(user) + if is_none_or_empty(token): + token = create_khoj_token(user) + else: + token = token[0] + headers["Authorization"] = f"Bearer {token}" + + # Call the chat API endpoint with authenticated user token and query + return requests.get(url, headers=headers)