mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Limit number of new websocket connections allowed per user
This commit is contained in:
@@ -60,6 +60,7 @@ from khoj.routers.helpers import (
|
|||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
DeleteMessageRequestBody,
|
DeleteMessageRequestBody,
|
||||||
FeedbackData,
|
FeedbackData,
|
||||||
|
WebSocketConnectionManager,
|
||||||
acreate_title_from_history,
|
acreate_title_from_history,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
aget_data_sources_and_output_format,
|
aget_data_sources_and_output_format,
|
||||||
@@ -1467,8 +1468,21 @@ async def chat_ws(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
):
|
):
|
||||||
|
# Limit open websocket connections per user
|
||||||
|
user = websocket.scope["user"].object
|
||||||
|
connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
|
||||||
|
connection_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if not await connection_manager.can_connect(websocket):
|
||||||
|
await websocket.close(code=1008, reason="Connection limit exceeded")
|
||||||
|
logger.info(f"WebSocket connection rejected for user {user.id}: connection limit exceeded")
|
||||||
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
|
# Note new websocket connection for the user
|
||||||
|
await connection_manager.register_connection(user, connection_id)
|
||||||
|
|
||||||
# Initialize rate limiters
|
# Initialize rate limiters
|
||||||
rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
rate_limiter_per_minute = ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
|
||||||
rate_limiter_per_day = ApiUserRateLimiter(
|
rate_limiter_per_day = ApiUserRateLimiter(
|
||||||
@@ -1539,6 +1553,9 @@ async def chat_ws(
|
|||||||
if current_task and not current_task.done():
|
if current_task and not current_task.done():
|
||||||
current_task.cancel()
|
current_task.cancel()
|
||||||
await websocket.close(code=1011, reason="Internal Server Error")
|
await websocket.close(code=1011, reason="Internal Server Error")
|
||||||
|
finally:
|
||||||
|
# Always unregister the connection on disconnect
|
||||||
|
await connection_manager.unregister_connection(user, connection_id)
|
||||||
|
|
||||||
|
|
||||||
async def process_chat_request(
|
async def process_chat_request(
|
||||||
|
|||||||
@@ -2075,6 +2075,48 @@ class ApiImageRateLimiter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketConnectionManager:
|
||||||
|
"""Limit max open websockets per user."""
|
||||||
|
|
||||||
|
def __init__(self, trial_user_max_connections: int = 10, subscribed_user_max_connections: int = 10):
|
||||||
|
self.trial_user_max_connections = trial_user_max_connections
|
||||||
|
self.subscribed_user_max_connections = subscribed_user_max_connections
|
||||||
|
self.connection_slug_prefix = "ws_connection_"
|
||||||
|
# Set cleanup window to 24 hours for truly stale connections (e.g., server crashes)
|
||||||
|
self.cleanup_window = 86400 # 24 hours
|
||||||
|
|
||||||
|
async def can_connect(self, websocket: WebSocket) -> bool:
|
||||||
|
"""Check if user can establish a new WebSocket connection."""
|
||||||
|
# Cleanup very old connections (likely from server crashes)
|
||||||
|
user: KhojUser = websocket.scope["user"].object
|
||||||
|
subscribed = has_required_scope(websocket, ["premium"])
|
||||||
|
max_connections = self.subscribed_user_max_connections if subscribed else self.trial_user_max_connections
|
||||||
|
|
||||||
|
await self._cleanup_stale_connections(user)
|
||||||
|
|
||||||
|
# Count ALL connections for this user (not filtered by time)
|
||||||
|
active_connections = await UserRequests.objects.filter(
|
||||||
|
user=user, slug__startswith=self.connection_slug_prefix
|
||||||
|
).acount()
|
||||||
|
|
||||||
|
return active_connections < max_connections
|
||||||
|
|
||||||
|
async def register_connection(self, user: KhojUser, connection_id: str) -> None:
|
||||||
|
"""Register a new WebSocket connection."""
|
||||||
|
await UserRequests.objects.acreate(user=user, slug=f"{self.connection_slug_prefix}{connection_id}")
|
||||||
|
|
||||||
|
async def unregister_connection(self, user: KhojUser, connection_id: str) -> None:
|
||||||
|
"""Remove a WebSocket connection record."""
|
||||||
|
await UserRequests.objects.filter(user=user, slug=f"{self.connection_slug_prefix}{connection_id}").adelete()
|
||||||
|
|
||||||
|
async def _cleanup_stale_connections(self, user: KhojUser) -> None:
|
||||||
|
"""Remove connection records older than cleanup window."""
|
||||||
|
cutoff = django_timezone.now() - timedelta(seconds=self.cleanup_window)
|
||||||
|
await UserRequests.objects.filter(
|
||||||
|
user=user, slug__startswith=self.connection_slug_prefix, created_at__lt=cutoff
|
||||||
|
).adelete()
|
||||||
|
|
||||||
|
|
||||||
class ConversationCommandRateLimiter:
|
class ConversationCommandRateLimiter:
|
||||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||||
self.slug = slug
|
self.slug = slug
|
||||||
|
|||||||
Reference in New Issue
Block a user