Merge pull request #569 from khoj-ai/features/enforce-subscription-status

Enforce subscription state on the chat API access
This commit is contained in:
sabaimran
2023-11-27 16:12:26 -08:00
committed by GitHub
15 changed files with 368 additions and 55 deletions

View File

@@ -361,12 +361,25 @@
if (newResponseText.getElementsByClassName("spinner").length > 0) { if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner); newResponseText.removeChild(loadingSpinner);
} }
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
rawResponse += chunk; readStream();
newResponseText.innerHTML = ""; }
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
} }
// Scroll to bottom of chat window as chat response is streamed // Scroll to bottom of chat window as chat response is streamed

View File

@@ -101,6 +101,9 @@
<div class="card-description-row"> <div class="card-description-row">
<div id="sync-status"></div> <div id="sync-status"></div>
</div> </div>
<div id="needs-subscription" style="display: none;">
Looks like you're out of space to sync your files. <a href="https://app.khoj.dev/config">Upgrade your plan</a> to unlock more space.
</div>
</div> </div>
</body> </body>

View File

@@ -198,6 +198,11 @@ function pushDataToKhoj (regenerate = false) {
}) })
.catch(error => { .catch(error => {
console.error(error); console.error(error);
if (error.response.status == 429) {
const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('needsSubscription', true);
if (win) win.webContents.send('update-state', state);
}
state['completed'] = false state['completed'] = false
}) })
.finally(() => { .finally(() => {
@@ -396,6 +401,11 @@ app.whenReady().then(() => {
event.reply('update-state', arg); event.reply('update-state', arg);
}); });
ipcMain.on('needsSubscription', (event, arg) => {
console.log(arg);
event.reply('needsSubscription', arg);
});
ipcMain.on('navigate', (event, page) => { ipcMain.on('navigate', (event, page) => {
win.loadFile(page); win.loadFile(page);
}); });

View File

@@ -31,6 +31,10 @@ contextBridge.exposeInMainWorld('updateStateAPI', {
onUpdateState: (callback) => ipcRenderer.on('update-state', callback) onUpdateState: (callback) => ipcRenderer.on('update-state', callback)
}) })
contextBridge.exposeInMainWorld('needsSubscriptionAPI', {
onNeedsSubscription: (callback) => ipcRenderer.on('needsSubscription', callback)
})
contextBridge.exposeInMainWorld('removeFileAPI', { contextBridge.exposeInMainWorld('removeFileAPI', {
removeFile: (filePath) => ipcRenderer.invoke('removeFile', filePath) removeFile: (filePath) => ipcRenderer.invoke('removeFile', filePath)
}) })

View File

