mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Move Subscription data into separate table in DB. Merge migrations
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Type, TypeVar, List
|
from typing import Optional, Type, TypeVar, List
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Type, TypeVar, List
|
from typing import Type, TypeVar, List
|
||||||
@@ -30,6 +30,7 @@ from database.models import (
|
|||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
|
Subscription,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
@@ -103,35 +104,51 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser:
|
def get_user_subscription(email: str) -> Optional[Subscription]:
|
||||||
user = await KhojUser.objects.filter(email=email).afirst()
|
return Subscription.objects.filter(user__email=email).first()
|
||||||
if user:
|
|
||||||
user.subscription_type = type
|
|
||||||
if is_subscribed is not None:
|
async def set_user_subscription(
|
||||||
user.is_subscribed = is_subscribed
|
email: str, is_recurring=None, renewal_date=None, type="standard"
|
||||||
|
) -> Optional[Subscription]:
|
||||||
|
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
|
||||||
|
if not user_subscription:
|
||||||
|
user = await get_user_by_email(email)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
user_subscription = await Subscription.objects.acreate(
|
||||||
|
user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date
|
||||||
|
)
|
||||||
|
return user_subscription
|
||||||
|
elif user_subscription:
|
||||||
|
user_subscription.type = type
|
||||||
|
if is_recurring is not None:
|
||||||
|
user_subscription.is_recurring = is_recurring
|
||||||
if renewal_date is False:
|
if renewal_date is False:
|
||||||
user.subscription_renewal_date = None
|
user_subscription.renewal_date = None
|
||||||
elif renewal_date is not None:
|
elif renewal_date is not None:
|
||||||
user.subscription_renewal_date = renewal_date
|
user_subscription.renewal_date = renewal_date
|
||||||
await user.asave()
|
await user_subscription.asave()
|
||||||
return user
|
return user_subscription
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_user_subscription_state(email: str) -> str:
|
def get_user_subscription_state(user_subscription: Subscription) -> str:
|
||||||
"""Get subscription state of user
|
"""Get subscription state of user
|
||||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||||
"""
|
"""
|
||||||
user = KhojUser.objects.filter(email=email).first()
|
if not user_subscription:
|
||||||
if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
|
|
||||||
return "trial"
|
return "trial"
|
||||||
elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
|
elif user_subscription.type == Subscription.Type.TRIAL:
|
||||||
|
return "trial"
|
||||||
|
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||||
return "subscribed"
|
return "subscribed"
|
||||||
elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
|
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||||
return "unsubscribed"
|
return "unsubscribed"
|
||||||
elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc):
|
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||||
return "expired"
|
return "expired"
|
||||||
|
return "invalid"
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_email(email: str) -> KhojUser:
|
async def get_user_by_email(email: str) -> KhojUser:
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
# Generated by Django 4.2.5 on 2023-11-07 18:19
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("database", "0012_entry_file_source"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="khojuser",
|
|
||||||
name="subscription_renewal_date",
|
|
||||||
field=models.DateTimeField(default=None, null=True),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="khojuser",
|
|
||||||
name="subscription_type",
|
|
||||||
field=models.CharField(
|
|
||||||
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
37
src/database/migrations/0013_subscription.py
Normal file
37
src/database/migrations/0013_subscription.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Generated by Django 4.2.5 on 2023-11-09 01:27
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0012_entry_file_source"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="Subscription",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
(
|
||||||
|
"type",
|
||||||
|
models.CharField(
|
||||||
|
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("is_recurring", models.BooleanField(default=False)),
|
||||||
|
("renewal_date", models.DateTimeField(default=None, null=True)),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# Generated by Django 4.2.5 on 2023-11-08 19:40
|
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("database", "0013_khojuser_subscription_renewal_date_and_more"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="khojuser",
|
|
||||||
name="is_subscribed",
|
|
||||||
field=models.BooleanField(default=False),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
@@ -14,16 +14,7 @@ class BaseModel(models.Model):
|
|||||||
|
|
||||||
|
|
||||||
class KhojUser(AbstractUser):
|
class KhojUser(AbstractUser):
|
||||||
class SubscriptionType(models.TextChoices):
|
|
||||||
TRIAL = "trial"
|
|
||||||
STANDARD = "standard"
|
|
||||||
|
|
||||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||||
subscription_type = models.CharField(
|
|
||||||
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
|
|
||||||
)
|
|
||||||
is_subscribed = models.BooleanField(default=False)
|
|
||||||
subscription_renewal_date = models.DateTimeField(null=True, default=None)
|
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
if not self.uuid:
|
if not self.uuid:
|
||||||
@@ -55,6 +46,17 @@ class KhojApiUser(models.Model):
|
|||||||
accessed_at = models.DateTimeField(null=True, default=None)
|
accessed_at = models.DateTimeField(null=True, default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(BaseModel):
|
||||||
|
class Type(models.TextChoices):
|
||||||
|
TRIAL = "trial"
|
||||||
|
STANDARD = "standard"
|
||||||
|
|
||||||
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
|
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||||
|
is_recurring = models.BooleanField(default=False)
|
||||||
|
renewal_date = models.DateTimeField(null=True, default=None)
|
||||||
|
|
||||||
|
|
||||||
class NotionConfig(BaseModel):
|
class NotionConfig(BaseModel):
|
||||||
token = models.CharField(max_length=200)
|
token = models.CharField(max_length=200)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
import stripe
|
import stripe
|
||||||
@@ -58,20 +59,20 @@ async def subscribe(request: Request):
|
|||||||
# Mark the user as subscribed and update the next renewal date on payment
|
# Mark the user as subscribed and update the next renewal date on payment
|
||||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||||
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
|
user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
|
||||||
success = user is not None
|
success = user is not None
|
||||||
elif event_type in {"customer.subscription.updated"}:
|
elif event_type in {"customer.subscription.updated"}:
|
||||||
user = await adapters.get_user_by_email(customer_email)
|
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
|
||||||
# Allow updating subscription status if paid user
|
# Allow updating subscription status if paid user
|
||||||
if user.subscription_renewal_date:
|
if user_subscription.renewal_date:
|
||||||
# Mark user as unsubscribed or resubscribed
|
# Mark user as unsubscribed or resubscribed
|
||||||
is_subscribed = not subscription["cancel_at_period_end"]
|
is_recurring = not subscription["cancel_at_period_end"]
|
||||||
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
|
updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
|
||||||
success = updated_user is not None
|
success = updated_user is not None
|
||||||
elif event_type in {"customer.subscription.deleted"}:
|
elif event_type in {"customer.subscription.deleted"}:
|
||||||
# Reset the user to trial state
|
# Reset the user to trial state
|
||||||
user = await adapters.set_user_subscription(
|
user = await adapters.set_user_subscription(
|
||||||
customer_email, is_subscribed=False, renewal_date=False, type="trial"
|
customer_email, is_recurring=False, renewal_date=False, type="trial"
|
||||||
)
|
)
|
||||||
success = user is not None
|
success = user is not None
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from fastapi import Request
|
|||||||
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
from database import adapters
|
||||||
from database.models import KhojUser
|
from database.models import KhojUser
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
GithubContentConfig,
|
GithubContentConfig,
|
||||||
@@ -117,9 +118,12 @@ def login_page(request: Request):
|
|||||||
def config_page(request: Request):
|
def config_page(request: Request):
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
user_picture = request.session.get("user", {}).get("picture")
|
user_picture = request.session.get("user", {}).get("picture")
|
||||||
user_subscription_state = get_user_subscription_state(user.email)
|
user_subscription = adapters.get_user_subscription(user.email)
|
||||||
|
user_subscription_state = get_user_subscription_state(user_subscription)
|
||||||
subscription_renewal_date = (
|
subscription_renewal_date = (
|
||||||
user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None
|
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||||
|
if user_subscription and user_subscription.renewal_date
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
|
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user