Fix interrupt UX and research when using websocket via web app

This commit is contained in:
Debanjum
2025-07-17 15:36:26 -07:00
parent 0ecd5f497d
commit b90e2367d5
4 changed files with 56 additions and 23 deletions

View File

@@ -271,11 +271,13 @@ export default function Chat() {
const controlMessage = JSON.parse(lastMessage.data);
if (controlMessage.type === "interrupt_acknowledged") {
console.log("Interrupt acknowledged by server");
setSocketUrl(null);
setProcessQuerySignal(false);
return;
}
if (controlMessage.error) {
} else if (controlMessage.type === "interrupt_message_acknowledged") {
console.log("Interrupt message acknowledged by server");
setProcessQuerySignal(false);
return;
} else if (controlMessage.error) {
console.error("WebSocket error:", controlMessage.error);
return;
}
@@ -360,24 +362,20 @@ export default function Chat() {
);
console.log("Sent interrupt message via WebSocket:", interruptMessage);
// Update the current message with the new query but keep it in processing state
const messageToProcess = interruptMessage || queryToProcess;
// Mark the last message as completed
setMessages((prevMessages) => {
const newMessages = [...prevMessages];
const currentMessage = newMessages[newMessages.length - 1];
if (currentMessage && !currentMessage.completed) {
currentMessage.rawQuery = messageToProcess;
currentMessage.completed = !!interruptMessage;
}
if (currentMessage) currentMessage.completed = true;
return newMessages;
});
// Update the query being processed
setQueryToProcess(messageToProcess);
setTriggeredAbort(!!interruptMessage);
// Set the interrupt message as the new query being processed
setQueryToProcess(interruptMessage);
setTriggeredAbort(false); // Always set to false after processing
setInterruptMessage("");
}
}, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]);
}, [triggeredAbort, sendMessage]);
useEffect(() => {
if (queryToProcess) {

View File

@@ -44,6 +44,7 @@ async def operate_environment(
query_files: str = None, # TODO: Handle query files
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: Optional[str] = "␃🔚␗",
tracer: dict = {},
):
response, user_input_message = None, None
@@ -144,6 +145,10 @@ async def operate_environment(
# Add interrupt query to current operator run
if interrupt_query := get_message_from_queue(interrupt_queue):
if interrupt_query == abort_message:
cancellation_event.set()
logger.debug(f"Operator run cancelled by user {user} via interrupt queue.")
break
# Add the interrupt query as a new user message to the research conversation history
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))

View File

@@ -69,6 +69,7 @@ from khoj.routers.helpers import (
generate_mermaidjs_diagram,
generate_summary_from_files,
get_conversation_command,
get_message_from_queue,
is_query_empty,
is_ready_to_chat,
read_chat_stream,
@@ -672,7 +673,7 @@ async def event_generator(
common: CommonQueryParams,
headers: Headers,
request_obj: Request | WebSocket,
interrupt_queue: asyncio.Queue = None,
parent_interrupt_queue: asyncio.Queue = None,
):
# Access the parameters from the body
q = body.q
@@ -697,8 +698,11 @@ async def event_generator(
user: KhojUser = user_scope.object
is_subscribed = has_required_scope(request_obj, ["premium"])
q = unquote(q)
defiltered_query = defilter_query(q)
train_of_thought = []
cancellation_event = asyncio.Event()
child_interrupt_queue: asyncio.Queue = asyncio.Queue()
event_delimiter = "␃🔚␗"
tracer: dict = {
"mid": turn_id,
@@ -744,6 +748,7 @@ async def event_generator(
disconnect_monitor_task = None
async def monitor_disconnection():
nonlocal q, defiltered_query
if isinstance(request_obj, Request):
try:
msg = await request_obj.receive()
@@ -779,12 +784,23 @@ async def event_generator(
except Exception as e:
logger.error(f"Error in disconnect monitor: {e}")
elif isinstance(request_obj, WebSocket):
while request_obj.client_state == WebSocketState.CONNECTED:
while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
await asyncio.sleep(1)
logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.")
cancellation_event.set()
if conversation:
# Check if any interrupt query is received
if interrupt_query := get_message_from_queue(parent_interrupt_queue):
if interrupt_query == event_delimiter:
cancellation_event.set()
logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
else:
# Pass the interrupt query to child tasks
logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
await child_interrupt_queue.put(interrupt_query)
q += f"\n\n{interrupt_query}"
defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
if conversation and cancellation_event.is_set():
await asyncio.shield(
save_to_conversation_log(
q,
@@ -821,7 +837,6 @@ async def event_generator(
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗"
if cancellation_event.is_set():
return
try:
@@ -1025,7 +1040,8 @@ async def event_generator(
query_files=attached_file_context,
tracer=tracer,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
interrupt_queue=child_interrupt_queue,
abort_message=event_delimiter,
):
if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult:
@@ -1218,6 +1234,7 @@ async def event_generator(
send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent,
cancellation_event=cancellation_event,
interrupt_queue=child_interrupt_queue,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1471,9 +1488,17 @@ async def chat_ws(
if data.get("type") == "interrupt":
if current_task and not current_task.done():
# Send interrupt signal to the ongoing task
await interrupt_queue.put(data.get("query", ""))
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}")
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"}))
abort_message = "␃🔚␗"
await interrupt_queue.put(data.get("query") or abort_message)
logger.info(
f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
)
if data.get("query"):
ack_type = "interrupt_message_acknowledged"
await websocket.send_text(json.dumps({"type": ack_type}))
else:
ack_type = "interrupt_acknowledged"
await websocket.send_text(json.dumps({"type": ack_type}))
else:
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
continue

View File

@@ -224,6 +224,7 @@ async def research(
query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: str = "␃🔚␗",
):
max_document_searches = 7
max_online_searches = 3
@@ -246,6 +247,10 @@ async def research(
# Update the query for the current research iteration
if interrupt_query := get_message_from_queue(interrupt_queue):
if interrupt_query == abort_message:
cancellation_event.set()
logger.debug(f"Research cancelled by user {user} via interrupt queue.")
break
# Add the interrupt query as a new user message to the research conversation history
logger.info(
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"