From 9314f0a398237229fca9881d9e7137bd991a0e1a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 13 Oct 2024 02:59:10 -0700 Subject: [PATCH] Fix default chat configs to use user model if no server chat model set Post merge cleanup in advanced reasoning to fallback to user chat model if no server chat model defined for advanced and default --- src/khoj/processor/conversation/helpers.py | 5 +++-- src/khoj/processor/tools/run_code.py | 2 +- src/khoj/routers/research.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/khoj/processor/conversation/helpers.py b/src/khoj/processor/conversation/helpers.py index 06a8557c..4b7e472c 100644 --- a/src/khoj/processor/conversation/helpers.py +++ b/src/khoj/processor/conversation/helpers.py @@ -18,11 +18,11 @@ async def send_message_to_model_wrapper( system_message: str = "", response_type: str = "text", chat_model_option: ChatModelOptions = None, - subscribed: bool = False, + user: KhojUser = None, uploaded_image_url: str = None, ): conversation_config: ChatModelOptions = ( - chat_model_option or await ConversationAdapters.aget_default_conversation_config() + chat_model_option or await ConversationAdapters.aget_default_conversation_config(user) ) vision_available = conversation_config.vision_enabled @@ -32,6 +32,7 @@ async def send_message_to_model_wrapper( conversation_config = vision_enabled_config vision_available = True + subscribed = await ais_user_subscribed(user) chat_model = conversation_config.chat_model max_tokens = ( conversation_config.subscribed_max_prompt_size diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 384b993c..9da04237 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -97,7 +97,7 @@ async def generate_python_code( code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", - subscribed=subscribed, + user=user, ) # Validate that the response is a non-empty, JSON-serializable list diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 60c22c80..ed43c864 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) async def apick_next_tool( query: str, conversation_history: dict, - subscribed: bool, + user: KhojUser = None, uploaded_image_url: str = None, location: LocationData = None, user_name: str = None, @@ -86,13 +86,13 @@ async def apick_next_tool( max_iterations=max_iterations, ) - chat_model_option = await ConversationAdapters.aget_advanced_conversation_config() + chat_model_option = await ConversationAdapters.aget_advanced_conversation_config(user) with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( function_planning_prompt, response_type="json_object", - subscribed=subscribed, + user=user, chat_model_option=chat_model_option, ) @@ -148,7 +148,7 @@ async def execute_information_collection( this_iteration = await apick_next_tool( query, conversation_history, - subscribed, + user, uploaded_image_url, location, user_name,