Limit number of new websocket connections allowed per user

This commit is contained in:
Debanjum
2025-07-19 19:22:29 -05:00
parent 76ddf8645c
commit 69a7d332fc
2 changed files with 59 additions and 0 deletions

View File

@@ -60,6 +60,7 @@ from khoj.routers.helpers import (
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
DeleteMessageRequestBody, DeleteMessageRequestBody,
FeedbackData, FeedbackData,
WebSocketConnectionManager,
acreate_title_from_history, acreate_title_from_history,
agenerate_chat_response, agenerate_chat_response,
aget_data_sources_and_output_format, aget_data_sources_and_output_format,
@@ -1467,8 +1468,21 @@ async def chat_ws(
websocket: WebSocket, websocket: WebSocket,
common: CommonQueryParams, common: CommonQueryParams,
): ):
# Limit open websocket connections per user
user = websocket.scope["user"].object
connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
connection_id = str(uuid.uuid4())
if not await connection_manager.can_connect(websocket):
await websocket.close(code=1008, reason="Connection limit exceeded")
logger.info(f"WebSocket connection rejected for user {user.id}: connection limit exceeded")
return
await websocket.accept() await websocket.accept()
# Note new websocket connection for the user
await connection_manager.register_connection(user, connection_id)
# Initialize rate limiters # Initialize rate limiters
rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
rate_limiter_per_day = ApiUserRateLimiter( rate_limiter_per_day = ApiUserRateLimiter(
@@ -1539,6 +1553,9 @@ async def chat_ws(
if current_task and not current_task.done(): if current_task and not current_task.done():
current_task.cancel() current_task.cancel()
await websocket.close(code=1011, reason="Internal Server Error") await websocket.close(code=1011, reason="Internal Server Error")
finally:
# Always unregister the connection on disconnect
await connection_manager.unregister_connection(user, connection_id)
async def process_chat_request( async def process_chat_request(

View File

@@ -2075,6 +2075,48 @@ class ApiImageRateLimiter:
) )
class WebSocketConnectionManager:
"""Limit max open websockets per user."""
def __init__(self, trial_user_max_connections: int = 10, subscribed_user_max_connections: int = 10):
self.trial_user_max_connections = trial_user_max_connections
self.subscribed_user_max_connections = subscribed_user_max_connections
self.connection_slug_prefix = "ws_connection_"
# Set cleanup window to 24 hours for truly stale connections (e.g., server crashes)
self.cleanup_window = 86400 # 24 hours
async def can_connect(self, websocket: WebSocket) -> bool:
"""Check if user can establish a new WebSocket connection."""
# Cleanup very old connections (likely from server crashes)
user: KhojUser = websocket.scope["user"].object
subscribed = has_required_scope(websocket, ["premium"])
max_connections = self.subscribed_user_max_connections if subscribed else self.trial_user_max_connections
await self._cleanup_stale_connections(user)
# Count ALL connections for this user (not filtered by time)
active_connections = await UserRequests.objects.filter(
user=user, slug__startswith=self.connection_slug_prefix
).acount()
return active_connections < max_connections
async def register_connection(self, user: KhojUser, connection_id: str) -> None:
"""Register a new WebSocket connection."""
await UserRequests.objects.acreate(user=user, slug=f"{self.connection_slug_prefix}{connection_id}")
async def unregister_connection(self, user: KhojUser, connection_id: str) -> None:
"""Remove a WebSocket connection record."""
await UserRequests.objects.filter(user=user, slug=f"{self.connection_slug_prefix}{connection_id}").adelete()
async def _cleanup_stale_connections(self, user: KhojUser) -> None:
"""Remove connection records older than cleanup window."""
cutoff = django_timezone.now() - timedelta(seconds=self.cleanup_window)
await UserRequests.objects.filter(
user=user, slug__startswith=self.connection_slug_prefix, created_at__lt=cutoff
).adelete()
class ConversationCommandRateLimiter: class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
self.slug = slug self.slug = slug