Use normalized email address for new users

Not check email deliverability for now to allow air-gapped usage or
authenticated/multi-user setups with admin managed otp

Closes #1069
This commit is contained in:
Debanjum
2025-01-11 10:00:55 +07:00
parent 85c34a5f0f
commit 6e955e158b
4 changed files with 36 additions and 12 deletions

View File

@@ -91,6 +91,7 @@ dependencies = [
"google-generativeai == 0.8.3",
"pyjson5 == 1.6.7",
"resend == 1.0.1",
"email-validator == 2.2.0",
]
dynamic = ["version"]

View File

@@ -72,6 +72,7 @@ from khoj.utils.helpers import (
generate_random_name,
in_debug_mode,
is_none_or_empty,
normalize_email,
timer,
)
@@ -231,17 +232,22 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
return user
async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
async def aget_or_create_user_by_email(input_email: str) -> tuple[KhojUser, bool]:
email, is_valid_email = normalize_email(input_email)
is_existing_user = await KhojUser.objects.filter(email=email).aexists()
# Validate email address of new users
if not is_existing_user and not is_valid_email:
logger.error(f"Account creation failed. Invalid email address: {email}")
return None, False
user, is_new = await KhojUser.objects.filter(email=email).aupdate_or_create(
defaults={"username": email, "email": email}
)
await user.asave()
if user:
# Generate a secure 6-digit numeric code
user.email_verification_code = f"{secrets.randbelow(1000000):06}"
user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5)
await user.asave()
# Generate a secure 6-digit numeric code
user.email_verification_code = f"{secrets.randbelow(1000000):06}"
user.email_verification_code_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=5)
await user.asave()
user_subscription = await Subscription.objects.filter(user=user).afirst()
if not user_subscription:
@@ -270,10 +276,15 @@ async def astart_trial_subscription(user: KhojUser) -> Subscription:
async def aget_user_validated_by_email_verification_code(code: str, email: str) -> tuple[Optional[KhojUser], bool]:
user = await KhojUser.objects.filter(email_verification_code=code, email=email).afirst()
# Normalize the email address
normalized_email, _ = normalize_email(email)
# Check if verification code exists for the user
user = await KhojUser.objects.filter(email_verification_code=code, email=normalized_email).afirst()
if not user:
return None, False
# Check if the code has expired
if user.email_verification_code_expiry < datetime.now(tz=timezone.utc):
return user, True
@@ -348,6 +359,8 @@ async def set_user_subscription(
) -> tuple[Optional[Subscription], bool]:
# Get or create the user object and their subscription
user, is_new = await aget_or_create_user_by_email(email)
if not user:
return None, is_new
user_subscription = await Subscription.objects.filter(user=user).afirst()
# Update the user subscription state

View File

@@ -86,12 +86,11 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
# Clear the session if user is already authenticated
request.session.pop("user", None)
email = form.email
user, is_new = await aget_or_create_user_by_email(email)
unique_id = user.email_verification_code
user, is_new = await aget_or_create_user_by_email(form.email)
if user:
await send_magic_link_email(email, unique_id, request.base_url)
unique_id = user.email_verification_code
await send_magic_link_email(user.email, unique_id, request.base_url)
if is_new:
update_telemetry_state(
request=request,

View File

@@ -27,6 +27,7 @@ import psutil
import requests
import torch
from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from magika import Magika
from PIL import Image
from pytz import country_names, country_timezones
@@ -614,3 +615,13 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o
base_url=api_base_url,
)
return client
def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
"""Normalize, validate and check deliverability of email address"""
lower_email = email.lower()
try:
valid_email = validate_email(lower_email, check_deliverability=check_deliverability)
return valid_email.normalized, True
except (EmailNotValidError, EmailUndeliverableError):
return lower_email, False