mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Add websocket chat api to ease bi-directional communication (#1207)
- Add a websocket api endpoint for chat. Reuse most of the existing chat logic. - Communicate from web app using the websocket chat api endpoint. - Pass interrupt messages using websocket to guide research, operator trajectory Previously we were using the abort and send new POST /api/chat mechanism. This didn't scale well to multi-worker setups as a different worker could pick up the new interrupt message request. Using websocket to send messages in the middle of long running tasks should work more naturally.
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import styles from "./chat.module.css";
|
||||
import React, { Suspense, useEffect, useRef, useState } from "react";
|
||||
import React, { Suspense, useCallback, useEffect, useRef, useState } from "react";
|
||||
import useWebSocket from "react-use-websocket";
|
||||
|
||||
import ChatHistory from "../components/chatHistory/chatHistory";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
@@ -45,7 +46,7 @@ interface ChatBodyDataProps {
|
||||
isMobileWidth?: boolean;
|
||||
isLoggedIn: boolean;
|
||||
setImages: (images: string[]) => void;
|
||||
setTriggeredAbort: (triggeredAbort: boolean) => void;
|
||||
setTriggeredAbort: (triggeredAbort: boolean, newMessage?: string) => void;
|
||||
isChatSideBarOpen: boolean;
|
||||
setIsChatSideBarOpen: (open: boolean) => void;
|
||||
isActive?: boolean;
|
||||
@@ -205,10 +206,10 @@ export default function Chat() {
|
||||
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
|
||||
const [abortMessageStreamController, setAbortMessageStreamController] =
|
||||
useState<AbortController | null>(null);
|
||||
const [triggeredAbort, setTriggeredAbort] = useState(false);
|
||||
const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false);
|
||||
const [interruptMessage, setInterruptMessage] = useState<string>("");
|
||||
const bufferRef = useRef("");
|
||||
const idleTimerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || {
|
||||
locationData: {
|
||||
@@ -222,6 +223,109 @@ export default function Chat() {
|
||||
} = useAuthenticatedData();
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
const [isChatSideBarOpen, setIsChatSideBarOpen] = useState(false);
|
||||
const [socketUrl, setSocketUrl] = useState<string | null>(null);
|
||||
|
||||
const disconnectFromServer = useCallback(() => {
|
||||
if (idleTimerRef.current) {
|
||||
clearTimeout(idleTimerRef.current);
|
||||
}
|
||||
setSocketUrl(null);
|
||||
console.log("WebSocket disconnected due to inactivity.");
|
||||
}, []);
|
||||
|
||||
const resetIdleTimer = useCallback(() => {
|
||||
const idleTimeout = 10 * 60 * 1000; // 10 minutes
|
||||
if (idleTimerRef.current) {
|
||||
clearTimeout(idleTimerRef.current);
|
||||
}
|
||||
idleTimerRef.current = setTimeout(disconnectFromServer, idleTimeout);
|
||||
}, [disconnectFromServer]);
|
||||
|
||||
const { sendMessage, lastMessage } = useWebSocket(socketUrl, {
|
||||
share: true,
|
||||
shouldReconnect: (closeEvent) => true,
|
||||
reconnectAttempts: 10,
|
||||
// reconnect using exponential backoff with jitter
|
||||
reconnectInterval: (attemptNumber) => {
|
||||
const baseDelay = 1000 * Math.pow(2, attemptNumber);
|
||||
const jitter = Math.random() * 1000; // Add jitter up to 1s
|
||||
return Math.min(baseDelay + jitter, 20000); // Cap backoff at 20s
|
||||
},
|
||||
onOpen: () => {
|
||||
console.log("WebSocket connection established.");
|
||||
resetIdleTimer();
|
||||
},
|
||||
onClose: () => {
|
||||
console.log("WebSocket connection closed.");
|
||||
if (idleTimerRef.current) {
|
||||
clearTimeout(idleTimerRef.current);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (lastMessage !== null) {
|
||||
resetIdleTimer();
|
||||
// Check if this is a control message (JSON) rather than a streaming event
|
||||
try {
|
||||
const controlMessage = JSON.parse(lastMessage.data);
|
||||
if (controlMessage.type === "interrupt_acknowledged") {
|
||||
console.log("Interrupt acknowledged by server");
|
||||
setProcessQuerySignal(false);
|
||||
return;
|
||||
} else if (controlMessage.type === "interrupt_message_acknowledged") {
|
||||
console.log("Interrupt message acknowledged by server");
|
||||
setProcessQuerySignal(false);
|
||||
return;
|
||||
} else if (controlMessage.error) {
|
||||
console.error("WebSocket error:", controlMessage.error);
|
||||
return;
|
||||
}
|
||||
} catch {
|
||||
// Not a JSON control message, process as streaming event
|
||||
}
|
||||
|
||||
const eventDelimiter = "␃🔚␗";
|
||||
bufferRef.current += lastMessage.data;
|
||||
|
||||
let newEventIndex;
|
||||
while ((newEventIndex = bufferRef.current.indexOf(eventDelimiter)) !== -1) {
|
||||
const eventChunk = bufferRef.current.slice(0, newEventIndex);
|
||||
bufferRef.current = bufferRef.current.slice(newEventIndex + eventDelimiter.length);
|
||||
if (eventChunk) {
|
||||
setMessages((prevMessages) => {
|
||||
const newMessages = [...prevMessages];
|
||||
const currentMessage = newMessages[newMessages.length - 1];
|
||||
if (!currentMessage || currentMessage.completed) {
|
||||
return prevMessages;
|
||||
}
|
||||
|
||||
const { context, onlineContext, codeContext } = processMessageChunk(
|
||||
eventChunk,
|
||||
currentMessage,
|
||||
currentMessage.context || [],
|
||||
currentMessage.onlineContext || {},
|
||||
currentMessage.codeContext || {},
|
||||
);
|
||||
|
||||
// Update the current message with the new reference data
|
||||
currentMessage.context = context;
|
||||
currentMessage.onlineContext = onlineContext;
|
||||
currentMessage.codeContext = codeContext;
|
||||
|
||||
if (currentMessage.completed) {
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
setImages([]);
|
||||
if (conversationId) generateNewTitle(conversationId, setTitle);
|
||||
}
|
||||
|
||||
return newMessages;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [lastMessage, setMessages]);
|
||||
|
||||
useEffect(() => {
|
||||
fetch("/api/chat/options")
|
||||
@@ -241,14 +345,37 @@ export default function Chat() {
|
||||
welcomeConsole();
|
||||
}, []);
|
||||
|
||||
const handleTriggeredAbort = (value: boolean, newMessage?: string) => {
|
||||
if (value) {
|
||||
setInterruptMessage(newMessage || "");
|
||||
}
|
||||
setTriggeredAbort(value);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (triggeredAbort) {
|
||||
abortMessageStreamController?.abort();
|
||||
handleAbortedMessage();
|
||||
setShouldSendWithInterrupt(true);
|
||||
setTriggeredAbort(false);
|
||||
sendMessage(
|
||||
JSON.stringify({
|
||||
type: "interrupt",
|
||||
query: interruptMessage,
|
||||
}),
|
||||
);
|
||||
console.log("Sent interrupt message via WebSocket:", interruptMessage);
|
||||
|
||||
// Mark the last message as completed
|
||||
setMessages((prevMessages) => {
|
||||
const newMessages = [...prevMessages];
|
||||
const currentMessage = newMessages[newMessages.length - 1];
|
||||
if (currentMessage) currentMessage.completed = true;
|
||||
return newMessages;
|
||||
});
|
||||
|
||||
// Set the interrupt message as the new query being processed
|
||||
setQueryToProcess(interruptMessage);
|
||||
setTriggeredAbort(false); // Always set to false after processing
|
||||
setInterruptMessage("");
|
||||
}
|
||||
}, [triggeredAbort]);
|
||||
}, [triggeredAbort, sendMessage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (queryToProcess) {
|
||||
@@ -266,7 +393,6 @@ export default function Chat() {
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||
setProcessQuerySignal(true);
|
||||
setAbortMessageStreamController(new AbortController());
|
||||
}
|
||||
}, [queryToProcess]);
|
||||
|
||||
@@ -280,70 +406,19 @@ export default function Chat() {
|
||||
}
|
||||
}, [processQuerySignal, locationDataLoading]);
|
||||
|
||||
async function readChatStream(response: Response) {
|
||||
if (!response.ok) throw new Error(response.statusText);
|
||||
if (!response.body) throw new Error("Response body is null");
|
||||
useEffect(() => {
|
||||
if (!conversationId) return;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
const eventDelimiter = "␃🔚␗";
|
||||
let buffer = "";
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`;
|
||||
setSocketUrl(wsUrl);
|
||||
|
||||
// Track context used for chat response
|
||||
let context: Context[] = [];
|
||||
let onlineContext: OnlineContext = {};
|
||||
let codeContext: CodeContext = {};
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
setImages([]);
|
||||
|
||||
if (conversationId) generateNewTitle(conversationId, setTitle);
|
||||
|
||||
break;
|
||||
return () => {
|
||||
if (idleTimerRef.current) {
|
||||
clearTimeout(idleTimerRef.current);
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
buffer += chunk;
|
||||
|
||||
let newEventIndex;
|
||||
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
|
||||
const event = buffer.slice(0, newEventIndex);
|
||||
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
|
||||
if (event) {
|
||||
const currentMessage = messages.find((message) => !message.completed);
|
||||
|
||||
if (!currentMessage) {
|
||||
console.error("No current message found");
|
||||
return;
|
||||
}
|
||||
|
||||
// Track context used for chat response. References are rendered at the end of the chat
|
||||
({ context, onlineContext, codeContext } = processMessageChunk(
|
||||
event,
|
||||
currentMessage,
|
||||
context,
|
||||
onlineContext,
|
||||
codeContext,
|
||||
));
|
||||
|
||||
setMessages([...messages]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleAbortedMessage() {
|
||||
const currentMessage = messages.find((message) => !message.completed);
|
||||
if (!currentMessage) return;
|
||||
|
||||
currentMessage.completed = true;
|
||||
setMessages([...messages]);
|
||||
setProcessQuerySignal(false);
|
||||
}
|
||||
};
|
||||
}, [conversationId]);
|
||||
|
||||
async function chat() {
|
||||
localStorage.removeItem("message");
|
||||
@@ -351,12 +426,19 @@ export default function Chat() {
|
||||
setProcessQuerySignal(false);
|
||||
return;
|
||||
}
|
||||
const chatAPI = "/api/chat?client=web";
|
||||
|
||||
// Re-establish WebSocket connection if disconnected
|
||||
resetIdleTimer();
|
||||
if (!socketUrl) {
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`;
|
||||
setSocketUrl(wsUrl);
|
||||
}
|
||||
|
||||
const chatAPIBody = {
|
||||
q: queryToProcess,
|
||||
conversation_id: conversationId,
|
||||
stream: true,
|
||||
interrupt: shouldSendWithInterrupt,
|
||||
...(locationData && {
|
||||
city: locationData.city,
|
||||
region: locationData.region,
|
||||
@@ -368,58 +450,7 @@ export default function Chat() {
|
||||
...(uploadedFiles && { files: uploadedFiles }),
|
||||
};
|
||||
|
||||
// Reset the flag after using it
|
||||
setShouldSendWithInterrupt(false);
|
||||
|
||||
const response = await fetch(chatAPI, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(chatAPIBody),
|
||||
signal: abortMessageStreamController?.signal,
|
||||
});
|
||||
|
||||
try {
|
||||
await readChatStream(response);
|
||||
} catch (err) {
|
||||
let apiError;
|
||||
try {
|
||||
apiError = await response.json();
|
||||
} catch (err) {
|
||||
// Error reading API error response
|
||||
apiError = {
|
||||
streamError: "Error reading API error response stream. Expected JSON response.",
|
||||
};
|
||||
}
|
||||
console.error(apiError);
|
||||
// Retrieve latest message being processed
|
||||
const currentMessage = messages.find((message) => !message.completed);
|
||||
if (!currentMessage) return;
|
||||
|
||||
// Render error message as current message
|
||||
const errorMessage = (err as Error).message;
|
||||
const errorName = (err as Error).name;
|
||||
if (errorMessage.includes("Error in input stream"))
|
||||
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
||||
else if (apiError.streamError) {
|
||||
currentMessage.rawResponse = `Umm, not sure what just happened but I lost my train of thought. Could you try again or ask my developers to look into this if the issue persists? They can be contacted at the Khoj Github, Discord or team@khoj.dev.`;
|
||||
} else if (response.status === 429) {
|
||||
"detail" in apiError
|
||||
? (currentMessage.rawResponse = `${apiError.detail}`)
|
||||
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
|
||||
} else if (errorName === "AbortError") {
|
||||
currentMessage.rawResponse = `I've stopped processing this message. If you'd like to continue, please send the message again.`;
|
||||
} else {
|
||||
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
||||
}
|
||||
|
||||
// Complete message streaming teardown properly
|
||||
currentMessage.completed = true;
|
||||
setMessages([...messages]);
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
}
|
||||
sendMessage(JSON.stringify(chatAPIBody));
|
||||
}
|
||||
|
||||
const handleConversationIdChange = (newConversationId: string) => {
|
||||
@@ -522,7 +553,7 @@ export default function Chat() {
|
||||
isMobileWidth={isMobileWidth}
|
||||
onConversationIdChange={handleConversationIdChange}
|
||||
setImages={setImages}
|
||||
setTriggeredAbort={setTriggeredAbort}
|
||||
setTriggeredAbort={handleTriggeredAbort}
|
||||
isChatSideBarOpen={isChatSideBarOpen}
|
||||
setIsChatSideBarOpen={setIsChatSideBarOpen}
|
||||
isActive={authenticatedData?.is_active}
|
||||
|
||||
@@ -82,7 +82,7 @@ interface ChatInputProps {
|
||||
isLoggedIn: boolean;
|
||||
agentColor?: string;
|
||||
isResearchModeEnabled?: boolean;
|
||||
setTriggeredAbort: (value: boolean) => void;
|
||||
setTriggeredAbort: (value: boolean, newMessage?: string) => void;
|
||||
prefillMessage?: string;
|
||||
focus?: ChatInputFocus;
|
||||
}
|
||||
@@ -189,9 +189,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
||||
return;
|
||||
}
|
||||
|
||||
// If currently processing, trigger abort first
|
||||
// If currently processing, handle interrupt first
|
||||
if (props.sendDisabled) {
|
||||
props.setTriggeredAbort(true);
|
||||
props.setTriggeredAbort(true, message.trim());
|
||||
setMessage(""); // Clear the input
|
||||
return; // Don't continue with regular message sending
|
||||
}
|
||||
|
||||
if (imageUploaded) {
|
||||
|
||||
@@ -71,6 +71,7 @@
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
"react-hook-form": "^7.52.1",
|
||||
"react-use-websocket": "^4.13.0",
|
||||
"shadcn-ui": "^0.9.0",
|
||||
"swr": "^2.2.5",
|
||||
"tailwind-merge": "^2.3.0",
|
||||
|
||||
@@ -4542,6 +4542,11 @@ react-style-singleton@^2.2.2, react-style-singleton@^2.2.3:
|
||||
get-nonce "^1.0.0"
|
||||
tslib "^2.0.0"
|
||||
|
||||
react-use-websocket@^4.13.0:
|
||||
version "4.13.0"
|
||||
resolved "https://registry.yarnpkg.com/react-use-websocket/-/react-use-websocket-4.13.0.tgz#9db1dbac6dc8ba2fdc02a5bba06205fbf6406736"
|
||||
integrity sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw==
|
||||
|
||||
react@^18:
|
||||
version "18.3.1"
|
||||
resolved "https://registry.yarnpkg.com/react/-/react-18.3.1.tgz#49ab892009c53933625bd16b2533fc754cab2891"
|
||||
@@ -4894,6 +4899,7 @@ string-argv@^0.3.2:
|
||||
integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==
|
||||
|
||||
"string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0:
|
||||
name string-width-cjs
|
||||
version "4.2.3"
|
||||
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
|
||||
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
|
||||
|
||||
@@ -1465,7 +1465,7 @@ class ConversationAdapters:
|
||||
@require_valid_user
|
||||
async def save_conversation(
|
||||
user: KhojUser,
|
||||
chat_history: List[ChatMessageModel],
|
||||
new_messages: List[ChatMessageModel],
|
||||
client_application: ClientApplication = None,
|
||||
conversation_id: str = None,
|
||||
user_message: str = None,
|
||||
@@ -1480,7 +1480,8 @@ class ConversationAdapters:
|
||||
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||
)
|
||||
|
||||
conversation_log = {"chat": [msg.model_dump() for msg in chat_history]}
|
||||
existing_messages = conversation.messages if conversation else []
|
||||
conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]}
|
||||
cleaned_conversation_log = clean_object_for_db(conversation_log)
|
||||
if conversation:
|
||||
conversation.conversation_log = cleaned_conversation_log
|
||||
|
||||
@@ -677,6 +677,34 @@ class Conversation(DbBaseModel):
|
||||
continue
|
||||
return validated_messages
|
||||
|
||||
async def pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]:
|
||||
"""
|
||||
Remove and return the last message from the conversation log, persisting the change to the database.
|
||||
When interrupted is True, we only drop the last message if it was an interrupted message by khoj.
|
||||
"""
|
||||
chat_log = self.conversation_log.get("chat", [])
|
||||
|
||||
if not chat_log:
|
||||
return None
|
||||
|
||||
last_message = chat_log[-1]
|
||||
is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message")
|
||||
# When handling an interruption, only pop if the last message is an empty one by khoj.
|
||||
if interrupted and not is_interrupted_msg:
|
||||
return None
|
||||
|
||||
# Pop the last message, save the conversation, and then return the message.
|
||||
popped_message_dict = chat_log.pop()
|
||||
await self.asave()
|
||||
|
||||
# Try to validate and return the popped message as a Pydantic model
|
||||
try:
|
||||
return ChatMessageModel.model_validate(popped_message_dict)
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}")
|
||||
# The invalid message was removed and saved, but we can't return a valid model.
|
||||
return None
|
||||
|
||||
|
||||
class PublicConversation(DbBaseModel):
|
||||
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
@@ -220,7 +220,16 @@ def set_state(args):
|
||||
def start_server(app, host=None, port=None, socket=None):
|
||||
logger.info("🌖 Khoj is ready to engage")
|
||||
if socket:
|
||||
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None)
|
||||
uvicorn.run(
|
||||
app,
|
||||
proxy_headers=True,
|
||||
uds=socket,
|
||||
log_level="debug" if state.verbose > 1 else "info",
|
||||
use_colors=True,
|
||||
log_config=None,
|
||||
ws_ping_timeout=300,
|
||||
timeout_keep_alive=60,
|
||||
)
|
||||
else:
|
||||
uvicorn.run(
|
||||
app,
|
||||
@@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None):
|
||||
log_level="debug" if state.verbose > 1 else "info",
|
||||
use_colors=True,
|
||||
log_config=None,
|
||||
ws_ping_timeout=300,
|
||||
timeout_keep_alive=60,
|
||||
**state.ssl_config if state.ssl_config else {},
|
||||
)
|
||||
|
||||
@@ -384,6 +384,7 @@ class ChatEvent(Enum):
|
||||
METADATA = "metadata"
|
||||
USAGE = "usage"
|
||||
END_RESPONSE = "end_response"
|
||||
INTERRUPT = "interrupt"
|
||||
|
||||
|
||||
def message_to_log(
|
||||
@@ -434,7 +435,6 @@ async def save_to_conversation_log(
|
||||
q: str,
|
||||
chat_response: str,
|
||||
user: KhojUser,
|
||||
chat_history: List[ChatMessageModel],
|
||||
user_message_time: str = None,
|
||||
compiled_references: List[Dict[str, Any]] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
@@ -480,22 +480,22 @@ async def save_to_conversation_log(
|
||||
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
|
||||
|
||||
try:
|
||||
updated_conversation = message_to_log(
|
||||
new_messages = message_to_log(
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata=user_message_metadata,
|
||||
khoj_message_metadata=khoj_message_metadata,
|
||||
chat_history=chat_history,
|
||||
chat_history=[],
|
||||
)
|
||||
except ValidationError as e:
|
||||
updated_conversation = None
|
||||
new_messages = None
|
||||
logger.error(f"Error constructing chat history: {e}")
|
||||
|
||||
db_conversation = None
|
||||
if updated_conversation:
|
||||
if new_messages:
|
||||
db_conversation = await ConversationAdapters.save_conversation(
|
||||
user,
|
||||
updated_conversation,
|
||||
new_messages,
|
||||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
user_message=q,
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Callable, List, Optional
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
construct_chat_history_for_operator,
|
||||
)
|
||||
@@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import (
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
|
||||
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
|
||||
from khoj.routers.helpers import ChatEvent
|
||||
from khoj.routers.helpers import ChatEvent, get_message_from_queue
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
@@ -42,6 +43,8 @@ async def operate_environment(
|
||||
agent: Agent = None,
|
||||
query_files: str = None, # TODO: Handle query files
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: Optional[str] = "␃🔚␗",
|
||||
tracer: dict = {},
|
||||
):
|
||||
response, user_input_message = None, None
|
||||
@@ -140,6 +143,18 @@ async def operate_environment(
|
||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||
break
|
||||
|
||||
# Add interrupt query to current operator run
|
||||
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||
if interrupt_query == abort_message:
|
||||
cancellation_event.set()
|
||||
logger.debug(f"Operator run cancelled by user {user} via interrupt queue.")
|
||||
break
|
||||
# Add the interrupt query as a new user message to the research conversation history
|
||||
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
|
||||
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
|
||||
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||
yield result
|
||||
|
||||
iterations += 1
|
||||
|
||||
# 1. Get current environment state
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -33,7 +33,7 @@ from apscheduler.job import Job
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.utils import timezone as django_timezone
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
@@ -1936,6 +1936,53 @@ class ApiUserRateLimiter:
|
||||
# Add the current request to the cache
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
|
||||
def check_websocket(self, websocket: WebSocket):
|
||||
"""WebSocket-specific rate limiting method"""
|
||||
# Rate limiting disabled if billing is disabled
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
# Rate limiting is disabled if user unauthenticated.
|
||||
if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
|
||||
return
|
||||
|
||||
user: KhojUser = websocket.scope["user"].object
|
||||
subscribed = has_required_scope(websocket, ["premium"])
|
||||
|
||||
# Remove requests outside of the time window
|
||||
cutoff = django_timezone.now() - timedelta(seconds=self.window)
|
||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
if subscribed and count_requests >= self.subscribed_requests:
|
||||
logger.info(
|
||||
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
|
||||
)
|
||||
if not subscribed and count_requests >= self.requests:
|
||||
if self.requests >= self.subscribed_requests:
|
||||
logger.info(
|
||||
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?",
|
||||
)
|
||||
|
||||
# Add the current request to the cache
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
|
||||
|
||||
class ApiImageRateLimiter:
|
||||
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
||||
@@ -1983,6 +2030,47 @@ class ApiImageRateLimiter:
|
||||
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
||||
)
|
||||
|
||||
def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
|
||||
"""WebSocket-specific image rate limiting method"""
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
# Rate limiting is disabled if user unauthenticated.
|
||||
if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
|
||||
return
|
||||
|
||||
if not body.images:
|
||||
return
|
||||
|
||||
# Check number of images
|
||||
if len(body.images) > self.max_images:
|
||||
logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
|
||||
)
|
||||
|
||||
# Check total size of images
|
||||
total_size_mb = 0.0
|
||||
for image in body.images:
|
||||
# Unquote the image in case it's URL encoded
|
||||
image = unquote(image)
|
||||
# Assuming the image is a base64 encoded string
|
||||
# Remove the data:image/jpeg;base64, part if present
|
||||
if "," in image:
|
||||
image = image.split(",", 1)[1]
|
||||
|
||||
# Decode base64 to get the actual size
|
||||
image_bytes = base64.b64decode(image)
|
||||
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
|
||||
|
||||
if total_size_mb > self.max_combined_size_mb:
|
||||
logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
||||
)
|
||||
|
||||
|
||||
class ConversationCommandRateLimiter:
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||
@@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter:
|
||||
self.subscribed_rate_limit = subscribed_rate_limit
|
||||
self.restricted_commands = [ConversationCommand.Research]
|
||||
|
||||
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
|
||||
async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
@@ -2512,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||
}
|
||||
|
||||
|
||||
def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]:
|
||||
"""Get any message in queue if available."""
|
||||
if not queue:
|
||||
return None
|
||||
try:
|
||||
# Non-blocking check for message in the queue
|
||||
return queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
is_active = has_required_scope(request, ["premium"])
|
||||
|
||||
@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
|
||||
ResearchIteration,
|
||||
ToolCall,
|
||||
construct_iteration_history,
|
||||
construct_structured_message,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
@@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
generate_summary_from_files,
|
||||
get_message_from_queue,
|
||||
grep_files,
|
||||
list_files,
|
||||
search_documents,
|
||||
@@ -74,7 +76,7 @@ async def apick_next_tool(
|
||||
):
|
||||
previous_iteration = previous_iterations[-1]
|
||||
yield ResearchIteration(
|
||||
query=query,
|
||||
query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore
|
||||
context=previous_iteration.context,
|
||||
onlineContext=previous_iteration.onlineContext,
|
||||
codeContext=previous_iteration.codeContext,
|
||||
@@ -221,6 +223,8 @@ async def research(
|
||||
tracer: dict = {},
|
||||
query_files: str = None,
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
interrupt_queue: Optional[asyncio.Queue] = None,
|
||||
abort_message: str = "␃🔚␗",
|
||||
):
|
||||
max_document_searches = 7
|
||||
max_online_searches = 3
|
||||
@@ -241,6 +245,26 @@ async def research(
|
||||
logger.debug(f"Research cancelled. User {user} disconnected client.")
|
||||
break
|
||||
|
||||
# Update the query for the current research iteration
|
||||
if interrupt_query := get_message_from_queue(interrupt_queue):
|
||||
if interrupt_query == abort_message:
|
||||
cancellation_event.set()
|
||||
logger.debug(f"Research cancelled by user {user} via interrupt queue.")
|
||||
break
|
||||
# Add the interrupt query as a new user message to the research conversation history
|
||||
logger.info(
|
||||
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
|
||||
)
|
||||
previous_iterations_history = construct_iteration_history(
|
||||
previous_iterations, query, query_images, query_files
|
||||
)
|
||||
research_conversation_history += previous_iterations_history
|
||||
query = interrupt_query
|
||||
previous_iterations = []
|
||||
|
||||
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
|
||||
yield result
|
||||
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
@@ -428,6 +452,7 @@ async def research(
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
||||
@@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel):
|
||||
images: Optional[list[str]] = None
|
||||
files: Optional[list[FileAttachment]] = []
|
||||
create_new: Optional[bool] = False
|
||||
interrupt: Optional[bool] = False
|
||||
|
||||
|
||||
class Entry:
|
||||
|
||||
Reference in New Issue
Block a user