Fix interrupt UX and research when using websocket via web app

This commit is contained in:
Debanjum
2025-07-17 15:36:26 -07:00
parent 0ecd5f497d
commit b90e2367d5
4 changed files with 56 additions and 23 deletions

View File

@@ -271,11 +271,13 @@ export default function Chat() {
const controlMessage = JSON.parse(lastMessage.data); const controlMessage = JSON.parse(lastMessage.data);
if (controlMessage.type === "interrupt_acknowledged") { if (controlMessage.type === "interrupt_acknowledged") {
console.log("Interrupt acknowledged by server"); console.log("Interrupt acknowledged by server");
setSocketUrl(null);
setProcessQuerySignal(false); setProcessQuerySignal(false);
return; return;
} } else if (controlMessage.type === "interrupt_message_acknowledged") {
if (controlMessage.error) { console.log("Interrupt message acknowledged by server");
setProcessQuerySignal(false);
return;
} else if (controlMessage.error) {
console.error("WebSocket error:", controlMessage.error); console.error("WebSocket error:", controlMessage.error);
return; return;
} }
@@ -360,24 +362,20 @@ export default function Chat() {
); );
console.log("Sent interrupt message via WebSocket:", interruptMessage); console.log("Sent interrupt message via WebSocket:", interruptMessage);
// Update the current message with the new query but keep it in processing state // Mark the last message as completed
const messageToProcess = interruptMessage || queryToProcess;
setMessages((prevMessages) => { setMessages((prevMessages) => {
const newMessages = [...prevMessages]; const newMessages = [...prevMessages];
const currentMessage = newMessages[newMessages.length - 1]; const currentMessage = newMessages[newMessages.length - 1];
if (currentMessage && !currentMessage.completed) { if (currentMessage) currentMessage.completed = true;
currentMessage.rawQuery = messageToProcess;
currentMessage.completed = !!interruptMessage;
}
return newMessages; return newMessages;
}); });
// Update the query being processed // Set the interrupt message as the new query being processed
setQueryToProcess(messageToProcess); setQueryToProcess(interruptMessage);
setTriggeredAbort(!!interruptMessage); setTriggeredAbort(false); // Always set to false after processing
setInterruptMessage(""); setInterruptMessage("");
} }
}, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]); }, [triggeredAbort, sendMessage]);
useEffect(() => { useEffect(() => {
if (queryToProcess) { if (queryToProcess) {

View File

@@ -44,6 +44,7 @@ async def operate_environment(
query_files: str = None, # TODO: Handle query files query_files: str = None, # TODO: Handle query files
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None, interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: Optional[str] = "␃🔚␗",
tracer: dict = {}, tracer: dict = {},
): ):
response, user_input_message = None, None response, user_input_message = None, None
@@ -144,6 +145,10 @@ async def operate_environment(
# Add interrupt query to current operator run # Add interrupt query to current operator run
if interrupt_query := get_message_from_queue(interrupt_queue): if interrupt_query := get_message_from_queue(interrupt_queue):
if interrupt_query == abort_message:
cancellation_event.set()
logger.debug(f"Operator run cancelled by user {user} via interrupt queue.")
break
# Add the interrupt query as a new user message to the research conversation history # Add the interrupt query as a new user message to the research conversation history
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}") logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query)) operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))

View File

@@ -69,6 +69,7 @@ from khoj.routers.helpers import (
generate_mermaidjs_diagram, generate_mermaidjs_diagram,
generate_summary_from_files, generate_summary_from_files,
get_conversation_command, get_conversation_command,
get_message_from_queue,
is_query_empty, is_query_empty,
is_ready_to_chat, is_ready_to_chat,
read_chat_stream, read_chat_stream,
@@ -672,7 +673,7 @@ async def event_generator(
common: CommonQueryParams, common: CommonQueryParams,
headers: Headers, headers: Headers,
request_obj: Request | WebSocket, request_obj: Request | WebSocket,
interrupt_queue: asyncio.Queue = None, parent_interrupt_queue: asyncio.Queue = None,
): ):
# Access the parameters from the body # Access the parameters from the body
q = body.q q = body.q
@@ -697,8 +698,11 @@ async def event_generator(
user: KhojUser = user_scope.object user: KhojUser = user_scope.object
is_subscribed = has_required_scope(request_obj, ["premium"]) is_subscribed = has_required_scope(request_obj, ["premium"])
q = unquote(q) q = unquote(q)
defiltered_query = defilter_query(q)
train_of_thought = [] train_of_thought = []
cancellation_event = asyncio.Event() cancellation_event = asyncio.Event()
child_interrupt_queue: asyncio.Queue = asyncio.Queue()
event_delimiter = "␃🔚␗"
tracer: dict = { tracer: dict = {
"mid": turn_id, "mid": turn_id,
@@ -744,6 +748,7 @@ async def event_generator(
disconnect_monitor_task = None disconnect_monitor_task = None
async def monitor_disconnection(): async def monitor_disconnection():
nonlocal q, defiltered_query
if isinstance(request_obj, Request): if isinstance(request_obj, Request):
try: try:
msg = await request_obj.receive() msg = await request_obj.receive()
@@ -779,12 +784,23 @@ async def event_generator(
except Exception as e: except Exception as e:
logger.error(f"Error in disconnect monitor: {e}") logger.error(f"Error in disconnect monitor: {e}")
elif isinstance(request_obj, WebSocket): elif isinstance(request_obj, WebSocket):
while request_obj.client_state == WebSocketState.CONNECTED: while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
await asyncio.sleep(1) await asyncio.sleep(1)
logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.") # Check if any interrupt query is received
cancellation_event.set() if interrupt_query := get_message_from_queue(parent_interrupt_queue):
if conversation: if interrupt_query == event_delimiter:
cancellation_event.set()
logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
else:
# Pass the interrupt query to child tasks
logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
await child_interrupt_queue.put(interrupt_query)
q += f"\n\n{interrupt_query}"
defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
if conversation and cancellation_event.is_set():
await asyncio.shield( await asyncio.shield(
save_to_conversation_log( save_to_conversation_log(
q, q,
@@ -821,7 +837,6 @@ async def event_generator(
async def send_event(event_type: ChatEvent, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal ttft, train_of_thought nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗"
if cancellation_event.is_set(): if cancellation_event.is_set():
return return
try: try:
@@ -1025,7 +1040,8 @@ async def event_generator(
query_files=attached_file_context, query_files=attached_file_context,
tracer=tracer, tracer=tracer,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue, interrupt_queue=child_interrupt_queue,
abort_message=event_delimiter,
): ):
if isinstance(research_result, ResearchIteration): if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult: if research_result.summarizedResult:
@@ -1218,6 +1234,7 @@ async def event_generator(
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent, agent=agent,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=child_interrupt_queue,
tracer=tracer, tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1471,9 +1488,17 @@ async def chat_ws(
if data.get("type") == "interrupt": if data.get("type") == "interrupt":
if current_task and not current_task.done(): if current_task and not current_task.done():
# Send interrupt signal to the ongoing task # Send interrupt signal to the ongoing task
await interrupt_queue.put(data.get("query", "")) abort_message = "␃🔚␗"
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}") await interrupt_queue.put(data.get("query") or abort_message)
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"})) logger.info(
f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
)
if data.get("query"):
ack_type = "interrupt_message_acknowledged"
await websocket.send_text(json.dumps({"type": ack_type}))
else:
ack_type = "interrupt_acknowledged"
await websocket.send_text(json.dumps({"type": ack_type}))
else: else:
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}") logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
continue continue

View File

@@ -224,6 +224,7 @@ async def research(
query_files: str = None, query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None, interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: str = "␃🔚␗",
): ):
max_document_searches = 7 max_document_searches = 7
max_online_searches = 3 max_online_searches = 3
@@ -246,6 +247,10 @@ async def research(
# Update the query for the current research iteration # Update the query for the current research iteration
if interrupt_query := get_message_from_queue(interrupt_queue): if interrupt_query := get_message_from_queue(interrupt_queue):
if interrupt_query == abort_message:
cancellation_event.set()
logger.debug(f"Research cancelled by user {user} via interrupt queue.")
break
# Add the interrupt query as a new user message to the research conversation history # Add the interrupt query as a new user message to the research conversation history
logger.info( logger.info(
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}" f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"