mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Enforce API use limits depending on whether the server has billing enabled
and whether the given user is subscribed
This commit is contained in:
@@ -361,9 +361,22 @@
|
|||||||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||||
newResponseText.removeChild(loadingSpinner);
|
newResponseText.removeChild(loadingSpinner);
|
||||||
}
|
}
|
||||||
|
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||||
newResponseText.innerHTML += chunk;
|
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||||
readStream();
|
try {
|
||||||
|
const responseAsJson = JSON.parse(chunk);
|
||||||
|
if (responseAsJson.detail) {
|
||||||
|
newResponseText.innerHTML += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
newResponseText.innerHTML += chunk;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
newResponseText.innerHTML += chunk;
|
||||||
|
readStream();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
|
|||||||
@@ -82,7 +82,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
||||||
user_with_token.user
|
user_with_token.user
|
||||||
)
|
)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||||
# Get bearer token from header
|
# Get bearer token from header
|
||||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||||
@@ -101,7 +102,8 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
||||||
user_with_token.user
|
user_with_token.user
|
||||||
)
|
)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||||
if state.anonymous_mode:
|
if state.anonymous_mode:
|
||||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||||
if user:
|
if user:
|
||||||
|
|||||||
@@ -403,8 +403,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
newResponseText.removeChild(loadingSpinner);
|
newResponseText.removeChild(loadingSpinner);
|
||||||
}
|
}
|
||||||
|
|
||||||
newResponseText.innerHTML += chunk;
|
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||||
readStream();
|
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||||
|
try {
|
||||||
|
const responseAsJson = JSON.parse(chunk);
|
||||||
|
if (responseAsJson.detail) {
|
||||||
|
newResponseText.innerHTML += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
newResponseText.innerHTML += chunk;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
newResponseText.innerHTML += chunk;
|
||||||
|
readStream();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
|
|||||||
@@ -573,8 +573,8 @@ async def chat(
|
|||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
d: Optional[float] = 0.18,
|
d: Optional[float] = 0.18,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import Depends, Header, HTTPException, Request
|
from fastapi import Depends, Header, HTTPException, Request
|
||||||
|
from starlette.authentication import has_required_scope
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription
|
||||||
@@ -270,13 +271,15 @@ def generate_chat_response(
|
|||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, window: int):
|
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
self.subscribed_requests = subscribed_requests
|
||||||
self.window = window
|
self.window = window
|
||||||
self.cache: dict[str, list[float]] = defaultdict(list)
|
self.cache: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
def __call__(self, request: Request):
|
def __call__(self, request: Request):
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
subscribed = has_required_scope(request, ["subscribed"])
|
||||||
user_requests = self.cache[user.uuid]
|
user_requests = self.cache[user.uuid]
|
||||||
|
|
||||||
# Remove requests outside of the time window
|
# Remove requests outside of the time window
|
||||||
@@ -285,8 +288,10 @@ class ApiUserRateLimiter:
|
|||||||
user_requests.pop(0)
|
user_requests.pop(0)
|
||||||
|
|
||||||
# Check if the user has exceeded the rate limit
|
# Check if the user has exceeded the rate limit
|
||||||
if len(user_requests) >= self.requests:
|
if subscribed and len(user_requests) >= self.subscribed_requests:
|
||||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||||
|
if not subscribed and len(user_requests) >= self.requests:
|
||||||
|
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
|
||||||
|
|
||||||
# Add the current request to the cache
|
# Add the current request to the cache
|
||||||
user_requests.append(time())
|
user_requests.append(time())
|
||||||
|
|||||||
Reference in New Issue
Block a user