@@ -1,7 +1,7 @@
const setFolderButton = document.getElementById('update-folder'); const setFolderButton = document.getElementById('update-folder');
const setFileButton = document.getElementById('update-file'); const setFileButton = document.getElementById('update-file');
const showKey = document.getElementById('show-key');
const loadingBar = document.getElementById('loading-bar'); const loadingBar = document.getElementById('loading-bar');
const needsSubscriptionElement = document.getElementById('needs-subscription');
async function removeFile(filePath) { async function removeFile(filePath) {
const updatedFiles = await window.removeFileAPI.removeFile(filePath); const updatedFiles = await window.removeFileAPI.removeFile(filePath);
@@ -165,6 +165,15 @@ window.updateStateAPI.onUpdateState((event, state) => {
syncStatusElement.innerHTML = `⏱️ Synced at ${currentTime.toLocaleTimeString(undefined, options)}. Next sync at ${nextSyncTime.toLocaleTimeString(undefined, options)}.`; syncStatusElement.innerHTML = `⏱️ Synced at ${currentTime.toLocaleTimeString(undefined, options)}. Next sync at ${nextSyncTime.toLocaleTimeString(undefined, options)}.`;
}); });
window.needsSubscriptionAPI.onNeedsSubscription((event, needsSubscription) => {
console.log("needs subscription", needsSubscription);
if (needsSubscription) {
needsSubscriptionElement.style.display = 'block';
} else {
needsSubscriptionElement.style.display = 'none';
}
});
const urlInput = document.getElementById('khoj-host-url'); const urlInput = document.getElementById('khoj-host-url');
(async function() { (async function() {
const url = await window.hostURLAPI.getURL(); const url = await window.hostURLAPI.getURL();

View File

@@ -21,7 +21,12 @@ from starlette.authentication import (
# Internal Packages # Internal Packages
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import get_all_users, get_or_create_search_model from khoj.database.adapters import (
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
SubscriptionState,
)
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state from khoj.utils import constants, state
@@ -70,7 +75,17 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst() .afirst()
) )
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header # Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1] bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@@ -82,11 +97,23 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst() .afirst()
) )
if user_with_token: if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
user_with_token.user
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode: if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(), UnauthenticatedUser() return AuthCredentials(), UnauthenticatedUser()

View File

@@ -1,8 +1,10 @@
import math import math
import random import random
import secrets import secrets
from datetime import date, datetime, timezone import sys
from datetime import date, datetime, timezone, timedelta
from typing import List, Optional, Type from typing import List, Optional, Type
from enum import Enum
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.backends.db import SessionStore
@@ -41,6 +43,14 @@ from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name from khoj.utils.helpers import generate_random_name
class SubscriptionState(Enum):
TRIAL = "trial"
SUBSCRIBED = "subscribed"
UNSUBSCRIBED = "unsubscribed"
EXPIRED = "expired"
INVALID = "invalid"
async def set_notion_config(token: str, user: KhojUser): async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst() notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config: if not notion_config:
@@ -128,22 +138,38 @@ async def set_user_subscription(
return None return None
def subscription_to_state(subscription: Subscription) -> str:
if not subscription:
return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL:
# Trial subscription is valid for 7 days
if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7):
return SubscriptionState.EXPIRED.value
return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value
return SubscriptionState.INVALID.value
def get_user_subscription_state(email: str) -> str: def get_user_subscription_state(email: str) -> str:
"""Get subscription state of user """Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
""" """
user_subscription = Subscription.objects.filter(user__email=email).first() user_subscription = Subscription.objects.filter(user__email=email).first()
if not user_subscription: return subscription_to_state(user_subscription)
return "trial"
elif user_subscription.type == Subscription.Type.TRIAL:
return "trial" async def aget_user_subscription_state(email: str) -> str:
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): """Get subscription state of user
return "subscribed" Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): """
return "unsubscribed" user_subscription = await Subscription.objects.filter(user__email=email).afirst()
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): return subscription_to_state(user_subscription)
return "expired"
return "invalid"
async def get_user_by_email(email: str) -> KhojUser: async def get_user_by_email(email: str) -> KhojUser:
@@ -458,6 +484,12 @@ class EntryAdapters:
async def adelete_all_entries(user: KhojUser): async def adelete_all_entries(user: KhojUser):
return await Entry.objects.filter(user=user).adelete() return await Entry.objects.filter(user=user).adelete()
@staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser):
entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
return total_size / 1024 / 1024
@staticmethod @staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q() q_filter_terms = Q()

View File

@@ -402,10 +402,24 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.removeChild(loadingSpinner); newResponseText.removeChild(loadingSpinner);
} }
rawResponse += chunk; // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
newResponseText.innerHTML = ""; if (chunk.startsWith("{") && chunk.endsWith("}")) {
newResponseText.appendChild(formatHTMLMessage(rawResponse)); try {
readStream(); const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
} }
// Scroll to bottom of chat window as chat response is streamed // Scroll to bottom of chat window as chat response is streamed

View File

@@ -4,6 +4,10 @@
<div class="page"> <div class="page">
<div id="content" class="section"> <div id="content" class="section">
<h2 class="section-title">Content</h2> <h2 class="section-title">Content</h2>
<button id="compute-index-size" class="card-button" onclick="getIndexedDataSize()">
Data Usage
</button>
<p id="indexed-data-size" class="card-description"></p>
<div class="section-cards"> <div class="section-cards">
<div class="card"> <div class="card">
<div class="card-title-row"> <div class="card-title-row">
@@ -191,7 +195,7 @@
<p id="trial-description" <p id="trial-description"
class="card-description" class="card-description"
style="display: {% if subscription_state != 'trial' %}none{% endif %}"> style="display: {% if subscription_state != 'trial' %}none{% endif %}">
Subscribe to Khoj Cloud Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
</p> </p>
<p id="unsubscribe-description" <p id="unsubscribe-description"
class="card-description" class="card-description"
@@ -471,6 +475,15 @@
}); });
} }
function getIndexedDataSize() {
document.getElementById("indexed-data-size").innerHTML = "Calculating...";
fetch('/api/config/index/size')
.then(response => response.json())
.then(data => {
document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used";
});
}
// List user's API keys on page load // List user's API keys on page load
listApiKeys(); listApiKeys();

View File

