diff --git a/src/khoj/routers/api_subscription.py b/src/khoj/routers/api_subscription.py index 71300f04..11798958 100644 --- a/src/khoj/routers/api_subscription.py +++ b/src/khoj/routers/api_subscription.py @@ -18,6 +18,7 @@ if state.billing_enabled: stripe.api_key = os.getenv("STRIPE_API_KEY") endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET") +official_product_id = os.getenv("STRIPE_KHOJ_PRODUCT_ID", "") logger = logging.getLogger(__name__) subscription_router = APIRouter() @@ -48,6 +49,27 @@ async def subscribe(request: Request): # Retrieve the customer's details subscription = event["data"]["object"] + + # Verify product ID if official_product_id is configured + if official_product_id: + # Get the product ID from the subscription items + subscription_items = subscription.get("items", {}).get("data", []) + if not subscription_items: + logger.warning(f"No subscription items found for event {event['id']}") + return {"success": False} + + # Check if any subscription item matches the official product ID + valid_product = False + for item in subscription_items: + product_id = item.get("price", {}).get("product") + if product_id == official_product_id: + valid_product = True + break + + if not valid_product: + logger.warning(f"Event {event['id']} for non-official product, ignoring") + return {"success": False} + customer_id = subscription["customer"] customer = stripe.Customer.retrieve(customer_id) customer_email = customer["email"]