Clean, merge subscription update events, API and functions

- Reduce webhook triggers for subscription updates
- Merge subscription update API endpoint, functions for (re/un-)subscribe
This commit is contained in:
Debanjum Singh Solanky
2023-11-08 15:14:51 -08:00
parent ef5c13f968
commit ec1395d072
4 changed files with 121 additions and 72 deletions

View File

@@ -103,33 +103,39 @@ async def create_google_user(token: dict) -> KhojUser:
return user return user
async def set_user_unsubscribed(email: str, type="standard") -> KhojUser: async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email, subscription_type=type).afirst()
if user:
user.is_subscribed = False
await user.asave()
return user
else:
return None
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email).afirst() user = await KhojUser.objects.filter(email=email).afirst()
if user: if user:
user.subscription_type = type user.subscription_type = type
user.is_subscribed = True if is_subscribed is not None:
start_date = user.subscription_renewal_date or datetime.now() user.is_subscribed = is_subscribed
user.subscription_renewal_date = start_date + timedelta(days=30) if renewal_date is False:
user.subscription_renewal_date = None
elif renewal_date is not None:
user.subscription_renewal_date = renewal_date
await user.asave() await user.asave()
return user return user
else: else:
return None return None
def is_user_subscribed(email: str, type="standard") -> bool: def get_user_subscription_state(email: str) -> str:
return KhojUser.objects.filter( """Get subscription state of user
email=email, subscription_type=type, subscription_renewal_date__gte=datetime.now(tz=timezone.utc) Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
).exists() """
user = KhojUser.objects.filter(email=email).first()
if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
return "trial"
elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
return "subscribed"
elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
return "unsubscribed"
elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc):
return "expired"
async def get_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst()
async def get_user_by_token(token: dict) -> KhojUser: async def get_user_by_token(token: dict) -> KhojUser:

View File