@@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Union
import uuid import uuid
# External Packages # External Packages
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from asgiref.sync import sync_to_async
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
@@ -334,6 +334,18 @@ def get_default_config_data():
return constants.empty_config return constants.empty_config
@api.get("/config/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
user = request.user.object
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
return Response(
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
media_type="application/json",
status_code=200,
)
@api.get("/config/types", response_model=List[str]) @api.get("/config/types", response_model=List[str])
@requires(["authenticated"]) @requires(["authenticated"])
def get_config_types( def get_config_types(
@@ -650,8 +662,8 @@ async def chat(
n: Optional[int] = 5, n: Optional[int] = 5,
d: Optional[float] = 0.18, d: Optional[float] = 0.18,
stream: Optional[bool] = False, stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)), rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user = request.user.object user = request.user.object

View File

@@ -9,10 +9,12 @@ from functools import partial
from time import time from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages from fastapi import Depends, Header, HTTPException, Request, UploadFile
from fastapi import Depends, Header, HTTPException, Request from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
from khoj.database.adapters import ConversationAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
@@ -270,13 +272,15 @@ def generate_chat_response(
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, window: int): def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests self.requests = requests
self.subscribed_requests = subscribed_requests
self.window = window self.window = window
self.cache: dict[str, list[float]] = defaultdict(list) self.cache: dict[str, list[float]] = defaultdict(list)
def __call__(self, request: Request): def __call__(self, request: Request):
user: KhojUser = request.user.object user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
user_requests = self.cache[user.uuid] user_requests = self.cache[user.uuid]
# Remove requests outside of the time window # Remove requests outside of the time window
@@ -285,13 +289,69 @@ class ApiUserRateLimiter:
user_requests.pop(0) user_requests.pop(0)
# Check if the user has exceeded the rate limit # Check if the user has exceeded the rate limit
if len(user_requests) >= self.requests: if subscribed and len(user_requests) >= self.subscribed_requests:
raise HTTPException(status_code=429, detail="Too Many Requests") raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_requests) >= self.requests:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
# Add the current request to the cache # Add the current request to the cache
user_requests.append(time()) user_requests.append(time())
class ApiIndexedDataLimiter:
def __init__(
self,
incoming_entries_size_limit: float,
subscribed_incoming_entries_size_limit: float,
total_entries_size_limit: float,
subscribed_total_entries_size_limit: float,
):
self.num_entries_size = incoming_entries_size_limit
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
self.total_entries_size_limit = total_entries_size_limit
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
def __call__(self, request: Request, files: List[UploadFile]):
if state.billing_enabled is False:
return
subscribed = has_required_scope(request, ["premium"])
incoming_data_size_mb = 0
deletion_file_names = set()
if not request.user.is_authenticated:
return
user: KhojUser = request.user.object
for file in files:
if file.size == 0:
deletion_file_names.add(file.filename)
incoming_data_size_mb += file.size / 1024 / 1024
num_deleted_entries = 0
for file_path in deletion_file_names:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
class CommonQueryParamsClass: class CommonQueryParamsClass:
def __init__( def __init__(
self, self,

View File

@@ -2,7 +2,7 @@ import asyncio
import logging import logging
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from fastapi import APIRouter, Header, Request, Response, UploadFile from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
from pydantic import BaseModel from pydantic import BaseModel
from starlette.authentication import requires from starlette.authentication import requires
@@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import ContentIndex, SearchModels from khoj.utils.config import ContentIndex, SearchModels
from khoj.utils.helpers import LRU, get_file_type from khoj.utils.helpers import LRU, get_file_type
from khoj.routers.helpers import ApiIndexedDataLimiter
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state from khoj.utils.yaml import save_config_to_file_updated_state
@@ -53,6 +54,14 @@ async def update(
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None), referer: Optional[str] = Header(None),
host: Optional[str] = Header(None), host: Optional[str] = Header(None),
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
ApiIndexedDataLimiter(
incoming_entries_size_limit=10,
subscribed_incoming_entries_size_limit=25,
total_entries_size_limit=10,
subscribed_total_entries_size_limit=100,
)
),
): ):
user = request.user.object user = request.user.object
try: try:
@@ -92,7 +101,7 @@ async def update(
logger.info("📬 Initializing content index on first run.") logger.info("📬 Initializing content index on first run.")
default_full_config = FullConfig( default_full_config = FullConfig(
content_type=None, content_type=None,
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]), search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
processor=None, processor=None,
) )
state.config = default_full_config state.config = default_full_config
@@ -116,7 +125,7 @@ async def update(
configure_content, configure_content,
state.content_index, state.content_index,
state.config.content_type, state.config.content_type,
indexer_input.dict(), indexer_input.model_dump(),
state.search_models, state.search_models,
force, force,
t, t,

View File

@@ -1,13 +1,15 @@
# System Packages # System Packages
import json import json
import os import os
import math
from datetime import timedelta
# External Packages # External Packages
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import Request from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.authentication import requires from starlette.authentication import requires, has_required_scope
from khoj.database import adapters from khoj.database import adapters
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
@@ -37,7 +39,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index(request: Request): def index(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -46,7 +47,7 @@ def index(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -57,7 +58,6 @@ def index(request: Request):
def index_post(request: Request): def index_post(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -66,7 +66,7 @@ def index_post(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -77,7 +77,6 @@ def index_post(request: Request):
def search_page(request: Request): def search_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -86,7 +85,7 @@ def search_page(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -97,7 +96,6 @@ def search_page(request: Request):
def chat_page(request: Request): def chat_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -106,7 +104,7 @@ def chat_page(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -141,7 +139,7 @@ def config_page(request: Request):
subscription_renewal_date = ( subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y") user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date if user_subscription and user_subscription.renewal_date
else None else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
) )
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
@@ -171,7 +169,7 @@ def config_page(request: Request):
"subscription_state": user_subscription_state, "subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date, "subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -182,7 +180,6 @@ def config_page(request: Request):
def github_config_page(request: Request): def github_config_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
current_github_config = get_user_github_config(user) current_github_config = get_user_github_config(user)
@@ -212,7 +209,7 @@ def github_config_page(request: Request):
"current_config": current_config, "current_config": current_config,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -223,7 +220,6 @@ def github_config_page(request: Request):
def notion_config_page(request: Request): def notion_config_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = adapters.get_user_subscription(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
current_notion_config = get_user_notion_config(user) current_notion_config = get_user_notion_config(user)
@@ -240,7 +236,7 @@ def notion_config_page(request: Request):
"current_config": current_config, "current_config": current_config,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )
@@ -251,7 +247,6 @@ def notion_config_page(request: Request):
def computer_config_page(request: Request): def computer_config_page(request: Request):
user = request.user.object user = request.user.object
user_picture = request.session.get("user", {}).get("picture") user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user) has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -260,7 +255,7 @@ def computer_config_page(request: Request):
"request": request, "request": request,
"username": user.username, "username": user.username,
"user_photo": user_picture, "user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", "is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents, "has_documents": has_documents,
}, },
) )

View File

@@ -102,6 +102,24 @@ def default_user3():
return user return user
@pytest.mark.django_db
@pytest.fixture
def default_user4():
"""
This user should not have a valid subscription
"""
if KhojUser.objects.filter(username="default4").exists():
return KhojUser.objects.get(username="default4")
user = KhojUser.objects.create(
username="default4",
email="default4@example.com",
password="default4",
)
SubscriptionFactory(user=user, renewal_date=None)
return user
@pytest.mark.django_db @pytest.mark.django_db
@pytest.fixture @pytest.fixture
def api_user(default_user): def api_user(default_user):
@@ -141,6 +159,19 @@ def api_user3(default_user3):
) )
@pytest.mark.django_db
@pytest.fixture
def api_user4(default_user4):
if KhojApiUser.objects.filter(user=default_user4).exists():
return KhojApiUser.objects.get(user=default_user4)
return KhojApiUser.objects.create(
user=default_user4,
name="api-key",
token="kk-diff-secret-4",
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def search_models(search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels() search_models = SearchModels()

View File

@@ -125,6 +125,67 @@ def test_regenerate_with_invalid_content_type(client):
assert response.status_code == 422 assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files(client):
# Arrange
state.billing_enabled = True
files = get_big_size_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 429
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser):
# Arrange
api_token = api_user4.token
state.billing_enabled = True
files = get_medium_size_sample_files_data()
headers = {"Authorization": f"Bearer {api_token}"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 429
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser):
# Arrange
api_token = api_user4.token
state.billing_enabled = True
files = get_sample_files_data()
headers = {"Authorization": f"Bearer {api_token}"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 200
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files_no_billing(client):
# Arrange
state.billing_enabled = False
files = get_big_size_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 200
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_index_update(client): def test_index_update(client):
@@ -421,3 +482,23 @@ def get_sample_files_data():
), ),
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")), ("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
] ]
def get_big_size_sample_files_data():
big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB
return [
(
"files",
("path/to/filename.org", big_text, "text/org"),
)
]
def get_medium_size_sample_files_data():
big_text = "a" * (10 * 1024 * 1024) # a string of approximately 10 MB
return [
(
"files",
("path/to/filename.org", big_text, "text/org"),
)
]