diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 620e83cc..f319ad7b 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -160,7 +160,7 @@ class UserAuthenticationBackend(AuthenticationBackend): if subscribed: return ( AuthCredentials(["authenticated", "premium"]), - AuthenticatedKhojUser(user), + AuthenticatedKhojUser(user, client_application), ) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application) if state.anonymous_mode: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index b4cf9c4c..7b36b40e 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -207,6 +207,8 @@ def subscription_to_state(subscription: Subscription) -> str: return SubscriptionState.TRIAL.value elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): return SubscriptionState.SUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date is None: + return SubscriptionState.EXPIRED.value elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): return SubscriptionState.UNSUBSCRIBED.value elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc): @@ -222,11 +224,11 @@ def get_user_subscription_state(email: str) -> str: return subscription_to_state(user_subscription) -async def aget_user_subscription_state(email: str) -> str: +async def aget_user_subscription_state(user: KhojUser) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired """ - user_subscription = await Subscription.objects.filter(user__email=email).afirst() + user_subscription = await Subscription.objects.filter(user=user).afirst() return subscription_to_state(user_subscription)