Clean, merge subscription update events, API and functions

- Reduce webhook triggers for subscription updates
- Merge subscription update API endpoint, functions for (re/un-)subscribe
This commit is contained in:
Debanjum Singh Solanky
2023-11-08 15:14:51 -08:00
parent ef5c13f968
commit ec1395d072
4 changed files with 121 additions and 72 deletions

View File

@@ -103,33 +103,39 @@ async def create_google_user(token: dict) -> KhojUser:
return user
async def set_user_unsubscribed(email: str, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email, subscription_type=type).afirst()
if user:
user.is_subscribed = False
await user.asave()
return user
else:
return None
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email).afirst()
if user:
user.subscription_type = type
user.is_subscribed = True
start_date = user.subscription_renewal_date or datetime.now()
user.subscription_renewal_date = start_date + timedelta(days=30)
if is_subscribed is not None:
user.is_subscribed = is_subscribed
if renewal_date is False:
user.subscription_renewal_date = None
elif renewal_date is not None:
user.subscription_renewal_date = renewal_date
await user.asave()
return user
else:
return None
def is_user_subscribed(email: str, type="standard") -> bool:
return KhojUser.objects.filter(
email=email, subscription_type=type, subscription_renewal_date__gte=datetime.now(tz=timezone.utc)
).exists()
def get_user_subscription_state(email: str) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user = KhojUser.objects.filter(email=email).first()
if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
return "trial"
elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
return "subscribed"
elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
return "unsubscribed"
elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc):
return "expired"
async def get_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst()
async def get_user_by_token(token: dict) -> KhojUser:

View File

@@ -167,29 +167,39 @@
<h3 class="card-title">
<span>Subscription</span>
<img id="configured-icon-subscription"
style="display: {% if not is_subscribed %}none{% endif %}"
style="display: {% if subscription_state == 'trial' or subscription_state == 'expired' %}none{% endif %}"
class="configured-icon"
src="/static/assets/icons/confirm-icon.svg"
alt="Configured">
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Manage your subscription to Khoj Cloud</p>
</div>
{% if not is_subscribed %}
<div class="card-action-row">
<a class="card-button happy" href="{{ khoj_cloud_subscription_url }}?email={{ username }}" target="_blank">
Subscribe
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if subscription_state == "subscribed" %}
<p class="card-description">You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>renew</b> on <b>{{ subscription_renewal_date }}</b></p>
{% elif subscription_state == "unsubscribed" %}
<p class="card-description">You are <b>subscribed</b> to Khoj Cloud. Subscription will <b>expire</b> on <b>{{ subscription_renewal_date }}</b></p>
{% elif subscription_state == "expired" %}
<p class="card-description">Subscribe to Khoj Cloud. Subscription <b>expired</b> on <b>{{ subscription_renewal_date }}</b></p>
{% else %}
<p class="card-description">Subscribe to Khoj Cloud</p>
{% endif %}
</div>
<div class="card-action-row">
{% if subscription_state == "subscribed" %}
<button class="card-button" onclick="unsubscribe()">
Unsubscribe
</button>
{% elif subscription_state == "unsubscribed" %}
<button class="card-button" onclick="resubscribe()">
Resubscribe
</button>
{% else %}
<a class="card-button happy" href="{{ khoj_cloud_subscription_url }}?prefilled_email={{ username }}" target="_blank">
Subscribe
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
{% endif %}
</div>
{% endif %}
</div>
</div>
</div>
@@ -258,8 +268,17 @@
};
function unsubscribe() {
fetch('/api/subscription?email=' + '{{username}}', {
method: 'DELETE',
fetch('/api/subscription?operation=cancel&email={{username}}', {
method: 'PATCH',
headers: {
'Content-Type': 'application/json',
},
})
}
function resubscribe() {
fetch('/api/subscription?operation=resubscribe&email={{username}}', {
method: 'PATCH',
headers: {
'Content-Type': 'application/json',
},

View File

@@ -1,5 +1,6 @@
# Standard Packages
import concurrent.futures
from datetime import datetime, timezone
import math
import os
import time
@@ -746,48 +747,72 @@ async def subscribe(request: Request):
# Invalid signature
raise e
# Handle the event
success = True
if (
event["type"] == "payment_intent.succeeded"
or event["type"] == "invoice.payment_succeeded"
or event["type"] == "customer.subscription.created"
):
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Mark the customer as subscribed
user = await adapters.set_user_subscribed(customer_email)
if not user:
success = False
elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Mark the customer as unsubscribed
user = await adapters.set_user_unsubscribed(customer_email)
if not user:
success = False
else:
event_type = event["type"]
if event_type not in {
"invoice.paid",
"customer.subscription.updated",
"customer.subscription.deleted",
"subscription_schedule.canceled",
}:
logger.warn(f"Unhandled Stripe event type: {event['type']}")
return {"success": False}
# Retrieve the customer's details
subscription = event["data"]["object"]
customer_id = subscription["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Handle valid stripe webhook events
success = True
if event_type in {"invoice.paid"}:
# Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
success = user is not None
elif event_type in {"customer.subscription.updated"}:
user = await adapters.get_user_by_email(customer_email)
# Allow updating subscription status if paid user
if user.subscription_renewal_date:
# Mark user as unsubscribed or resubscribed
is_subscribed = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
success = updated_user is not None
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user = await adapters.set_user_subscription(
customer_email, is_subscribed=False, renewal_date=False, type="trial"
)
success = user is not None
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success}
@api.delete("/subscription")
@api.patch("/subscription")
@requires(["authenticated"])
async def unsubscribe(request: Request, email: str):
customer = stripe.Customer.list(email=email).data
if not is_none_or_empty(customer):
customer_id = customer[0].id
async def unsubscribe(request: Request, email: str, operation: str):
# Retrieve the customer's details
customers = stripe.Customer.list(email=email).auto_paging_iter()
customer = next(customers, None)
if customer is None:
return {"success": False, "message": "Customer not found"}
if operation == "cancel":
customer_id = customer.id
for subscription in stripe.Subscription.list(customer=customer_id):
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
success = True
else:
success = False
return {"success": True}
return {"success": success}
elif operation == "resubscribe":
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
# Find the subscription that is set to cancel at the end of the period
for subscription in subscriptions:
if subscription.cancel_at_period_end:
# Update the subscription to not cancel at the end of the period
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
return {"success": True}
return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"}

View File

@@ -1,5 +1,4 @@
# System Packages
from datetime import datetime, timezone
import json
import os
@@ -23,7 +22,7 @@ from database.adapters import (
get_user_github_config,
get_user_notion_config,
ConversationAdapters,
is_user_subscribed,
get_user_subscription_state,
)
# Initialize Router
@@ -118,9 +117,9 @@ def login_page(request: Request):
def config_page(request: Request):
user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_is_subscribed = is_user_subscribed(user.email)
days_to_renewal = (
(user.subscription_renewal_date - datetime.now(tz=timezone.utc)).days if user.subscription_renewal_date else 0
user_subscription_state = get_user_subscription_state(user.email)
subscription_renewal_date = (
user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None
)
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
@@ -147,8 +146,8 @@ def config_page(request: Request):
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture,
"is_subscribed": user_is_subscribed,
"days_to_renewal": days_to_renewal,
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
},
)