Create API webhook, endpoints for subscription payments using Stripe

- Add fields to mark users as subscribed to a specific plan and
  subscription renewal date in DB
- Add ability to unsubscribe a user using their email address
- Expose webhook for stripe to callback confirming payment
This commit is contained in:
Debanjum Singh Solanky
2023-11-07 10:15:21 -08:00
parent 156421d30a
commit 9aaf475c8a
5 changed files with 116 additions and 3 deletions

View File

@@ -73,7 +73,8 @@ dependencies = [
"gunicorn == 21.2.0", "gunicorn == 21.2.0",
"lxml == 4.9.3", "lxml == 4.9.3",
"tzdata == 2023.3", "tzdata == 2023.3",
"rapidocr-onnxruntime == 1.3.8" "rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
] ]
dynamic = ["version"] dynamic = ["version"]

View File

@@ -1,5 +1,5 @@
from typing import Type, TypeVar, List from typing import Type, TypeVar, List
from datetime import date from datetime import date, datetime, timedelta
import secrets import secrets
from typing import Type, TypeVar, List from typing import Type, TypeVar, List
from datetime import date from datetime import date
@@ -103,6 +103,27 @@ async def create_google_user(token: dict) -> KhojUser:
return user return user
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email).afirst()
if user:
user.subscription_type = type
start_date = user.subscription_renewal_date or datetime.now()
user.subscription_renewal_date = start_date + timedelta(days=30)
await user.asave()
return user
else:
return None
def is_user_subscribed(email: str, type="standard") -> bool:
user = KhojUser.objects.filter(email=email, subscription_type=type).first()
if user and user.subscription_renewal_date:
is_subscribed = user.subscription_renewal_date > date.today()
return is_subscribed
else:
return False
async def get_user_by_token(token: dict) -> KhojUser: async def get_user_by_token(token: dict) -> KhojUser:
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst() google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
if not google_user: if not google_user:

View File

@@ -0,0 +1,24 @@
# Generated by Django 4.2.5 on 2023-11-07 18:19
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0012_entry_file_source"),
]
operations = [
migrations.AddField(
model_name="khojuser",
name="subscription_renewal_date",
field=models.DateTimeField(default=None, null=True),
),
migrations.AddField(
model_name="khojuser",
name="subscription_type",
field=models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
),
),
]

View File

@@ -14,7 +14,15 @@ class BaseModel(models.Model):
class KhojUser(AbstractUser): class KhojUser(AbstractUser):
class SubscriptionType(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
subscription_type = models.CharField(
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
)
subscription_renewal_date = models.DateTimeField(null=True, default=None)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.uuid: if not self.uuid:

View File

@@ -1,6 +1,7 @@
# Standard Packages # Standard Packages
import concurrent.futures import concurrent.futures
import math import math
import os
import time import time
import logging import logging
import json import json
@@ -10,6 +11,7 @@ from typing import List, Optional, Union, Any
from fastapi import APIRouter, HTTPException, Header, Request from fastapi import APIRouter, HTTPException, Header, Request
from starlette.authentication import requires from starlette.authentication import requires
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
import stripe
# Internal Packages # Internal Packages
from khoj.configure import configure_server from khoj.configure import configure_server
@@ -23,7 +25,6 @@ from khoj.utils.rawconfig import (
FullConfig, FullConfig,
SearchConfig, SearchConfig,
SearchResponse, SearchResponse,
TextContentConfig,
GithubContentConfig, GithubContentConfig,
NotionContentConfig, NotionContentConfig,
) )
@@ -723,3 +724,61 @@ async def extract_references_and_questions(
compiled_references = [item.additional["compiled"] for item in result_list] compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query return compiled_references, inferred_queries, defiltered_query
# Stripe integration for Khoj Cloud Subscription
stripe.api_key = os.getenv("STRIPE_API_KEY")
endpoint_secret = os.getenv("STRIPE_SIGINING_SECRET")
@api.post("/subscription")
async def subscribe(request: Request):
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
event = None
try:
payload = await request.body()
sig_header = request.headers["stripe-signature"]
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
except ValueError as e:
# Invalid payload
raise e
except stripe.error.SignatureVerificationError as e:
# Invalid signature
raise e
# Handle the event
success = True
if (
event["type"] == "payment_intent.succeeded"
or event["type"] == "invoice.payment_succeeded"
or event["type"] == "customer.subscription.created"
):
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Mark the customer as subscribed
user = await adapters.set_user_subscribed(customer_email)
if not user:
success = False
elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success}
@api.delete("/subscription")
@requires(["authenticated"])
async def unsubscribe(request: Request, user_email: str):
customer = stripe.Customer.list(email=user_email).data
if not is_none_or_empty(customer):
stripe.Subscription.modify(customer[0].id, cancel_at_period_end=True)
success = True
else:
success = False
return {"success": success}