mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Add a ConversationCommand rate limiter for the chat endpoint
This commit is contained in:
@@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||
.afirst()
|
||||
)
|
||||
if user:
|
||||
if state.billing_enabled:
|
||||
subscription_state = await aget_user_subscription_state(user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
if not state.billing_enabled:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
|
||||
subscription_state = await aget_user_subscription_state(user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||
# Get bearer token from header
|
||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||
@@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||
.afirst()
|
||||
)
|
||||
if user_with_token:
|
||||
if state.billing_enabled:
|
||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
|
||||
user_with_token.user
|
||||
)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if not state.billing_enabled:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
|
||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||
if user:
|
||||
|
||||
@@ -46,6 +46,7 @@ from khoj.routers.helpers import (
|
||||
is_ready_to_chat,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
ConversationCommandRateLimiter,
|
||||
)
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -67,6 +68,7 @@ from khoj.utils.state import SearchType
|
||||
# Initialize Router
|
||||
api = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
||||
|
||||
|
||||
def map_config_to_object(content_source: str):
|
||||
@@ -670,6 +672,8 @@ async def chat(
|
||||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
|
||||
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||
|
||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
@@ -300,6 +300,40 @@ class ApiUserRateLimiter:
|
||||
user_requests.append(time())
|
||||
|
||||
|
||||
class ConversationCommandRateLimiter:
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
|
||||
self.cache = defaultdict(lambda: defaultdict(list))
|
||||
self.trial_rate_limit = trial_rate_limit
|
||||
self.subscribed_rate_limit = subscribed_rate_limit
|
||||
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
|
||||
|
||||
def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
if not request.user.is_authenticated:
|
||||
return
|
||||
|
||||
if conversation_command not in self.restricted_commands:
|
||||
return
|
||||
|
||||
user: KhojUser = request.user.object
|
||||
user_cache = self.cache[user.uuid]
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
user_cache[conversation_command].append(time())
|
||||
|
||||
# Remove requests outside of the 24-hr time window
|
||||
cutoff = time() - 60 * 60 * 24
|
||||
while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
|
||||
user_cache[conversation_command].pop(0)
|
||||
|
||||
if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||
if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
|
||||
return
|
||||
|
||||
|
||||
class ApiIndexedDataLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user