Files
khoj/src/khoj/routers/subscription.py
2024-01-18 14:46:10 +05:30

102 lines
4.1 KiB
Python

import logging
import os
from datetime import datetime, timezone
import stripe
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from starlette.authentication import requires
from khoj.database import adapters
# Stripe integration for Khoj Cloud Subscription
# The Stripe API key and the webhook signing secret are read from the environment.
# os.getenv returns None for either value when its variable is not set.
stripe.api_key = os.getenv("STRIPE_API_KEY")
# Passed to stripe.Webhook.construct_event to verify incoming webhook payloads
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
logger = logging.getLogger(__name__)
# Router carrying the subscription webhook (POST) and subscription update (PATCH) routes
subscription_router = APIRouter()
@subscription_router.post("")
async def subscribe(request: Request):
    """Webhook for Stripe to send subscription events to Khoj Cloud.

    Verifies the Stripe signature on the request body, looks up the Stripe
    customer referenced by the event and passes the customer's email to the
    subscription adapters.

    Handled event types:
        invoice.paid: mark the user as subscribed and store the next renewal date.
        customer.subscription.updated: toggle the recurring flag for users that
            already have a renewal date.
        customer.subscription.deleted: reset the user to the trial state.

    Returns:
        {"success": bool}. False when the event type is not handled, when a paid
        invoice has no subscription to renew against, or when the adapter
        returns no user.

    Raises:
        ValueError: if Stripe cannot parse the payload.
        stripe.error.SignatureVerificationError: if the signature is invalid.
    """
    try:
        payload = await request.body()
        sig_header = request.headers["stripe-signature"]
        event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
    except ValueError:
        # Invalid payload
        raise
    except stripe.error.SignatureVerificationError:
        # Invalid signature
        raise

    event_type = event["type"]
    if event_type not in {
        "invoice.paid",
        "customer.subscription.updated",
        "customer.subscription.deleted",
    }:
        logger.warning(f"Unhandled Stripe event type: {event['type']}")
        return {"success": False}

    # Retrieve the customer's details.
    # The event object is an invoice for invoice.paid and a subscription otherwise.
    event_object = event["data"]["object"]
    customer_id = event_object["customer"]
    customer = stripe.Customer.retrieve(customer_id)
    customer_email = customer["email"]

    # Handle valid stripe webhook events
    success = True
    if event_type == "invoice.paid":
        # Mark the user as subscribed and update the next renewal date on payment
        subscriptions = stripe.Subscription.list(customer=customer_id).data
        if not subscriptions:
            # Nothing to renew against; avoid indexing into an empty list
            logger.warning(f"No Stripe subscription found for {customer_email} on {event_type}")
            return {"success": False}

        # Subscription objects have no "lines" field (that belongs to invoices).
        # Read the period end from the subscription itself and fall back to the
        # paid invoice's first line item when the subscription does not expose it.
        period_end = subscriptions[0].get("current_period_end")
        if period_end is None:
            period_end = event_object["lines"]["data"][0]["period"]["end"]
        renewal_date = datetime.fromtimestamp(period_end, tz=timezone.utc)

        user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
        success = user is not None
    elif event_type == "customer.subscription.updated":
        user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
        # Allow updating subscription status if paid user
        if user_subscription and user_subscription.renewal_date:
            # Mark user as unsubscribed or resubscribed
            is_recurring = not event_object["cancel_at_period_end"]
            updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
            success = updated_user is not None
    elif event_type == "customer.subscription.deleted":
        # Reset the user to trial state
        user = await adapters.set_user_subscription(
            customer_email, is_recurring=False, renewal_date=False, type="trial"
        )
        success = user is not None

    logger.info(f'Stripe subscription {event["type"]} for {customer_email}')
    return {"success": success}
@subscription_router.patch("")
@requires(["authenticated"])
async def update_subscription(request: Request, email: str, operation: str):
    """Cancel or resume the Stripe subscription(s) of the customer with the given email.

    Args:
        request: Incoming request; the route requires the "authenticated" scope.
        email: Email used to look up the Stripe customer.
        operation: "cancel" sets every subscription of the customer to cancel at
            the end of its period. "resubscribe" clears that flag on the first
            subscription found that is pending cancellation.

    Returns:
        {"success": True} when the operation was applied, otherwise
        {"success": False, "message": <reason>}.
    """
    # NOTE(review): email is supplied by the caller and is not compared against the
    # authenticated user in this function - confirm callers cannot pass another user's email.

    # Retrieve the customer's details
    customer = next(stripe.Customer.list(email=email).auto_paging_iter(), None)
    if customer is None:
        return {"success": False, "message": "Customer not found"}

    if operation == "cancel":
        # Iterate with auto-paging, like the resubscribe branch, so subscriptions
        # beyond the first page of results are cancelled too
        for subscription in stripe.Subscription.list(customer=customer.id).auto_paging_iter():
            stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
        return {"success": True}

    if operation == "resubscribe":
        subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
        # Find the subscription that is set to cancel at the end of the period
        for subscription in subscriptions:
            if subscription.cancel_at_period_end:
                # Update the subscription to not cancel at the end of the period
                stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
                return {"success": True}
        return {"success": False, "message": "No subscription found that is set to cancel"}

    return {"success": False, "message": "Invalid operation"}