mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Limit number of new websocket connections allowed per user
This commit is contained in:
@@ -60,6 +60,7 @@ from khoj.routers.helpers import (
|
||||
ConversationCommandRateLimiter,
|
||||
DeleteMessageRequestBody,
|
||||
FeedbackData,
|
||||
WebSocketConnectionManager,
|
||||
acreate_title_from_history,
|
||||
agenerate_chat_response,
|
||||
aget_data_sources_and_output_format,
|
||||
@@ -1467,8 +1468,21 @@ async def chat_ws(
|
||||
websocket: WebSocket,
|
||||
common: CommonQueryParams,
|
||||
):
|
||||
# Limit open websocket connections per user
|
||||
user = websocket.scope["user"].object
|
||||
connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
|
||||
connection_id = str(uuid.uuid4())
|
||||
|
||||
if not await connection_manager.can_connect(websocket):
|
||||
await websocket.close(code=1008, reason="Connection limit exceeded")
|
||||
logger.info(f"WebSocket connection rejected for user {user.id}: connection limit exceeded")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# Note new websocket connection for the user
|
||||
await connection_manager.register_connection(user, connection_id)
|
||||
|
||||
# Initialize rate limiters
|
||||
rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
||||
rate_limiter_per_day = ApiUserRateLimiter(
|
||||
@@ -1539,6 +1553,9 @@ async def chat_ws(
|
||||
if current_task and not current_task.done():
|
||||
current_task.cancel()
|
||||
await websocket.close(code=1011, reason="Internal Server Error")
|
||||
finally:
|
||||
# Always unregister the connection on disconnect
|
||||
await connection_manager.unregister_connection(user, connection_id)
|
||||
|
||||
|
||||
async def process_chat_request(
|
||||
|
||||
@@ -2075,6 +2075,48 @@ class ApiImageRateLimiter:
|
||||
)
|
||||
|
||||
|
||||
class WebSocketConnectionManager:
|
||||
"""Limit max open websockets per user."""
|
||||
|
||||
def __init__(self, trial_user_max_connections: int = 10, subscribed_user_max_connections: int = 10):
|
||||
self.trial_user_max_connections = trial_user_max_connections
|
||||
self.subscribed_user_max_connections = subscribed_user_max_connections
|
||||
self.connection_slug_prefix = "ws_connection_"
|
||||
# Set cleanup window to 24 hours for truly stale connections (e.g., server crashes)
|
||||
self.cleanup_window = 86400 # 24 hours
|
||||
|
||||
async def can_connect(self, websocket: WebSocket) -> bool:
|
||||
"""Check if user can establish a new WebSocket connection."""
|
||||
# Cleanup very old connections (likely from server crashes)
|
||||
user: KhojUser = websocket.scope["user"].object
|
||||
subscribed = has_required_scope(websocket, ["premium"])
|
||||
max_connections = self.subscribed_user_max_connections if subscribed else self.trial_user_max_connections
|
||||
|
||||
await self._cleanup_stale_connections(user)
|
||||
|
||||
# Count ALL connections for this user (not filtered by time)
|
||||
active_connections = await UserRequests.objects.filter(
|
||||
user=user, slug__startswith=self.connection_slug_prefix
|
||||
).acount()
|
||||
|
||||
return active_connections < max_connections
|
||||
|
||||
async def register_connection(self, user: KhojUser, connection_id: str) -> None:
|
||||
"""Register a new WebSocket connection."""
|
||||
await UserRequests.objects.acreate(user=user, slug=f"{self.connection_slug_prefix}{connection_id}")
|
||||
|
||||
async def unregister_connection(self, user: KhojUser, connection_id: str) -> None:
|
||||
"""Remove a WebSocket connection record."""
|
||||
await UserRequests.objects.filter(user=user, slug=f"{self.connection_slug_prefix}{connection_id}").adelete()
|
||||
|
||||
async def _cleanup_stale_connections(self, user: KhojUser) -> None:
|
||||
"""Remove connection records older than cleanup window."""
|
||||
cutoff = django_timezone.now() - timedelta(seconds=self.cleanup_window)
|
||||
await UserRequests.objects.filter(
|
||||
user=user, slug__startswith=self.connection_slug_prefix, created_at__lt=cutoff
|
||||
).adelete()
|
||||
|
||||
|
||||
class ConversationCommandRateLimiter:
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||
self.slug = slug
|
||||
|
||||
Reference in New Issue
Block a user