diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py
index 3358aa79..a8080166 100644
--- a/src/database/adapters/__init__.py
+++ b/src/database/adapters/__init__.py
@@ -103,33 +103,39 @@ async def create_google_user(token: dict) -> KhojUser:
return user
-async def set_user_unsubscribed(email: str, type="standard") -> KhojUser:
- user = await KhojUser.objects.filter(email=email, subscription_type=type).afirst()
- if user:
- user.is_subscribed = False
- await user.asave()
- return user
- else:
- return None
-
-
-async def set_user_subscribed(email: str, type="standard") -> KhojUser:
+async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email).afirst()
if user:
user.subscription_type = type
- user.is_subscribed = True
- start_date = user.subscription_renewal_date or datetime.now()
- user.subscription_renewal_date = start_date + timedelta(days=30)
+ if is_subscribed is not None:
+ user.is_subscribed = is_subscribed
+ if renewal_date is False:
+ user.subscription_renewal_date = None
+ elif renewal_date is not None:
+ user.subscription_renewal_date = renewal_date
await user.asave()
return user
else:
return None
-def is_user_subscribed(email: str, type="standard") -> bool:
- return KhojUser.objects.filter(
- email=email, subscription_type=type, subscription_renewal_date__gte=datetime.now(tz=timezone.utc)
- ).exists()
+def get_user_subscription_state(email: str) -> str:
+ """Get subscription state of user
+ Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
+ """
+ user = KhojUser.objects.filter(email=email).first()
+ if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
+ return "trial"
+ elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
+ return "subscribed"
+ elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
+ return "unsubscribed"
+ elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc):
+ return "expired"
+
+
+async def get_user_by_email(email: str) -> KhojUser:
+ return await KhojUser.objects.filter(email=email).afirst()
async def get_user_by_token(token: dict) -> KhojUser:
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index e86204bb..58c68349 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -167,29 +167,39 @@
Subscription
-
Manage your subscription to Khoj Cloud
-
- {% if not is_subscribed %}
-
+ {% if subscription_state == "subscribed" %}
+ You are subscribed to Khoj Cloud. Subscription will renew on {{ subscription_renewal_date }}
+ {% elif subscription_state == "unsubscribed" %}
+ You are subscribed to Khoj Cloud. Subscription will expire on {{ subscription_renewal_date }}
+ {% elif subscription_state == "expired" %}
+ Subscribe to Khoj Cloud. Subscription expired on {{ subscription_renewal_date }}
{% else %}
+ Subscribe to Khoj Cloud
+ {% endif %}
+
+ {% if subscription_state == "subscribed" %}
+ {% elif subscription_state == "unsubscribed" %}
+
+ {% else %}
+
+ Subscribe
+
+
+ {% endif %}
- {% endif %}
@@ -258,8 +268,17 @@
};
function unsubscribe() {
- fetch('/api/subscription?email=' + '{{username}}', {
- method: 'DELETE',
+ fetch('/api/subscription?operation=cancel&email={{username}}', {
+ method: 'PATCH',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ })
+ }
+
+ function resubscribe() {
+ fetch('/api/subscription?operation=resubscribe&email={{username}}', {
+ method: 'PATCH',
headers: {
'Content-Type': 'application/json',
},
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 77e377e7..e2187f0b 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -1,5 +1,6 @@
# Standard Packages
import concurrent.futures
+from datetime import datetime, timezone
import math
import os
import time
@@ -746,48 +747,72 @@ async def subscribe(request: Request):
# Invalid signature
raise e
- # Handle the event
- success = True
- if (
- event["type"] == "payment_intent.succeeded"
- or event["type"] == "invoice.payment_succeeded"
- or event["type"] == "customer.subscription.created"
- ):
- # Retrieve the customer's details
- customer_id = event["data"]["object"]["customer"]
- customer = stripe.Customer.retrieve(customer_id)
- customer_email = customer["email"]
- # Mark the customer as subscribed
- user = await adapters.set_user_subscribed(customer_email)
- if not user:
- success = False
- elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
- # Retrieve the customer's details
- customer_id = event["data"]["object"]["customer"]
- customer = stripe.Customer.retrieve(customer_id)
- customer_email = customer["email"]
- # Mark the customer as unsubscribed
- user = await adapters.set_user_unsubscribed(customer_email)
- if not user:
- success = False
- else:
+ event_type = event["type"]
+ if event_type not in {
+ "invoice.paid",
+ "customer.subscription.updated",
+ "customer.subscription.deleted",
+ "subscription_schedule.canceled",
+ }:
logger.warn(f"Unhandled Stripe event type: {event['type']}")
return {"success": False}
+ # Retrieve the customer's details
+ subscription = event["data"]["object"]
+ customer_id = subscription["customer"]
+ customer = stripe.Customer.retrieve(customer_id)
+ customer_email = customer["email"]
+
+ # Handle valid stripe webhook events
+ success = True
+ if event_type in {"invoice.paid"}:
+ # Mark the user as subscribed and update the next renewal date on payment
+ subscription = stripe.Subscription.list(customer=customer_id).data[0]
+ renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
+ user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
+ success = user is not None
+ elif event_type in {"customer.subscription.updated"}:
+ user = await adapters.get_user_by_email(customer_email)
+ # Allow updating subscription status if paid user
+ if user.subscription_renewal_date:
+ # Mark user as unsubscribed or resubscribed
+ is_subscribed = not subscription["cancel_at_period_end"]
+ updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
+ success = updated_user is not None
+ elif event_type in {"customer.subscription.deleted"}:
+ # Reset the user to trial state
+ user = await adapters.set_user_subscription(
+ customer_email, is_subscribed=False, renewal_date=False, type="trial"
+ )
+ success = user is not None
+
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success}
-@api.delete("/subscription")
+@api.patch("/subscription")
@requires(["authenticated"])
-async def unsubscribe(request: Request, email: str):
- customer = stripe.Customer.list(email=email).data
- if not is_none_or_empty(customer):
- customer_id = customer[0].id
+async def unsubscribe(request: Request, email: str, operation: str):
+ # Retrieve the customer's details
+ customers = stripe.Customer.list(email=email).auto_paging_iter()
+ customer = next(customers, None)
+ if customer is None:
+ return {"success": False, "message": "Customer not found"}
+
+ if operation == "cancel":
+ customer_id = customer.id
for subscription in stripe.Subscription.list(customer=customer_id):
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
- success = True
- else:
- success = False
+ return {"success": True}
- return {"success": success}
+ elif operation == "resubscribe":
+ subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
+ # Find the subscription that is set to cancel at the end of the period
+ for subscription in subscriptions:
+ if subscription.cancel_at_period_end:
+ # Update the subscription to not cancel at the end of the period
+ stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
+ return {"success": True}
+ return {"success": False, "message": "No subscription found that is set to cancel"}
+
+ return {"success": False, "message": "Invalid operation"}
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index b62f6b94..fd96dc8f 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -1,5 +1,4 @@
# System Packages
-from datetime import datetime, timezone
import json
import os
@@ -23,7 +22,7 @@ from database.adapters import (
get_user_github_config,
get_user_notion_config,
ConversationAdapters,
- is_user_subscribed,
+ get_user_subscription_state,
)
# Initialize Router
@@ -118,9 +117,9 @@ def login_page(request: Request):
def config_page(request: Request):
user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_is_subscribed = is_user_subscribed(user.email)
- days_to_renewal = (
- (user.subscription_renewal_date - datetime.now(tz=timezone.utc)).days if user.subscription_renewal_date else 0
+ user_subscription_state = get_user_subscription_state(user.email)
+ subscription_renewal_date = (
+ user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None
)
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
@@ -147,8 +146,8 @@ def config_page(request: Request):
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture,
- "is_subscribed": user_is_subscribed,
- "days_to_renewal": days_to_renewal,
+ "subscription_state": user_subscription_state,
+ "subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
},
)