Add websocket chat api to ease bi-directional communication (#1207)

- Add a websocket api endpoint for chat. Reuse most of the existing chat
logic.
- Communicate from web app using the websocket chat api endpoint.
- Pass interrupt messages over the websocket to guide research and
operator trajectories mid-run.
Previously we aborted the in-flight request and sent a new POST
/api/chat request to deliver the interrupt.
This didn't scale to multi-worker setups, as a different worker
could pick up the new interrupt request than the one running the task.
Using the websocket to send messages in the middle of long-running tasks
works more naturally, since the same connection (and worker) handles both.
This commit is contained in:
Debanjum
2025-07-17 18:06:43 -07:00
committed by GitHub
13 changed files with 1206 additions and 849 deletions

View File

@@ -1,7 +1,8 @@
"use client"; "use client";
import styles from "./chat.module.css"; import styles from "./chat.module.css";
import React, { Suspense, useEffect, useRef, useState } from "react"; import React, { Suspense, useCallback, useEffect, useRef, useState } from "react";
import useWebSocket from "react-use-websocket";
import ChatHistory from "../components/chatHistory/chatHistory"; import ChatHistory from "../components/chatHistory/chatHistory";
import { useSearchParams } from "next/navigation"; import { useSearchParams } from "next/navigation";
@@ -45,7 +46,7 @@ interface ChatBodyDataProps {
isMobileWidth?: boolean; isMobileWidth?: boolean;
isLoggedIn: boolean; isLoggedIn: boolean;
setImages: (images: string[]) => void; setImages: (images: string[]) => void;
setTriggeredAbort: (triggeredAbort: boolean) => void; setTriggeredAbort: (triggeredAbort: boolean, newMessage?: string) => void;
isChatSideBarOpen: boolean; isChatSideBarOpen: boolean;
setIsChatSideBarOpen: (open: boolean) => void; setIsChatSideBarOpen: (open: boolean) => void;
isActive?: boolean; isActive?: boolean;
@@ -205,10 +206,10 @@ export default function Chat() {
const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined); const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
const [images, setImages] = useState<string[]>([]); const [images, setImages] = useState<string[]>([]);
const [abortMessageStreamController, setAbortMessageStreamController] =
useState<AbortController | null>(null);
const [triggeredAbort, setTriggeredAbort] = useState(false); const [triggeredAbort, setTriggeredAbort] = useState(false);
const [shouldSendWithInterrupt, setShouldSendWithInterrupt] = useState(false); const [interruptMessage, setInterruptMessage] = useState<string>("");
const bufferRef = useRef("");
const idleTimerRef = useRef<NodeJS.Timeout | null>(null);
const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || { const { locationData, locationDataError, locationDataLoading } = useIPLocationData() || {
locationData: { locationData: {
@@ -222,6 +223,109 @@ export default function Chat() {
} = useAuthenticatedData(); } = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth(); const isMobileWidth = useIsMobileWidth();
const [isChatSideBarOpen, setIsChatSideBarOpen] = useState(false); const [isChatSideBarOpen, setIsChatSideBarOpen] = useState(false);
const [socketUrl, setSocketUrl] = useState<string | null>(null);
const disconnectFromServer = useCallback(() => {
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
setSocketUrl(null);
console.log("WebSocket disconnected due to inactivity.");
}, []);
const resetIdleTimer = useCallback(() => {
const idleTimeout = 10 * 60 * 1000; // 10 minutes
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
idleTimerRef.current = setTimeout(disconnectFromServer, idleTimeout);
}, [disconnectFromServer]);
const { sendMessage, lastMessage } = useWebSocket(socketUrl, {
share: true,
shouldReconnect: (closeEvent) => true,
reconnectAttempts: 10,
// reconnect using exponential backoff with jitter
reconnectInterval: (attemptNumber) => {
const baseDelay = 1000 * Math.pow(2, attemptNumber);
const jitter = Math.random() * 1000; // Add jitter up to 1s
return Math.min(baseDelay + jitter, 20000); // Cap backoff at 20s
},
onOpen: () => {
console.log("WebSocket connection established.");
resetIdleTimer();
},
onClose: () => {
console.log("WebSocket connection closed.");
if (idleTimerRef.current) {
clearTimeout(idleTimerRef.current);
}
},
});
useEffect(() => {
if (lastMessage !== null) {
resetIdleTimer();
// Check if this is a control message (JSON) rather than a streaming event
try {
const controlMessage = JSON.parse(lastMessage.data);
if (controlMessage.type === "interrupt_acknowledged") {
console.log("Interrupt acknowledged by server");
setProcessQuerySignal(false);
return;
} else if (controlMessage.type === "interrupt_message_acknowledged") {
console.log("Interrupt message acknowledged by server");
setProcessQuerySignal(false);
return;
} else if (controlMessage.error) {
console.error("WebSocket error:", controlMessage.error);
return;
}
} catch {
// Not a JSON control message, process as streaming event
}
const eventDelimiter = "␃🔚␗";
bufferRef.current += lastMessage.data;
let newEventIndex;
while ((newEventIndex = bufferRef.current.indexOf(eventDelimiter)) !== -1) {
const eventChunk = bufferRef.current.slice(0, newEventIndex);
bufferRef.current = bufferRef.current.slice(newEventIndex + eventDelimiter.length);
if (eventChunk) {
setMessages((prevMessages) => {
const newMessages = [...prevMessages];
const currentMessage = newMessages[newMessages.length - 1];
if (!currentMessage || currentMessage.completed) {
return prevMessages;
}
const { context, onlineContext, codeContext } = processMessageChunk(
eventChunk,
currentMessage,
currentMessage.context || [],
currentMessage.onlineContext || {},
currentMessage.codeContext || {},
);
// Update the current message with the new reference data
currentMessage.context = context;
currentMessage.onlineContext = onlineContext;
currentMessage.codeContext = codeContext;
if (currentMessage.completed) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImages([]);
if (conversationId) generateNewTitle(conversationId, setTitle);
}
return newMessages;
});
}
}
}
}, [lastMessage, setMessages]);
useEffect(() => { useEffect(() => {
fetch("/api/chat/options") fetch("/api/chat/options")
@@ -241,14 +345,37 @@ export default function Chat() {
welcomeConsole(); welcomeConsole();
}, []); }, []);
const handleTriggeredAbort = (value: boolean, newMessage?: string) => {
if (value) {
setInterruptMessage(newMessage || "");
}
setTriggeredAbort(value);
};
useEffect(() => { useEffect(() => {
if (triggeredAbort) { if (triggeredAbort) {
abortMessageStreamController?.abort(); sendMessage(
handleAbortedMessage(); JSON.stringify({
setShouldSendWithInterrupt(true); type: "interrupt",
setTriggeredAbort(false); query: interruptMessage,
}),
);
console.log("Sent interrupt message via WebSocket:", interruptMessage);
// Mark the last message as completed
setMessages((prevMessages) => {
const newMessages = [...prevMessages];
const currentMessage = newMessages[newMessages.length - 1];
if (currentMessage) currentMessage.completed = true;
return newMessages;
});
// Set the interrupt message as the new query being processed
setQueryToProcess(interruptMessage);
setTriggeredAbort(false); // Always set to false after processing
setInterruptMessage("");
} }
}, [triggeredAbort]); }, [triggeredAbort, sendMessage]);
useEffect(() => { useEffect(() => {
if (queryToProcess) { if (queryToProcess) {
@@ -266,7 +393,6 @@ export default function Chat() {
}; };
setMessages((prevMessages) => [...prevMessages, newStreamMessage]); setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true); setProcessQuerySignal(true);
setAbortMessageStreamController(new AbortController());
} }
}, [queryToProcess]); }, [queryToProcess]);
@@ -280,70 +406,19 @@ export default function Chat() {
} }
}, [processQuerySignal, locationDataLoading]); }, [processQuerySignal, locationDataLoading]);
async function readChatStream(response: Response) { useEffect(() => {
if (!response.ok) throw new Error(response.statusText); if (!conversationId) return;
if (!response.body) throw new Error("Response body is null");
const reader = response.body.getReader(); const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const decoder = new TextDecoder(); const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`;
const eventDelimiter = "␃🔚␗"; setSocketUrl(wsUrl);
let buffer = "";
// Track context used for chat response return () => {
let context: Context[] = []; if (idleTimerRef.current) {
let onlineContext: OnlineContext = {}; clearTimeout(idleTimerRef.current);
let codeContext: CodeContext = {};
while (true) {
const { done, value } = await reader.read();
if (done) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImages([]);
if (conversationId) generateNewTitle(conversationId, setTitle);
break;
}
const chunk = decoder.decode(value, { stream: true });
buffer += chunk;
let newEventIndex;
while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
const event = buffer.slice(0, newEventIndex);
buffer = buffer.slice(newEventIndex + eventDelimiter.length);
if (event) {
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) {
console.error("No current message found");
return;
}
// Track context used for chat response. References are rendered at the end of the chat
({ context, onlineContext, codeContext } = processMessageChunk(
event,
currentMessage,
context,
onlineContext,
codeContext,
));
setMessages([...messages]);
}
}
}
}
function handleAbortedMessage() {
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
currentMessage.completed = true;
setMessages([...messages]);
setProcessQuerySignal(false);
} }
};
}, [conversationId]);
async function chat() { async function chat() {
localStorage.removeItem("message"); localStorage.removeItem("message");
@@ -351,12 +426,19 @@ export default function Chat() {
setProcessQuerySignal(false); setProcessQuerySignal(false);
return; return;
} }
const chatAPI = "/api/chat?client=web";
// Re-establish WebSocket connection if disconnected
resetIdleTimer();
if (!socketUrl) {
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const wsUrl = `${protocol}//${window.location.host}/api/chat/ws?client=web`;
setSocketUrl(wsUrl);
}
const chatAPIBody = { const chatAPIBody = {
q: queryToProcess, q: queryToProcess,
conversation_id: conversationId, conversation_id: conversationId,
stream: true, stream: true,
interrupt: shouldSendWithInterrupt,
...(locationData && { ...(locationData && {
city: locationData.city, city: locationData.city,
region: locationData.region, region: locationData.region,
@@ -368,58 +450,7 @@ export default function Chat() {
...(uploadedFiles && { files: uploadedFiles }), ...(uploadedFiles && { files: uploadedFiles }),
}; };
// Reset the flag after using it sendMessage(JSON.stringify(chatAPIBody));
setShouldSendWithInterrupt(false);
const response = await fetch(chatAPI, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(chatAPIBody),
signal: abortMessageStreamController?.signal,
});
try {
await readChatStream(response);
} catch (err) {
let apiError;
try {
apiError = await response.json();
} catch (err) {
// Error reading API error response
apiError = {
streamError: "Error reading API error response stream. Expected JSON response.",
};
}
console.error(apiError);
// Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
// Render error message as current message
const errorMessage = (err as Error).message;
const errorName = (err as Error).name;
if (errorMessage.includes("Error in input stream"))
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
else if (apiError.streamError) {
currentMessage.rawResponse = `Umm, not sure what just happened but I lost my train of thought. Could you try again or ask my developers to look into this if the issue persists? They can be contacted at the Khoj Github, Discord or team@khoj.dev.`;
} else if (response.status === 429) {
"detail" in apiError
? (currentMessage.rawResponse = `${apiError.detail}`)
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
} else if (errorName === "AbortError") {
currentMessage.rawResponse = `I've stopped processing this message. If you'd like to continue, please send the message again.`;
} else {
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
}
// Complete message streaming teardown properly
currentMessage.completed = true;
setMessages([...messages]);
setQueryToProcess("");
setProcessQuerySignal(false);
}
} }
const handleConversationIdChange = (newConversationId: string) => { const handleConversationIdChange = (newConversationId: string) => {
@@ -522,7 +553,7 @@ export default function Chat() {
isMobileWidth={isMobileWidth} isMobileWidth={isMobileWidth}
onConversationIdChange={handleConversationIdChange} onConversationIdChange={handleConversationIdChange}
setImages={setImages} setImages={setImages}
setTriggeredAbort={setTriggeredAbort} setTriggeredAbort={handleTriggeredAbort}
isChatSideBarOpen={isChatSideBarOpen} isChatSideBarOpen={isChatSideBarOpen}
setIsChatSideBarOpen={setIsChatSideBarOpen} setIsChatSideBarOpen={setIsChatSideBarOpen}
isActive={authenticatedData?.is_active} isActive={authenticatedData?.is_active}

View File

@@ -82,7 +82,7 @@ interface ChatInputProps {
isLoggedIn: boolean; isLoggedIn: boolean;
agentColor?: string; agentColor?: string;
isResearchModeEnabled?: boolean; isResearchModeEnabled?: boolean;
setTriggeredAbort: (value: boolean) => void; setTriggeredAbort: (value: boolean, newMessage?: string) => void;
prefillMessage?: string; prefillMessage?: string;
focus?: ChatInputFocus; focus?: ChatInputFocus;
} }
@@ -189,9 +189,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
return; return;
} }
// If currently processing, trigger abort first // If currently processing, handle interrupt first
if (props.sendDisabled) { if (props.sendDisabled) {
props.setTriggeredAbort(true); props.setTriggeredAbort(true, message.trim());
setMessage(""); // Clear the input
return; // Don't continue with regular message sending
} }
if (imageUploaded) { if (imageUploaded) {

View File

@@ -71,6 +71,7 @@
"react": "^18", "react": "^18",
"react-dom": "^18", "react-dom": "^18",
"react-hook-form": "^7.52.1", "react-hook-form": "^7.52.1",
"react-use-websocket": "^4.13.0",
"shadcn-ui": "^0.9.0", "shadcn-ui": "^0.9.0",
"swr": "^2.2.5", "swr": "^2.2.5",
"tailwind-merge": "^2.3.0", "tailwind-merge": "^2.3.0",

View File

@@ -4542,6 +4542,11 @@ react-style-singleton@^2.2.2, react-style-singleton@^2.2.3:
get-nonce "^1.0.0" get-nonce "^1.0.0"
tslib "^2.0.0" tslib "^2.0.0"
react-use-websocket@^4.13.0:
version "4.13.0"
resolved "https://registry.yarnpkg.com/react-use-websocket/-/react-use-websocket-4.13.0.tgz#9db1dbac6dc8ba2fdc02a5bba06205fbf6406736"
integrity sha512-anMuVoV//g2N76Wxqvqjjo1X48r9Np3y1/gMl7arX84tAPXdy5R7sB5lO5hvCzQRYjqXwV8XMAiEBOUbyrZFrw==
react@^18: react@^18:
version "18.3.1" version "18.3.1"
resolved "https://registry.yarnpkg.com/react/-/react-18.3.1.tgz#49ab892009c53933625bd16b2533fc754cab2891" resolved "https://registry.yarnpkg.com/react/-/react-18.3.1.tgz#49ab892009c53933625bd16b2533fc754cab2891"
@@ -4894,6 +4899,7 @@ string-argv@^0.3.2:
integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q== integrity sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==
"string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0: "string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0:
name string-width-cjs
version "4.2.3" version "4.2.3"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==

View File

@@ -1465,7 +1465,7 @@ class ConversationAdapters:
@require_valid_user @require_valid_user
async def save_conversation( async def save_conversation(
user: KhojUser, user: KhojUser,
chat_history: List[ChatMessageModel], new_messages: List[ChatMessageModel],
client_application: ClientApplication = None, client_application: ClientApplication = None,
conversation_id: str = None, conversation_id: str = None,
user_message: str = None, user_message: str = None,
@@ -1480,7 +1480,8 @@ class ConversationAdapters:
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
) )
conversation_log = {"chat": [msg.model_dump() for msg in chat_history]} existing_messages = conversation.messages if conversation else []
conversation_log = {"chat": [msg.model_dump() for msg in existing_messages + new_messages]}
cleaned_conversation_log = clean_object_for_db(conversation_log) cleaned_conversation_log = clean_object_for_db(conversation_log)
if conversation: if conversation:
conversation.conversation_log = cleaned_conversation_log conversation.conversation_log = cleaned_conversation_log

View File

@@ -677,6 +677,34 @@ class Conversation(DbBaseModel):
continue continue
return validated_messages return validated_messages
async def pop_message(self, interrupted: bool = False) -> Optional[ChatMessageModel]:
"""
Remove and return the last message from the conversation log, persisting the change to the database.
When interrupted is True, we only drop the last message if it was an interrupted message by khoj.
"""
chat_log = self.conversation_log.get("chat", [])
if not chat_log:
return None
last_message = chat_log[-1]
is_interrupted_msg = last_message.get("by") == "khoj" and not last_message.get("message")
# When handling an interruption, only pop if the last message is an empty one by khoj.
if interrupted and not is_interrupted_msg:
return None
# Pop the last message, save the conversation, and then return the message.
popped_message_dict = chat_log.pop()
await self.asave()
# Try to validate and return the popped message as a Pydantic model
try:
return ChatMessageModel.model_validate(popped_message_dict)
except ValidationError as e:
logger.warning(f"Popped an invalid message from conversation. The removal has been saved. Error: {e}")
# The invalid message was removed and saved, but we can't return a valid model.
return None
class PublicConversation(DbBaseModel): class PublicConversation(DbBaseModel):
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE) source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

View File

@@ -220,7 +220,16 @@ def set_state(args):
def start_server(app, host=None, port=None, socket=None): def start_server(app, host=None, port=None, socket=None):
logger.info("🌖 Khoj is ready to engage") logger.info("🌖 Khoj is ready to engage")
if socket: if socket:
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None) uvicorn.run(
app,
proxy_headers=True,
uds=socket,
log_level="debug" if state.verbose > 1 else "info",
use_colors=True,
log_config=None,
ws_ping_timeout=300,
timeout_keep_alive=60,
)
else: else:
uvicorn.run( uvicorn.run(
app, app,
@@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None):
log_level="debug" if state.verbose > 1 else "info", log_level="debug" if state.verbose > 1 else "info",
use_colors=True, use_colors=True,
log_config=None, log_config=None,
ws_ping_timeout=300,
timeout_keep_alive=60, timeout_keep_alive=60,
**state.ssl_config if state.ssl_config else {}, **state.ssl_config if state.ssl_config else {},
) )

View File

@@ -384,6 +384,7 @@ class ChatEvent(Enum):
METADATA = "metadata" METADATA = "metadata"
USAGE = "usage" USAGE = "usage"
END_RESPONSE = "end_response" END_RESPONSE = "end_response"
INTERRUPT = "interrupt"
def message_to_log( def message_to_log(
@@ -434,7 +435,6 @@ async def save_to_conversation_log(
q: str, q: str,
chat_response: str, chat_response: str,
user: KhojUser, user: KhojUser,
chat_history: List[ChatMessageModel],
user_message_time: str = None, user_message_time: str = None,
compiled_references: List[Dict[str, Any]] = [], compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {}, online_results: Dict[str, Any] = {},
@@ -480,22 +480,22 @@ async def save_to_conversation_log(
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
try: try:
updated_conversation = message_to_log( new_messages = message_to_log(
user_message=q, user_message=q,
chat_response=chat_response, chat_response=chat_response,
user_message_metadata=user_message_metadata, user_message_metadata=user_message_metadata,
khoj_message_metadata=khoj_message_metadata, khoj_message_metadata=khoj_message_metadata,
chat_history=chat_history, chat_history=[],
) )
except ValidationError as e: except ValidationError as e:
updated_conversation = None new_messages = None
logger.error(f"Error constructing chat history: {e}") logger.error(f"Error constructing chat history: {e}")
db_conversation = None db_conversation = None
if updated_conversation: if new_messages:
db_conversation = await ConversationAdapters.save_conversation( db_conversation = await ConversationAdapters.save_conversation(
user, user,
updated_conversation, new_messages,
client_application=client_application, client_application=client_application,
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=q, user_message=q,

View File

@@ -7,6 +7,7 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun, OperatorRun,
construct_chat_history_for_operator, construct_chat_history_for_operator,
) )
@@ -22,7 +23,7 @@ from khoj.processor.operator.operator_environment_base import (
) )
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
from khoj.processor.operator.operator_environment_computer import ComputerEnvironment from khoj.processor.operator.operator_environment_computer import ComputerEnvironment
from khoj.routers.helpers import ChatEvent from khoj.routers.helpers import ChatEvent, get_message_from_queue
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@@ -42,6 +43,8 @@ async def operate_environment(
agent: Agent = None, agent: Agent = None,
query_files: str = None, # TODO: Handle query files query_files: str = None, # TODO: Handle query files
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: Optional[str] = "␃🔚␗",
tracer: dict = {}, tracer: dict = {},
): ):
response, user_input_message = None, None response, user_input_message = None, None
@@ -140,6 +143,18 @@ async def operate_environment(
logger.debug(f"{environment_type.value} operator cancelled by client disconnect") logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break break
# Add interrupt query to current operator run
if interrupt_query := get_message_from_queue(interrupt_queue):
if interrupt_query == abort_message:
cancellation_event.set()
logger.debug(f"Operator run cancelled by user {user} via interrupt queue.")
break
# Add the interrupt query as a new user message to the research conversation history
logger.info(f"Continuing operator run with the new instruction: {interrupt_query}")
operator_agent.messages.append(AgentMessage(role="user", content=interrupt_query))
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result
iterations += 1 iterations += 1
# 1. Get current environment state # 1. Get current environment state

View File

@@ -10,9 +10,18 @@ from typing import Any, Dict, List, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.responses import RedirectResponse, Response, StreamingResponse from fastapi.responses import RedirectResponse, Response, StreamingResponse
from fastapi.websockets import WebSocketState
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
from starlette.requests import Headers
from khoj.app.settings import ALLOWED_HOSTS from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
@@ -60,6 +69,7 @@ from khoj.routers.helpers import (
generate_mermaidjs_diagram, generate_mermaidjs_diagram,
generate_summary_from_files, generate_summary_from_files,
get_conversation_command, get_conversation_command,
get_message_from_queue,
is_query_empty, is_query_empty,
is_ready_to_chat, is_ready_to_chat,
read_chat_stream, read_chat_stream,
@@ -657,19 +667,13 @@ def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404) return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
@api_chat.post("") async def event_generator(
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
body: ChatRequestBody, body: ChatRequestBody,
rate_limiter_per_minute=Depends( user_scope: Any,
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") common: CommonQueryParams,
), headers: Headers,
rate_limiter_per_day=Depends( request_obj: Request | WebSocket,
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") parent_interrupt_queue: asyncio.Queue = None,
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
): ):
# Access the parameters from the body # Access the parameters from the body
q = body.q q = body.q
@@ -686,20 +690,19 @@ async def chat(
timezone = body.timezone timezone = body.timezone
raw_images = body.images raw_images = body.images
raw_query_files = body.files raw_query_files = body.files
interrupt_flag = body.interrupt
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
chat_metadata: dict = {} chat_metadata: dict = {}
conversation = None conversation = None
user: KhojUser = request.user.object user: KhojUser = user_scope.object
is_subscribed = has_required_scope(request, ["premium"]) is_subscribed = has_required_scope(request_obj, ["premium"])
q = unquote(q) q = unquote(q)
defiltered_query = defilter_query(q)
train_of_thought = [] train_of_thought = []
nonlocal conversation_id
nonlocal raw_query_files
cancellation_event = asyncio.Event() cancellation_event = asyncio.Event()
child_interrupt_queue: asyncio.Queue = asyncio.Queue()
event_delimiter = "␃🔚␗"
tracer: dict = { tracer: dict = {
"mid": turn_id, "mid": turn_id,
@@ -709,13 +712,13 @@ async def chat(
} }
uploaded_images: list[str] = [] uploaded_images: list[str] = []
if images: if raw_images:
for image in images: for image in raw_images:
decoded_string = unquote(image) decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1] base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data) image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes) webp_image_bytes = convert_image_to_webp(image_bytes)
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id) uploaded_image = upload_user_image_to_bucket(webp_image_bytes, user.id)
if not uploaded_image: if not uploaded_image:
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
uploaded_image = f"data:image/webp;base64,{base64_webp_image}" uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
@@ -739,14 +742,16 @@ async def chat(
generated_mermaidjs_diagram: str = None generated_mermaidjs_diagram: str = None
generated_asset_results: Dict = dict() generated_asset_results: Dict = dict()
program_execution_context: List[str] = [] program_execution_context: List[str] = []
chat_history: List[ChatMessageModel] = [] user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create a task to monitor for disconnections # Create a task to monitor for disconnections
disconnect_monitor_task = None disconnect_monitor_task = None
async def monitor_disconnection(): async def monitor_disconnection():
nonlocal q, defiltered_query
if isinstance(request_obj, Request):
try: try:
msg = await request.receive() msg = await request_obj.receive()
if msg["type"] == "http.disconnect": if msg["type"] == "http.disconnect":
logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.") logger.debug(f"Request cancelled. User {user} disconnected from {common.client} client.")
cancellation_event.set() cancellation_event.set()
@@ -758,14 +763,13 @@ async def chat(
q, q,
chat_response="", chat_response="",
user=user, user=user,
chat_history=chat_history,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,
operator_results=operator_results, operator_results=operator_results,
research_results=research_results, research_results=research_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=uploaded_images, query_images=uploaded_images,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
@@ -773,11 +777,53 @@ async def chat(
generated_images=generated_images, generated_images=generated_images,
raw_generated_files=generated_asset_results, raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram, generated_mermaidjs_diagram=generated_mermaidjs_diagram,
user_message_time=user_message_time,
tracer=tracer, tracer=tracer,
) )
) )
except Exception as e: except Exception as e:
logger.error(f"Error in disconnect monitor: {e}") logger.error(f"Error in disconnect monitor: {e}")
elif isinstance(request_obj, WebSocket):
while request_obj.client_state == WebSocketState.CONNECTED and not cancellation_event.is_set():
await asyncio.sleep(1)
# Check if any interrupt query is received
if interrupt_query := get_message_from_queue(parent_interrupt_queue):
if interrupt_query == event_delimiter:
cancellation_event.set()
logger.debug(f"Chat cancelled by user {user} via interrupt queue.")
else:
# Pass the interrupt query to child tasks
logger.info(f"Continuing chat with the new instruction: {interrupt_query}")
await child_interrupt_queue.put(interrupt_query)
q += f"\n\n{interrupt_query}"
defiltered_query += f"\n\n{defilter_query(interrupt_query)}"
logger.debug(f"WebSocket disconnected or chat cancelled by user {user} from {common.client} client.")
if conversation and cancellation_event.is_set():
await asyncio.shield(
save_to_conversation_log(
q,
chat_response="",
user=user,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
research_results=research_results,
inferred_queries=inferred_queries,
client_application=user_scope.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=generated_asset_results,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
user_message_time=user_message_time,
tracer=tracer,
)
)
# Cancel the disconnect monitor task if it is still running # Cancel the disconnect monitor task if it is still running
async def cancel_disconnect_monitor(): async def cancel_disconnect_monitor():
@@ -791,7 +837,6 @@ async def chat(
async def send_event(event_type: ChatEvent, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal ttft, train_of_thought nonlocal ttft, train_of_thought
event_delimiter = "␃🔚␗"
if cancellation_event.is_set(): if cancellation_event.is_set():
return return
try: try:
@@ -864,12 +909,12 @@ async def chat(
logger.info(f"Chat response total time: {latency:.3f} seconds") logger.info(f"Chat response total time: {latency:.3f} seconds")
logger.info(f"Chat response cost: ${cost:.5f}") logger.info(f"Chat response cost: ${cost:.5f}")
update_telemetry_state( update_telemetry_state(
request=request, request=request_obj,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
client=common.client, client=common.client,
user_agent=request.headers.get("user-agent"), user_agent=headers.get("user-agent"),
host=request.headers.get("host"), host=headers.get("host"),
metadata=chat_metadata, metadata=chat_metadata,
) )
@@ -894,7 +939,7 @@ async def chat(
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, user,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
title=title, title=title,
create_new=body.create_new, create_new=body.create_new,
@@ -923,46 +968,11 @@ async def chat(
location = None location = None
if city or region or country or country_code: if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code) location = LocationData(city=city, region=region, country=country, country_code=country_code)
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_history = conversation.messages chat_history = conversation.messages
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
if interrupt_flag:
max_wait_time = 20.0 # seconds
wait_interval = 0.3 # seconds
wait_start = wait_current = time.time()
while wait_current - wait_start < max_wait_time:
# Refresh conversation to check if interrupted message saved to DB
conversation = await ConversationAdapters.aget_conversation_by_user(
user,
client_application=request.user.client_app,
conversation_id=conversation_id,
)
if (
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
break
await asyncio.sleep(wait_interval)
wait_current = time.time()
if wait_current - wait_start >= max_wait_time:
logger.warning(
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
)
# If interrupted message in DB # If interrupted message in DB
if ( if last_message := await conversation.pop_message(interrupted=True):
conversation
and conversation.messages
and conversation.messages[-1].by == "khoj"
and not conversation.messages[-1].message
):
# Populate context from interrupted message # Populate context from interrupted message
last_message = conversation.messages[-1]
online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []} online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []} code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
compiled_references = [ref.model_dump() for ref in last_message.context or []] compiled_references = [ref.model_dump() for ref in last_message.context or []]
@@ -973,8 +983,6 @@ async def chat(
] ]
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
chat_history.pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]: if conversation_commands == [ConversationCommand.Default]:
@@ -1006,7 +1014,7 @@ async def chat(
cmds_to_rate_limit += conversation_commands cmds_to_rate_limit += conversation_commands
for cmd in cmds_to_rate_limit: for cmd in cmds_to_rate_limit:
try: try:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) await conversation_command_rate_limiter.update_and_check_if_valid(request_obj, cmd)
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
except HTTPException as e: except HTTPException as e:
async for result in send_llm_response(str(e.detail), tracer.get("usage")): async for result in send_llm_response(str(e.detail), tracer.get("usage")):
@@ -1032,6 +1040,8 @@ async def chat(
query_files=attached_file_context, query_files=attached_file_context,
tracer=tracer, tracer=tracer,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=child_interrupt_queue,
abort_message=event_delimiter,
): ):
if isinstance(research_result, ResearchIteration): if isinstance(research_result, ResearchIteration):
if research_result.summarizedResult: if research_result.summarizedResult:
@@ -1091,9 +1101,7 @@ async def chat(
inferred_queries.extend(result[1]) inferred_queries.extend(result[1])
defiltered_query = result[2] defiltered_query = result[2]
except Exception as e: except Exception as e:
error_message = ( error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
f"Error searching knowledge base: {e}. Attempting to respond without document references."
)
logger.error(error_message, exc_info=True) logger.error(error_message, exc_info=True)
async for result in send_event( async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references" ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
@@ -1101,10 +1109,14 @@ async def chat(
yield result yield result
if not is_none_or_empty(compiled_references): if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) distinct_headings = set([d.get("compiled").split("\n")[0] for d in compiled_references if "compiled" in d])
distinct_files = set([d["file"] for d in compiled_references])
# Strip only leading # from headings # Strip only leading # from headings
headings = headings.replace("#", "") headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): async for result in send_event(
ChatEvent.STATUS,
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}",
):
yield result yield result
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
@@ -1222,6 +1234,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent, agent=agent,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=child_interrupt_queue,
tracer=tracer, tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1343,9 +1356,7 @@ async def chat(
else: else:
error_message = "Failed to generate diagram. Please try again later." error_message = "Failed to generate diagram. Please try again later."
program_execution_context.append( program_execution_context.append(
prompts.failed_diagram_generation.format( prompts.failed_diagram_generation.format(attempted_diagram=better_diagram_description_prompt)
attempted_diagram=better_diagram_description_prompt
)
) )
async for result in send_event(ChatEvent.STATUS, error_message): async for result in send_event(ChatEvent.STATUS, error_message):
@@ -1416,14 +1427,13 @@ async def chat(
q, q,
chat_response=full_response, chat_response=full_response,
user=user, user=user,
chat_history=chat_history,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
code_results=code_results, code_results=code_results,
operator_results=operator_results, operator_results=operator_results,
research_results=research_results, research_results=research_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=request.user.client_app, client_application=user_scope.client_app,
conversation_id=str(conversation.id), conversation_id=str(conversation.id),
query_images=uploaded_images, query_images=uploaded_images,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
@@ -1450,11 +1460,141 @@ async def chat(
# Cancel the disconnect monitor task if it is still running # Cancel the disconnect monitor task if it is still running
await cancel_disconnect_monitor() await cancel_disconnect_monitor()
## Stream Text Response
if stream: @api_chat.websocket("/ws")
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") @requires(["authenticated"])
## Non-Streaming Text Response async def chat_ws(
websocket: WebSocket,
common: CommonQueryParams,
):
await websocket.accept()
# Initialize rate limiters
rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
rate_limiter_per_day = ApiUserRateLimiter(
requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day"
)
image_rate_limiter = ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)
# Shared interrupt queue for communicating interrupts to ongoing research
interrupt_queue: asyncio.Queue = asyncio.Queue()
current_task = None
try:
while True:
data = await websocket.receive_json()
# Check if this is an interrupt message
if data.get("type") == "interrupt":
if current_task and not current_task.done():
# Send interrupt signal to the ongoing task
abort_message = "␃🔚␗"
await interrupt_queue.put(data.get("query") or abort_message)
logger.info(
f"Interrupt signal sent to ongoing task for user {websocket.scope['user'].object.id} with query: {data.get('query')}"
)
if data.get("query"):
ack_type = "interrupt_message_acknowledged"
await websocket.send_text(json.dumps({"type": ack_type}))
else:
ack_type = "interrupt_acknowledged"
await websocket.send_text(json.dumps({"type": ack_type}))
else:
logger.info(f"No ongoing task to interrupt for user {websocket.scope['user'].object.id}")
continue
# Handle regular chat messages - ensure data has required fields
if "q" not in data:
await websocket.send_text(json.dumps({"error": "Missing required field 'q' in chat message"}))
continue
body = ChatRequestBody(**data)
# Apply rate limiting manually
try:
rate_limiter_per_minute.check_websocket(websocket)
rate_limiter_per_day.check_websocket(websocket)
image_rate_limiter.check_websocket(websocket, body)
except HTTPException as e:
await websocket.send_text(json.dumps({"error": e.detail}))
continue
# Cancel any ongoing task before starting a new one
if current_task and not current_task.done():
current_task.cancel()
try:
await current_task
except asyncio.CancelledError:
pass
# Create a new task for processing the chat request
current_task = asyncio.create_task(process_chat_request(websocket, body, common, interrupt_queue))
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for user {websocket.scope['user'].object.id}")
if current_task and not current_task.done():
current_task.cancel()
except Exception as e:
logger.error(f"Error in websocket chat: {e}", exc_info=True)
if current_task and not current_task.done():
current_task.cancel()
await websocket.close(code=1011, reason="Internal Server Error")
async def process_chat_request(
websocket: WebSocket,
body: ChatRequestBody,
common: CommonQueryParams,
interrupt_queue: asyncio.Queue,
):
"""Process a single chat request with interrupt support"""
try:
# Since we are using websockets, we can ignore the stream parameter and always stream
response_iterator = event_generator(
body,
websocket.scope["user"],
common,
websocket.headers,
websocket,
interrupt_queue,
)
async for event in response_iterator:
await websocket.send_text(event)
except asyncio.CancelledError:
logger.debug(f"Chat request cancelled for user {websocket.scope['user'].object.id}")
raise
except Exception as e:
logger.error(f"Error processing chat request: {e}", exc_info=True)
await websocket.send_text(json.dumps({"error": "Internal server error"}))
raise
@api_chat.post("")
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
body: ChatRequestBody,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
):
response_iterator = event_generator(
body,
request.user,
common,
request.headers,
request,
)
# Stream Text Response
if body.stream:
return StreamingResponse(response_iterator, media_type="text/plain")
# Non-Streaming Text Response
else: else:
response_iterator = event_generator(q, images=raw_images)
response_data = await read_chat_stream(response_iterator) response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)

View File

@@ -33,7 +33,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.utils import timezone as django_timezone from django.utils import timezone as django_timezone
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from starlette.requests import URL from starlette.requests import URL
@@ -1936,6 +1936,53 @@ class ApiUserRateLimiter:
# Add the current request to the cache # Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug) UserRequests.objects.create(user=user, slug=self.slug)
def check_websocket(self, websocket: WebSocket):
"""WebSocket-specific rate limiting method"""
# Rate limiting disabled if billing is disabled
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
return
user: KhojUser = websocket.scope["user"].object
subscribed = has_required_scope(websocket, ["premium"])
# Remove requests outside of the time window
cutoff = django_timezone.now() - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
# Check if the user has exceeded the rate limit
if subscribed and count_requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
)
if not subscribed and count_requests >= self.requests:
if self.requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
)
logger.info(
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?",
)
# Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug)
class ApiImageRateLimiter: class ApiImageRateLimiter:
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10): def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
@@ -1983,6 +2030,47 @@ class ApiImageRateLimiter:
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
) )
def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
"""WebSocket-specific image rate limiting method"""
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
return
if not body.images:
return
# Check number of images
if len(body.images) > self.max_images:
logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.")
raise HTTPException(
status_code=429,
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
)
# Check total size of images
total_size_mb = 0.0
for image in body.images:
# Unquote the image in case it's URL encoded
image = unquote(image)
# Assuming the image is a base64 encoded string
# Remove the data:image/jpeg;base64, part if present
if "," in image:
image = image.split(",", 1)[1]
# Decode base64 to get the actual size
image_bytes = base64.b64decode(image)
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
if total_size_mb > self.max_combined_size_mb:
logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
raise HTTPException(
status_code=429,
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
)
class ConversationCommandRateLimiter: class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
@@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter:
self.subscribed_rate_limit = subscribed_rate_limit self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Research] self.restricted_commands = [ConversationCommand.Research]
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
if state.billing_enabled is False: if state.billing_enabled is False:
return return
@@ -2512,6 +2600,17 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
} }
def get_message_from_queue(queue: asyncio.Queue) -> Optional[str]:
"""Get any message in queue if available."""
if not queue:
return None
try:
# Non-blocking check for message in the queue
return queue.get_nowait()
except asyncio.QueueEmpty:
return None
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False): def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
is_active = has_required_scope(request, ["premium"]) is_active = has_required_scope(request, ["premium"])

