mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Fix interrupt UX and research when using websocket via web app
This commit is contained in:
@@ -271,11 +271,13 @@ export default function Chat() {
|
||||
const controlMessage = JSON.parse(lastMessage.data);
|
||||
if (controlMessage.type === "interrupt_acknowledged") {
|
||||
console.log("Interrupt acknowledged by server");
|
||||
setSocketUrl(null);
|
||||
setProcessQuerySignal(false);
|
||||
return;
|
||||
}
|
||||
if (controlMessage.error) {
|
||||
} else if (controlMessage.type === "interrupt_message_acknowledged") {
|
||||
console.log("Interrupt message acknowledged by server");
|
||||
setProcessQuerySignal(false);
|
||||
return;
|
||||
} else if (controlMessage.error) {
|
||||
console.error("WebSocket error:", controlMessage.error);
|
||||
return;
|
||||
}
|
||||
@@ -360,24 +362,20 @@ export default function Chat() {
|
||||
);
|
||||
console.log("Sent interrupt message via WebSocket:", interruptMessage);
|
||||
|
||||
// Update the current message with the new query but keep it in processing state
|
||||
const messageToProcess = interruptMessage || queryToProcess;
|
||||
// Mark the last message as completed
|
||||
setMessages((prevMessages) => {
|
||||
const newMessages = [...prevMessages];
|
||||
const currentMessage = newMessages[newMessages.length - 1];
|
||||
if (currentMessage && !currentMessage.completed) {
|
||||
currentMessage.rawQuery = messageToProcess;
|
||||
currentMessage.completed = !!interruptMessage;
|
||||
}
|
||||
if (currentMessage) currentMessage.completed = true;
|
||||
return newMessages;
|
||||
});
|
||||
|
||||
// Update the query being processed
|
||||
setQueryToProcess(messageToProcess);
|
||||
setTriggeredAbort(!!interruptMessage);
|
||||
// Set the interrupt message as the new query being processed
|
||||
setQueryToProcess(interruptMessage);
|
||||
setTriggeredAbort(false); // Always set to false after processing
|
||||
setInterruptMessage("");
|
||||
}
|
||||
}, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]);
|
||||
}, [triggeredAbort, sendMessage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (queryToProcess) {
|
||||
|
||||
@@ -44,6 +44,7 @@ async def operate_environment(
|
||||
query_files: str = None, # TODO: Handle query files
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: Optional[str] = "␃🔚␗",
|
||||
tracer: dict = {},
|
||||
):
|
||||
response, user_input_message = None, None
|
||||
@@ -144,6 +145,10 @@ async def operate_environment(
|
||||
|
||||
# Add interrupt query to current operator run
|
||||
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||
if interrupt_query == abort_message:
|
||||
cancellation_event.set()
|
||||
logger.debug(f"Operator run cancelled by user {user} via interrupt queue.")
|
||||
break
|
||||
# Add the interrupt query as a new user message to the research conversation history
|
||||
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
|
||||
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
|
||||
|
||||
@@ -69,6 +69,7 @@ from khoj.routers.helpers import (
|
||||
generate_mermaidjs_diagram,
|
||||
generate_summary_from_files,
|
||||
get_conversation_command,
|
||||
get_message_from_queue,
|
||||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
read_chat_stream,
|
||||
@@ -672,7 +673,7 @@ async def event_generator(
|
||||
common: CommonQueryParams,
|
||||
headers: Headers,
|
||||
request_obj: Request | WebSocket,
|
||||
interrupt_queue: asyncio.Queue = None,
|
||||
parent_interrupt_queue: asyncio.Queue = None,
|
||||
):
|
||||
# Access the parameters from the body
|
||||
q = body.q
|
||||
@@ -697,8 +698,11 @@ async def event_generator(
|
||||
user: KhojUser = user_scope.object
|
||||
is_subscribed = has_required_scope(request_obj, ["premium"])
|
||||
q = unquote(q)
|
||||
defiltered_query = defilter_query(q)
|
||||
train_of_thought = []
|
||||
cancellation_event = asyncio.Event()
|
||||
child_interrupt_queue: asyncio.Queue = asyncio.Queue()
|
||||
event_delimiter = "␃🔚␗"
|
||||
|
||||
tracer: dict = {
|
||||
"mid": turn_id,
|
||||
@@ -744,6 +748,7 @@ async def event_generator(
|
||||
disconnect_monitor_task = None
|
||||
|
||||
async def monitor_disconnection():
|
||||
nonlocal q, defiltered_query
|
||||
if isinstance(request_obj, Request):
|
||||
try:
|
||||
msg = await request_obj.receive()
|
||||
@@ -779,12 +784,23 @@ async def event_generator(
|
||||
except Exception as e:
|
||||
logger.error(f"Error in disconnect monitor: {e}")
|
||||
elif isinstance(request_obj, WebSocket):
|
||||
while request_obj.client_state == WebSocketState.CONNECTED:
|
||||
while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.debug(f"WebSocket disconnected. User {user} from {common.client} client.")
|
||||
cancellation_event.set()
|
||||
if conversation:
|
||||
# Check if any interrupt query is received
|
||||
if interrupt_query := get_message_from_queue(parent_interrupt_queue):
|
||||
if interrupt_query == event_delimiter:
|
||||
cancellation_event.set()
|
||||
logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
|
||||
else:
|
||||
# Pass the interrupt query to child tasks
|
||||
logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
|
||||
await child_interrupt_queue.put(interrupt_query)
|
||||
q += f"\n\n{interrupt_query}"
|
||||
defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
|
||||
|
||||
logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
|
||||
if conversation and cancellation_event.is_set():
|
||||
await asyncio.shield(
|
||||
save_to_conversation_log(
|
||||
q,
|
||||
@@ -821,7 +837,6 @@ async def event_generator(
|
||||
|
||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||
nonlocal ttft, train_of_thought
|
||||
event_delimiter = "␃🔚␗"
|
||||
if cancellation_event.is_set():
|
||||
return
|
||||
try:
|
||||
@@ -1025,7 +1040,8 @@ async def event_generator(
|
||||
query_files=attached_file_context,
|
||||
tracer=tracer,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
interrupt_queue=child_interrupt_queue,
|
||||
abort_message=event_delimiter,
|
||||
):
|
||||
if isinstance(research_result, ResearchIteration):
|
||||
if research_result.summarizedResult:
|
||||
@@ -1218,6 +1234,7 @@ async def event_generator(
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
agent=agent,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=child_interrupt_queue,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
@@ -1471,9 +1488,17 @@ async def chat_ws(
|
||||
if data.get("type") == "interrupt":
|
||||
if current_task and not current_task.done():
|
||||
# Send interrupt signal to the ongoing task
|
||||
await interrupt_queue.put(data.get("query", ""))
|
||||
logger.info(f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id}")
|
||||
await websocket.send_text(json.dumps({"type": "interrupt_acknowledged"}))
|
||||
abort_message = "␃🔚␗"
|
||||
await interrupt_queue.put(data.get("query") or abort_message)
|
||||
logger.info(
|
||||
f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
|
||||
)
|
||||
if data.get("query"):
|
||||
ack_type = "interrupt_message_acknowledged"
|
||||
await websocket.send_text(json.dumps({"type": ack_type}))
|
||||
else:
|
||||
ack_type = "interrupt_acknowledged"
|
||||
await websocket.send_text(json.dumps({"type": ack_type}))
|
||||
else:
|
||||
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
|
||||
continue
|
||||
|
||||
@@ -224,6 +224,7 @@ async def research(
|
||||
query_files: str = None,
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: str = "␃🔚␗",
|
||||
):
|
||||
max_document_searches = 7
|
||||
max_online_searches = 3
|
||||
@@ -246,6 +247,10 @@ async def research(
|
||||
|
||||
# Update the query for the current research iteration
|
||||
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||
if interrupt_query == abort_message:
|
||||
cancellation_event.set()
|
||||
logger.debug(f"Research cancelled by user {user} via interrupt queue.")
|
||||
break
|
||||
# Add the interrupt query as a new user message to the research conversation history
|
||||
logger.info(
|
||||
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
|
||||
|
||||
Reference in New Issue
Block a user