mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Move subscription API to separate, independent router
This commit is contained in:
@@ -330,7 +330,7 @@ class EntryAdapters:
|
|||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_all_entries_by_source(user: KhojUser, file_source: str = None):
|
def delete_all_entries(user: KhojUser, file_source: str = None):
|
||||||
if file_source is None:
|
if file_source is None:
|
||||||
deleted_count, _ = Entry.objects.filter(user=user).delete()
|
deleted_count, _ = Entry.objects.filter(user=user).delete()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -145,10 +145,12 @@ def configure_routes(app):
|
|||||||
from khoj.routers.web_client import web_client
|
from khoj.routers.web_client import web_client
|
||||||
from khoj.routers.indexer import indexer
|
from khoj.routers.indexer import indexer
|
||||||
from khoj.routers.auth import auth_router
|
from khoj.routers.auth import auth_router
|
||||||
|
from khoj.routers.subscription import subscription_router
|
||||||
|
|
||||||
app.include_router(api, prefix="/api")
|
app.include_router(api, prefix="/api")
|
||||||
app.include_router(api_beta, prefix="/api/beta")
|
app.include_router(api_beta, prefix="/api/beta")
|
||||||
app.include_router(indexer, prefix="/api/v1/index")
|
app.include_router(indexer, prefix="/api/v1/index")
|
||||||
|
app.include_router(subscription_router, prefix="/api/subscription")
|
||||||
app.include_router(web_client)
|
app.include_router(web_client)
|
||||||
app.include_router(auth_router, prefix="/auth")
|
app.include_router(auth_router, prefix="/auth")
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from datetime import datetime, timezone
|
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
@@ -12,7 +10,6 @@ from typing import List, Optional, Union, Any
|
|||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, HTTPException, Header, Request
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
import stripe
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
@@ -245,7 +242,7 @@ async def remove_content_source_data(
|
|||||||
raise ValueError(f"Invalid content source: {content_source}")
|
raise ValueError(f"Invalid content source: {content_source}")
|
||||||
elif content_object != "Computer":
|
elif content_object != "Computer":
|
||||||
await content_object.objects.filter(user=user).adelete()
|
await content_object.objects.filter(user=user).adelete()
|
||||||
await sync_to_async(EntryAdapters.delete_all_entries_by_source)(user, content_source)
|
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
|
||||||
|
|
||||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
@@ -725,94 +722,3 @@ async def extract_references_and_questions(
|
|||||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||||
|
|
||||||
return compiled_references, inferred_queries, defiltered_query
|
return compiled_references, inferred_queries, defiltered_query
|
||||||
|
|
||||||
|
|
||||||
# Stripe integration for Khoj Cloud Subscription
|
|
||||||
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
|
||||||
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/subscription")
|
|
||||||
async def subscribe(request: Request):
|
|
||||||
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
|
||||||
event = None
|
|
||||||
try:
|
|
||||||
payload = await request.body()
|
|
||||||
sig_header = request.headers["stripe-signature"]
|
|
||||||
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
|
||||||
except ValueError as e:
|
|
||||||
# Invalid payload
|
|
||||||
raise e
|
|
||||||
except stripe.error.SignatureVerificationError as e:
|
|
||||||
# Invalid signature
|
|
||||||
raise e
|
|
||||||
|
|
||||||
event_type = event["type"]
|
|
||||||
if event_type not in {
|
|
||||||
"invoice.paid",
|
|
||||||
"customer.subscription.updated",
|
|
||||||
"customer.subscription.deleted",
|
|
||||||
"subscription_schedule.canceled",
|
|
||||||
}:
|
|
||||||
logger.warn(f"Unhandled Stripe event type: {event['type']}")
|
|
||||||
return {"success": False}
|
|
||||||
|
|
||||||
# Retrieve the customer's details
|
|
||||||
subscription = event["data"]["object"]
|
|
||||||
customer_id = subscription["customer"]
|
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
|
||||||
customer_email = customer["email"]
|
|
||||||
|
|
||||||
# Handle valid stripe webhook events
|
|
||||||
success = True
|
|
||||||
if event_type in {"invoice.paid"}:
|
|
||||||
# Mark the user as subscribed and update the next renewal date on payment
|
|
||||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
|
||||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
|
||||||
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
|
|
||||||
success = user is not None
|
|
||||||
elif event_type in {"customer.subscription.updated"}:
|
|
||||||
user = await adapters.get_user_by_email(customer_email)
|
|
||||||
# Allow updating subscription status if paid user
|
|
||||||
if user.subscription_renewal_date:
|
|
||||||
# Mark user as unsubscribed or resubscribed
|
|
||||||
is_subscribed = not subscription["cancel_at_period_end"]
|
|
||||||
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
|
|
||||||
success = updated_user is not None
|
|
||||||
elif event_type in {"customer.subscription.deleted"}:
|
|
||||||
# Reset the user to trial state
|
|
||||||
user = await adapters.set_user_subscription(
|
|
||||||
customer_email, is_subscribed=False, renewal_date=False, type="trial"
|
|
||||||
)
|
|
||||||
success = user is not None
|
|
||||||
|
|
||||||
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
|
||||||
return {"success": success}
|
|
||||||
|
|
||||||
|
|
||||||
@api.patch("/subscription")
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def unsubscribe(request: Request, email: str, operation: str):
|
|
||||||
# Retrieve the customer's details
|
|
||||||
customers = stripe.Customer.list(email=email).auto_paging_iter()
|
|
||||||
customer = next(customers, None)
|
|
||||||
if customer is None:
|
|
||||||
return {"success": False, "message": "Customer not found"}
|
|
||||||
|
|
||||||
if operation == "cancel":
|
|
||||||
customer_id = customer.id
|
|
||||||
for subscription in stripe.Subscription.list(customer=customer_id):
|
|
||||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
elif operation == "resubscribe":
|
|
||||||
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
|
|
||||||
# Find the subscription that is set to cancel at the end of the period
|
|
||||||
for subscription in subscriptions:
|
|
||||||
if subscription.cancel_at_period_end:
|
|
||||||
# Update the subscription to not cancel at the end of the period
|
|
||||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
|
|
||||||
return {"success": True}
|
|
||||||
return {"success": False, "message": "No subscription found that is set to cancel"}
|
|
||||||
|
|
||||||
return {"success": False, "message": "Invalid operation"}
|
|
||||||
|
|||||||
107
src/khoj/routers/subscription.py
Normal file
107
src/khoj/routers/subscription.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# Standard Packages
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from starlette.authentication import requires
|
||||||
|
import stripe
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from database import adapters
|
||||||
|
|
||||||
|
# Stripe integration for Khoj Cloud Subscription
|
||||||
|
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
||||||
|
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
subscription_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@subscription_router.post("")
|
||||||
|
async def subscribe(request: Request):
|
||||||
|
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
||||||
|
event = None
|
||||||
|
try:
|
||||||
|
payload = await request.body()
|
||||||
|
sig_header = request.headers["stripe-signature"]
|
||||||
|
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
||||||
|
except ValueError as e:
|
||||||
|
# Invalid payload
|
||||||
|
raise e
|
||||||
|
except stripe.error.SignatureVerificationError as e:
|
||||||
|
# Invalid signature
|
||||||
|
raise e
|
||||||
|
|
||||||
|
event_type = event["type"]
|
||||||
|
if event_type not in {
|
||||||
|
"invoice.paid",
|
||||||
|
"customer.subscription.updated",
|
||||||
|
"customer.subscription.deleted",
|
||||||
|
"subscription_schedule.canceled",
|
||||||
|
}:
|
||||||
|
logger.warn(f"Unhandled Stripe event type: {event['type']}")
|
||||||
|
return {"success": False}
|
||||||
|
|
||||||
|
# Retrieve the customer's details
|
||||||
|
subscription = event["data"]["object"]
|
||||||
|
customer_id = subscription["customer"]
|
||||||
|
customer = stripe.Customer.retrieve(customer_id)
|
||||||
|
customer_email = customer["email"]
|
||||||
|
|
||||||
|
# Handle valid stripe webhook events
|
||||||
|
success = True
|
||||||
|
if event_type in {"invoice.paid"}:
|
||||||
|
# Mark the user as subscribed and update the next renewal date on payment
|
||||||
|
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||||
|
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||||
|
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
|
||||||
|
success = user is not None
|
||||||
|
elif event_type in {"customer.subscription.updated"}:
|
||||||
|
user = await adapters.get_user_by_email(customer_email)
|
||||||
|
# Allow updating subscription status if paid user
|
||||||
|
if user.subscription_renewal_date:
|
||||||
|
# Mark user as unsubscribed or resubscribed
|
||||||
|
is_subscribed = not subscription["cancel_at_period_end"]
|
||||||
|
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
|
||||||
|
success = updated_user is not None
|
||||||
|
elif event_type in {"customer.subscription.deleted"}:
|
||||||
|
# Reset the user to trial state
|
||||||
|
user = await adapters.set_user_subscription(
|
||||||
|
customer_email, is_subscribed=False, renewal_date=False, type="trial"
|
||||||
|
)
|
||||||
|
success = user is not None
|
||||||
|
|
||||||
|
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
||||||
|
return {"success": success}
|
||||||
|
|
||||||
|
|
||||||
|
@subscription_router.patch("")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def update_subscription(request: Request, email: str, operation: str):
|
||||||
|
# Retrieve the customer's details
|
||||||
|
customers = stripe.Customer.list(email=email).auto_paging_iter()
|
||||||
|
customer = next(customers, None)
|
||||||
|
if customer is None:
|
||||||
|
return {"success": False, "message": "Customer not found"}
|
||||||
|
|
||||||
|
if operation == "cancel":
|
||||||
|
customer_id = customer.id
|
||||||
|
for subscription in stripe.Subscription.list(customer=customer_id):
|
||||||
|
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
elif operation == "resubscribe":
|
||||||
|
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
|
||||||
|
# Find the subscription that is set to cancel at the end of the period
|
||||||
|
for subscription in subscriptions:
|
||||||
|
if subscription.cancel_at_period_end:
|
||||||
|
# Update the subscription to not cancel at the end of the period
|
||||||
|
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
|
||||||
|
return {"success": True}
|
||||||
|
return {"success": False, "message": "No subscription found that is set to cancel"}
|
||||||
|
|
||||||
|
return {"success": False, "message": "Invalid operation"}
|
||||||
Reference in New Issue
Block a user