mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Make post oauth next url redirect more robust, handle q params better
This commit is contained in:
@@ -3,6 +3,7 @@ import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends
|
||||
@@ -204,23 +205,33 @@ async def auth_post(request: Request):
|
||||
|
||||
@auth_router.get("/redirect")
|
||||
async def auth(request: Request):
|
||||
next_url = get_next_url(request)
|
||||
for q in request.query_params:
|
||||
if q in ["code", "state", "scope", "authuser", "prompt", "session_state", "access_type"]:
|
||||
continue
|
||||
if q != "next":
|
||||
next_url += f"&{q}={request.query_params[q]}"
|
||||
next_url_path = get_next_url(request)
|
||||
|
||||
code = request.query_params.get("code")
|
||||
# Add query params from request, excluding OAuth params to next URL
|
||||
oauth_params = {"code", "state", "scope", "authuser", "prompt", "session_state", "access_type", "next"}
|
||||
query_params = {param: value for param, value in request.query_params.items() if param not in oauth_params}
|
||||
|
||||
# 1. Construct the full redirect URI including domain
|
||||
# Rebuild next URL with updated query params
|
||||
parsed_next_url_path = urlparse(next_url_path)
|
||||
next_url = urlunparse(
|
||||
(
|
||||
parsed_next_url_path.scheme,
|
||||
parsed_next_url_path.netloc,
|
||||
parsed_next_url_path.path,
|
||||
parsed_next_url_path.params,
|
||||
urlencode(query_params, doseq=True),
|
||||
parsed_next_url_path.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
# Construct the full redirect URI including domain
|
||||
base_url = str(request.base_url).rstrip("/")
|
||||
|
||||
if not DISABLE_HTTPS:
|
||||
base_url = base_url.replace("http://", "https://")
|
||||
|
||||
redirect_uri = f"{base_url}{request.app.url_path_for('auth')}"
|
||||
|
||||
# Build the payload for the token request
|
||||
code = request.query_params.get("code")
|
||||
payload = {
|
||||
"code": code,
|
||||
"client_id": os.environ["GOOGLE_CLIENT_ID"],
|
||||
@@ -229,12 +240,14 @@ async def auth(request: Request):
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
|
||||
# Request the token from Google
|
||||
verified_data = requests.post(
|
||||
"https://oauth2.googleapis.com/token",
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data=payload,
|
||||
)
|
||||
|
||||
# Validate the OAuth response
|
||||
if verified_data.status_code != 200:
|
||||
logger.error(f"Token request failed: {verified_data.text}")
|
||||
try:
|
||||
@@ -245,20 +258,24 @@ async def auth(request: Request):
|
||||
verified_data.raise_for_status()
|
||||
|
||||
credential = verified_data.json().get("id_token")
|
||||
|
||||
if not credential:
|
||||
logger.error("Missing id_token in OAuth response")
|
||||
return RedirectResponse(url="/login?error=invalid_token", status_code=HTTP_302_FOUND)
|
||||
|
||||
# Validate the OAuth token
|
||||
try:
|
||||
idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
|
||||
except OAuthError as error:
|
||||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||
|
||||
# Get or create the authenticated user in the database
|
||||
khoj_user = await get_or_create_user(idinfo)
|
||||
|
||||
# Set the user session if the user is authenticated
|
||||
if khoj_user:
|
||||
request.session["user"] = dict(idinfo)
|
||||
|
||||
# Send a welcome email to new users
|
||||
if datetime.timedelta(minutes=3) > (datetime.datetime.now(datetime.timezone.utc) - khoj_user.date_joined):
|
||||
asyncio.create_task(send_welcome_email(idinfo["name"], idinfo["email"]))
|
||||
update_telemetry_state(
|
||||
@@ -269,6 +286,7 @@ async def auth(request: Request):
|
||||
)
|
||||
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
||||
|
||||
# Redirect the user to the next URL
|
||||
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user