diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index a8080166..0a4cb05e 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -330,7 +330,7 @@ class EntryAdapters: return deleted_count @staticmethod - def delete_all_entries_by_source(user: KhojUser, file_source: str = None): + def delete_all_entries(user: KhojUser, file_source: str = None): if file_source is None: deleted_count, _ = Entry.objects.filter(user=user).delete() else: diff --git a/src/khoj/configure.py b/src/khoj/configure.py index ecd35cf9..fd0c67fe 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -145,10 +145,12 @@ def configure_routes(app): from khoj.routers.web_client import web_client from khoj.routers.indexer import indexer from khoj.routers.auth import auth_router + from khoj.routers.subscription import subscription_router app.include_router(api, prefix="/api") app.include_router(api_beta, prefix="/api/beta") app.include_router(indexer, prefix="/api/v1/index") + app.include_router(subscription_router, prefix="/api/subscription") app.include_router(web_client) app.include_router(auth_router, prefix="/auth") diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index e2187f0b..6b913351 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -1,8 +1,6 @@ # Standard Packages import concurrent.futures -from datetime import datetime, timezone import math -import os import time import logging import json @@ -12,7 +10,6 @@ from typing import List, Optional, Union, Any from fastapi import APIRouter, HTTPException, Header, Request from starlette.authentication import requires from asgiref.sync import sync_to_async -import stripe # Internal Packages from khoj.configure import configure_server @@ -245,7 +242,7 @@ async def remove_content_source_data( raise ValueError(f"Invalid content source: {content_source}") elif content_object != "Computer": await content_object.objects.filter(user=user).adelete() - await sync_to_async(EntryAdapters.delete_all_entries_by_source)(user, content_source) + await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source) enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) return {"status": "ok"} @@ -725,94 +722,3 @@ async def extract_references_and_questions( compiled_references = [item.additional["compiled"] for item in result_list] return compiled_references, inferred_queries, defiltered_query - - -# Stripe integration for Khoj Cloud Subscription -stripe.api_key = os.getenv("STRIPE_API_KEY") -endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET") - - -@api.post("/subscription") -async def subscribe(request: Request): - """Webhook for Stripe to send subscription events to Khoj Cloud""" - event = None - try: - payload = await request.body() - sig_header = request.headers["stripe-signature"] - event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret) - except ValueError as e: - # Invalid payload - raise e - except stripe.error.SignatureVerificationError as e: - # Invalid signature - raise e - - event_type = event["type"] - if event_type not in { - "invoice.paid", - "customer.subscription.updated", - "customer.subscription.deleted", - "subscription_schedule.canceled", - }: - logger.warn(f"Unhandled Stripe event type: {event['type']}") - return {"success": False} - - # Retrieve the customer's details - subscription = event["data"]["object"] - customer_id = subscription["customer"] - customer = stripe.Customer.retrieve(customer_id) - customer_email = customer["email"] - - # Handle valid stripe webhook events - success = True - if event_type in {"invoice.paid"}: - # Mark the user as subscribed and update the next renewal date on payment - subscription = stripe.Subscription.list(customer=customer_id).data[0] - renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) - user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date) - success = user is not None - elif event_type in {"customer.subscription.updated"}: - user = await adapters.get_user_by_email(customer_email) - # Allow updating subscription status if paid user - if user.subscription_renewal_date: - # Mark user as unsubscribed or resubscribed - is_subscribed = not subscription["cancel_at_period_end"] - updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed) - success = updated_user is not None - elif event_type in {"customer.subscription.deleted"}: - # Reset the user to trial state - user = await adapters.set_user_subscription( - customer_email, is_subscribed=False, renewal_date=False, type="trial" - ) - success = user is not None - - logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}') - return {"success": success} - - -@api.patch("/subscription") -@requires(["authenticated"]) -async def unsubscribe(request: Request, email: str, operation: str): - # Retrieve the customer's details - customers = stripe.Customer.list(email=email).auto_paging_iter() - customer = next(customers, None) - if customer is None: - return {"success": False, "message": "Customer not found"} - - if operation == "cancel": - customer_id = customer.id - for subscription in stripe.Subscription.list(customer=customer_id): - stripe.Subscription.modify(subscription.id, cancel_at_period_end=True) - return {"success": True} - - elif operation == "resubscribe": - subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter() - # Find the subscription that is set to cancel at the end of the period - for subscription in subscriptions: - if subscription.cancel_at_period_end: - # Update the subscription to not cancel at the end of the period - stripe.Subscription.modify(subscription.id, cancel_at_period_end=False) - return {"success": True} - return {"success": False, "message": "No subscription found that is set to cancel"} - - return {"success": False, "message": "Invalid operation"} diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py new file mode 100644 index 00000000..1c6902f8 --- /dev/null +++ b/src/khoj/routers/subscription.py @@ -0,0 +1,107 @@ +# Standard Packages +from datetime import datetime, timezone +import logging +import os + +# External Packages +from fastapi import APIRouter, Request +from starlette.authentication import requires +import stripe + +# Internal Packages +from database import adapters + +# Stripe integration for Khoj Cloud Subscription +stripe.api_key = os.getenv("STRIPE_API_KEY") +endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET") + + +logger = logging.getLogger(__name__) + +subscription_router = APIRouter() + + +@subscription_router.post("") +async def subscribe(request: Request): + """Webhook for Stripe to send subscription events to Khoj Cloud""" + event = None + try: + payload = await request.body() + sig_header = request.headers["stripe-signature"] + event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret) + except ValueError as e: + # Invalid payload + raise e + except stripe.error.SignatureVerificationError as e: + # Invalid signature + raise e + + event_type = event["type"] + if event_type not in { + "invoice.paid", + "customer.subscription.updated", + "customer.subscription.deleted", + "subscription_schedule.canceled", + }: + logger.warn(f"Unhandled Stripe event type: {event['type']}") + return {"success": False} + + # Retrieve the customer's details + subscription = event["data"]["object"] + customer_id = subscription["customer"] + customer = stripe.Customer.retrieve(customer_id) + customer_email = customer["email"] + + # Handle valid stripe webhook events + success = True + if event_type in {"invoice.paid"}: + # Mark the user as subscribed and update the next renewal date on payment + subscription = stripe.Subscription.list(customer=customer_id).data[0] + renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) + user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date) + success = user is not None + elif event_type in {"customer.subscription.updated"}: + user = await adapters.get_user_by_email(customer_email) + # Allow updating subscription status if paid user + if user.subscription_renewal_date: + # Mark user as unsubscribed or resubscribed + is_subscribed = not subscription["cancel_at_period_end"] + updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed) + success = updated_user is not None + elif event_type in {"customer.subscription.deleted"}: + # Reset the user to trial state + user = await adapters.set_user_subscription( + customer_email, is_subscribed=False, renewal_date=False, type="trial" + ) + success = user is not None + + logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}') + return {"success": success} + + +@subscription_router.patch("") +@requires(["authenticated"]) +async def update_subscription(request: Request, email: str, operation: str): + # Retrieve the customer's details + customers = stripe.Customer.list(email=email).auto_paging_iter() + customer = next(customers, None) + if customer is None: + return {"success": False, "message": "Customer not found"} + + if operation == "cancel": + customer_id = customer.id + for subscription in stripe.Subscription.list(customer=customer_id): + stripe.Subscription.modify(subscription.id, cancel_at_period_end=True) + return {"success": True} + + elif operation == "resubscribe": + subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter() + # Find the subscription that is set to cancel at the end of the period + for subscription in subscriptions: + if subscription.cancel_at_period_end: + # Update the subscription to not cancel at the end of the period + stripe.Subscription.modify(subscription.id, cancel_at_period_end=False) + return {"success": True} + return {"success": False, "message": "No subscription found that is set to cancel"} + + return {"success": False, "message": "Invalid operation"}