@@ -167,29 +167,39 @@
<h3 class="card-title"> <h3 class="card-title">
<span>Subscription</span> <span>Subscription</span>
<img id="configured-icon-subscription" <img id="configured-icon-subscription"
style="display: {% if not is_subscribed %}none{% endif %}" style="display: {% if subscription_state == 'trial' or subscription_state == 'expired' %}none{% endif %}"
class="configured-icon" class="configured-icon"
src="/static/assets/icons/confirm-icon.svg" src="/static/assets/icons/confirm-icon.svg"
alt="Configured"> alt="Configured">
</h3> </h3>
</div> </div>
<div class="card-description-row"> <div class="card-description-row">
<p class="card-description">Manage your subscription to Khoj Cloud</p> {% if subscription_state == "subscribed" %}
</div> <p class="card-description">You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>renew</b> on <b>{{ subscription_renewal_date }}</b></p>
{% if not is_subscribed %} {% elif subscription_state == "unsubscribed" %}
<div class="card-action-row"> <p class="card-description">You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>expire</b> on <b>{{ subscription_renewal_date }}</b></p>
<a class="card-button happy" href="{{ khoj_cloud_subscription_url }}?email={{ username }}" target="_blank"> {% elif subscription_state == "expired" %}
Subscribe <p class="card-description">Subscribe to Khoj Cloud. Subscription <b>expired</b> on <b>{{ subscription_renewal_date }}</b></p>
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% else %} {% else %}
<p class="card-description">Subscribe to Khoj Cloud</p>
{% endif %}
</div>
<div class="card-action-row"> <div class="card-action-row">
{% if subscription_state == "subscribed" %}
<button class="card-button" onclick="unsubscribe()"> <button class="card-button" onclick="unsubscribe()">
Unsubscribe Unsubscribe
</button> </button>
{% elif subscription_state == "unsubscribed" %}
<button class="card-button" onclick="resubscribe()">
Resubscribe
</button>
{% else %}
<a class="card-button happy" href="{{ khoj_cloud_subscription_url }}?prefilled_email={{ username }}" target="_blank">
Subscribe
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
{% endif %}
</div> </div>
{% endif %}
</div> </div>
</div> </div>
</div> </div>
@@ -258,8 +268,17 @@
}; };
function unsubscribe() { function unsubscribe() {
fetch('/api/subscription?email=' + '{{username}}', { fetch('/api/subscription?operation=cancel&email={{username}}', {
method: 'DELETE', method: 'PATCH',
headers: {
'Content-Type': 'application/json',
},
})
}
function resubscribe() {
fetch('/api/subscription?operation=resubscribe&email={{username}}', {
method: 'PATCH',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },

View File

@@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import concurrent.futures import concurrent.futures
from datetime import datetime, timezone
import math import math
import os import os
import time import time
@@ -746,48 +747,72 @@ async def subscribe(request: Request):
# Invalid signature # Invalid signature
raise e raise e
# Handle the event event_type = event["type"]
success = True if event_type not in {
if ( "invoice.paid",
event["type"] == "payment_intent.succeeded" "customer.subscription.updated",
or event["type"] == "invoice.payment_succeeded" "customer.subscription.deleted",
or event["type"] == "customer.subscription.created" "subscription_schedule.canceled",
): }:
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Mark the customer as subscribed
user = await adapters.set_user_subscribed(customer_email)
if not user:
success = False
elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Mark the customer as unsubscribed
user = await adapters.set_user_unsubscribed(customer_email)
if not user:
success = False
else:
logger.warn(f"Unhandled Stripe event type: {event['type']}") logger.warn(f"Unhandled Stripe event type: {event['type']}")
return {"success": False} return {"success": False}
# Retrieve the customer's details
subscription = event["data"]["object"]
customer_id = subscription["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Handle valid stripe webhook events
success = True
if event_type in {"invoice.paid"}:
# Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
success = user is not None
elif event_type in {"customer.subscription.updated"}:
user = await adapters.get_user_by_email(customer_email)
# Allow updating subscription status if paid user
if user.subscription_renewal_date:
# Mark user as unsubscribed or resubscribed
is_subscribed = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
success = updated_user is not None
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user = await adapters.set_user_subscription(
customer_email, is_subscribed=False, renewal_date=False, type="trial"
)
success = user is not None
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}') logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success} return {"success": success}
@api.delete("/subscription") @api.patch("/subscription")
@requires(["authenticated"]) @requires(["authenticated"])
async def unsubscribe(request: Request, email: str): async def unsubscribe(request: Request, email: str, operation: str):
customer = stripe.Customer.list(email=email).data # Retrieve the customer's details
if not is_none_or_empty(customer): customers = stripe.Customer.list(email=email).auto_paging_iter()
customer_id = customer[0].id customer = next(customers, None)
if customer is None:
return {"success": False, "message": "Customer not found"}
if operation == "cancel":
customer_id = customer.id
for subscription in stripe.Subscription.list(customer=customer_id): for subscription in stripe.Subscription.list(customer=customer_id):
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True) stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
success = True return {"success": True}
else:
success = False
return {"success": success} elif operation == "resubscribe":
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
# Find the subscription that is set to cancel at the end of the period
for subscription in subscriptions:
if subscription.cancel_at_period_end:
# Update the subscription to not cancel at the end of the period
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
return {"success": True}
return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"}

View File

@@ -1,5 +1,4 @@
# System Packages # System Packages
from datetime import datetime, timezone
import json import json
import os import os
@@ -23,7 +22,7 @@ from database.adapters import (
get_user_github_config, get_user_github_config,
get_user_notion_config, get_user_notion_config,
ConversationAdapters, ConversationAdapters,
is_user_subscribed, get_user_subscription_state,
) )
# Initialize Router # Initialize Router
@@ -118,9 +117,9 @@ def login_page(request: Request):
def config_page(request: Request): def config_page(request: Request):
user: KhojUser = request.user.object user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_is_subscribed = is_user_subscribed(user.email) user_subscription_state = get_user_subscription_state(user.email)
days_to_renewal = ( subscription_renewal_date = (
(user.subscription_renewal_date - datetime.now(tz=timezone.utc)).days if user.subscription_renewal_date else 0 user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None
) )
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all()) enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
@@ -147,8 +146,8 @@ def config_page(request: Request):
"conversation_options": all_conversation_options, "conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture, "user_photo": user_picture,
"is_subscribed": user_is_subscribed, "subscription_state": user_subscription_state,
"days_to_renewal": days_to_renewal, "subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
}, },
) )