diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 717ad859..4bb23812 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - if state.billing_enabled: - subscription_state = await aget_user_subscription_state(user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) - if subscribed: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + + subscription_state = await aget_user_subscription_state(user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: - if state.billing_enabled: - subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) - if subscribed: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser( - user_with_token.user - ) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 7efd8bfd..65516a9c 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -46,6 +46,7 @@ from khoj.routers.helpers import ( is_ready_to_chat, update_telemetry_state, validate_conversation_config, + ConversationCommandRateLimiter, ) from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -67,6 +68,7 @@ from khoj.utils.state import SearchType # Initialize Router api = APIRouter() logger = logging.getLogger(__name__) +conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100) def map_config_to_object(content_source: str): @@ -670,6 +672,8 @@ async def chat( await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) + conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command) + q = q.replace(f"/{conversation_command.value}", "").strip() meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 4e883f35..56fba861 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -300,6 +300,40 @@ class ApiUserRateLimiter: user_requests.append(time()) +class ConversationCommandRateLimiter: + def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int): + self.cache = defaultdict(lambda: defaultdict(list)) + self.trial_rate_limit = trial_rate_limit + self.subscribed_rate_limit = subscribed_rate_limit + self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image] + + def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): + if state.billing_enabled is False: + return + + if not request.user.is_authenticated: + return + + if conversation_command not in self.restricted_commands: + return + + user: KhojUser = request.user.object + user_cache = self.cache[user.uuid] + subscribed = has_required_scope(request, ["premium"]) + user_cache[conversation_command].append(time()) + + # Remove requests outside of the 24-hr time window + cutoff = time() - 60 * 60 * 24 + while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff: + user_cache[conversation_command].pop(0) + + if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit: + raise HTTPException(status_code=429, detail="Too Many Requests") + if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit: + raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.") + return + + class ApiIndexedDataLimiter: def __init__( self,