mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Create API webhook, endpoints for subscription payments using Stripe
- Add fields to mark users as subscribed to a specific plan and subscription renewal date in DB - Add ability to unsubscribe a user using their email address - Expose webhook for stripe to callback confirming payment
This commit is contained in:
@@ -73,7 +73,8 @@ dependencies = [
|
|||||||
"gunicorn == 21.2.0",
|
"gunicorn == 21.2.0",
|
||||||
"lxml == 4.9.3",
|
"lxml == 4.9.3",
|
||||||
"tzdata == 2023.3",
|
"tzdata == 2023.3",
|
||||||
"rapidocr-onnxruntime == 1.3.8"
|
"rapidocr-onnxruntime == 1.3.8",
|
||||||
|
"stripe == 7.3.0",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Type, TypeVar, List
|
from typing import Type, TypeVar, List
|
||||||
from datetime import date
|
from datetime import date, datetime, timedelta
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Type, TypeVar, List
|
from typing import Type, TypeVar, List
|
||||||
from datetime import date
|
from datetime import date
|
||||||
@@ -103,6 +103,27 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
|
||||||
|
user = await KhojUser.objects.filter(email=email).afirst()
|
||||||
|
if user:
|
||||||
|
user.subscription_type = type
|
||||||
|
start_date = user.subscription_renewal_date or datetime.now()
|
||||||
|
user.subscription_renewal_date = start_date + timedelta(days=30)
|
||||||
|
await user.asave()
|
||||||
|
return user
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_user_subscribed(email: str, type="standard") -> bool:
|
||||||
|
user = KhojUser.objects.filter(email=email, subscription_type=type).first()
|
||||||
|
if user and user.subscription_renewal_date:
|
||||||
|
is_subscribed = user.subscription_renewal_date > date.today()
|
||||||
|
return is_subscribed
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_token(token: dict) -> KhojUser:
|
async def get_user_by_token(token: dict) -> KhojUser:
|
||||||
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
|
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
|
||||||
if not google_user:
|
if not google_user:
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# Generated by Django 4.2.5 on 2023-11-07 18:19
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0012_entry_file_source"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="khojuser",
|
||||||
|
name="subscription_renewal_date",
|
||||||
|
field=models.DateTimeField(default=None, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="khojuser",
|
||||||
|
name="subscription_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -14,7 +14,15 @@ class BaseModel(models.Model):
|
|||||||
|
|
||||||
|
|
||||||
class KhojUser(AbstractUser):
|
class KhojUser(AbstractUser):
|
||||||
|
class SubscriptionType(models.TextChoices):
|
||||||
|
TRIAL = "trial"
|
||||||
|
STANDARD = "standard"
|
||||||
|
|
||||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||||
|
subscription_type = models.CharField(
|
||||||
|
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
|
||||||
|
)
|
||||||
|
subscription_renewal_date = models.DateTimeField(null=True, default=None)
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
if not self.uuid:
|
if not self.uuid:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
@@ -10,6 +11,7 @@ from typing import List, Optional, Union, Any
|
|||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, HTTPException, Header, Request
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
import stripe
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
@@ -23,7 +25,6 @@ from khoj.utils.rawconfig import (
|
|||||||
FullConfig,
|
FullConfig,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
SearchResponse,
|
SearchResponse,
|
||||||
TextContentConfig,
|
|
||||||
GithubContentConfig,
|
GithubContentConfig,
|
||||||
NotionContentConfig,
|
NotionContentConfig,
|
||||||
)
|
)
|
||||||
@@ -723,3 +724,61 @@ async def extract_references_and_questions(
|
|||||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||||
|
|
||||||
return compiled_references, inferred_queries, defiltered_query
|
return compiled_references, inferred_queries, defiltered_query
|
||||||
|
|
||||||
|
|
||||||
|
# Stripe integration for Khoj Cloud Subscription
|
||||||
|
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
||||||
|
endpoint_secret = os.getenv("STRIPE_SIGINING_SECRET")
|
||||||
|
|
||||||
|
|
||||||
|
@api.post("/subscription")
|
||||||
|
async def subscribe(request: Request):
|
||||||
|
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
||||||
|
event = None
|
||||||
|
try:
|
||||||
|
payload = await request.body()
|
||||||
|
sig_header = request.headers["stripe-signature"]
|
||||||
|
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
||||||
|
except ValueError as e:
|
||||||
|
# Invalid payload
|
||||||
|
raise e
|
||||||
|
except stripe.error.SignatureVerificationError as e:
|
||||||
|
# Invalid signature
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Handle the event
|
||||||
|
success = True
|
||||||
|
if (
|
||||||
|
event["type"] == "payment_intent.succeeded"
|
||||||
|
or event["type"] == "invoice.payment_succeeded"
|
||||||
|
or event["type"] == "customer.subscription.created"
|
||||||
|
):
|
||||||
|
# Retrieve the customer's details
|
||||||
|
customer_id = event["data"]["object"]["customer"]
|
||||||
|
customer = stripe.Customer.retrieve(customer_id)
|
||||||
|
customer_email = customer["email"]
|
||||||
|
# Mark the customer as subscribed
|
||||||
|
user = await adapters.set_user_subscribed(customer_email)
|
||||||
|
if not user:
|
||||||
|
success = False
|
||||||
|
elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
|
||||||
|
# Retrieve the customer's details
|
||||||
|
customer_id = event["data"]["object"]["customer"]
|
||||||
|
customer = stripe.Customer.retrieve(customer_id)
|
||||||
|
|
||||||
|
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
||||||
|
|
||||||
|
return {"success": success}
|
||||||
|
|
||||||
|
|
||||||
|
@api.delete("/subscription")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def unsubscribe(request: Request, user_email: str):
|
||||||
|
customer = stripe.Customer.list(email=user_email).data
|
||||||
|
if not is_none_or_empty(customer):
|
||||||
|
stripe.Subscription.modify(customer[0].id, cancel_at_period_end=True)
|
||||||
|
success = True
|
||||||
|
else:
|
||||||
|
success = False
|
||||||
|
|
||||||
|
return {"success": success}
|
||||||
|
|||||||
Reference in New Issue
Block a user