Prefer explicitly configured OpenAI API url, key for image gen model

Previously we'd use the general openai client, even if the image
generation model has a different api key and base url set.

This change uses the openai config of the image generation models when
set. Otherwise it fallbacks to use the first openai api provider set
This commit is contained in:
Debanjum
2025-01-15 17:55:21 +07:00
parent 24204873c8
commit 182c49b41c

View File

@@ -119,25 +119,27 @@ async def text_to_image(
def generate_image_with_openai( def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
): ):
"Generate image using OpenAI API" "Generate image using OpenAI (compatible) API"
# Get the API key from the user's configuration # Get the API config from the user's configuration
api_key = None
if text_to_image_config.api_key: if text_to_image_config.api_key:
api_key = text_to_image_config.api_key api_key = text_to_image_config.api_key
openai_client = openai.OpenAI(api_key=api_key)
elif text_to_image_config.ai_model_api: elif text_to_image_config.ai_model_api:
api_key = text_to_image_config.ai_model_api.api_key api_key = text_to_image_config.ai_model_api.api_key
api_base_url = text_to_image_config.ai_model_api.api_base_url
openai_client = openai.OpenAI(api_key=api_key, base_url=api_base_url)
elif state.openai_client: elif state.openai_client:
api_key = state.openai_client.api_key openai_client = state.openai_client
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
# Generate image using OpenAI API # Generate image using OpenAI API
OPENAI_IMAGE_GEN_STYLE = "vivid" OPENAI_IMAGE_GEN_STYLE = "vivid"
response = state.openai_client.images.generate( response = openai_client.images.generate(
prompt=improved_image_prompt, prompt=improved_image_prompt,
model=text2image_model, model=text2image_model,
style=OPENAI_IMAGE_GEN_STYLE, style=OPENAI_IMAGE_GEN_STYLE,
response_format="b64_json", response_format="b64_json",
extra_headers=auth_header,
) )
# Extract the base64 image from the response # Extract the base64 image from the response