diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 0a4cb05e..4f71c7aa 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,4 +1,4 @@ -from typing import Type, TypeVar, List +from typing import Optional, Type, TypeVar, List from datetime import date, datetime, timedelta import secrets from typing import Type, TypeVar, List @@ -30,6 +30,7 @@ from database.models import ( GithubRepoConfig, Conversation, ChatModelOptions, + Subscription, UserConversationConfig, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, @@ -103,35 +104,51 @@ async def create_google_user(token: dict) -> KhojUser: return user -async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser: - user = await KhojUser.objects.filter(email=email).afirst() - if user: - user.subscription_type = type - if is_subscribed is not None: - user.is_subscribed = is_subscribed +def get_user_subscription(email: str) -> Optional[Subscription]: + return Subscription.objects.filter(user__email=email).first() + + +async def set_user_subscription( + email: str, is_recurring=None, renewal_date=None, type="standard" +) -> Optional[Subscription]: + user_subscription = await Subscription.objects.filter(user__email=email).afirst() + if not user_subscription: + user = await get_user_by_email(email) + if not user: + return None + user_subscription = await Subscription.objects.acreate( + user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date + ) + return user_subscription + elif user_subscription: + user_subscription.type = type + if is_recurring is not None: + user_subscription.is_recurring = is_recurring if renewal_date is False: - user.subscription_renewal_date = None + user_subscription.renewal_date = None elif renewal_date is not None: - user.subscription_renewal_date = renewal_date - await user.asave() - return user + user_subscription.renewal_date = renewal_date + await user_subscription.asave() + return user_subscription else: return None -def get_user_subscription_state(email: str) -> str: +def get_user_subscription_state(user_subscription: Subscription) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired """ - user = KhojUser.objects.filter(email=email).first() - if user.subscription_type == KhojUser.SubscriptionType.TRIAL: + if not user_subscription: return "trial" - elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc): + elif user_subscription.type == Subscription.Type.TRIAL: + return "trial" + elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): return "subscribed" - elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc): + elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): return "unsubscribed" - elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc): + elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): return "expired" + return "invalid" async def get_user_by_email(email: str) -> KhojUser: diff --git a/src/database/migrations/0013_khojuser_subscription_renewal_date_and_more.py b/src/database/migrations/0013_khojuser_subscription_renewal_date_and_more.py deleted file mode 100644 index d7f3df5c..00000000 --- a/src/database/migrations/0013_khojuser_subscription_renewal_date_and_more.py +++ /dev/null @@ -1,24 +0,0 @@ -# Generated by Django 4.2.5 on 2023-11-07 18:19 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("database", "0012_entry_file_source"), - ] - - operations = [ - migrations.AddField( - model_name="khojuser", - name="subscription_renewal_date", - field=models.DateTimeField(default=None, null=True), - ), - migrations.AddField( - model_name="khojuser", - name="subscription_type", - field=models.CharField( - choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20 - ), - ), - ] diff --git a/src/database/migrations/0013_subscription.py b/src/database/migrations/0013_subscription.py new file mode 100644 index 00000000..931cea12 --- /dev/null +++ b/src/database/migrations/0013_subscription.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.5 on 2023-11-09 01:27 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0012_entry_file_source"), + ] + + operations = [ + migrations.CreateModel( + name="Subscription", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "type", + models.CharField( + choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20 + ), + ), + ("is_recurring", models.BooleanField(default=False)), + ("renewal_date", models.DateTimeField(default=None, null=True)), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/database/migrations/0014_khojuser_is_subscribed.py b/src/database/migrations/0014_khojuser_is_subscribed.py deleted file mode 100644 index 79b035ba..00000000 --- a/src/database/migrations/0014_khojuser_is_subscribed.py +++ /dev/null @@ -1,17 +0,0 @@ -# Generated by Django 4.2.5 on 2023-11-08 19:40 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("database", "0013_khojuser_subscription_renewal_date_and_more"), - ] - - operations = [ - migrations.AddField( - model_name="khojuser", - name="is_subscribed", - field=models.BooleanField(default=False), - ), - ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 5a3b96c5..28f8cd2a 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -14,16 +14,7 @@ class BaseModel(models.Model): class KhojUser(AbstractUser): - class SubscriptionType(models.TextChoices): - TRIAL = "trial" - STANDARD = "standard" - uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) - subscription_type = models.CharField( - max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL - ) - is_subscribed = models.BooleanField(default=False) - subscription_renewal_date = models.DateTimeField(null=True, default=None) def save(self, *args, **kwargs): if not self.uuid: @@ -55,6 +46,17 @@ class KhojApiUser(models.Model): accessed_at = models.DateTimeField(null=True, default=None) +class Subscription(BaseModel): + class Type(models.TextChoices): + TRIAL = "trial" + STANDARD = "standard" + + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL) + is_recurring = models.BooleanField(default=False) + renewal_date = models.DateTimeField(null=True, default=None) + + class NotionConfig(BaseModel): token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/subscription.py index 1c6902f8..a9862f51 100644 --- a/src/khoj/routers/subscription.py +++ b/src/khoj/routers/subscription.py @@ -4,6 +4,7 @@ import logging import os # External Packages +from asgiref.sync import sync_to_async from fastapi import APIRouter, Request from starlette.authentication import requires import stripe @@ -58,20 +59,20 @@ async def subscribe(request: Request): # Mark the user as subscribed and update the next renewal date on payment subscription = stripe.Subscription.list(customer=customer_id).data[0] renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) - user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date) + user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date) success = user is not None elif event_type in {"customer.subscription.updated"}: - user = await adapters.get_user_by_email(customer_email) + user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email) # Allow updating subscription status if paid user - if user.subscription_renewal_date: + if user_subscription.renewal_date: # Mark user as unsubscribed or resubscribed - is_subscribed = not subscription["cancel_at_period_end"] - updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed) + is_recurring = not subscription["cancel_at_period_end"] + updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring) success = updated_user is not None elif event_type in {"customer.subscription.deleted"}: # Reset the user to trial state user = await adapters.set_user_subscription( - customer_email, is_subscribed=False, renewal_date=False, type="trial" + customer_email, is_recurring=False, renewal_date=False, type="trial" ) success = user is not None diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index fd96dc8f..b47f6537 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -8,6 +8,7 @@ from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.templating import Jinja2Templates from starlette.authentication import requires +from database import adapters from database.models import KhojUser from khoj.utils.rawconfig import ( GithubContentConfig, @@ -117,9 +118,12 @@ def login_page(request: Request): def config_page(request: Request): user: KhojUser = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) + user_subscription = adapters.get_user_subscription(user.email) + user_subscription_state = get_user_subscription_state(user_subscription) subscription_renewal_date = ( - user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None + user_subscription.renewal_date.strftime("%d %b %Y") + if user_subscription and user_subscription.renewal_date + else None ) enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())