View File

@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
ResearchIteration, ResearchIteration,
ToolCall, ToolCall,
construct_iteration_history, construct_iteration_history,
construct_structured_message,
construct_tool_chat_history, construct_tool_chat_history,
load_complex_json, load_complex_json,
) )
@@ -24,6 +25,7 @@ from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
generate_summary_from_files, generate_summary_from_files,
get_message_from_queue,
grep_files, grep_files,
list_files, list_files,
search_documents, search_documents,
@@ -74,7 +76,7 @@ async def apick_next_tool(
): ):
previous_iteration = previous_iterations[-1] previous_iteration = previous_iterations[-1]
yield ResearchIteration( yield ResearchIteration(
query=query, query=ToolCall(name=previous_iteration.query.name, args={"query": query}, id=previous_iteration.query.id), # type: ignore
context=previous_iteration.context, context=previous_iteration.context,
onlineContext=previous_iteration.onlineContext, onlineContext=previous_iteration.onlineContext,
codeContext=previous_iteration.codeContext, codeContext=previous_iteration.codeContext,
@@ -221,6 +223,8 @@ async def research(
tracer: dict = {}, tracer: dict = {},
query_files: str = None, query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: str = "␃🔚␗",
): ):
max_document_searches = 7 max_document_searches = 7
max_online_searches = 3 max_online_searches = 3
@@ -241,6 +245,26 @@ async def research(
logger.debug(f"Research cancelled. User {user} disconnected client.") logger.debug(f"Research cancelled. User {user} disconnected client.")
break break
# Update the query for the current research iteration
if interrupt_query := get_message_from_queue(interrupt_queue):
if interrupt_query == abort_message:
cancellation_event.set()
logger.debug(f"Research cancelled by user {user} via interrupt queue.")
break
# Add the interrupt query as a new user message to the research conversation history
logger.info(
f"Continuing research with the previous {len(previous_iterations)} iterations and new instruction: {interrupt_query}"
)
previous_iterations_history = construct_iteration_history(
previous_iterations, query, query_images, query_files
)
research_conversation_history += previous_iterations_history
query = interrupt_query
previous_iterations = []
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result
online_results: Dict = dict() online_results: Dict = dict()
code_results: Dict = dict() code_results: Dict = dict()
document_results: List[Dict[str, str]] = [] document_results: List[Dict[str, str]] = []
@@ -428,6 +452,7 @@ async def research(
agent=agent, agent=agent,
query_files=query_files, query_files=query_files,
cancellation_event=cancellation_event, cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer, tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:

View File

@@ -168,7 +168,6 @@ class ChatRequestBody(BaseModel):
images: Optional[list[str]] = None images: Optional[list[str]] = None
files: Optional[list[FileAttachment]] = [] files: Optional[list[FileAttachment]] = []
create_new: Optional[bool] = False create_new: Optional[bool] = False
interrupt: Optional[bool] = False
class Entry: class Entry: