Validate websocket origin before establishing connection

This commit is contained in:
Debanjum
2025-07-19 20:07:21 -05:00
parent 69a7d332fc
commit 749160e38d

View File

@@ -21,7 +21,7 @@ from fastapi import (
from fastapi.responses import RedirectResponse, Response, StreamingResponse
from fastapi.websockets import WebSocketState
from starlette.authentication import has_required_scope, requires
from starlette.requests import Headers
from starlette.requests import URL, Headers
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
@@ -1468,6 +1468,12 @@ async def chat_ws(
websocket: WebSocket,
common: CommonQueryParams,
):
# Validate WebSocket Origin
origin = websocket.headers.get("origin")
if not origin or URL(origin).hostname not in ALLOWED_HOSTS:
await websocket.close(code=1008, reason="Origin not allowed")
return
# Limit open websocket connections per user
user = websocket.scope["user"].object
connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)