diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5d881ce5..f0e2d974 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1166,35 +1166,40 @@ async def send_message_to_model_wrapper( if vision_available and query_images: logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.") - subscribed = await ais_user_subscribed(user) - chat_model_name = chat_model.name + subscribed = await ais_user_subscribed(user) if user else False max_tokens = ( chat_model.subscribed_max_prompt_size if subscribed and chat_model.subscribed_max_prompt_size else chat_model.max_prompt_size ) + chat_model_name = chat_model.name tokenizer = chat_model.tokenizer model_type = chat_model.model_type vision_available = chat_model.vision_enabled + api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url + loaded_model = None if model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - loaded_model = state.offline_chat_processor_config.loaded_model - truncated_messages = generate_chatml_messages_with_context( - user_message=query, - context_message=context, - system_message=system_message, - model_name=chat_model_name, - loaded_model=loaded_model, - tokenizer_name=tokenizer, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - model_type=chat_model.model_type, - query_files=query_files, - ) + truncated_messages = generate_chatml_messages_with_context( + user_message=query, + context_message=context, + system_message=system_message, + model_name=chat_model_name, + loaded_model=loaded_model, + tokenizer_name=tokenizer, + max_prompt_size=max_tokens, + vision_enabled=vision_available, + query_images=query_images, + model_type=model_type, + query_files=query_files, + ) + + if model_type 
== ChatModel.ModelType.OFFLINE: return send_message_to_model_offline( messages=truncated_messages, loaded_model=loaded_model, @@ -1206,22 +1211,6 @@ async def send_message_to_model_wrapper( ) elif model_type == ChatModel.ModelType.OPENAI: - openai_chat_config = chat_model.ai_model_api - api_key = openai_chat_config.api_key - api_base_url = openai_chat_config.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=query, - context_message=context, - system_message=system_message, - model_name=chat_model_name, - max_prompt_size=max_tokens, - tokenizer_name=tokenizer, - vision_enabled=vision_available, - query_images=query_images, - model_type=chat_model.model_type, - query_files=query_files, - ) - return send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1233,21 +1222,6 @@ async def send_message_to_model_wrapper( tracer=tracer, ) elif model_type == ChatModel.ModelType.ANTHROPIC: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=query, - context_message=context, - system_message=system_message, - model_name=chat_model_name, - max_prompt_size=max_tokens, - tokenizer_name=tokenizer, - vision_enabled=vision_available, - query_images=query_images, - model_type=chat_model.model_type, - query_files=query_files, - ) - return anthropic_send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1258,21 +1232,6 @@ async def send_message_to_model_wrapper( tracer=tracer, ) elif model_type == ChatModel.ModelType.GOOGLE: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=query, - context_message=context, - system_message=system_message, - model_name=chat_model_name, - max_prompt_size=max_tokens, - tokenizer_name=tokenizer, - vision_enabled=vision_available, - 
query_images=query_images, - model_type=chat_model.model_type, - query_files=query_files, - ) - return gemini_send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1302,27 +1261,37 @@ def send_message_to_model_wrapper_sync( if chat_model is None: raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") + subscribed = is_user_subscribed(user) if user else False + max_tokens = ( + chat_model.subscribed_max_prompt_size + if subscribed and chat_model.subscribed_max_prompt_size + else chat_model.max_prompt_size + ) chat_model_name = chat_model.name - max_tokens = chat_model.max_prompt_size + model_type = chat_model.model_type vision_available = chat_model.vision_enabled + api_key = chat_model.ai_model_api.api_key + api_base_url = chat_model.ai_model_api.api_base_url + loaded_model = None - if chat_model.model_type == ChatModel.ModelType.OFFLINE: + if model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - loaded_model = state.offline_chat_processor_config.loaded_model - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model_name, - loaded_model=loaded_model, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - model_type=chat_model.model_type, - query_images=query_images, - query_files=query_files, - ) + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model_name, + loaded_model=loaded_model, + max_prompt_size=max_tokens, + vision_enabled=vision_available, + model_type=model_type, + query_images=query_images, + query_files=query_files, + ) + + if model_type == ChatModel.ModelType.OFFLINE: return send_message_to_model_offline( 
messages=truncated_messages, loaded_model=loaded_model, @@ -1333,20 +1302,7 @@ def send_message_to_model_wrapper_sync( tracer=tracer, ) - elif chat_model.model_type == ChatModel.ModelType.OPENAI: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model_name, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - model_type=chat_model.model_type, - query_images=query_images, - query_files=query_files, - ) - + elif model_type == ChatModel.ModelType.OPENAI: return send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1357,20 +1313,7 @@ def send_message_to_model_wrapper_sync( tracer=tracer, ) - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model_name, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - model_type=chat_model.model_type, - query_images=query_images, - query_files=query_files, - ) - + elif model_type == ChatModel.ModelType.ANTHROPIC: return anthropic_send_message_to_model( messages=truncated_messages, api_key=api_key, @@ -1380,20 +1323,7 @@ def send_message_to_model_wrapper_sync( tracer=tracer, ) - elif chat_model.model_type == ChatModel.ModelType.GOOGLE: - api_key = chat_model.ai_model_api.api_key - api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model_name, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - model_type=chat_model.model_type, - query_images=query_images, - query_files=query_files, - ) - + elif model_type == 
ChatModel.ModelType.GOOGLE: return gemini_send_message_to_model( messages=truncated_messages, api_key=api_key,