diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19c1fd81..fc64af0f 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -77,7 +77,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user: if state.billing_enabled: subscription_state = await aget_user_subscription_state(user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user @@ -97,7 +100,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user_with_token: if state.billing_enabled: subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 146de11c..7f76b796 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1,7 +1,7 @@ import math import random import secrets -from datetime import date, datetime, timezone +from datetime import date, datetime, timezone, timedelta from typing import List, Optional, Type from enum import Enum @@ -140,6 +140,10 @@ def subscription_to_state(subscription: Subscription) -> str: if not subscription: return SubscriptionState.INVALID.value elif subscription.type == Subscription.Type.TRIAL: + # Trial subscription is valid for 7 days + if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7): + return SubscriptionState.EXPIRED.value + return SubscriptionState.TRIAL.value elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): return SubscriptionState.SUBSCRIBED.value