diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 09c7651a..6ecf0c37 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -60,6 +60,7 @@ from khoj.routers.helpers import ( ConversationCommandRateLimiter, DeleteMessageRequestBody, FeedbackData, + WebSocketConnectionManager, acreate_title_from_history, agenerate_chat_response, aget_data_sources_and_output_format, @@ -1467,8 +1468,21 @@ async def chat_ws( websocket: WebSocket, common: CommonQueryParams, ): + # Limit open websocket connections per user + user = websocket.scope["user"].object + connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10) + connection_id = str(uuid.uuid4()) + + if not await connection_manager.can_connect(websocket): + await websocket.close(code=1008, reason="Connection limit exceeded") + logger.info(f"WebSocket connection rejected for user {user.id}: connection limit exceeded") + return + await websocket.accept() + # Note new websocket connection for the user + await connection_manager.register_connection(user, connection_id) + # Initialize rate limiters rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") rate_limiter_per_day = ApiUserRateLimiter( @@ -1539,6 +1553,9 @@ async def chat_ws( if current_task and not current_task.done(): current_task.cancel() await websocket.close(code=1011, reason="Internal Server Error") + finally: + # Always unregister the connection on disconnect + await connection_manager.unregister_connection(user, connection_id) async def process_chat_request( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 7b9a0d48..0bb2d7ba 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -2075,6 +2075,48 @@ class ApiImageRateLimiter: ) +class WebSocketConnectionManager: + """Limit max open websockets per user.""" + + def __init__(self, trial_user_max_connections: int = 10, subscribed_user_max_connections: int = 10): + self.trial_user_max_connections = trial_user_max_connections + self.subscribed_user_max_connections = subscribed_user_max_connections + self.connection_slug_prefix = "ws_connection_" + # Set cleanup window to 24 hours for truly stale connections (e.g., server crashes) + self.cleanup_window = 86400 # 24 hours + + async def can_connect(self, websocket: WebSocket) -> bool: + """Check if user can establish a new WebSocket connection.""" + # Cleanup very old connections (likely from server crashes) + user: KhojUser = websocket.scope["user"].object + subscribed = has_required_scope(websocket, ["premium"]) + max_connections = self.subscribed_user_max_connections if subscribed else self.trial_user_max_connections + + await self._cleanup_stale_connections(user) + + # Count ALL connections for this user (not filtered by time) + active_connections = await UserRequests.objects.filter( + user=user, slug__startswith=self.connection_slug_prefix + ).acount() + + return active_connections < max_connections + + async def register_connection(self, user: KhojUser, connection_id: str) -> None: + """Register a new WebSocket connection.""" + await UserRequests.objects.acreate(user=user, slug=f"{self.connection_slug_prefix}{connection_id}") + + async def unregister_connection(self, user: KhojUser, connection_id: str) -> None: + """Remove a WebSocket connection record.""" + await UserRequests.objects.filter(user=user, slug=f"{self.connection_slug_prefix}{connection_id}").adelete() + + async def _cleanup_stale_connections(self, user: KhojUser) -> None: + """Remove connection records older than cleanup window.""" + cutoff = django_timezone.now() - timedelta(seconds=self.cleanup_window) + await UserRequests.objects.filter( + user=user, slug__startswith=self.connection_slug_prefix, created_at__lt=cutoff + ).adelete() + + class ConversationCommandRateLimiter: def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): self.slug = slug