mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Merge pull request #584 from khoj-ai/features/enforce-usage-limits-conversation-type
Add a ConversationCommand rate limiter for the chat endpoint
This commit is contained in:
@@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if user:
|
if user:
|
||||||
if state.billing_enabled:
|
if not state.billing_enabled:
|
||||||
subscription_state = await aget_user_subscription_state(user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
subscription_state = await aget_user_subscription_state(user)
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
subscribed = (
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
)
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
if subscribed:
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
if subscribed:
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||||
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||||
# Get bearer token from header
|
# Get bearer token from header
|
||||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||||
@@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if user_with_token:
|
if user_with_token:
|
||||||
if state.billing_enabled:
|
if not state.billing_enabled:
|
||||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
subscribed = (
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
)
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
if subscribed:
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
|
)
|
||||||
user_with_token.user
|
if subscribed:
|
||||||
)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
|
||||||
if state.anonymous_mode:
|
if state.anonymous_mode:
|
||||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||||
if user:
|
if user:
|
||||||
|
|||||||
@@ -401,7 +401,7 @@ class ConversationAdapters:
|
|||||||
)
|
)
|
||||||
|
|
||||||
max_results = 3
|
max_results = 3
|
||||||
all_questions = await sync_to_async(list)(all_questions)
|
all_questions = await sync_to_async(list)(all_questions) # type: ignore
|
||||||
if len(all_questions) < max_results:
|
if len(all_questions) < max_results:
|
||||||
return all_questions
|
return all_questions
|
||||||
|
|
||||||
|
|||||||
@@ -642,6 +642,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||||
} else if (err.status === 422) {
|
} else if (err.status === 422) {
|
||||||
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||||
|
} else if (err.status === 429) {
|
||||||
|
flashStatusInChatInput("⛔️ " + err.statusText);
|
||||||
} else {
|
} else {
|
||||||
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from khoj.routers.helpers import (
|
|||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
|
ConversationCommandRateLimiter,
|
||||||
)
|
)
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
@@ -67,6 +68,7 @@ from khoj.utils.state import SearchType
|
|||||||
# Initialize Router
|
# Initialize Router
|
||||||
api = APIRouter()
|
api = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
||||||
|
|
||||||
|
|
||||||
def map_config_to_object(content_source: str):
|
def map_config_to_object(content_source: str):
|
||||||
@@ -604,7 +606,13 @@ async def chat_options(
|
|||||||
|
|
||||||
@api.post("/transcribe")
|
@api.post("/transcribe")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
|
async def transcribe(
|
||||||
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
|
||||||
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
|
):
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
||||||
user_message: str = None
|
user_message: str = None
|
||||||
@@ -670,6 +678,8 @@ async def chat(
|
|||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||||
|
|
||||||
|
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||||
|
|
||||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||||
|
|
||||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
|
|||||||
)
|
)
|
||||||
image = response.data[0].b64_json
|
image = response.data[0].b64_json
|
||||||
except openai.OpenAIError as e:
|
except openai.OpenAIError as e:
|
||||||
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
status_code = 500
|
status_code = 500
|
||||||
|
|
||||||
return image, status_code
|
return image, status_code
|
||||||
@@ -300,6 +300,40 @@ class ApiUserRateLimiter:
|
|||||||
user_requests.append(time())
|
user_requests.append(time())
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationCommandRateLimiter:
    """Rolling 24-hour rate limiter for restricted conversation commands.

    Request timestamps are kept in memory on this instance, keyed first by
    the user's ``uuid`` and then by the conversation command. Only commands
    listed in ``restricted_commands`` are limited; all others pass through.
    """

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
        """Create a limiter.

        Args:
            trial_rate_limit: Requests allowed per command in the window for
                users without the "premium" scope.
            subscribed_rate_limit: Requests allowed per command in the window
                for users with the "premium" scope.
        """
        # NOTE(review): keys are `user.uuid` and `ConversationCommand` members;
        # the `str` key annotations look approximate - confirm.
        self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

    def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        """Record this request if it is within the caller's limit, else reject it.

        Does nothing when billing is disabled, when the request user is not
        authenticated, or when the command is not a restricted one.

        Args:
            request: Incoming request; its user and scopes are inspected.
            conversation_command: Command detected for the chat query.

        Raises:
            HTTPException: With status 429 when the user has already used up
                their allowance for this command in the last 24 hours.
        """
        if state.billing_enabled is False:
            return

        if not request.user.is_authenticated:
            return

        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        timestamps = self.cache[user.uuid][conversation_command]
        subscribed = has_required_scope(request, ["premium"])

        # Read the clock once so the recorded stamp and the cutoff agree.
        now = time()
        cutoff = now - 60 * 60 * 24

        # Remove requests outside of the 24-hr time window. Timestamps are
        # appended in order, so the expired ones form a prefix; drop them with
        # a single slice deletion instead of repeated O(n) pop(0) calls.
        expired = 0
        for stamp in timestamps:
            if stamp >= cutoff:
                break
            expired += 1
        del timestamps[:expired]

        if subscribed:
            limit = self.subscribed_rate_limit
            detail = "Too Many Requests"
        else:
            limit = self.trial_rate_limit
            detail = "Too Many Requests. Subscribe to increase your rate limit."

        # Check before recording: a rejected request must not be stored,
        # otherwise retries would grow the list without bound and keep
        # extending the user's lockout. The number of accepted requests per
        # window is still exactly `limit`.
        if len(timestamps) >= limit:
            raise HTTPException(status_code=429, detail=detail)

        timestamps.append(now)
        return
|
||||||
|
|
||||||
|
|
||||||
class ApiIndexedDataLimiter:
|
class ApiIndexedDataLimiter:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -317,7 +351,7 @@ class ApiIndexedDataLimiter:
|
|||||||
if state.billing_enabled is False:
|
if state.billing_enabled is False:
|
||||||
return
|
return
|
||||||
subscribed = has_required_scope(request, ["premium"])
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
incoming_data_size_mb = 0
|
incoming_data_size_mb = 0.0
|
||||||
deletion_file_names = set()
|
deletion_file_names = set()
|
||||||
|
|
||||||
if not request.user.is_authenticated:
|
if not request.user.is_authenticated:
|
||||||
|
|||||||
Reference in New Issue
Block a user