Use scopes to represent whether the use has a valid subscription in the middleware

This commit is contained in:
sabaimran
2023-11-24 20:29:36 -08:00
parent c13953311a
commit 69c8f45830
5 changed files with 59 additions and 32 deletions

View File

@@ -21,7 +21,12 @@ from starlette.authentication import (
# Internal Packages # Internal Packages
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import get_all_users, get_or_create_search_model from khoj.database.adapters import (
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
SubscriptionState,
)
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state from khoj.utils import constants, state
@@ -70,7 +75,11 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst() .afirst()
) )
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) subscription_state = await aget_user_subscription_state(user)
subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header # Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1] bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@@ -82,11 +91,15 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst() .afirst()
) )
if user_with_token: if user_with_token:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode: if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
return AuthCredentials(), UnauthenticatedUser() return AuthCredentials(), UnauthenticatedUser()

View File

@@ -3,6 +3,7 @@ import random
import secrets import secrets
from datetime import date, datetime, timezone from datetime import date, datetime, timezone
from typing import List, Optional, Type from typing import List, Optional, Type
from enum import Enum
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.backends.db import SessionStore
@@ -40,6 +41,14 @@ from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name from khoj.utils.helpers import generate_random_name
class SubscriptionState(Enum):
TRIAL = "trial"
SUBSCRIBED = "subscribed"
UNSUBSCRIBED = "unsubscribed"
EXPIRED = "expired"
INVALID = "invalid"
async def set_notion_config(token: str, user: KhojUser): async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst() notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config: if not notion_config:
@@ -127,22 +136,34 @@ async def set_user_subscription(
return None return None
def subscription_to_state(subscription: Subscription) -> str:
if not subscription:
return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL:
return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value
return SubscriptionState.INVALID.value
def get_user_subscription_state(email: str) -> str: def get_user_subscription_state(email: str) -> str:
"""Get subscription state of user """Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
""" """
user_subscription = Subscription.objects.filter(user__email=email).first() user_subscription = Subscription.objects.filter(user__email=email).first()
if not user_subscription: return subscription_to_state(user_subscription)
return "trial"
elif user_subscription.type == Subscription.Type.TRIAL:
return "trial" async def aget_user_subscription_state(email: str) -> str:
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): """Get subscription state of user
return "subscribed" Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): """
return "unsubscribed" user_subscription = await Subscription.objects.filter(user__email=email).afirst()
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): return subscription_to_state(user_subscription)
return "expired"
return "invalid"
async def get_user_by_email(email: str) -> KhojUser: async def get_user_by_email(email: str) -> KhojUser:

View File

@@ -171,7 +171,7 @@
</div> </div>
</div> </div>
</div> </div>
{% if billing_enabled %} {% if not billing_enabled %}
<div id="billing" class="section"> <div id="billing" class="section">
<h2 class="section-title">Billing</h2> <h2 class="section-title">Billing</h2>
<div class="section-cards"> <div class="section-cards">

View File

@@ -12,7 +12,7 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Header, HTTPException, Request from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires, has_required_scope
# Internal Packages # Internal Packages
from khoj.configure import configure_server from khoj.configure import configure_server

View File

@@ -7,7 +7,7 @@ from fastapi import APIRouter
from fastapi import Request from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.authentication import requires from starlette.authentication import requires, has_required_scope
from khoj.database import adapters from khoj.database import adapters
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
@@ -37,7 +37,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index(request: Request): def index(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -46,7 +45,7 @@ def index(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -57,7 +56,6 @@ def index(request: Request):
def index_post(request: Request): def index_post(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -66,7 +64,7 @@ def index_post(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -77,7 +75,6 @@ def index_post(request: Request):
def search_page(request: Request): def search_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -86,7 +83,7 @@ def search_page(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -97,7 +94,6 @@ def search_page(request: Request):
def chat_page(request: Request): def chat_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -106,7 +102,7 @@ def chat_page(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -171,7 +167,7 @@ def config_page(request: Request):
"subscription_state": user_subscription_state, "subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date, "subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -182,7 +178,6 @@ def config_page(request: Request):
def github_config_page(request: Request): def github_config_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
current_github_config = get_user_github_config(user) current_github_config = get_user_github_config(user)
@@ -212,7 +207,7 @@ def github_config_page(request: Request):
"current_config": current_config, "current_config": current_config,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -223,7 +218,6 @@ def github_config_page(request: Request):
def notion_config_page(request: Request): def notion_config_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = adapters.get_user_subscription(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
current_notion_config = get_user_notion_config(user) current_notion_config = get_user_notion_config(user)
@@ -240,7 +234,7 @@ def notion_config_page(request: Request):
"current_config": current_config, "current_config": current_config,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -251,7 +245,6 @@ def notion_config_page(request: Request):
def computer_config_page(request: Request): def computer_config_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -260,7 +253,7 @@ def computer_config_page(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )