Dedupe the code that formats messages before sending them to the appropriate chat model.
Fall back to assuming the user is not subscribed when no user is passed. This makes the user arg actually optional in the async send_message_to_model_wrapper function.
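In effect, the fallback plus the subscription-aware token cap reduce to the pattern below. This is a minimal, runnable sketch for illustration only: ChatModelConfig and pick_max_tokens are hypothetical stand-ins for Khoj's ChatModel fields and wrapper logic, and ais_user_subscribed is stubbed rather than Khoj's real async subscription check.

    # Sketch of the fallback: a missing user is treated as unsubscribed,
    # so `user` can default to None in the async wrapper's signature.
    # `ChatModelConfig`, `pick_max_tokens`, and the stubbed
    # `ais_user_subscribed` are illustrative stand-ins, not Khoj's code.
    import asyncio
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ChatModelConfig:
        max_prompt_size: int = 4096
        subscribed_max_prompt_size: Optional[int] = 16384

    async def ais_user_subscribed(user: object) -> bool:
        return True  # stand-in for Khoj's async subscription lookup

    async def pick_max_tokens(chat_model: ChatModelConfig, user: Optional[object] = None) -> int:
        # Assume not subscribed when no user is passed, as in the diff below.
        subscribed = await ais_user_subscribed(user) if user else False
        return (
            chat_model.subscribed_max_prompt_size
            if subscribed and chat_model.subscribed_max_prompt_size
            else chat_model.max_prompt_size
        )

    assert asyncio.run(pick_max_tokens(ChatModelConfig())) == 4096
    assert asyncio.run(pick_max_tokens(ChatModelConfig(), user="subscriber")) == 16384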
@@ -1166,35 +1166,40 @@ async def send_message_to_model_wrapper(
     if vision_available and query_images:
         logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
 
-    subscribed = await ais_user_subscribed(user)
-    chat_model_name = chat_model.name
+    subscribed = await ais_user_subscribed(user) if user else False
     max_tokens = (
         chat_model.subscribed_max_prompt_size
         if subscribed and chat_model.subscribed_max_prompt_size
         else chat_model.max_prompt_size
     )
+    chat_model_name = chat_model.name
     tokenizer = chat_model.tokenizer
     model_type = chat_model.model_type
     vision_available = chat_model.vision_enabled
+    api_key = chat_model.ai_model_api.api_key
+    api_base_url = chat_model.ai_model_api.api_base_url
+    loaded_model = None
 
     if model_type == ChatModel.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
             state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
 
         loaded_model = state.offline_chat_processor_config.loaded_model
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=query,
-            context_message=context,
-            system_message=system_message,
-            model_name=chat_model_name,
-            loaded_model=loaded_model,
-            tokenizer_name=tokenizer,
-            max_prompt_size=max_tokens,
-            vision_enabled=vision_available,
-            model_type=chat_model.model_type,
-            query_files=query_files,
-        )
-
+    truncated_messages = generate_chatml_messages_with_context(
+        user_message=query,
+        context_message=context,
+        system_message=system_message,
+        model_name=chat_model_name,
+        loaded_model=loaded_model,
+        tokenizer_name=tokenizer,
+        max_prompt_size=max_tokens,
+        vision_enabled=vision_available,
+        query_images=query_images,
+        model_type=model_type,
+        query_files=query_files,
+    )
+
+    if model_type == ChatModel.ModelType.OFFLINE:
         return send_message_to_model_offline(
             messages=truncated_messages,
             loaded_model=loaded_model,
@@ -1206,22 +1211,6 @@ async def send_message_to_model_wrapper(
         )
 
     elif model_type == ChatModel.ModelType.OPENAI:
-        openai_chat_config = chat_model.ai_model_api
-        api_key = openai_chat_config.api_key
-        api_base_url = openai_chat_config.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=query,
-            context_message=context,
-            system_message=system_message,
-            model_name=chat_model_name,
-            max_prompt_size=max_tokens,
-            tokenizer_name=tokenizer,
-            vision_enabled=vision_available,
-            query_images=query_images,
-            model_type=chat_model.model_type,
-            query_files=query_files,
-        )
-
         return send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
@@ -1233,21 +1222,6 @@ async def send_message_to_model_wrapper(
             tracer=tracer,
         )
     elif model_type == ChatModel.ModelType.ANTHROPIC:
-        api_key = chat_model.ai_model_api.api_key
-        api_base_url = chat_model.ai_model_api.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=query,
-            context_message=context,
-            system_message=system_message,
-            model_name=chat_model_name,
-            max_prompt_size=max_tokens,
-            tokenizer_name=tokenizer,
-            vision_enabled=vision_available,
-            query_images=query_images,
-            model_type=chat_model.model_type,
-            query_files=query_files,
-        )
-
         return anthropic_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
@@ -1258,21 +1232,6 @@ async def send_message_to_model_wrapper(
             tracer=tracer,
         )
     elif model_type == ChatModel.ModelType.GOOGLE:
-        api_key = chat_model.ai_model_api.api_key
-        api_base_url = chat_model.ai_model_api.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=query,
-            context_message=context,
-            system_message=system_message,
-            model_name=chat_model_name,
-            max_prompt_size=max_tokens,
-            tokenizer_name=tokenizer,
-            vision_enabled=vision_available,
-            query_images=query_images,
-            model_type=chat_model.model_type,
-            query_files=query_files,
-        )
-
         return gemini_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
@@ -1302,27 +1261,37 @@ def send_message_to_model_wrapper_sync(
     if chat_model is None:
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
 
+    subscribed = ais_user_subscribed(user) if user else False
+    max_tokens = (
+        chat_model.subscribed_max_prompt_size
+        if subscribed and chat_model.subscribed_max_prompt_size
+        else chat_model.max_prompt_size
+    )
     chat_model_name = chat_model.name
-    max_tokens = chat_model.max_prompt_size
+    model_type = chat_model.model_type
     vision_available = chat_model.vision_enabled
+    api_key = chat_model.ai_model_api.api_key
+    api_base_url = chat_model.ai_model_api.api_base_url
+    loaded_model = None
 
-    if chat_model.model_type == ChatModel.ModelType.OFFLINE:
+    if model_type == ChatModel.ModelType.OFFLINE:
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
 
        loaded_model = state.offline_chat_processor_config.loaded_model
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model_name,
-            loaded_model=loaded_model,
-            max_prompt_size=max_tokens,
-            vision_enabled=vision_available,
-            model_type=chat_model.model_type,
-            query_images=query_images,
-            query_files=query_files,
-        )
-
+    truncated_messages = generate_chatml_messages_with_context(
+        user_message=message,
+        system_message=system_message,
+        model_name=chat_model_name,
+        loaded_model=loaded_model,
+        max_prompt_size=max_tokens,
+        vision_enabled=vision_available,
+        model_type=model_type,
+        query_images=query_images,
+        query_files=query_files,
+    )
+
+    if model_type == ChatModel.ModelType.OFFLINE:
         return send_message_to_model_offline(
             messages=truncated_messages,
             loaded_model=loaded_model,
@@ -1333,20 +1302,7 @@ def send_message_to_model_wrapper_sync(
             tracer=tracer,
         )
 
-    elif chat_model.model_type == ChatModel.ModelType.OPENAI:
-        api_key = chat_model.ai_model_api.api_key
-        api_base_url = chat_model.ai_model_api.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model_name,
-            max_prompt_size=max_tokens,
-            vision_enabled=vision_available,
-            model_type=chat_model.model_type,
-            query_images=query_images,
-            query_files=query_files,
-        )
-
+    elif model_type == ChatModel.ModelType.OPENAI:
         return send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
@@ -1357,20 +1313,7 @@ def send_message_to_model_wrapper_sync(
             tracer=tracer,
         )
 
-    elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
-        api_key = chat_model.ai_model_api.api_key
-        api_base_url = chat_model.ai_model_api.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model_name,
-            max_prompt_size=max_tokens,
-            vision_enabled=vision_available,
-            model_type=chat_model.model_type,
-            query_images=query_images,
-            query_files=query_files,
-        )
-
+    elif model_type == ChatModel.ModelType.ANTHROPIC:
         return anthropic_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
@@ -1380,20 +1323,7 @@ def send_message_to_model_wrapper_sync(
             tracer=tracer,
         )
 
-    elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
-        api_key = chat_model.ai_model_api.api_key
-        api_base_url = chat_model.ai_model_api.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model_name,
-            max_prompt_size=max_tokens,
-            vision_enabled=vision_available,
-            model_type=chat_model.model_type,
-            query_images=query_images,
-            query_files=query_files,
-        )
-
+    elif model_type == ChatModel.ModelType.GOOGLE:
        return gemini_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
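After this change both wrappers share the same shape: build truncated_messages once with generate_chatml_messages_with_context, then branch only on the provider-specific send call. Below is a condensed, runnable illustration of that dispatch pattern; the ModelType enum and the toy senders are stand-ins for the real offline/OpenAI/Anthropic/Google calls, not Khoj code.

    # Toy illustration of the deduplicated dispatch: the message payload is
    # formatted exactly once, and only the transport differs per model type.
    from enum import Enum
    from typing import Callable, Dict, List

    class ModelType(Enum):
        OFFLINE = "offline"
        OPENAI = "openai"
        ANTHROPIC = "anthropic"
        GOOGLE = "google"

    def format_messages(query: str, system_message: str) -> List[dict]:
        # Shared formatting happens once, before the dispatch.
        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ]

    def send_message(model_type: ModelType, query: str, system_message: str) -> str:
        messages = format_messages(query, system_message)
        # Provider-specific senders differ only in transport, not formatting.
        senders: Dict[ModelType, Callable[[List[dict]], str]] = {
            model: (lambda m, name=model.value: f"{name} got {len(m)} messages")
            for model in ModelType
        }
        return senders[model_type](messages)

    print(send_message(ModelType.OPENAI, "hello", "be brief"))  # openai got 2 messages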