Add a ConversationCommand rate limiter for the chat endpoint

This commit is contained in:
sabaimran
2023-12-16 09:03:52 +05:30
parent 9b961ed496
commit 73a107690d
3 changed files with 62 additions and 24 deletions

View File

@@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
user_with_token.user
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:

View File

@@ -46,6 +46,7 @@ from khoj.routers.helpers import (
is_ready_to_chat,
update_telemetry_state,
validate_conversation_config,
ConversationCommandRateLimiter,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -67,6 +68,7 @@ from khoj.utils.state import SearchType
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
# Shared limiter for restricted chat commands, counted per user over a 24-hour window:
# 5 requests for users without the "premium" scope, 100 for users with it.
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
def map_config_to_object(content_source: str):
@@ -670,6 +672,8 @@ async def chat(
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log

View File

@@ -300,6 +300,40 @@ class ApiUserRateLimiter:
user_requests.append(time())
class ConversationCommandRateLimiter:
    """Limit how often each user may invoke restricted conversation commands.

    Timestamps of accepted requests are kept in memory, keyed first by user
    UUID and then by conversation command, and are counted over a rolling
    24-hour window. Requests carrying the "premium" scope are held to
    ``subscribed_rate_limit``; every other authenticated request is held to
    ``trial_rate_limit``.
    """

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
        """Create an empty cache and store the per-tier limits.

        Args:
            trial_rate_limit: Requests allowed per command in the window for
                requests without the "premium" scope.
            subscribed_rate_limit: Requests allowed per command in the window
                for requests with the "premium" scope.
        """
        # user uuid -> conversation command -> timestamps of accepted requests
        self.cache = defaultdict(lambda: defaultdict(list))
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        # Only these commands are rate limited; all others pass through unchecked.
        self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

    def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        """Record this request if it is within the caller's limit, else reject it.

        Nothing is checked or recorded when billing is disabled, when the
        request is unauthenticated, or when the command is not restricted.

        Args:
            request: Incoming request; its user and auth scopes select the limit.
            conversation_command: Command the user is trying to run.

        Raises:
            HTTPException: With status 429 when the user has already used up
                the allowed number of requests for this command in the window.
        """
        if state.billing_enabled is False:
            return
        if not request.user.is_authenticated:
            return
        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        command_requests = self.cache[user.uuid][conversation_command]
        subscribed = has_required_scope(request, ["premium"])

        # Read the clock once so pruning and recording agree on "now".
        now = time()

        # Remove requests outside of the 24-hr time window
        cutoff = now - 60 * 60 * 24
        while command_requests and command_requests[0] < cutoff:
            command_requests.pop(0)

        # Check before recording: a rejected request must not consume quota,
        # otherwise retries while blocked would keep extending the lockout and
        # grow the cache without bound.
        if subscribed and len(command_requests) >= self.subscribed_rate_limit:
            raise HTTPException(status_code=429, detail="Too Many Requests")
        if not subscribed and len(command_requests) >= self.trial_rate_limit:
            raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")

        # Request accepted; count it against the window.
        command_requests.append(now)
        return
class ApiIndexedDataLimiter:
def __init__(
self,