From b2afbaa31537a344ff96919b9ff71640343f8e34 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 25 Nov 2023 20:28:04 -0800 Subject: [PATCH] Add support for rate limiting the amount of data indexed - Add a dependency on the indexer API endpoint that rounds up the amount of data indexed and uses that to determine whether the next set of data should be processed - Delete any files that are being removed for adminstering the calculation - Show current amount of data indexed in the config page --- src/khoj/configure.py | 2 + src/khoj/database/adapters/__init__.py | 7 +++ src/khoj/interface/web/config.html | 5 ++- src/khoj/routers/api.py | 4 +- src/khoj/routers/helpers.py | 61 ++++++++++++++++++++++++-- src/khoj/routers/indexer.py | 15 +++++-- src/khoj/routers/web_client.py | 6 ++- tests/test_client.py | 38 ++++++++++++++++ 8 files changed, 127 insertions(+), 11 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 6aa747e8..bff7e3ca 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -80,6 +80,7 @@ class UserAuthenticationBackend(AuthenticationBackend): subscribed = ( subscription_state == SubscriptionState.SUBSCRIBED.value or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) @@ -101,6 +102,7 @@ class UserAuthenticationBackend(AuthenticationBackend): subscribed = ( subscription_state == SubscriptionState.SUBSCRIBED.value or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 7f76b796..bcf4856c 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1,6 +1,7 @@ import math import random import secrets +import sys from datetime import date, datetime, timezone, timedelta from typing import List, Optional, Type from enum import Enum @@ -474,6 +475,12 @@ class EntryAdapters: async def adelete_all_entries(user: KhojUser): return await Entry.objects.filter(user=user).adelete() + @staticmethod + def get_size_of_indexed_data_in_mb(user: KhojUser): + entries = Entry.objects.filter(user=user).iterator() + total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) + return total_size / 1024 / 1024 + @staticmethod def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): q_filter_terms = Q() diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 318759df..88fbc70d 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -4,6 +4,7 @@

Content

+

{{indexed_data_size_in_mb}} MB used

@@ -171,7 +172,7 @@
- {% if not billing_enabled %} + {% if billing_enabled %}

Billing

@@ -191,7 +192,7 @@

- Subscribe to Khoj Cloud + Subscribe to Khoj Cloud. See pricing for details.

= self.subscribed_num_entries_size: + raise HTTPException(status_code=429, detail="Too much data indexed.") + if not subscribed and incoming_data_size_mb >= self.num_entries_size: + raise HTTPException( + status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit." + ) + + user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user) + if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size: + raise HTTPException(status_code=429, detail="Too much data indexed.") + if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit: + raise HTTPException( + status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit." + ) + + class CommonQueryParamsClass: def __init__( self, diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 0432eed0..0c906707 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -2,7 +2,7 @@ import asyncio import logging from typing import Dict, Optional, Union -from fastapi import APIRouter, Header, Request, Response, UploadFile +from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends from pydantic import BaseModel from starlette.authentication import requires @@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search from khoj.utils import constants, state from khoj.utils.config import ContentIndex, SearchModels from khoj.utils.helpers import LRU, get_file_type +from khoj.routers.helpers import ApiIndexedDataLimiter from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.yaml import save_config_to_file_updated_state @@ -53,6 +54,14 @@ async def update( user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), host: Optional[str] = Header(None), + indexed_data_limiter: ApiIndexedDataLimiter = Depends( + ApiIndexedDataLimiter( + incoming_entries_size_limit=10, + subscribed_incoming_entries_size_limit=25, + total_entries_size_limit=10, + subscribed_total_entries_size_limit=100, + ) + ), ): user = request.user.object try: @@ -92,7 +101,7 @@ async def update( logger.info("📬 Initializing content index on first run.") default_full_config = FullConfig( content_type=None, - search_type=SearchConfig.parse_obj(constants.default_config["search-type"]), + search_type=SearchConfig.model_validate(constants.default_config["search-type"]), processor=None, ) state.config = default_full_config @@ -116,7 +125,7 @@ async def update( configure_content, state.content_index, state.config.content_type, - indexer_input.dict(), + indexer_input.model_dump(), state.search_models, force, t, diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index bf3ff957..c17704bd 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -1,6 +1,8 @@ # System Packages import json import os +import math +from datetime import timedelta # External Packages from fastapi import APIRouter @@ -137,8 +139,9 @@ def config_page(request: Request): subscription_renewal_date = ( user_subscription.renewal_date.strftime("%d %b %Y") if user_subscription and user_subscription.renewal_date - else None + else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y") ) + indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user)) enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) successfully_configured = { @@ -169,6 +172,7 @@ def config_page(request: Request): "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, + "indexed_data_size_in_mb": indexed_data_size_in_mb, }, ) diff --git a/tests/test_client.py b/tests/test_client.py index 19aba03b..98affe27 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -125,6 +125,34 @@ def test_regenerate_with_invalid_content_type(client): assert response.status_code == 422 +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_big_files(client): + state.billing_enabled = True + # Arrange + files = get_big_size_sample_files_data() + headers = {"Authorization": "Bearer kk-secret"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 429 + + +@pytest.mark.django_db(transaction=True) +def test_index_update_big_files_no_billing(client): + # Arrange + files = get_big_size_sample_files_data() + headers = {"Authorization": "Bearer kk-secret"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 200 + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_index_update(client): @@ -421,3 +449,13 @@ def get_sample_files_data(): ), ("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")), ] + + +def get_big_size_sample_files_data(): + big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB + return [ + ( + "files", + ("path/to/filename.org", big_text, "text/org"), + ) + ]