mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Make post oauth next url redirect more robust, handle q params better
This commit is contained in:
@@ -3,6 +3,7 @@ import datetime
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
@@ -204,23 +205,33 @@ async def auth_post(request: Request):
|
|||||||
|
|
||||||
@auth_router.get("/redirect")
|
@auth_router.get("/redirect")
|
||||||
async def auth(request: Request):
|
async def auth(request: Request):
|
||||||
next_url = get_next_url(request)
|
next_url_path = get_next_url(request)
|
||||||
for q in request.query_params:
|
|
||||||
if q in ["code", "state", "scope", "authuser", "prompt", "session_state", "access_type"]:
|
|
||||||
continue
|
|
||||||
if q != "next":
|
|
||||||
next_url += f"&{q}={request.query_params[q]}"
|
|
||||||
|
|
||||||
code = request.query_params.get("code")
|
# Add query params from request, excluding OAuth params to next URL
|
||||||
|
oauth_params = {"code", "state", "scope", "authuser", "prompt", "session_state", "access_type", "next"}
|
||||||
|
query_params = {param: value for param, value in request.query_params.items() if param not in oauth_params}
|
||||||
|
|
||||||
# 1. Construct the full redirect URI including domain
|
# Rebuild next URL with updated query params
|
||||||
|
parsed_next_url_path = urlparse(next_url_path)
|
||||||
|
next_url = urlunparse(
|
||||||
|
(
|
||||||
|
parsed_next_url_path.scheme,
|
||||||
|
parsed_next_url_path.netloc,
|
||||||
|
parsed_next_url_path.path,
|
||||||
|
parsed_next_url_path.params,
|
||||||
|
urlencode(query_params, doseq=True),
|
||||||
|
parsed_next_url_path.fragment,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Construct the full redirect URI including domain
|
||||||
base_url = str(request.base_url).rstrip("/")
|
base_url = str(request.base_url).rstrip("/")
|
||||||
|
|
||||||
if not DISABLE_HTTPS:
|
if not DISABLE_HTTPS:
|
||||||
base_url = base_url.replace("http://", "https://")
|
base_url = base_url.replace("http://", "https://")
|
||||||
|
|
||||||
redirect_uri = f"{base_url}{request.app.url_path_for('auth')}"
|
redirect_uri = f"{base_url}{request.app.url_path_for('auth')}"
|
||||||
|
|
||||||
|
# Build the payload for the token request
|
||||||
|
code = request.query_params.get("code")
|
||||||
payload = {
|
payload = {
|
||||||
"code": code,
|
"code": code,
|
||||||
"client_id": os.environ["GOOGLE_CLIENT_ID"],
|
"client_id": os.environ["GOOGLE_CLIENT_ID"],
|
||||||
@@ -229,12 +240,14 @@ async def auth(request: Request):
|
|||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Request the token from Google
|
||||||
verified_data = requests.post(
|
verified_data = requests.post(
|
||||||
"https://oauth2.googleapis.com/token",
|
"https://oauth2.googleapis.com/token",
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
data=payload,
|
data=payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the OAuth response
|
||||||
if verified_data.status_code != 200:
|
if verified_data.status_code != 200:
|
||||||
logger.error(f"Token request failed: {verified_data.text}")
|
logger.error(f"Token request failed: {verified_data.text}")
|
||||||
try:
|
try:
|
||||||
@@ -245,20 +258,24 @@ async def auth(request: Request):
|
|||||||
verified_data.raise_for_status()
|
verified_data.raise_for_status()
|
||||||
|
|
||||||
credential = verified_data.json().get("id_token")
|
credential = verified_data.json().get("id_token")
|
||||||
|
|
||||||
if not credential:
|
if not credential:
|
||||||
logger.error("Missing id_token in OAuth response")
|
logger.error("Missing id_token in OAuth response")
|
||||||
return RedirectResponse(url="/login?error=invalid_token", status_code=HTTP_302_FOUND)
|
return RedirectResponse(url="/login?error=invalid_token", status_code=HTTP_302_FOUND)
|
||||||
|
|
||||||
|
# Validate the OAuth token
|
||||||
try:
|
try:
|
||||||
idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
|
idinfo = id_token.verify_oauth2_token(credential, google_requests.Request(), os.environ["GOOGLE_CLIENT_ID"])
|
||||||
except OAuthError as error:
|
except OAuthError as error:
|
||||||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||||
|
|
||||||
|
# Get or create the authenticated user in the database
|
||||||
khoj_user = await get_or_create_user(idinfo)
|
khoj_user = await get_or_create_user(idinfo)
|
||||||
|
|
||||||
|
# Set the user session if the user is authenticated
|
||||||
if khoj_user:
|
if khoj_user:
|
||||||
request.session["user"] = dict(idinfo)
|
request.session["user"] = dict(idinfo)
|
||||||
|
|
||||||
|
# Send a welcome email to new users
|
||||||
if datetime.timedelta(minutes=3) > (datetime.datetime.now(datetime.timezone.utc) - khoj_user.date_joined):
|
if datetime.timedelta(minutes=3) > (datetime.datetime.now(datetime.timezone.utc) - khoj_user.date_joined):
|
||||||
asyncio.create_task(send_welcome_email(idinfo["name"], idinfo["email"]))
|
asyncio.create_task(send_welcome_email(idinfo["name"], idinfo["email"]))
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
@@ -269,6 +286,7 @@ async def auth(request: Request):
|
|||||||
)
|
)
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
||||||
|
|
||||||
|
# Redirect the user to the next URL
|
||||||
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user