From 7eaf0e80c57b8a3fd5095b49580869dfe60de7c6 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 30 May 2025 16:40:53 -0700 Subject: [PATCH] Get max prompt size for given user, model via reusable functions --- src/khoj/database/adapters/__init__.py | 20 ++++++++++++++++++++ src/khoj/routers/helpers.py | 14 ++------------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index b3f6efe8..ac121ff9 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1312,6 +1312,26 @@ class ConversationAdapters: else: ServerChatSettings.objects.create(chat_default=chat_model, chat_advanced=chat_model) + @staticmethod + def get_max_context_size(chat_model: ChatModel, user: KhojUser) -> int | None: + """Get the max context size for the user based on the chat model.""" + subscribed = is_user_subscribed(user) if user else False + if subscribed and chat_model.subscribed_max_prompt_size: + max_tokens = chat_model.subscribed_max_prompt_size + else: + max_tokens = chat_model.max_prompt_size + return max_tokens + + @staticmethod + async def aget_max_context_size(chat_model: ChatModel, user: KhojUser) -> int | None: + """Get the max context size for the user based on the chat model.""" + subscribed = await ais_user_subscribed(user) if user else False + if subscribed and chat_model.subscribed_max_prompt_size: + max_tokens = chat_model.subscribed_max_prompt_size + else: + max_tokens = chat_model.max_prompt_size + return max_tokens + @staticmethod async def aget_server_webscraper(): server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst() diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 52610726..6d0d0064 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1175,12 +1175,7 @@ async def send_message_to_model_wrapper( if vision_available and query_images: logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.") - subscribed = await ais_user_subscribed(user) if user else False - max_tokens = ( - chat_model.subscribed_max_prompt_size - if subscribed and chat_model.subscribed_max_prompt_size - else chat_model.max_prompt_size - ) + max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user) chat_model_name = chat_model.name tokenizer = chat_model.tokenizer model_type = chat_model.model_type @@ -1272,12 +1267,7 @@ def send_message_to_model_wrapper_sync( if chat_model is None: raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") - subscribed = is_user_subscribed(user) if user else False - max_tokens = ( - chat_model.subscribed_max_prompt_size - if subscribed and chat_model.subscribed_max_prompt_size - else chat_model.max_prompt_size - ) + max_tokens = ConversationAdapters.get_max_context_size(chat_model, user) chat_model_name = chat_model.name model_type = chat_model.model_type vision_available = chat_model.vision_enabled