diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 8ef97b8d..796290cf 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -271,11 +271,13 @@ export default function Chat() { const controlMessage = JSON.parse(lastMessage.data); if (controlMessage.type === "interrupt_acknowledged") { console.log("Interrupt acknowledged by server"); - setSocketUrl(null); setProcessQuerySignal(false); return; - } - if (controlMessage.error) { + } else if (controlMessage.type === "interrupt_message_acknowledged") { + console.log("Interrupt message acknowledged by server"); + setProcessQuerySignal(false); + return; + } else if (controlMessage.error) { console.error("WebSocket error:", controlMessage.error); return; } @@ -360,24 +362,20 @@ export default function Chat() { ); console.log("Sent interrupt message via WebSocket:", interruptMessage); - // Update the current message with the new query but keep it in processing state - const messageToProcess = interruptMessage || queryToProcess; + // Mark the last message as completed setMessages((prevMessages) => { const newMessages = [...prevMessages]; const currentMessage = newMessages[newMessages.length - 1]; - if (currentMessage && !currentMessage.completed) { - currentMessage.rawQuery = messageToProcess; - currentMessage.completed = !!interruptMessage; - } + if (currentMessage) currentMessage.completed = true; return newMessages; }); - // Update the query being processed - setQueryToProcess(messageToProcess); - setTriggeredAbort(!!interruptMessage); + // Set the interrupt message as the new query being processed + setQueryToProcess(interruptMessage); + setTriggeredAbort(false); // Always set to false after processing setInterruptMessage(""); } - }, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]); + }, [triggeredAbort, sendMessage]); useEffect(() => { if (queryToProcess) { diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index c07dec90..d7068bff 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -44,6 +44,7 @@ async def operate_environment( query_files: str = None, # TODO: Handle query files cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, + abort_message: Optional[str] = "␃🔚␗", tracer: dict = {}, ): response, user_input_message = None, None @@ -144,6 +145,10 @@ async def operate_environment( # Add interrupt query to current operator run if interrupt_query := get_message_from_queue(interrupt_queue): + if interrupt_query == abort_message: + cancellation_event.set() + logger.debug(f"Operator run cancelled by user {user} via interrupt queue.") + break # Add the interrupt query as a new user message to the research conversation history logger.info(f"Continuing operator run with the new instruction: {interrupt_query}") operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query)) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index c8c23c8f..09c7651a 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -69,6 +69,7 @@ from khoj.routers.helpers import ( generate_mermaidjs_diagram, generate_summary_from_files, get_conversation_command, + get_message_from_queue, is_query_empty, is_ready_to_chat, read_chat_stream, @@ -672,7 +673,7 @@ async def event_generator( common: CommonQueryParams, headers: Headers, request_obj: Request | WebSocket, - interrupt_queue: asyncio.Queue = None, + parent_interrupt_queue: asyncio.Queue = None, ): # Access the parameters from the body q = body.q @@ -697,8 +698,11 @@ async def event_generator( user: KhojUser = user_scope.object is_subscribed = has_required_scope(request_obj, ["premium"]) q = unquote(q) + defiltered_query = defilter_query(q) train_of_thought = [] cancellation_event = asyncio.Event() + child_interrupt_queue: asyncio.Queue = asyncio.Queue() + event_delimiter = "␃🔚␗" tracer: dict = { "mid": turn_id, @@ -744,6 +748,7 @@ async def event_generator( disconnect_monitor_task = None async def monitor_disconnection(): + nonlocal q, defiltered_query if isinstance(request_obj, Request): try: msg = await request_obj.receive() @@ -779,12 +784,23 @@ async def event_generator( except Exception as e: logger.error(f"Error in disconnect monitor: {e}") elif isinstance(request_obj, WebSocket): - while request_obj.client_state == WebSocketState.CONNECTED: + while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set(): await asyncio.sleep(1) - logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.") - cancellation_event.set() - if conversation: + # Check if any interrupt query is received + if interrupt_query := get_message_from_queue(parent_interrupt_queue): + if interrupt_query == event_delimiter: + cancellation_event.set() + logger.debug(f"Chat cancelled by user {user} via interrupt queue.") + else: + # Pass the interrupt query to child tasks + logger.info(f"Continuing chat with the new instruction: {interrupt_query}") + await child_interrupt_queue.put(interrupt_query) + q += f"\n\n{interrupt_query}" + defiltered_query += f"\n\n{defilter_query(interrupt_query)}" + + logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.") + if conversation and cancellation_event.is_set(): await asyncio.shield( save_to_conversation_log( q, @@ -821,7 +837,6 @@ async def event_generator( async def send_event(event_type: ChatEvent, data: str | dict): nonlocal ttft, train_of_thought - event_delimiter = "␃🔚␗" if cancellation_event.is_set(): return try: @@ -1025,7 +1040,8 @@ async def event_generator( query_files=attached_file_context, tracer=tracer, cancellation_event=cancellation_event, - interrupt_queue=interrupt_queue, + interrupt_queue=child_interrupt_queue, + abort_message=event_delimiter, ): if isinstance(research_result, ResearchIteration): if research_result.summarizedResult: @@ -1218,6 +1234,7 @@ async def event_generator( send_status_func=partial(send_event, ChatEvent.STATUS), agent=agent, cancellation_event=cancellation_event, + interrupt_queue=child_interrupt_queue, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1471,9 +1488,17 @@ async def chat_ws( if data.get("type") == "interrupt": if current_task and not current_task.done(): # Send interrupt signal to the ongoing task - await interrupt_queue.put(data.get("query", "")) - logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}") - await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"})) + abort_message = "␃🔚␗" + await interrupt_queue.put(data.get("query") or abort_message) + logger.info( + f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}" + ) + if data.get("query"): + ack_type = "interrupt_message_acknowledged" + await websocket.send_text(json.dumps({"type": ack_type})) + else: + ack_type = "interrupt_acknowledged" + await websocket.send_text(json.dumps({"type": ack_type})) else: logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}") continue diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 050fd3f8..616acae2 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -224,6 +224,7 @@ async def research( query_files: str = None, cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, + abort_message: str = "␃🔚␗", ): max_document_searches = 7 max_online_searches = 3 @@ -246,6 +247,10 @@ async def research( # Update the query for the current research iteration if interrupt_query := get_message_from_queue(interrupt_queue): + if interrupt_query == abort_message: + cancellation_event.set() + logger.debug(f"Research cancelled by user {user} via interrupt queue.") + break # Add the interrupt query as a new user message to the research conversation history logger.info( f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"