diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 3358aa79..a8080166 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -103,33 +103,39 @@ async def create_google_user(token: dict) -> KhojUser: return user -async def set_user_unsubscribed(email: str, type="standard") -> KhojUser: - user = await KhojUser.objects.filter(email=email, subscription_type=type).afirst() - if user: - user.is_subscribed = False - await user.asave() - return user - else: - return None - - -async def set_user_subscribed(email: str, type="standard") -> KhojUser: +async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser: user = await KhojUser.objects.filter(email=email).afirst() if user: user.subscription_type = type - user.is_subscribed = True - start_date = user.subscription_renewal_date or datetime.now() - user.subscription_renewal_date = start_date + timedelta(days=30) + if is_subscribed is not None: + user.is_subscribed = is_subscribed + if renewal_date is False: + user.subscription_renewal_date = None + elif renewal_date is not None: + user.subscription_renewal_date = renewal_date await user.asave() return user else: return None -def is_user_subscribed(email: str, type="standard") -> bool: - return KhojUser.objects.filter( - email=email, subscription_type=type, subscription_renewal_date__gte=datetime.now(tz=timezone.utc) - ).exists() +def get_user_subscription_state(email: str) -> str: + """Get subscription state of user + Valid state transitions: trial -> subscribed <-> unsubscribed OR expired + """ + user = KhojUser.objects.filter(email=email).first() + if user.subscription_type == KhojUser.SubscriptionType.TRIAL: + return "trial" + elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc): + return "subscribed" + elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc): + return "unsubscribed" + elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc): + return "expired" + + +async def get_user_by_email(email: str) -> KhojUser: + return await KhojUser.objects.filter(email=email).afirst() async def get_user_by_token(token: dict) -> KhojUser: diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index e86204bb..58c68349 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -167,29 +167,39 @@

Subscription Configured

-

Manage your subscription to Khoj Cloud

-
- {% if not is_subscribed %} -
- - Subscribe - - -
+ {% if subscription_state == "subscribed" %} +

You are subscribed to Khoj Cloud. Subscription will renew on {{ subscription_renewal_date }}

+ {% elif subscription_state == "unsubscribed" %} +

You are subscribed to Khoj Cloud. Subscription will expire on {{ subscription_renewal_date }}

+ {% elif subscription_state == "expired" %} +

Subscribe to Khoj Cloud. Subscription expired on {{ subscription_renewal_date }}

{% else %} +

Subscribe to Khoj Cloud

