Make post oauth next url redirect more robust, handle q params better

This commit is contained in:
Debanjum
2024-12-23 17:48:57 -08:00
parent 17f8ba732d
commit 0b91383deb

View File

@@ -3,6 +3,7 @@ import datetime
import logging
import os
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
import requests
from fastapi import APIRouter, Depends
@@ -204,23 +205,33 @@ async def auth_post(request: Request):
@auth_router.get("/redirect")
async def auth(request: Request):
next_url = get_next_url(request)
for q in request.query_params:
if q in ["code", "state", "scope", "authuser", "prompt", "session_state", "access_type"]:
continue
if q != "next":
next_url += f"&{q}={request.query_params[q]}"
next_url_path = get_next_url(request)
code = request.query_params.get("code")
# Add query params from request, excluding OAuth params to next URL
oauth_params = {"code", "state", "scope", "authuser", "prompt", "session_state", "access_type", "next"}
query_params = {param: value for param, value in request.query_params.items() if param not in oauth_params}
# 1. Construct the full redirect URI including domain
# Rebuild next URL with updated query params
parsed_next_url_path = urlparse(next_url_path)
next_url = urlunparse(
(
parsed_next_url_path.scheme,
parsed_next_url_path.netloc,
parsed_next_url_path.path,
parsed_next_url_path.params,
urlencode(query_params, doseq=True),
parsed_next_url_path.fragment,
)
)
# Construct the full redirect URI including domain
base_url = str(request.base_url).rstrip("/")
if not DISABLE_HTTPS:
base_url = base_url.replace("http://", "https://")
redirect_uri = f"{base_url}{request.app.url_path_for('auth')}"
# Build the payload for the token request
code = request.query_params.get("code")
payload = {
"code": code,
"client_id": os.environ["GOOGLE_CLIENT_ID"],
@@ -229,12 +240,14 @@ async def auth(request: Request):
"grant_type": "authorization_code",
}
# Request the token from Google
verified_data = requests.post(
"https://oauth2.googleapis.com/token",
headers={"Content-Type": "application/x-www-form-urlencoded"},
data=payload,
)
# Validate the OAuth response
if verified_data.status_code != 200:
logger.error(f"Token request failed: {verified_data.text}")
try:
@@ -245,20 +258,24 @@ async def auth(request: Request):
verified_data.raise_for_status()
credential = verified_data.json().get("id_token")
if not credential:
logger.error("Missing id_token in OAuth response")
return RedirectResponse(url="/login?error=invalid_token", status_code=HTTP_302_FOUND)
# Validate the OAuth token
try:
idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
except OAuthError as error:
return HTMLResponse(f"<h1>{error.error}</h1>")
# Get or create the authenticated user in the database
khoj_user = await get_or_create_user(idinfo)
# Set the user session if the user is authenticated
if khoj_user:
request.session["user"] = dict(idinfo)
# Send a welcome email to new users
if datetime.timedelta(minutes=3) > (datetime.datetime.now(datetime.timezone.utc) - khoj_user.date_joined):
asyncio.create_task(send_welcome_email(idinfo["name"], idinfo["email"]))
update_telemetry_state(
@@ -269,6 +286,7 @@ async def auth(request: Request):
)
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
# Redirect the user to the next URL
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)