Use websocket chat api endpoint to communicate from web app

- Use websocket library to handle setup, reconnection from web app
Use react-use-websocket library to handle websocket connection and
reconnection logic. Previously connection wasn't re-established on
disconnects.

- Send interrupt messages with ws to update research, operator trajectory

Previously we were using the abort and send new POST /api/chat
mechanism.

But now we can use the websocket's bi-directional messaging capability
to send users messages in the middle of a research, operator run.

This change should
1. Allow for a faster, more interactive interruption to shift the
research direction without breaking the conversation flow. As
previously we were using the DB to communicate interrupts across
workers, this would take time and feel sluggish on the UX.

2. Be a more robust interrupt mechanism that'll work in multi worker
setups. As same worker is interacted with to send interrupt messages
instead of potentially new worker receiving the POST /api/chat with
the interrupt user message.

On the server we're using an asyncio Queue to pass messages down from
websocket api to researcher via event generator. This can be extended
to pass to other iterative agents like operator.
This commit is contained in:
Debanjum
2025-06-18 16:46:10 -07:00
parent 9f0eff6541
commit eaed0c839e
4 changed files with 172 additions and 130 deletions

View File

@@ -1,7 +1,8 @@
"use client";
import styles from "./chat.module.css";
import React, { Suspense, useEffect, useRef, useState } from "react";
import React, { Suspense, useCallback, useEffect, useRef, useState } from "react";
import useWebSocket from "react-use-websocket";
import ChatHistory from "../components/chatHistory/chatHistory";
import { useSearchParams } from "next/navigation";
@@ -45,7 +46,7 @@ interface ChatBodyDataProps {
isMobileWidth?: boolean;
isLoggedIn: boolean;
setImages: (images: string[]) => void;
setTriggeredAbort: (triggeredAbort: boolean) => void;
setTriggeredAbort: (triggeredAbort: boolean, newMessage?: string) => void;
isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean;
@@ -205,10 +206,10 @@ export default function Chat() {
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
const [images, setImages] = useState<string[]>([]);
const [abortMessageStreamController, setAbortMessageStreamController] =
useState<AbortController | null>(null);
const [triggeredAbort, setTriggeredAbort] = useState(false);
const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false);
const [interruptMessage, setInterruptMessage] = useState<string>("");
const bufferRef = useRef("");
const idleTimerRef = useRef<NodeJS.Timeout | null>(null);
const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || {
locationData: {
@@ -222,6 +223,107 @@ export default function Chat() {
} = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth();
const [isChatSideBarOpen, setIsChatSideBarOpen] = useState(false);
const [socketUrl, setSocketUrl] = useState<string | null>(null);
const disconnectFromServer = useCallback(() => {
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
setSocketUrl(null);
console.log("WebSocket disconnected due to inactivity.");
}, []);
const resetIdleTimer = useCallback(() => {
const idleTimeout = 10 * 60 * 1000; // 10 minutes
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
idleTimerRef.current = setTimeout(disconnectFromServer, idleTimeout);
}, [disconnectFromServer]);
const { sendMessage, lastMessage } = useWebSocket(socketUrl, {
share: true,
shouldReconnect: (closeEvent) => true,
reconnectAttempts: 10,
// reconnect using exponential backoff with jitter
reconnectInterval: (attemptNumber) => {
const baseDelay = 1000 * Math.pow(2, attemptNumber);
const jitter = Math.random() * 1000; // Add jitter up to 1s
return Math.min(baseDelay + jitter, 20000); // Cap backoff at 20s
},
onOpen: () => {
console.log("WebSocket connection established.");
resetIdleTimer();
},
onClose: () => {
console.log("WebSocket connection closed.");
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
},
});
useEffect(() => {
if (lastMessage !== null) {
resetIdleTimer();
// Check if this is a control message (JSON) rather than a streaming event
try {
const controlMessage = JSON.parse(lastMessage.data);
if (controlMessage.type === "interrupt_acknowledged") {
console.log("Interrupt acknowledged by server");
setSocketUrl(null);
setProcessQuerySignal(false);
return;
}
if (controlMessage.error) {
console.error("WebSocket error:", controlMessage.error);
return;
}
} catch {
// Not a JSON control message, process as streaming event
}
const eventDelimiter = "␃🔚␗";
bufferRef.current += lastMessage.data;
let newEventIndex;
while ((newEventIndex = bufferRef.current.indexOf(eventDelimiter)) !== -1) {
const eventChunk = bufferRef.current.slice(0, newEventIndex);
bufferRef.current = bufferRef.current.slice(newEventIndex + eventDelimiter.length);
if (eventChunk) {
setMessages((prevMessages) => {
const newMessages = [...prevMessages];
const currentMessage = newMessages[newMessages.length - 1];
if (!currentMessage || currentMessage.completed) {
return prevMessages;
}
const { context, onlineContext, codeContext } = processMessageChunk(
eventChunk,
currentMessage,
currentMessage.context || [],
currentMessage.onlineContext || {},
currentMessage.codeContext || {},
);
// Update the current message with the new reference data
currentMessage.context = context;
currentMessage.onlineContext = onlineContext;
currentMessage.codeContext = codeContext;
if (currentMessage.completed) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImages([]);
if (conversationId) generateNewTitle(conversationId, setTitle);
}
return newMessages;
});
}
}
}
}, [lastMessage, setMessages]);
useEffect(() => {
fetch("/api/chat/options")
@@ -241,14 +343,41 @@ export default function Chat() {
welcomeConsole();
}, []);
const handleTriggeredAbort = (value: boolean, newMessage?: string) => {
if (value) {
setInterruptMessage(newMessage || "");
}
setTriggeredAbort(value);
};
useEffect(() => {
if (triggeredAbort) {
abortMessageStreamController?.abort();
handleAbortedMessage();
setShouldSendWithInterrupt(true);
setTriggeredAbort(false);
sendMessage(
JSON.stringify({
type: "interrupt",
query: interruptMessage,
}),
);
console.log("Sent interrupt message via WebSocket:", interruptMessage);
// Update the current message with the new query but keep it in processing state
const messageToProcess = interruptMessage || queryToProcess;
setMessages((prevMessages) => {
const newMessages = [...prevMessages];
const currentMessage = newMessages[newMessages.length - 1];
if (currentMessage && !currentMessage.completed) {
currentMessage.rawQuery = messageToProcess;
currentMessage.completed = !!interruptMessage;
}
return newMessages;
});
// Update the query being processed
setQueryToProcess(messageToProcess);
setTriggeredAbort(!!interruptMessage);
setInterruptMessage("");
}
}, [triggeredAbort]);
}, [triggeredAbort, interruptMessage, queryToProcess, sendMessage]);
useEffect(() => {
if (queryToProcess) {
@@ -266,7 +395,6 @@ export default function Chat() {
};
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true);
setAbortMessageStreamController(new AbortController());
}
}, [queryToProcess]);
@@ -280,70 +408,19 @@ export default function Chat() {
}
}, [processQuerySignal, locationDataLoading]);
async function readChatStream(response: Response) {
if (!response.ok) throw new Error(response.statusText);
if (!response.body) throw new Error("Response body is null");
useEffect(() => {
if (!conversationId) return;
const reader = response.body.getReader();
const decoder = new TextDecoder();
const eventDelimiter = "␃🔚␗";
let buffer = "";
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`;
setSocketUrl(wsUrl);
// Track context used for chat response
let context: Context[] = [];
let onlineContext: OnlineContext = {};
let codeContext: CodeContext = {};
while (true) {
const { done, value } = await reader.read();
if (done) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImages([]);
if (conversationId) generateNewTitle(conversationId, setTitle);
break;
return () => {
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
const chunk = decoder.decode(value, { stream: true });
buffer += chunk;
let newEventIndex;
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
const event = buffer.slice(0, newEventIndex);
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
if (event) {
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) {
console.error("No current message found");
return;
}
// Track context used for chat response. References are rendered at the end of the chat
({ context, onlineContext, codeContext } = processMessageChunk(
event,
currentMessage,
context,
onlineContext,
codeContext,
));
setMessages([...messages]);
}
}
}
}
function handleAbortedMessage() {
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
currentMessage.completed = true;
setMessages([...messages]);
setProcessQuerySignal(false);
}
};
}, [conversationId]);
async function chat() {
localStorage.removeItem("message");
@@ -351,12 +428,19 @@ export default function Chat() {
setProcessQuerySignal(false);
return;
}
const chatAPI = "/api/chat?client=web";
// Re-establish WebSocket connection if disconnected
resetIdleTimer();
if (!socketUrl) {
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`;
setSocketUrl(wsUrl);
}
const chatAPIBody = {
q: queryToProcess,
conversation_id: conversationId,
stream: true,
interrupt: shouldSendWithInterrupt,
...(locationData && {
city: locationData.city,
region: locationData.region,
@@ -368,58 +452,7 @@ export default function Chat() {
...(uploadedFiles && { files: uploadedFiles }),
};
// Reset the flag after using it
setShouldSendWithInterrupt(false);
const response = await fetch(chatAPI, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(chatAPIBody),
signal: abortMessageStreamController?.signal,
});
try {
await readChatStream(response);
} catch (err) {
let apiError;
try {
apiError = await response.json();
} catch (err) {
// Error reading API error response
apiError = {
streamError: "Error reading API error response stream. Expected JSON response.",
};
}
console.error(apiError);
// Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
// Render error message as current message
const errorMessage = (err as Error).message;
const errorName = (err as Error).name;
if (errorMessage.includes("Error in input stream"))
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
else if (apiError.streamError) {
currentMessage.rawResponse = `Umm, not sure what just happened but I lost my train of thought. Could you try again or ask my developers to look into this if the issue persists? They can be contacted at the Khoj Github, Discord or team@khoj.dev.`;
} else if (response.status === 429) {
"detail" in apiError
? (currentMessage.rawResponse = `${apiError.detail}`)
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
} else if (errorName === "AbortError") {
currentMessage.rawResponse = `I've stopped processing this message. If you'd like to continue, please send the message again.`;
} else {
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
}
// Complete message streaming teardown properly
currentMessage.completed = true;
setMessages([...messages]);
setQueryToProcess("");
setProcessQuerySignal(false);
}
sendMessage(JSON.stringify(chatAPIBody));
}
const handleConversationIdChange = (newConversationId: string) => {
@@ -522,7 +555,7 @@ export default function Chat() {
isMobileWidth={isMobileWidth}
onConversationIdChange={handleConversationIdChange}
setImages={setImages}
setTriggeredAbort={setTriggeredAbort}
setTriggeredAbort={handleTriggeredAbort}
isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active}

View File

@@ -82,7 +82,7 @@ interface ChatInputProps {
isLoggedIn: boolean;
agentColor?: string;
isResearchModeEnabled?: boolean;
setTriggeredAbort: (value: boolean) => void;
setTriggeredAbort: (value: boolean, newMessage?: string) => void;
prefillMessage?: string;
focus?: ChatInputFocus;
}
@@ -189,9 +189,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
return;
}
// If currently processing, trigger abort first
// If currently processing, handle interrupt first
if (props.sendDisabled) {
props.setTriggeredAbort(true);
props.setTriggeredAbort(true, message.trim());
setMessage(""); // Clear the input
return; // Don't continue with regular message sending
}
if (imageUploaded) {

View File

@@ -71,6 +71,7 @@
"react": "^18",
"react-dom": "^18",
"react-hook-form": "^7.52.1",
"react-use-websocket": "^4.13.0",
"shadcn-ui": "^0.9.0",
"swr": "^2.2.5",
"tailwind-merge": "^2.3.0",

View File

@@ -4542,6 +4542,11 @@ react-style-singleton@^2.2.2, react-style-singleton@^2.2.3:
get-nonce "^1.0.0"
tslib "^2.0.0"
react-use-websocket@^4.13.0:
version "4.13.0"
resolved "https://registry.yarnpkg.com/react-use-websocket/-/react-use-websocket-4.13.0.tgz#9db1dbac6dc8ba2fdc02a5bba06205fbf6406736"
integrity sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw==
react@^18:
version "18.3.1"
resolved "https://registry.yarnpkg.com/react/-/react-18.3.1.tgz#49ab892009c53933625bd16b2533fc754cab2891"
@@ -4894,6 +4899,7 @@ string-argv@^0.3.2:
integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==
"string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0:
name string-width-cjs
version "4.2.3"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==