diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index d34205d9..4e4e1008 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -21,7 +21,12 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
-from khoj.database.adapters import get_all_users, get_or_create_search_model
+from khoj.database.adapters import (
+ get_all_users,
+ get_or_create_search_model,
+ aget_user_subscription_state,
+ SubscriptionState,
+)
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state
@@ -70,7 +75,11 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
- return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
+ subscription_state = await aget_user_subscription_state(user)
+ subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
+ if subscribed:
+ return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
+ return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@@ -82,11 +91,15 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
+ subscription_state = await aget_user_subscription_state(user_with_token.user)
+ subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
+ if subscribed:
+ return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
- return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
+ return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
return AuthCredentials(), UnauthenticatedUser()
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 7fd04006..146de11c 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -3,6 +3,7 @@ import random
import secrets
from datetime import date, datetime, timezone
from typing import List, Optional, Type
+from enum import Enum
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@@ -40,6 +41,14 @@ from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name
+class SubscriptionState(Enum):
+ TRIAL = "trial"
+ SUBSCRIBED = "subscribed"
+ UNSUBSCRIBED = "unsubscribed"
+ EXPIRED = "expired"
+ INVALID = "invalid"
+
+
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config:
@@ -127,22 +136,34 @@ async def set_user_subscription(
return None
+def subscription_to_state(subscription: Subscription) -> str:
+ if not subscription:
+ return SubscriptionState.INVALID.value
+ elif subscription.type == Subscription.Type.TRIAL:
+ return SubscriptionState.TRIAL.value
+ elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
+ return SubscriptionState.SUBSCRIBED.value
+ elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
+ return SubscriptionState.UNSUBSCRIBED.value
+ elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
+ return SubscriptionState.EXPIRED.value
+ return SubscriptionState.INVALID.value
+
+
def get_user_subscription_state(email: str) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = Subscription.objects.filter(user__email=email).first()
- if not user_subscription:
- return "trial"
- elif user_subscription.type == Subscription.Type.TRIAL:
- return "trial"
- elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
- return "subscribed"
- elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
- return "unsubscribed"
- elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
- return "expired"
- return "invalid"
+ return subscription_to_state(user_subscription)
+
+
+async def aget_user_subscription_state(email: str) -> str:
+ """Get subscription state of user
+ Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
+ """
+ user_subscription = await Subscription.objects.filter(user__email=email).afirst()
+ return subscription_to_state(user_subscription)
async def get_user_by_email(email: str) -> KhojUser:
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index 01a3786f..318759df 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -171,7 +171,7 @@
- {% if billing_enabled %}
+ {% if not billing_enabled %}
Billing
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 83955088..d5b6ce0e 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -12,7 +12,7 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
-from starlette.authentication import requires
+from starlette.authentication import requires, has_required_scope
# Internal Packages
from khoj.configure import configure_server
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index dab16fa8..bf3ff957 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -7,7 +7,7 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
-from starlette.authentication import requires
+from starlette.authentication import requires, has_required_scope
from khoj.database import adapters
from khoj.database.models import KhojUser
from khoj.utils.rawconfig import (
@@ -37,7 +37,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@@ -46,7 +45,7 @@ def index(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -57,7 +56,6 @@ def index(request: Request):
def index_post(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@@ -66,7 +64,7 @@ def index_post(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -77,7 +75,6 @@ def index_post(request: Request):
def search_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@@ -86,7 +83,7 @@ def search_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -97,7 +94,6 @@ def search_page(request: Request):
def chat_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@@ -106,7 +102,7 @@ def chat_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -171,7 +167,7 @@ def config_page(request: Request):
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -182,7 +178,6 @@ def config_page(request: Request):
def github_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
current_github_config = get_user_github_config(user)
@@ -212,7 +207,7 @@ def github_config_page(request: Request):
"current_config": current_config,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -223,7 +218,6 @@ def github_config_page(request: Request):
def notion_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = adapters.get_user_subscription(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
current_notion_config = get_user_notion_config(user)
@@ -240,7 +234,7 @@ def notion_config_page(request: Request):
"current_config": current_config,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@@ -251,7 +245,6 @@ def notion_config_page(request: Request):
def computer_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
- user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@@ -260,7 +253,7 @@ def computer_config_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
- "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
+ "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)