diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 717ad859..4bb23812 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - if state.billing_enabled: - subscription_state = await aget_user_subscription_state(user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) - if subscribed: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + + subscription_state = await aget_user_subscription_state(user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: - if state.billing_enabled: - subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) - if subscribed: - return 
AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser( - user_with_token.user - ) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 6e17c15b..31aa1d57 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -401,7 +401,7 @@ class ConversationAdapters: ) max_results = 3 - all_questions = await sync_to_async(list)(all_questions) + all_questions = await sync_to_async(list)(all_questions) # type: ignore if len(all_questions) < max_results: return all_questions diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index e85759fb..aa6bd4b9 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -642,6 +642,8 @@ To get started, just start typing below. 
You can also type / to see a list of co flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") } else if (err.status === 422) { flashStatusInChatInput("⛔️ Audio file to large to process.") + } else if (err.status === 429) { + flashStatusInChatInput("⛔️ " + err.statusText); } else { flashStatusInChatInput("⛔️ Failed to transcribe audio.") } diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 7efd8bfd..a4063afa 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -46,6 +46,7 @@ from khoj.routers.helpers import ( is_ready_to_chat, update_telemetry_state, validate_conversation_config, + ConversationCommandRateLimiter, ) from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -67,6 +68,7 @@ from khoj.utils.state import SearchType # Initialize Router api = APIRouter() logger = logging.getLogger(__name__) +conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100) def map_config_to_object(content_source: str): @@ -604,7 +606,13 @@ async def chat_options( @api.post("/transcribe") @requires(["authenticated"]) -async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)): +async def transcribe( + request: Request, + common: CommonQueryParams, + file: UploadFile = File(...), + rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)), + rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), +): user: KhojUser = request.user.object audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm" user_message: str = None @@ -670,6 +678,8 @@ async def chat( await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) + conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command) + q = q.replace(f"/{conversation_command.value}", 
class ConversationCommandRateLimiter:
    """Limit how often each user may run restricted conversation commands.

    Accepted requests are tracked in memory on this instance, per user and per
    command, over a rolling 24-hour window. Only the commands listed in
    ``restricted_commands`` (Online and Image) are limited.
    """

    # Length of the rolling window, in seconds (24 hours).
    _WINDOW_SECONDS = 60 * 60 * 24

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
        """Create a limiter.

        Args:
            trial_rate_limit: Requests allowed per window, per restricted
                command, for users without the "premium" scope.
            subscribed_rate_limit: Requests allowed per window, per restricted
                command, for users with the "premium" scope.
        """
        # Keyed by user.uuid, then by command -> timestamps of accepted requests (oldest first).
        self.cache: Dict[str, Dict[ConversationCommand, List[float]]] = defaultdict(lambda: defaultdict(list))
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

    def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        """Record a use of ``conversation_command`` or reject it if over the limit.

        Does nothing when billing is disabled, when the requester is not
        authenticated, or when the command is not restricted.

        Args:
            request: Incoming request; ``request.user.object`` supplies the user
                and the "premium" scope selects which limit applies.
            conversation_command: Command the user is invoking.

        Raises:
            HTTPException: With status 429 when the user has already used up
                their allowance for this command within the window.
        """
        if state.billing_enabled is False:
            return

        if not request.user.is_authenticated:
            return

        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        timestamps = self.cache[user.uuid][conversation_command]
        subscribed = has_required_scope(request, ["premium"])

        # Drop requests that fell out of the window. Timestamps are appended in
        # order, so the stale ones form a prefix that is removed in one slice.
        now = time()
        cutoff = now - self._WINDOW_SECONDS
        stale = 0
        while stale < len(timestamps) and timestamps[stale] < cutoff:
            stale += 1
        del timestamps[:stale]

        if subscribed:
            limit = self.subscribed_rate_limit
            detail = "Too Many Requests"
        else:
            limit = self.trial_rate_limit
            detail = "Too Many Requests. Subscribe to increase your rate limit."

        if len(timestamps) >= limit:
            raise HTTPException(status_code=429, detail=detail)

        # Record the request only once it is accepted, so repeated rejected
        # calls neither grow the list without bound nor keep pushing back the
        # moment the user is allowed through again.
        timestamps.append(now)