Move Subscription data into separate table in DB. Merge migrations

This commit is contained in:
Debanjum Singh Solanky
2023-11-08 17:45:25 -08:00
parent 3bb10128ef
commit 8178004e6d
7 changed files with 95 additions and 75 deletions

View File

@@ -1,4 +1,4 @@
from typing import Type, TypeVar, List from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import secrets import secrets
from typing import Type, TypeVar, List from typing import Type, TypeVar, List
@@ -30,6 +30,7 @@ from database.models import (
GithubRepoConfig, GithubRepoConfig,
Conversation, Conversation,
ChatModelOptions, ChatModelOptions,
Subscription,
UserConversationConfig, UserConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
@@ -103,35 +104,51 @@ async def create_google_user(token: dict) -> KhojUser:
return user return user
async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser: def get_user_subscription(email: str) -> Optional[Subscription]:
user = await KhojUser.objects.filter(email=email).afirst() return Subscription.objects.filter(user__email=email).first()
if user:
user.subscription_type = type
if is_subscribed is not None: async def set_user_subscription(
user.is_subscribed = is_subscribed email: str, is_recurring=None, renewal_date=None, type="standard"
) -> Optional[Subscription]:
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
if not user_subscription:
user = await get_user_by_email(email)
if not user:
return None
user_subscription = await Subscription.objects.acreate(
user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date
)
return user_subscription
elif user_subscription:
user_subscription.type = type
if is_recurring is not None:
user_subscription.is_recurring = is_recurring
if renewal_date is False: if renewal_date is False:
user.subscription_renewal_date = None user_subscription.renewal_date = None
elif renewal_date is not None: elif renewal_date is not None:
user.subscription_renewal_date = renewal_date user_subscription.renewal_date = renewal_date
await user.asave() await user_subscription.asave()
return user return user_subscription
else: else:
return None return None
def get_user_subscription_state(email: str) -> str: def get_user_subscription_state(user_subscription: Subscription) -> str:
"""Get subscription state of user """Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
""" """
user = KhojUser.objects.filter(email=email).first() if not user_subscription:
if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
return "trial" return "trial"
elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc): elif user_subscription.type == Subscription.Type.TRIAL:
return "trial"
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "subscribed" return "subscribed"
elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc): elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "unsubscribed" return "unsubscribed"
elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc): elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
return "expired" return "expired"
return "invalid"
async def get_user_by_email(email: str) -> KhojUser: async def get_user_by_email(email: str) -> KhojUser:

View File

@@ -1,24 +0,0 @@
# Generated by Django 4.2.5 on 2023-11-07 18:19
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0012_entry_file_source"),
]
operations = [
migrations.AddField(
model_name="khojuser",
name="subscription_renewal_date",
field=models.DateTimeField(default=None, null=True),
),
migrations.AddField(
model_name="khojuser",
name="subscription_type",
field=models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
),
),
]

View File

@@ -0,0 +1,37 @@
# Generated by Django 4.2.5 on 2023-11-09 01:27
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0012_entry_file_source"),
]
operations = [
migrations.CreateModel(
name="Subscription",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"type",
models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
),
),
("is_recurring", models.BooleanField(default=False)),
("renewal_date", models.DateTimeField(default=None, null=True)),
(
"user",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
],
options={
"abstract": False,
},
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 4.2.5 on 2023-11-08 19:40
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0013_khojuser_subscription_renewal_date_and_more"),
]
operations = [
migrations.AddField(
model_name="khojuser",
name="is_subscribed",
field=models.BooleanField(default=False),
),
]

View File

@@ -14,16 +14,7 @@ class BaseModel(models.Model):
class KhojUser(AbstractUser): class KhojUser(AbstractUser):
class SubscriptionType(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
subscription_type = models.CharField(
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
)
is_subscribed = models.BooleanField(default=False)
subscription_renewal_date = models.DateTimeField(null=True, default=None)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.uuid: if not self.uuid:
@@ -55,6 +46,17 @@ class KhojApiUser(models.Model):
accessed_at = models.DateTimeField(null=True, default=None) accessed_at = models.DateTimeField(null=True, default=None)
class Subscription(BaseModel):
class Type(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None)
class NotionConfig(BaseModel): class NotionConfig(BaseModel):
token = models.CharField(max_length=200) token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

View File

@@ -4,6 +4,7 @@ import logging
import os import os
# External Packages # External Packages
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from starlette.authentication import requires from starlette.authentication import requires
import stripe import stripe
@@ -58,20 +59,20 @@ async def subscribe(request: Request):
# Mark the user as subscribed and update the next renewal date on payment # Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0] subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc) renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date) user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
success = user is not None success = user is not None
elif event_type in {"customer.subscription.updated"}: elif event_type in {"customer.subscription.updated"}:
user = await adapters.get_user_by_email(customer_email) user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
# Allow updating subscription status if paid user # Allow updating subscription status if paid user
if user.subscription_renewal_date: if user_subscription.renewal_date:
# Mark user as unsubscribed or resubscribed # Mark user as unsubscribed or resubscribed
is_subscribed = not subscription["cancel_at_period_end"] is_recurring = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed) updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
success = updated_user is not None success = updated_user is not None
elif event_type in {"customer.subscription.deleted"}: elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state # Reset the user to trial state
user = await adapters.set_user_subscription( user = await adapters.set_user_subscription(
customer_email, is_subscribed=False, renewal_date=False, type="trial" customer_email, is_recurring=False, renewal_date=False, type="trial"
) )
success = user is not None success = user is not None

View File

@@ -8,6 +8,7 @@ from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.authentication import requires from starlette.authentication import requires
from database import adapters
from database.models import KhojUser from database.models import KhojUser
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
GithubContentConfig, GithubContentConfig,
@@ -117,9 +118,12 @@ def login_page(request: Request):
def config_page(request: Request): def config_page(request: Request):
user: KhojUser = request.user.object user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email) user_subscription = adapters.get_user_subscription(user.email)
user_subscription_state = get_user_subscription_state(user_subscription)
subscription_renewal_date = ( subscription_renewal_date = (
user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
else None
) )
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all()) enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())