mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Validate websocket origin before establishing connection
This commit is contained in:
@@ -21,7 +21,7 @@ from fastapi import (
|
|||||||
from fastapi.responses import RedirectResponse, Response, StreamingResponse
|
from fastapi.responses import RedirectResponse, Response, StreamingResponse
|
||||||
from fastapi.websockets import WebSocketState
|
from fastapi.websockets import WebSocketState
|
||||||
from starlette.authentication import has_required_scope, requires
|
from starlette.authentication import has_required_scope, requires
|
||||||
from starlette.requests import Headers
|
from starlette.requests import URL, Headers
|
||||||
|
|
||||||
from khoj.app.settings import ALLOWED_HOSTS
|
from khoj.app.settings import ALLOWED_HOSTS
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
@@ -1468,6 +1468,12 @@ async def chat_ws(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
):
|
):
|
||||||
|
# Validate WebSocket Origin
|
||||||
|
origin = websocket.headers.get("origin")
|
||||||
|
if not origin or URL(origin).hostname not in ALLOWED_HOSTS:
|
||||||
|
await websocket.close(code=1008, reason="Origin not allowed")
|
||||||
|
return
|
||||||
|
|
||||||
# Limit open websocket connections per user
|
# Limit open websocket connections per user
|
||||||
user = websocket.scope["user"].object
|
user = websocket.scope["user"].object
|
||||||
connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
|
connection_manager = WebSocketConnectionManager(trial_user_max_connections=5, subscribed_user_max_connections=10)
|
||||||
|
|||||||
Reference in New Issue
Block a user