mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Add websocket chat api endpoint to ease bi-directional communication
This commit is contained in:
@@ -220,7 +220,16 @@ def set_state(args):
|
|||||||
def start_server(app, host=None, port=None, socket=None):
|
def start_server(app, host=None, port=None, socket=None):
|
||||||
logger.info("🌖 Khoj is ready to engage")
|
logger.info("🌖 Khoj is ready to engage")
|
||||||
if socket:
|
if socket:
|
||||||
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None)
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
proxy_headers=True,
|
||||||
|
uds=socket,
|
||||||
|
log_level="debug" if state.verbose > 1 else "info",
|
||||||
|
use_colors=True,
|
||||||
|
log_config=None,
|
||||||
|
ws_ping_timeout=300,
|
||||||
|
timeout_keep_alive=60,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
@@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None):
|
|||||||
log_level="debug" if state.verbose > 1 else "info",
|
log_level="debug" if state.verbose > 1 else "info",
|
||||||
use_colors=True,
|
use_colors=True,
|
||||||
log_config=None,
|
log_config=None,
|
||||||
|
ws_ping_timeout=300,
|
||||||
timeout_keep_alive=60,
|
timeout_keep_alive=60,
|
||||||
**state.ssl_config if state.ssl_config else {},
|
**state.ssl_config if state.ssl_config else {},
|
||||||
)
|
)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -33,7 +33,7 @@ from apscheduler.job import Job
|
|||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.utils import timezone as django_timezone
|
from django.utils import timezone as django_timezone
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
|
||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
@@ -1936,6 +1936,53 @@ class ApiUserRateLimiter:
|
|||||||
# Add the current request to the cache
|
# Add the current request to the cache
|
||||||
UserRequests.objects.create(user=user, slug=self.slug)
|
UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
def check_websocket(self, websocket: WebSocket):
    """Enforce this limiter's per-user request limit on a WebSocket connection.

    Counts the user's ``UserRequests`` rows with this limiter's slug created
    within the last ``self.window`` seconds. Connections holding the
    "premium" scope are checked against ``self.subscribed_requests``; all
    other connections are checked against ``self.requests``. When the user is
    under their limit, the current request is recorded as a new
    ``UserRequests`` row.

    The check is skipped when billing is disabled or when the connection has
    no authenticated user.

    NOTE(review): the ``UserRequests`` queries look like blocking ORM calls;
    presumably async callers wrap this method (e.g. with ``sync_to_async``) -
    confirm at the call site.

    Args:
        websocket: The WebSocket connection whose ``scope["user"]`` identifies
            the requester.

    Raises:
        HTTPException: With status 429 when the applicable limit is reached.
    """
    # Rate limiting disabled if billing is disabled
    if state.billing_enabled is False:
        return

    # Rate limiting is disabled if user unauthenticated.
    if not websocket.scope.get("user") or not websocket.scope["user"].is_authenticated:
        return

    user: KhojUser = websocket.scope["user"].object
    subscribed = has_required_scope(websocket, ["premium"])

    # Count only the requests made inside the current time window
    cutoff = django_timezone.now() - timedelta(seconds=self.window)
    count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()

    limit_reached_detail = "I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?"
    subscribe_to_continue_detail = "I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?"

    # Check if the user has exceeded the rate limit
    if subscribed and count_requests >= self.subscribed_requests:
        logger.info(
            f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
        )
        raise HTTPException(status_code=429, detail=limit_reached_detail)

    if not subscribed and count_requests >= self.requests:
        # The limit applied to a non-subscribed user is always self.requests,
        # so that is the value reported in the log.
        logger.info(
            f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
        )
        # Only suggest subscribing when doing so would actually raise the limit.
        if self.requests >= self.subscribed_requests:
            raise HTTPException(status_code=429, detail=limit_reached_detail)
        raise HTTPException(status_code=429, detail=subscribe_to_continue_detail)

    # Record the current request so it counts towards later checks
    UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
class ApiImageRateLimiter:
|
class ApiImageRateLimiter:
|
||||||
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
||||||
@@ -1983,6 +2030,47 @@ class ApiImageRateLimiter:
|
|||||||
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
    """Validate the images attached to a chat message sent over a WebSocket.

    The check is skipped when billing is disabled, when the connection has no
    authenticated user, or when the message carries no images.

    Args:
        websocket: The WebSocket connection whose ``scope["user"]`` identifies
            the requester.
        body: The chat request; only ``body.images`` is read.

    Raises:
        HTTPException: With status 429 when the message has more than
            ``self.max_images`` images, or when their combined decoded size
            exceeds ``self.max_combined_size_mb`` megabytes.
    """
    # Nothing to enforce when billing is off
    if state.billing_enabled is False:
        return

    # Unauthenticated connections are not limited
    connection_user = websocket.scope.get("user")
    if not (connection_user and connection_user.is_authenticated):
        return

    # A message without images trivially passes
    if not body.images:
        return

    # Reject messages carrying more images than allowed
    if len(body.images) > self.max_images:
        logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.")
        raise HTTPException(
            status_code=429,
            detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
        )

    # Tally the decoded size of every attached image, in megabytes
    bytes_per_mb = 1024 * 1024
    total_size_mb = 0.0
    for encoded_image in body.images:
        # The image may arrive URL encoded
        payload = unquote(encoded_image)
        # Drop a leading data URI header (e.g. data:image/jpeg;base64,) when present
        _, comma, remainder = payload.partition(",")
        if comma:
            payload = remainder
        # The image is assumed to be a base64 string; measure its decoded bytes
        total_size_mb += len(base64.b64decode(payload)) / bytes_per_mb

    # Reject messages whose images are too large in total
    if total_size_mb > self.max_combined_size_mb:
        logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
        raise HTTPException(
            status_code=429,
            detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
        )
|
||||||
|
|
||||||
|
|
||||||
class ConversationCommandRateLimiter:
|
class ConversationCommandRateLimiter:
|
||||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||||
@@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter:
|
|||||||
self.subscribed_rate_limit = subscribed_rate_limit
|
self.subscribed_rate_limit = subscribed_rate_limit
|
||||||
self.restricted_commands = [ConversationCommand.Research]
|
self.restricted_commands = [ConversationCommand.Research]
|
||||||
|
|
||||||
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
|
async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
|
||||||
if state.billing_enabled is False:
|
if state.billing_enabled is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user