diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 6ecf0c37..374f66f3 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -21,7 +21,7 @@ from fastapi import ( from fastapi.responses import RedirectResponse, Response, StreamingResponse from fastapi.websockets import WebSocketState from starlette.authentication import has_required_scope, requires -from starlette.requests import Headers +from starlette.requests import URL, Headers from khoj.app.settings import ALLOWED_HOSTS from khoj.database.adapters import ( @@ -1468,6 +1468,12 @@ async def chat_ws( websocket: WebSocket, common: CommonQueryParams, ): + # Validate WebSocket Origin + origin = websocket.headers.get("origin") + if not origin or URL(origin).hostname not in ALLOWED_HOSTS: + await websocket.close(code=1008, reason="Origin not allowed") + return + # Limit open websocket connections per user user = websocket.scope["user"].object connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)