diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 014db373..f09fedc6 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -89,7 +89,9 @@ async def get_or_create_user_by_email(email: str) -> KhojUser: user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email}) await user.asave() - await Subscription.objects.acreate(user=user, type="trial") + user_subscription = await Subscription.objects.filter(user=user).afirst() + if not user_subscription: + await Subscription.objects.acreate(user=user, type="trial") return user @@ -124,30 +126,20 @@ def get_user_subscription(email: str) -> Optional[Subscription]: async def set_user_subscription( email: str, is_recurring=None, renewal_date=None, type="standard" ) -> Optional[Subscription]: - # Get or create the user object + # Get or create the user object and their subscription user = await get_or_create_user_by_email(email) - user_subscription = await Subscription.objects.filter(user=user).afirst() - if not user_subscription: - user = await get_user_by_email(email) - if not user: - return None - user_subscription = await Subscription.objects.acreate( - user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date - ) - return user_subscription - elif user_subscription: - user_subscription.type = type - if is_recurring is not None: - user_subscription.is_recurring = is_recurring - if renewal_date is False: - user_subscription.renewal_date = None - elif renewal_date is not None: - user_subscription.renewal_date = renewal_date - await user_subscription.asave() - return user_subscription - else: - return None + + # Update the user subscription state + user_subscription.type = type + if is_recurring is not None: + user_subscription.is_recurring = is_recurring + if renewal_date is False: + user_subscription.renewal_date = None + elif renewal_date is not None: + user_subscription.renewal_date = renewal_date + await user_subscription.asave() + return user_subscription def subscription_to_state(subscription: Subscription) -> str: diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py index bed53c8c..09d2a7d4 100644 --- a/src/khoj/routers/subscription.py +++ b/src/khoj/routers/subscription.py @@ -51,7 +51,7 @@ async def subscribe(request: Request): if event_type in {"invoice.paid"}: # Mark the user as subscribed and update the next renewal date on payment subscription = stripe.Subscription.list(customer=customer_id).data[0] - renewal_date = datetime.fromtimestamp(subscription["lines"]["data"][0]["period"]["end"], tz=timezone.utc) + renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date) success = user is not None elif event_type in {"customer.subscription.updated"}: