diff --git a/src/khoj/routers/api_automation.py b/src/khoj/routers/api_automation.py index f20f07d5..b965c9f9 100644 --- a/src/khoj/routers/api_automation.py +++ b/src/khoj/routers/api_automation.py @@ -102,7 +102,7 @@ def post_automation( # Schedule automation with query_to_run, timezone, subject directly provided by user try: # Use the query to run as the scheduling request if the scheduling request is unset - calling_url = request.url.replace(query=f"{request.url.query}") + calling_url = str(request.url.replace(query=f"{request.url.query}")) automation = schedule_automation( query_to_run, subject, crontime, timezone, q, user, calling_url, str(conversation.id) ) @@ -224,7 +224,7 @@ def edit_job( "subject": subject, "scheduling_request": q, "user": user, - "calling_url": request.url, + "calling_url": str(request.url), "conversation_id": conversation_id, }, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0643d134..db75423a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -619,13 +619,13 @@ def schedule_query( # Validate that the response is a non-empty, JSON-serializable list try: - raw_response = raw_response.strip() - response: Dict[str, str] = json.loads(clean_json(raw_response)) + raw_response_text = raw_response.text + response: Dict[str, str] = json.loads(clean_json(raw_response_text)) if not response or not isinstance(response, Dict) or len(response) != 3: raise AssertionError(f"Invalid response for scheduling query : {response}") return response.get("crontime"), response.get("query"), response.get("subject") except Exception: - raise AssertionError(f"Invalid response for scheduling query: {raw_response}") + raise AssertionError(f"Invalid response for scheduling query: {raw_response.text}") async def aschedule_query( @@ -2130,7 +2130,8 @@ def format_automation_response(scheduling_request: str, executed_query: str, ai_ ) with timer("Chat actor: Format automation response", logger): - return send_message_to_model_wrapper_sync(automation_format_prompt, user=user) + raw_response = send_message_to_model_wrapper_sync(automation_format_prompt, user=user) + return raw_response.text if raw_response else None def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool: @@ -2150,8 +2151,10 @@ def should_notify(original_query: str, executed_query: str, ai_response: str, us with timer("Chat actor: Decide to notify user of automation response", logger): try: # TODO Replace with async call so we don't have to maintain a sync version - raw_response = send_message_to_model_wrapper_sync(to_notify_or_not, user=user, response_type="json_object") - response = json.loads(clean_json(raw_response)) + raw_response: ResponseWithThought = send_message_to_model_wrapper_sync( + to_notify_or_not, user=user, response_type="json_object" + ) + response = json.loads(clean_json(raw_response.text)) should_notify_result = response["decision"] == "Yes" reason = response.get("reason", "unknown") logger.info( @@ -2171,7 +2174,7 @@ def scheduled_chat( scheduling_request: str, subject: str, user: KhojUser, - calling_url: URL, + calling_url: str, job_id: str = None, conversation_id: str = None, ): @@ -2191,8 +2194,9 @@ def scheduled_chat( return # Extract relevant params from the original URL - scheme = "http" if not calling_url.is_secure else "https" - query_dict = parse_qs(calling_url.query) + parsed_url = urlparse(calling_url) + scheme = parsed_url.scheme + query_dict = parse_qs(parsed_url.query) # Pop the stream value from query_dict if it exists query_dict.pop("stream", None) @@ -2214,7 +2218,7 @@ def scheduled_chat( json_payload = {key: values[0] for key, values in query_dict.items()} # Construct the URL to call the chat API with the scheduled query string - url = f"{scheme}://{calling_url.netloc}/api/chat?client=khoj" + url = f"{scheme}://{parsed_url.netloc}/api/chat?client=khoj" # Construct the Headers for the chat API headers = {"User-Agent": "Khoj", "Content-Type": "application/json"} diff --git a/tests/test_api_automation.py b/tests/test_api_automation.py new file mode 100644 index 00000000..920c86d8 --- /dev/null +++ b/tests/test_api_automation.py @@ -0,0 +1,126 @@ +import pytest +from apscheduler.schedulers.background import BackgroundScheduler +from django_apscheduler.jobstores import DjangoJobStore +from fastapi.testclient import TestClient + +from khoj.utils import state +from tests.helpers import ChatModelFactory + + +@pytest.fixture(autouse=True) +def setup_scheduler(): + state.scheduler = BackgroundScheduler() + state.scheduler.add_jobstore(DjangoJobStore(), "default") + state.scheduler.start() + yield + state.scheduler.shutdown() + + +def create_test_automation(client: TestClient) -> str: + """Helper function to create a test automation and return its ID.""" + state.anonymous_mode = True + ChatModelFactory(name="gpt-4o-mini", model_type="openai") + params = { + "q": "test automation", + "crontime": "0 0 * * *", + } + response = client.post("/api/automation", params=params) + assert response.status_code == 200 + return response.json()["id"] + + +@pytest.mark.django_db(transaction=True) +def test_create_automation(client: TestClient): + """Test that creating an automation works as expected.""" + # Arrange + state.anonymous_mode = True + ChatModelFactory(name="gpt-4o-mini", model_type="openai") + params = { + "q": "test automation", + "crontime": "0 0 * * *", + } + + # Act + response = client.post("/api/automation", params=params) + + # Assert + assert response.status_code == 200 + response_json = response.json() + assert response_json["scheduling_request"] == "test automation" + assert response_json["crontime"] == "0 0 * * *" + + +@pytest.mark.django_db(transaction=True) +def test_get_automations(client: TestClient): + """Test that getting a list of automations works.""" + automation_id = create_test_automation(client) + + # Act + response = client.get("/api/automation") + + # Assert + assert response.status_code == 200 + automations = response.json() + assert isinstance(automations, list) + assert len(automations) > 0 + assert any(a["id"] == automation_id for a in automations) + + +@pytest.mark.django_db(transaction=True) +def test_delete_automation(client: TestClient): + """Test that deleting an automation works.""" + automation_id = create_test_automation(client) + + # Act + response = client.delete(f"/api/automation?automation_id={automation_id}") + + # Assert + assert response.status_code == 200 + + # Verify it's gone + response = client.get("/api/automation") + assert response.status_code == 200 + automations = response.json() + assert not any(a["id"] == automation_id for a in automations) + + +@pytest.mark.django_db(transaction=True) +def test_edit_automation(client: TestClient): + """Test that editing an automation works.""" + automation_id = create_test_automation(client) + + edit_params = { + "automation_id": automation_id, + "q": "edited automation", + "crontime": "0 1 * * *", + "subject": "edited subject", + "timezone": "UTC", + } + + # Act + response = client.put("/api/automation", params=edit_params) + + # Assert + if response.status_code != 200: + print(response.text) + assert response.status_code == 200 + edited_automation = response.json() + assert edited_automation["scheduling_request"] == "edited automation" + assert edited_automation["crontime"] == "0 1 * * *" + assert edited_automation["subject"] == "edited subject" + + +@pytest.mark.django_db(transaction=True) +def test_trigger_automation(client: TestClient): + """Test that triggering an automation works.""" + automation_id = create_test_automation(client) + + # Act + response = client.post(f"/api/automation/trigger?automation_id={automation_id}") + + # Assert + assert response.status_code == 200 + # NOTE: We are not testing the execution of the triggered job itself, + # as that would require a more complex test setup with mocking. + # A 200 response is sufficient to indicate that the trigger was received. + assert response.text == "Automation triggered"