Harden the user check of the Notion integration

This commit is contained in:
Debanjum
2025-12-28 21:47:28 -08:00
parent b8eeefa0b1
commit 1b7ccd141d

View File

@@ -7,9 +7,9 @@ from concurrent.futures import ThreadPoolExecutor
import requests import requests
from fastapi import APIRouter, BackgroundTasks, Request, Response from fastapi import APIRouter, BackgroundTasks, Request, Response
from starlette.authentication import requires
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from khoj.database.adapters import aget_user_by_uuid
from khoj.database.models import KhojUser, NotionConfig from khoj.database.models import KhojUser, NotionConfig
from khoj.routers.helpers import configure_content from khoj.routers.helpers import configure_content
from khoj.utils.state import SearchType from khoj.utils.state import SearchType
@@ -31,19 +31,23 @@ async def run_in_executor(func, *args):
@notion_router.get("/auth/callback") @notion_router.get("/auth/callback")
@requires(["authenticated"], redirect="login_page")
async def notion_auth_callback(request: Request, background_tasks: BackgroundTasks): async def notion_auth_callback(request: Request, background_tasks: BackgroundTasks):
code = request.query_params.get("code") code = request.query_params.get("code")
state = request.query_params.get("state") state = request.query_params.get("state")
if not code or not state: if not code or not state:
return Response("Missing code or state", status_code=400) return Response("Missing code or state", status_code=400)
user: KhojUser = await aget_user_by_uuid(state) # Use authenticated user from session instead of trusting the state parameter
user: KhojUser = request.user.object
# Verify state matches user UUID as CSRF protection
if state != str(user.uuid):
logger.warning(f"Notion OAuth state mismatch for user {user.uuid}")
return Response("Invalid state parameter", status_code=400)
await NotionConfig.objects.filter(user=user).adelete() await NotionConfig.objects.filter(user=user).adelete()
if not user:
raise Exception("User not found")
bearer_token = f"{NOTION_OAUTH_CLIENT_ID}:{NOTION_OAUTH_CLIENT_SECRET}" bearer_token = f"{NOTION_OAUTH_CLIENT_ID}:{NOTION_OAUTH_CLIENT_SECRET}"
base64_encoded_token = base64.b64encode(bearer_token.encode()).decode() base64_encoded_token = base64.b64encode(bearer_token.encode()).decode()