Add cancellation support to research mode via asyncio.Event

This commit is contained in:
Debanjum
2025-04-07 21:03:05 +05:30
parent 1572781946
commit e94bf00e1e
3 changed files with 70 additions and 22 deletions

View File

@@ -241,8 +241,7 @@ export default function Chat() {
handleAbortedMessage(); handleAbortedMessage();
setTriggeredAbort(false); setTriggeredAbort(false);
} }
}), }, [triggeredAbort]);
[triggeredAbort];
useEffect(() => { useEffect(() => {
if (queryToProcess) { if (queryToProcess) {

View File

@@ -683,14 +683,13 @@ async def chat(
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
chat_metadata: dict = {} chat_metadata: dict = {}
connection_alive = True
user: KhojUser = request.user.object user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"]) is_subscribed = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗"
q = unquote(q) q = unquote(q)
train_of_thought = [] train_of_thought = []
nonlocal conversation_id nonlocal conversation_id
nonlocal raw_query_files nonlocal raw_query_files
cancellation_event = asyncio.Event()
tracer: dict = { tracer: dict = {
"mid": turn_id, "mid": turn_id,
@@ -717,11 +716,33 @@ async def chat(
for file in raw_query_files: for file in raw_query_files:
query_files[file.name] = file.content query_files[file.name] = file.content
# Create a task to monitor for disconnections
disconnect_monitor_task = None
async def monitor_disconnection():
try:
msg = await request.receive()
if msg["type"] == "http.disconnect":
logger.debug(f"User {user} disconnected from {common.client} client.")
cancellation_event.set()
except Exception as e:
logger.error(f"Error in disconnect monitor: {e}")
# Cancel the disconnect monitor task if it is still running
async def cancel_disconnect_monitor():
    """Cancel the disconnect monitor task if it exists and has not finished.

    Does nothing when ``disconnect_monitor_task`` is unset or already done.
    Otherwise requests cancellation and awaits the task, absorbing the
    resulting ``asyncio.CancelledError``.
    """
    monitor = disconnect_monitor_task
    # Nothing to cancel: the monitor was never started or has already ended.
    if not monitor or monitor.done():
        return

    logger.debug(f"Cancelling disconnect monitor task for user {user}")
    monitor.cancel()
    try:
        # Wait for the task to actually unwind after the cancel request.
        await monitor
    except asyncio.CancelledError:
        pass
async def send_event(event_type: ChatEvent, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft, train_of_thought nonlocal ttft, train_of_thought
if not connection_alive or await request.is_disconnected(): event_delimiter = "␃🔚␗"
connection_alive = False if cancellation_event.is_set():
logger.warning(f"User {user} disconnected from {common.client} client") logger.debug(f"User {user} disconnected from {common.client} client. Setting cancellation event.")
return return
try: try:
if event_type == ChatEvent.END_LLM_RESPONSE: if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -746,17 +767,25 @@ async def chat(
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e: except asyncio.CancelledError as e:
connection_alive = False if cancellation_event.is_set():
logger.warn(f"User {user} disconnected from {common.client} client: {e}") logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client: {e}.")
return
except Exception as e: except Exception as e:
connection_alive = False if not cancellation_event.is_set():
logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) logger.error(
return f"Failed to stream chat API response to {user} on {common.client}: {e}.",
exc_info=True,
)
finally: finally:
yield event_delimiter if not cancellation_event.is_set():
yield event_delimiter
# Cancel the disconnect monitor task if it is still running
if cancellation_event.is_set() or event_type == ChatEvent.END_RESPONSE:
await cancel_disconnect_monitor()
async def send_llm_response(response: str, usage: dict = None): async def send_llm_response(response: str, usage: dict = None):
# Check if the client is still connected
if cancellation_event.is_set():
return
# Send Chat Response # Send Chat Response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result yield result
@@ -797,6 +826,9 @@ async def chat(
metadata=chat_metadata, metadata=chat_metadata,
) )
# Start the disconnect monitor in the background
disconnect_monitor_task = asyncio.create_task(monitor_disconnection())
if is_query_empty(q): if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
yield result yield result
@@ -914,6 +946,7 @@ async def chat(
file_filters=conversation.file_filters if conversation else [], file_filters=conversation.file_filters if conversation else [],
query_files=attached_file_context, query_files=attached_file_context,
tracer=tracer, tracer=tracer,
cancellation_event=cancellation_event,
): ):
if isinstance(research_result, InformationCollectionIteration): if isinstance(research_result, InformationCollectionIteration):
if research_result.summarizedResult: if research_result.summarizedResult:
@@ -1288,6 +1321,13 @@ async def chat(
async for result in send_event(ChatEvent.STATUS, error_message): async for result in send_event(ChatEvent.STATUS, error_message):
yield result yield result
# Check if the user has disconnected
if cancellation_event.is_set():
logger.debug(f"User {user} disconnected from {common.client} client. Stopping LLM response.")
# Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor()
return
## Generate Text Output ## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
yield result yield result
@@ -1320,14 +1360,12 @@ async def chat(
tracer, tracer,
) )
continue_stream = True
async for item in llm_response: async for item in llm_response:
# Should not happen with async generator, end is signaled by loop exit. Skip. # Should not happen with async generator, end is signaled by loop exit. Skip.
if item is None: if item is None:
continue continue
if not connection_alive or not continue_stream: if cancellation_event.is_set():
# Drain the generator if disconnected but keep processing internally break
continue
message = item.response if isinstance(item, ResponseWithThought) else item message = item.response if isinstance(item, ResponseWithThought) else item
if isinstance(item, ResponseWithThought) and item.thought: if isinstance(item, ResponseWithThought) and item.thought:
async for result in send_event(ChatEvent.THOUGHT, item.thought): async for result in send_event(ChatEvent.THOUGHT, item.thought):
@@ -1342,11 +1380,12 @@ async def chat(
async for result in send_event(ChatEvent.MESSAGE, message): async for result in send_event(ChatEvent.MESSAGE, message):
yield result yield result
except Exception as e: except Exception as e:
continue_stream = False if not cancellation_event.is_set():
logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}") logger.warning(f"Error during streaming. Stopping send: {e}")
break
# Signal end of LLM response after the loop finishes # Signal end of LLM response after the loop finishes
if connection_alive: if not cancellation_event.is_set():
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result yield result
# Send Usage Metadata once llm interactions are complete # Send Usage Metadata once llm interactions are complete
@@ -1357,6 +1396,9 @@ async def chat(
yield result yield result
logger.debug("Finished streaming response") logger.debug("Finished streaming response")
# Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor()
## Stream Text Response ## Stream Text Response
if stream: if stream:
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")

View File

@@ -1,3 +1,4 @@
import asyncio
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
@@ -205,11 +206,17 @@ async def execute_information_collection(
file_filters: List[str] = [], file_filters: List[str] = [],
tracer: dict = {}, tracer: dict = {},
query_files: str = None, query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None,
): ):
current_iteration = 0 current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
previous_iterations: List[InformationCollectionIteration] = [] previous_iterations: List[InformationCollectionIteration] = []
while current_iteration < MAX_ITERATIONS: while current_iteration < MAX_ITERATIONS:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.debug(f"User {user} disconnected client. Research cancelled.")
break
online_results: Dict = dict() online_results: Dict = dict()
code_results: Dict = dict() code_results: Dict = dict()
document_results: List[Dict[str, str]] = [] document_results: List[Dict[str, str]] = []