Move subscription API to separate, independent router

This commit is contained in:
Debanjum Singh Solanky
2023-11-08 16:20:27 -08:00
parent ec1395d072
commit 3bb10128ef
4 changed files with 111 additions and 96 deletions

View File

@@ -330,7 +330,7 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod @staticmethod
def delete_all_entries_by_source(user: KhojUser, file_source: str = None): def delete_all_entries(user: KhojUser, file_source: str = None):
if file_source is None: if file_source is None:
deleted_count, _ = Entry.objects.filter(user=user).delete() deleted_count, _ = Entry.objects.filter(user=user).delete()
else: else:

View File

@@ -145,10 +145,12 @@ def configure_routes(app):
from khoj.routers.web_client import web_client from khoj.routers.web_client import web_client
from khoj.routers.indexer import indexer from khoj.routers.indexer import indexer
from khoj.routers.auth import auth_router from khoj.routers.auth import auth_router
from khoj.routers.subscription import subscription_router
app.include_router(api, prefix="/api") app.include_router(api, prefix="/api")
app.include_router(api_beta, prefix="/api/beta") app.include_router(api_beta, prefix="/api/beta")
app.include_router(indexer, prefix="/api/v1/index") app.include_router(indexer, prefix="/api/v1/index")
app.include_router(subscription_router, prefix="/api/subscription")
app.include_router(web_client) app.include_router(web_client)
app.include_router(auth_router, prefix="/auth") app.include_router(auth_router, prefix="/auth")

View File

@@ -1,8 +1,6 @@
# Standard Packages # Standard Packages
import concurrent.futures import concurrent.futures
from datetime import datetime, timezone
import math import math
import os
import time import time
import logging import logging
import json import json
@@ -12,7 +10,6 @@ from typing import List, Optional, Union, Any
from fastapi import APIRouter, HTTPException, Header, Request from fastapi import APIRouter, HTTPException, Header, Request
from starlette.authentication import requires from starlette.authentication import requires
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
import stripe
# Internal Packages # Internal Packages
from khoj.configure import configure_server from khoj.configure import configure_server
@@ -245,7 +242,7 @@ async def remove_content_source_data(
raise ValueError(f"Invalid content source: {content_source}") raise ValueError(f"Invalid content source: {content_source}")
elif content_object != "Computer": elif content_object != "Computer":
await content_object.objects.filter(user=user).adelete() await content_object.objects.filter(user=user).adelete()
await sync_to_async(EntryAdapters.delete_all_entries_by_source)(user, content_source) await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"} return {"status": "ok"}
@@ -725,94 +722,3 @@ async def extract_references_and_questions(
compiled_references = [item.additional["compiled"] for item in result_list] compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query return compiled_references, inferred_queries, defiltered_query
# Stripe integration for Khoj Cloud Subscription
stripe.api_key = os.getenv("STRIPE_API_KEY")
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
@api.post("/subscription")
async def subscribe(request: Request):
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
event = None
try:
payload = await request.body()
sig_header = request.headers["stripe-signature"]
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
except ValueError as e:
# Invalid payload
raise e
except stripe.error.SignatureVerificationError as e:
# Invalid signature
raise e
event_type = event["type"]
if event_type not in {
"invoice.paid",
"customer.subscription.updated",
"customer.subscription.deleted",
"subscription_schedule.canceled",
}:
logger.warn(f"Unhandled Stripe event type: {event['type']}")
return {"success": False}
# Retrieve the customer's details
subscription = event["data"]["object"]
customer_id = subscription["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Handle valid stripe webhook events
success = True
if event_type in {"invoice.paid"}:
# Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
success = user is not None
elif event_type in {"customer.subscription.updated"}:
user = await adapters.get_user_by_email(customer_email)
# Allow updating subscription status if paid user
if user.subscription_renewal_date:
# Mark user as unsubscribed or resubscribed
is_subscribed = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
success = updated_user is not None
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user = await adapters.set_user_subscription(
customer_email, is_subscribed=False, renewal_date=False, type="trial"
)
success = user is not None
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success}
@api.patch("/subscription")
@requires(["authenticated"])
async def unsubscribe(request: Request, email: str, operation: str):
# Retrieve the customer's details
customers = stripe.Customer.list(email=email).auto_paging_iter()
customer = next(customers, None)
if customer is None:
return {"success": False, "message": "Customer not found"}
if operation == "cancel":
customer_id = customer.id
for subscription in stripe.Subscription.list(customer=customer_id):
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
return {"success": True}
elif operation == "resubscribe":
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
# Find the subscription that is set to cancel at the end of the period
for subscription in subscriptions:
if subscription.cancel_at_period_end:
# Update the subscription to not cancel at the end of the period
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
return {"success": True}
return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"}

View File

@@ -0,0 +1,107 @@
# Standard Packages
from datetime import datetime, timezone
import logging
import os
# External Packages
from fastapi import APIRouter, Request
from starlette.authentication import requires
import stripe
# Internal Packages
from database import adapters
# Stripe integration for Khoj Cloud Subscription
stripe.api_key = os.getenv("STRIPE_API_KEY")
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
logger = logging.getLogger(__name__)
subscription_router = APIRouter()
@subscription_router.post("")
async def subscribe(request: Request):
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
event = None
try:
payload = await request.body()
sig_header = request.headers["stripe-signature"]
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
except ValueError as e:
# Invalid payload
raise e
except stripe.error.SignatureVerificationError as e:
# Invalid signature
raise e
event_type = event["type"]
if event_type not in {
"invoice.paid",
"customer.subscription.updated",
"customer.subscription.deleted",
"subscription_schedule.canceled",
}:
logger.warn(f"Unhandled Stripe event type: {event['type']}")
return {"success": False}
# Retrieve the customer's details
subscription = event["data"]["object"]
customer_id = subscription["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Handle valid stripe webhook events
success = True
if event_type in {"invoice.paid"}:
# Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
success = user is not None
elif event_type in {"customer.subscription.updated"}:
user = await adapters.get_user_by_email(customer_email)
# Allow updating subscription status if paid user
if user.subscription_renewal_date:
# Mark user as unsubscribed or resubscribed
is_subscribed = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
success = updated_user is not None
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user = await adapters.set_user_subscription(
customer_email, is_subscribed=False, renewal_date=False, type="trial"
)
success = user is not None
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success}
@subscription_router.patch("")
@requires(["authenticated"])
async def update_subscription(request: Request, email: str, operation: str):
# Retrieve the customer's details
customers = stripe.Customer.list(email=email).auto_paging_iter()
customer = next(customers, None)
if customer is None:
return {"success": False, "message": "Customer not found"}
if operation == "cancel":
customer_id = customer.id
for subscription in stripe.Subscription.list(customer=customer_id):
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
return {"success": True}
elif operation == "resubscribe":
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
# Find the subscription that is set to cancel at the end of the period
for subscription in subscriptions:
if subscription.cancel_at_period_end:
# Update the subscription to not cancel at the end of the period
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
return {"success": True}
return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"}