From e94bf00e1e4ed6bd9a67feef3eb297373d036582 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 7 Apr 2025 21:03:05 +0530 Subject: [PATCH] Add cancellation support to research mode via asyncio.Event --- src/interface/web/app/chat/page.tsx | 3 +- src/khoj/routers/api_chat.py | 82 ++++++++++++++++++++++------- src/khoj/routers/research.py | 7 +++ 3 files changed, 70 insertions(+), 22 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index e4926ccc..b672091d 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -241,8 +241,7 @@ export default function Chat() { handleAbortedMessage(); setTriggeredAbort(false); } - }), - [triggeredAbort]; + }, [triggeredAbort]); useEffect(() => { if (queryToProcess) { diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 6201a483..163ffc46 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -683,14 +683,13 @@ async def chat( start_time = time.perf_counter() ttft = None chat_metadata: dict = {} - connection_alive = True user: KhojUser = request.user.object is_subscribed = has_required_scope(request, ["premium"]) - event_delimiter = "␃🔚␗" q = unquote(q) train_of_thought = [] nonlocal conversation_id nonlocal raw_query_files + cancellation_event = asyncio.Event() tracer: dict = { "mid": turn_id, @@ -717,11 +716,33 @@ async def chat( for file in raw_query_files: query_files[file.name] = file.content + # Create a task to monitor for disconnections + disconnect_monitor_task = None + + async def monitor_disconnection(): + try: + msg = await request.receive() + if msg["type"] == "http.disconnect": + logger.debug(f"User {user} disconnected from {common.client} client.") + cancellation_event.set() + except Exception as e: + logger.error(f"Error in disconnect monitor: {e}") + + # Cancel the disconnect monitor task if it is still running + async def cancel_disconnect_monitor(): + if disconnect_monitor_task and not disconnect_monitor_task.done(): + logger.debug(f"Cancelling disconnect monitor task for user {user}") + disconnect_monitor_task.cancel() + try: + await disconnect_monitor_task + except asyncio.CancelledError: + pass + async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal connection_alive, ttft, train_of_thought - if not connection_alive or await request.is_disconnected(): - connection_alive = False - logger.warning(f"User {user} disconnected from {common.client} client") + nonlocal ttft, train_of_thought + event_delimiter = "␃🔚␗" + if cancellation_event.is_set(): + logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.") return try: if event_type == ChatEvent.END_LLM_RESPONSE: @@ -746,17 +767,25 @@ async def chat( elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) except asyncio.CancelledError as e: - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client: {e}") - return + if cancellation_event.is_set(): + logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.") except Exception as e: - connection_alive = False - logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) - return + if not cancellation_event.is_set(): + logger.error( + f"Failed to stream chat API response to {user} on {common.client}: {e}.", + exc_info=True, + ) finally: - yield event_delimiter + if not cancellation_event.is_set(): + yield event_delimiter + # Cancel the disconnect monitor task if it is still running + if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE: + await cancel_disconnect_monitor() async def send_llm_response(response: str, usage: dict = None): + # Check if the client is still connected + if cancellation_event.is_set(): + return # Send Chat Response async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): yield result @@ -797,6 +826,9 @@ async def chat( metadata=chat_metadata, ) + # Start the disconnect monitor in the background + disconnect_monitor_task = asyncio.create_task(monitor_disconnection()) + if is_query_empty(q): async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): yield result @@ -914,6 +946,7 @@ async def chat( file_filters=conversation.file_filters if conversation else [], query_files=attached_file_context, tracer=tracer, + cancellation_event=cancellation_event, ): if isinstance(research_result, InformationCollectionIteration): if research_result.summarizedResult: @@ -1288,6 +1321,13 @@ async def chat( async for result in send_event(ChatEvent.STATUS, error_message): yield result + # Check if the user has disconnected + if cancellation_event.is_set(): + logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.") + # Cancel the disconnect monitor task if it is still running + await cancel_disconnect_monitor() + return + ## Generate Text Output async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): yield result @@ -1320,14 +1360,12 @@ async def chat( tracer, ) - continue_stream = True async for item in llm_response: # Should not happen with async generator, end is signaled by loop exit. Skip. if item is None: continue - if not connection_alive or not continue_stream: - # Drain the generator if disconnected but keep processing internally - continue + if cancellation_event.is_set(): + break message = item.response if isinstance(item, ResponseWithThought) else item if isinstance(item, ResponseWithThought) and item.thought: async for result in send_event(ChatEvent.THOUGHT, item.thought): @@ -1342,11 +1380,12 @@ async def chat( async for result in send_event(ChatEvent.MESSAGE, message): yield result except Exception as e: - continue_stream = False - logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}") + if not cancellation_event.is_set(): + logger.warning(f"Error during streaming. Stopping send: {e}") + break # Signal end of LLM response after the loop finishes - if connection_alive: + if not cancellation_event.is_set(): async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result # Send Usage Metadata once llm interactions are complete @@ -1357,6 +1396,9 @@ async def chat( yield result logger.debug("Finished streaming response") + # Cancel the disconnect monitor task if it is still running + await cancel_disconnect_monitor() + ## Stream Text Response if stream: return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index fa855b9c..9fb7c229 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,3 +1,4 @@ +import asyncio import logging import os from datetime import datetime @@ -205,11 +206,17 @@ async def execute_information_collection( file_filters: List[str] = [], tracer: dict = {}, query_files: str = None, + cancellation_event: Optional[asyncio.Event] = None, ): current_iteration = 0 MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) previous_iterations: List[InformationCollectionIteration] = [] while current_iteration < MAX_ITERATIONS: + # Check for cancellation at the start of each iteration + if cancellation_event and cancellation_event.is_set(): + logger.debug(f"User {user} disconnected client. Research cancelled.") + break + online_results: Dict = dict() code_results: Dict = dict() document_results: List[Dict[str, str]] = []