+ {% endif %} +
+ {% if subscription_state == "subscribed" %} + {% elif subscription_state == "unsubscribed" %} + + {% else %} + + Subscribe + + + {% endif %}
- {% endif %} @@ -258,8 +268,17 @@ }; function unsubscribe() { - fetch('/api/subscription?email=' + '{{username}}', { - method: 'DELETE', + fetch('/api/subscription?operation=cancel&email={{username}}', { + method: 'PATCH', + headers: { + 'Content-Type': 'application/json', + }, + }) + } + + function resubscribe() { + fetch('/api/subscription?operation=resubscribe&email={{username}}', { + method: 'PATCH', headers: { 'Content-Type': 'application/json', }, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 77e377e7..e2187f0b 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -1,5 +1,6 @@ # Standard Packages import concurrent.futures +from datetime import datetime, timezone import math import os import time @@ -746,48 +747,72 @@ async def subscribe(request: Request): # Invalid signature raise e - # Handle the event - success = True - if ( - event["type"] == "payment_intent.succeeded" - or event["type"] == "invoice.payment_succeeded" - or event["type"] == "customer.subscription.created" - ): - # Retrieve the customer's details - customer_id = event["data"]["object"]["customer"] - customer = stripe.Customer.retrieve(customer_id) - customer_email = customer["email"] - # Mark the customer as subscribed - user = await adapters.set_user_subscribed(customer_email) - if not user: - success = False - elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted": - # Retrieve the customer's details - customer_id = event["data"]["object"]["customer"] - customer = stripe.Customer.retrieve(customer_id) - customer_email = customer["email"] - # Mark the customer as unsubscribed - user = await adapters.set_user_unsubscribed(customer_email) - if not user: - success = False - else: + event_type = event["type"] + if event_type not in { + "invoice.paid", + "customer.subscription.updated", + "customer.subscription.deleted", + "subscription_schedule.canceled", + }: logger.warn(f"Unhandled Stripe event type: {event['type']}") return {"success": False} + # Retrieve the customer's details + subscription = event["data"]["object"] + customer_id = subscription["customer"] + customer = stripe.Customer.retrieve(customer_id) + customer_email = customer["email"] + + # Handle valid stripe webhook events + success = True + if event_type in {"invoice.paid"}: + # Mark the user as subscribed and update the next renewal date on payment + subscription = stripe.Subscription.list(customer=customer_id).data[0] + renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) + user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date) + success = user is not None + elif event_type in {"customer.subscription.updated"}: + user = await adapters.get_user_by_email(customer_email) + # Allow updating subscription status if paid user + if user.subscription_renewal_date: + # Mark user as unsubscribed or resubscribed + is_subscribed = not subscription["cancel_at_period_end"] + updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed) + success = updated_user is not None + elif event_type in {"customer.subscription.deleted"}: + # Reset the user to trial state + user = await adapters.set_user_subscription( + customer_email, is_subscribed=False, renewal_date=False, type="trial" + ) + success = user is not None + logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}') return {"success": success} -@api.delete("/subscription") +@api.patch("/subscription") @requires(["authenticated"]) -async def unsubscribe(request: Request, email: str): - customer = stripe.Customer.list(email=email).data - if not is_none_or_empty(customer): - customer_id = customer[0].id +async def unsubscribe(request: Request, email: str, operation: str): + # Retrieve the customer's details + customers = stripe.Customer.list(email=email).auto_paging_iter() + customer = next(customers, None) + if customer is None: + return {"success": False, "message": "Customer not found"} + + if operation == "cancel": + customer_id = customer.id for subscription in stripe.Subscription.list(customer=customer_id): stripe.Subscription.modify(subscription.id, cancel_at_period_end=True) - success = True - else: - success = False + return {"success": True} - return {"success": success} + elif operation == "resubscribe": + subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter() + # Find the subscription that is set to cancel at the end of the period + for subscription in subscriptions: + if subscription.cancel_at_period_end: + # Update the subscription to not cancel at the end of the period + stripe.Subscription.modify(subscription.id, cancel_at_period_end=False) + return {"success": True} + return {"success": False, "message": "No subscription found that is set to cancel"} + + return {"success": False, "message": "Invalid operation"} diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index b62f6b94..fd96dc8f 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -1,5 +1,4 @@ # System Packages -from datetime import datetime, timezone import json import os @@ -23,7 +22,7 @@ from database.adapters import ( get_user_github_config, get_user_notion_config, ConversationAdapters, - is_user_subscribed, + get_user_subscription_state, ) # Initialize Router @@ -118,9 +117,9 @@ def login_page(request: Request): def config_page(request: Request): user: KhojUser = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_is_subscribed = is_user_subscribed(user.email) - days_to_renewal = ( - (user.subscription_renewal_date - datetime.now(tz=timezone.utc)).days if user.subscription_renewal_date else 0 + user_subscription_state = get_user_subscription_state(user.email) + subscription_renewal_date = ( + user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None ) enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all()) @@ -147,8 +146,8 @@ def config_page(request: Request): "conversation_options": all_conversation_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "user_photo": user_picture, - "is_subscribed": user_is_subscribed, - "days_to_renewal": days_to_renewal, + "subscription_state": user_subscription_state, + "subscription_renewal_date": subscription_renewal_date, "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), }, )