diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 12a127e9..2848dc2f 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -41,6 +41,7 @@ from khoj.search_filter.word_filter import WordFilter from khoj.utils import state from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import generate_random_name +from khoj.database.adapters import get_or_create_user_by_email class SubscriptionState(Enum): @@ -85,6 +86,15 @@ async def get_or_create_user(token: dict) -> KhojUser: return user +async def get_or_create_user_by_email(email: str) -> KhojUser: + user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email}) + await user.asave() + + await Subscription.objects.acreate(user=user, type="trial") + + return user + + async def create_user_by_google_token(token: dict) -> KhojUser: user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create( defaults={"username": token.get("email"), "email": token.get("email")} @@ -115,7 +125,10 @@ def get_user_subscription(email: str) -> Optional[Subscription]: async def set_user_subscription( email: str, is_recurring=None, renewal_date=None, type="standard" ) -> Optional[Subscription]: - user_subscription = await Subscription.objects.filter(user__email=email).afirst() + # Get or create the user object + user = await get_or_create_user_by_email(email) + + user_subscription = await Subscription.objects.filter(user=user).afirst() if not user_subscription: user = await get_user_by_email(email) if not user: