Add websocket chat api endpoint to ease bi-directional communication

This commit is contained in:
Debanjum
2025-06-18 16:45:09 -07:00
parent 99ed796c00
commit 38dd85c91f
3 changed files with 939 additions and 708 deletions

View File

@@ -220,7 +220,16 @@ def set_state(args):
def start_server(app, host=None, port=None, socket=None): def start_server(app, host=None, port=None, socket=None):
logger.info("🌖 Khoj is ready to engage") logger.info("🌖 Khoj is ready to engage")
if socket: if socket:
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None) uvicorn.run(
app,
proxy_headers=True,
uds=socket,
log_level="debug" if state.verbose > 1 else "info",
use_colors=True,
log_config=None,
ws_ping_timeout=300,
timeout_keep_alive=60,
)
else: else:
uvicorn.run( uvicorn.run(
app, app,
@@ -229,6 +238,7 @@ def start_server(app, host=None, port=None, socket=None):
log_level="debug" if state.verbose > 1 else "info", log_level="debug" if state.verbose > 1 else "info",
use_colors=True, use_colors=True,
log_config=None, log_config=None,
ws_ping_timeout=300,
timeout_keep_alive=60, timeout_keep_alive=60,
**state.ssl_config if state.ssl_config else {}, **state.ssl_config if state.ssl_config else {},
) )

File diff suppressed because it is too large Load Diff

View File

@@ -33,7 +33,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.utils import timezone as django_timezone from django.utils import timezone as django_timezone
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from starlette.requests import URL from starlette.requests import URL
@@ -1936,6 +1936,53 @@ class ApiUserRateLimiter:
# Add the current request to the cache # Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug) UserRequests.objects.create(user=user, slug=self.slug)
def check_websocket(self, websocket: WebSocket):
    """Apply this limiter's request quota to a WebSocket connection.

    Returns without doing anything when billing is disabled or when the
    connection scope carries no authenticated user. Otherwise counts the
    user's ``UserRequests`` rows for this limiter's slug created within the
    last ``self.window`` seconds, raises when the applicable quota has
    already been reached, and records the current request when it has not.

    Args:
        websocket: Connection whose ``scope["user"]`` identifies the caller.

    Raises:
        HTTPException: Status 429 when the user's quota is used up.
    """
    # Rate limiting only applies when billing is enabled.
    if state.billing_enabled is False:
        return

    # Connections without an authenticated user are not rate limited.
    connection_user = websocket.scope.get("user")
    if not connection_user or not connection_user.is_authenticated:
        return

    user: KhojUser = connection_user.object
    # Connections holding the "premium" scope are treated as subscribed.
    subscribed = has_required_scope(websocket, ["premium"])

    # Only requests made inside the current time window count towards the limit.
    window_start = django_timezone.now() - timedelta(seconds=self.window)
    recent_requests = UserRequests.objects.filter(user=user, created_at__gte=window_start, slug=self.slug)
    count_requests = recent_requests.count()

    limit_reached_detail = "I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?"
    upgrade_detail = "I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?"

    if subscribed:
        if count_requests >= self.subscribed_requests:
            logger.info(
                f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
            )
            raise HTTPException(status_code=429, detail=limit_reached_detail)
    elif count_requests >= self.requests:
        if self.requests < self.subscribed_requests:
            # Subscribing would raise the limit, so point the user to their settings.
            logger.info(
                f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
            )
            raise HTTPException(status_code=429, detail=upgrade_detail)
        # Subscribing would not raise the limit, so do not suggest it.
        logger.info(
            f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}."
        )
        raise HTTPException(status_code=429, detail=limit_reached_detail)

    # Record the current request so it counts towards later checks.
    UserRequests.objects.create(user=user, slug=self.slug)
class ApiImageRateLimiter: class ApiImageRateLimiter:
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10): def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
@@ -1983,6 +2030,47 @@ class ApiImageRateLimiter:
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.", detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
) )
def check_websocket(self, websocket: WebSocket, body: ChatRequestBody):
    """Apply the per-message image limits to a chat request sent over WebSocket.

    Returns without doing anything when billing is disabled, when the
    connection scope carries no authenticated user, or when the request has
    no images. Otherwise rejects the message if it has more than
    ``self.max_images`` images or if their combined decoded size exceeds
    ``self.max_combined_size_mb`` megabytes.

    Args:
        websocket: Connection whose ``scope["user"]`` identifies the caller.
        body: Chat request whose ``images`` are checked.

    Raises:
        HTTPException: Status 429 when either image limit is exceeded.
    """
    if state.billing_enabled is False:
        return

    # Connections without an authenticated user are not rate limited.
    connection_user = websocket.scope.get("user")
    if not connection_user or not connection_user.is_authenticated:
        return

    # Nothing to check for messages without images.
    if not body.images:
        return

    # Limit the number of images per message.
    image_count = len(body.images)
    if image_count > self.max_images:
        logger.info(f"Rate limit: {image_count}/{self.max_images} images not allowed per message.")
        raise HTTPException(
            status_code=429,
            detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
        )

    # Limit the combined decoded size of the images.
    bytes_per_mb = 1024 * 1024
    total_size_mb = 0.0
    for encoded_image in body.images:
        # The image may arrive URL encoded.
        decoded_url = unquote(encoded_image)
        # Images are assumed to be base64 strings; drop a leading
        # "data:image/jpeg;base64," style prefix when one is present.
        _, comma, after_comma = decoded_url.partition(",")
        payload = after_comma if comma else decoded_url
        # Decode to measure the actual byte size. Decode errors are not caught here.
        raw_image = base64.b64decode(payload)
        total_size_mb += len(raw_image) / bytes_per_mb

    if total_size_mb > self.max_combined_size_mb:
        logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
        raise HTTPException(
            status_code=429,
            detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
        )
class ConversationCommandRateLimiter: class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
@@ -1991,7 +2079,7 @@ class ConversationCommandRateLimiter:
self.subscribed_rate_limit = subscribed_rate_limit self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Research] self.restricted_commands = [ConversationCommand.Research]
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): async def update_and_check_if_valid(self, request: Request | WebSocket, conversation_command: ConversationCommand):
if state.billing_enabled is False: if state.billing_enabled is False:
return return