From 0b91383debff16243657a9bc53bd3c053d9a6831 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 23 Dec 2024 17:48:57 -0800 Subject: [PATCH] Make post oauth next url redirect more robust, handle q params better --- src/khoj/routers/auth.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 94ac0a26..ff42e070 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -3,6 +3,7 @@ import datetime import logging import os from typing import Optional +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import requests from fastapi import APIRouter, Depends @@ -204,23 +205,33 @@ async def auth_post(request: Request): @auth_router.get("/redirect") async def auth(request: Request): - next_url = get_next_url(request) - for q in request.query_params: - if q in ["code", "state", "scope", "authuser", "prompt", "session_state", "access_type"]: - continue - if q != "next": - next_url += f"&{q}={request.query_params[q]}" + next_url_path = get_next_url(request) - code = request.query_params.get("code") + # Add query params from request, excluding OAuth params to next URL + oauth_params = {"code", "state", "scope", "authuser", "prompt", "session_state", "access_type", "next"} + query_params = {param: value for param, value in request.query_params.items() if param not in oauth_params} - # 1. Construct the full redirect URI including domain + # Rebuild next URL with updated query params + parsed_next_url_path = urlparse(next_url_path) + next_url = urlunparse( + ( + parsed_next_url_path.scheme, + parsed_next_url_path.netloc, + parsed_next_url_path.path, + parsed_next_url_path.params, + urlencode(query_params, doseq=True), + parsed_next_url_path.fragment, + ) + ) + + # Construct the full redirect URI including domain base_url = str(request.base_url).rstrip("/") - if not DISABLE_HTTPS: base_url = base_url.replace("http://", "https://") - redirect_uri = f"{base_url}{request.app.url_path_for('auth')}" + # Build the payload for the token request + code = request.query_params.get("code") payload = { "code": code, "client_id": os.environ["GOOGLE_CLIENT_ID"], @@ -229,12 +240,14 @@ async def auth(request: Request): "grant_type": "authorization_code", } + # Request the token from Google verified_data = requests.post( "https://oauth2.googleapis.com/token", headers={"Content-Type": "application/x-www-form-urlencoded"}, data=payload, ) + # Validate the OAuth response if verified_data.status_code != 200: logger.error(f"Token request failed: {verified_data.text}") try: @@ -245,20 +258,24 @@ async def auth(request: Request): verified_data.raise_for_status() credential = verified_data.json().get("id_token") - if not credential: logger.error("Missing id_token in OAuth response") return RedirectResponse(url="/login?error=invalid_token", status_code=HTTP_302_FOUND) + # Validate the OAuth token try: idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"]) except OAuthError as error: return HTMLResponse(f"

{error.error}

") + + # Get or create the authenticated user in the database khoj_user = await get_or_create_user(idinfo) + # Set the user session if the user is authenticated if khoj_user: request.session["user"] = dict(idinfo) + # Send a welcome email to new users if datetime.timedelta(minutes=3) > (datetime.datetime.now(datetime.timezone.utc) - khoj_user.date_joined): asyncio.create_task(send_welcome_email(idinfo["name"], idinfo["email"])) update_telemetry_state( @@ -269,6 +286,7 @@ async def auth(request: Request): ) logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}") + # Redirect the user to the next URL return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)