From 771f9bcfa1844827bb23b99be9f9bf6bf3e98d35 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 22:08:32 -0800 Subject: [PATCH] If the user subscription was created over 7 days ago, then their trial is expired --- src/khoj/configure.py | 10 ++++++++-- src/khoj/database/adapters/__init__.py | 6 +++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19c1fd81..fc64af0f 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -77,7 +77,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user: if state.billing_enabled: subscription_state = await aget_user_subscription_state(user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user @@ -97,7 +100,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user_with_token: if state.billing_enabled: subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 146de11c..7f76b796 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1,7 +1,7 @@ import math import random import secrets -from datetime import date, datetime, timezone +from datetime import date, datetime, timezone, timedelta from typing import List, Optional, Type from enum import Enum @@ -140,6 +140,10 @@ def subscription_to_state(subscription: Subscription) -> str: if not subscription: return SubscriptionState.INVALID.value elif subscription.type == Subscription.Type.TRIAL: + # Trial subscription is valid for 7 days + if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7): + return SubscriptionState.EXPIRED.value + return SubscriptionState.TRIAL.value elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): return SubscriptionState.SUBSCRIBED.value