diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 8b230641..34f6be08 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1533,8 +1533,8 @@ async def chat_ws( # Apply rate limiting manually try: - rate_limiter_per_minute.check_websocket(websocket) - rate_limiter_per_day.check_websocket(websocket) + await rate_limiter_per_minute.check_websocket(websocket) + await rate_limiter_per_day.check_websocket(websocket) image_rate_limiter.check_websocket(websocket, body) except HTTPException as e: await websocket.send_text(json.dumps({"error": e.detail})) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0bb2d7ba..4c286980 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1936,7 +1936,7 @@ class ApiUserRateLimiter: # Add the current request to the cache UserRequests.objects.create(user=user, slug=self.slug) - def check_websocket(self, websocket: WebSocket): + async def check_websocket(self, websocket: WebSocket): """WebSocket-specific rate limiting method""" # Rate limiting disabled if billing is disabled if state.billing_enabled is False: @@ -1954,7 +1954,7 @@ class ApiUserRateLimiter: # Remove requests outside of the time window cutoff = django_timezone.now() - timedelta(seconds=self.window) - count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count() + count_requests = await UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).acount() # Check if the user has exceeded the rate limit if subscribed and count_requests >= self.subscribed_requests: @@ -1984,7 +1984,7 @@ class ApiUserRateLimiter: ) # Add the current request to the cache - UserRequests.objects.create(user=user, slug=self.slug) + await UserRequests.objects.acreate(user=user, slug=self.slug) class ApiImageRateLimiter: