diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9342b913..a5be3086 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1741,7 +1741,7 @@ class AutomationAdapters: return { "id": automation.id, "subject": automation_metadata["subject"], - "query_to_run": re.sub(r"^/automated_task\s*", "", automation_metadata["query_to_run"]), + "query_to_run": automation_metadata["query_to_run"], "scheduling_request": automation_metadata["scheduling_request"], "schedule": schedule, "crontime": crontime, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index b936488f..a29e993a 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -38,13 +38,12 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio -from khoj.processor.conversation.utils import defilter_query +from khoj.processor.conversation.utils import clean_json, defilter_query from khoj.routers.helpers import ( ApiUserRateLimiter, ChatEvent, CommonQueryParams, ConversationCommandRateLimiter, - acreate_title_from_query, get_user_config, schedule_automation, schedule_query, @@ -567,7 +566,7 @@ def delete_automation(request: Request, automation_id: str) -> Response: @api.post("/automation", response_class=Response) @requires(["authenticated"]) -async def post_automation( +def post_automation( request: Request, q: str, crontime: str, @@ -586,7 +585,7 @@ async def post_automation( return Response(content="Invalid crontime", status_code=400) # Infer subject, query to run - _, query_to_run, generated_subject = await schedule_query(q, conversation_history={}, user=user) + _, query_to_run, generated_subject = schedule_query(q, conversation_history={}, user=user) subject = subject or generated_subject # Normalize query parameters @@ -614,13 +613,13 @@ async def post_automation( # Create new Conversation Session associated with this new task title = f"Automation: {subject}" - conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app, title=title) + conversation = ConversationAdapters.create_conversation_session(user, request.user.client_app, title=title) # Schedule automation with query_to_run, timezone, subject directly provided by user try: # Use the query to run as the scheduling request if the scheduling request is unset calling_url = request.url.replace(query=f"{request.url.query}") - automation = await schedule_automation( + automation = schedule_automation( query_to_run, subject, crontime, timezone, q, user, calling_url, str(conversation.id) ) except Exception as e: @@ -665,7 +664,7 @@ def trigger_manual_job( @api.put("/automation", response_class=Response) @requires(["authenticated"]) -async def edit_job( +def edit_job( request: Request, automation_id: str, q: Optional[str], @@ -686,13 +685,13 @@ async def edit_job( # Check, get automation to edit try: - automation: Job = await AutomationAdapters.aget_automation(user, automation_id) + automation: Job = AutomationAdapters.get_automation(user, automation_id) except ValueError as e: logger.error(f"Error editing automation {automation_id} for {user.email}: {e}", exc_info=True) return Response(content="Invalid automation", status_code=403) # Infer subject, query to run - _, query_to_run, _ = await schedule_query(q, conversation_history={}, user=user) + _, query_to_run, _ = schedule_query(q, conversation_history={}, user=user) subject = subject # Normalize query parameters @@ -717,7 +716,7 @@ async def edit_job( ) # Construct updated automation metadata - automation_metadata = json.loads(automation.name) + automation_metadata: dict[str, str] = json.loads(clean_json(automation.name)) automation_metadata["scheduling_request"] = q automation_metadata["query_to_run"] = query_to_run automation_metadata["subject"] = subject.strip() @@ -728,15 +727,13 @@ async def edit_job( title = f"Automation: {subject}" # Create new Conversation Session associated with this new task - conversation = await ConversationAdapters.acreate_conversation_session( - user, request.user.client_app, title=title - ) + conversation = ConversationAdapters.create_conversation_session(user, request.user.client_app, title=title) conversation_id = str(conversation.id) automation_metadata["conversation_id"] = conversation_id # Modify automation with updated query, subject - await sync_to_async(automation.modify)( + automation.modify( name=json.dumps(automation_metadata), kwargs={ "query_to_run": query_to_run, @@ -752,11 +749,11 @@ async def edit_job( user_timezone = pytz.timezone(timezone) trigger = CronTrigger.from_crontab(crontime, user_timezone) if automation.trigger != trigger: - await sync_to_async(automation.reschedule)(trigger=trigger) + automation.reschedule(trigger=trigger) # Collate info about the updated user automation - automation = await AutomationAdapters.aget_automation(user, automation.id) - automation_info = await sync_to_async(AutomationAdapters.get_automation_metadata)(user, automation) + automation = AutomationAdapters.get_automation(user, automation.id) + automation_info = AutomationAdapters.get_automation_metadata(user, automation) # Return modified automation information as a JSON response return Response(content=json.dumps(automation_info), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6dcb1b4b..d6afddb3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -551,7 +551,35 @@ async def generate_online_subqueries( return {q} -async def schedule_query( +def schedule_query( + q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} +) -> Tuple[str, str, str]: + """ + Schedule the date, time to run the query. Assume the server timezone is UTC. + """ + chat_history = construct_chat_history(conversation_history) + + crontime_prompt = prompts.crontime_prompt.format( + query=q, + chat_history=chat_history, + ) + + raw_response = send_message_to_model_wrapper_sync( + crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer + ) + + # Validate that the response is a non-empty, JSON-serializable list + try: + raw_response = raw_response.strip() + response: Dict[str, str] = json.loads(clean_json(raw_response)) + if not response or not isinstance(response, Dict) or len(response) != 3: + raise AssertionError(f"Invalid response for scheduling query : {response}") + return response.get("crontime"), response.get("query"), response.get("subject") + except Exception: + raise AssertionError(f"Invalid response for scheduling query: {raw_response}") + + +async def aschedule_query( q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} ) -> Tuple[str, str, str]: """ @@ -571,7 +599,7 @@ async def schedule_query( # Validate that the response is a non-empty, JSON-serializable list try: raw_response = raw_response.strip() - response: Dict[str, str] = json.loads(raw_response) + response: Dict[str, str] = json.loads(clean_json(raw_response)) if not response or not isinstance(response, Dict) or len(response) != 3: raise AssertionError(f"Invalid response for scheduling query : {response}") return response.get("crontime"), response.get("query"), response.get("subject") @@ -1065,6 +1093,7 @@ def send_message_to_model_wrapper_sync( system_message: str = "", response_type: str = "text", user: KhojUser = None, + query_images: List[str] = None, query_files: str = "", tracer: dict = {}, ): @@ -1090,6 +1119,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=chat_model.model_type, + query_images=query_images, query_files=query_files, ) @@ -1112,6 +1142,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=chat_model.model_type, + query_images=query_images, query_files=query_files, ) @@ -1134,6 +1165,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=chat_model.model_type, + query_images=query_images, query_files=query_files, ) @@ -1154,6 +1186,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=chat_model.model_type, + query_images=query_images, query_files=query_files, ) @@ -1794,12 +1827,66 @@ async def create_automation( conversation_id: str = None, tracer: dict = {}, ): - crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer) - job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) + crontime, query_to_run, subject = await aschedule_query(q, meta_log, user, tracer=tracer) + job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) return job, crontime, query_to_run, subject -async def schedule_automation( +def schedule_automation( + query_to_run: str, + subject: str, + crontime: str, + timezone: str, + scheduling_request: str, + user: KhojUser, + calling_url: URL, + conversation_id: str, +): + # Disable minute level automation recurrence + minute_value = crontime.split(" ")[0] + if not minute_value.isdigit(): + # Run automation at some random minute (to distribute request load) instead of running every X minutes + crontime = " ".join([str(math.floor(random() * 60))] + crontime.split(" ")[1:]) + + user_timezone = pytz.timezone(timezone) + trigger = CronTrigger.from_crontab(crontime, user_timezone) + trigger.jitter = 60 + # Generate id and metadata used by task scheduler and process locks for the task runs + job_metadata = json.dumps( + { + "query_to_run": query_to_run, + "scheduling_request": scheduling_request, + "subject": subject, + "crontime": crontime, + "conversation_id": str(conversation_id), + } + ) + query_id = hashlib.md5(f"{query_to_run}_{crontime}".encode("utf-8")).hexdigest() + job_id = f"automation_{user.uuid}_{query_id}" + job = state.scheduler.add_job( + run_with_process_lock, + trigger=trigger, + args=( + scheduled_chat, + f"{ProcessLock.Operation.SCHEDULED_JOB}_{user.uuid}_{query_id}", + ), + kwargs={ + "query_to_run": query_to_run, + "scheduling_request": scheduling_request, + "subject": subject, + "user": user, + "calling_url": calling_url, + "job_id": job_id, + "conversation_id": conversation_id, + }, + id=job_id, + name=job_metadata, + max_instances=2, # Allow second instance to kill any previous instance with stale lock + ) + return job + + +async def aschedule_automation( query_to_run: str, subject: str, crontime: str, diff --git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index f873be45..2b08fc74 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -636,11 +636,11 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, defa ), ], ) -async def test_infer_task_scheduling_request( +def test_infer_task_scheduling_request( chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2 ): # Act - crontime, inferred_query, _ = await schedule_query(user_query, {}, default_user2) + crontime, inferred_query, _ = schedule_query(user_query, {}, default_user2) inferred_query = inferred_query.lower() # Assert