Add support for Google Imagen AI models for image generation

Use the new Google GenAI client to generate images with Imagen
This commit is contained in:
Debanjum
2025-03-11 22:52:20 +05:30
parent bd06fcd9be
commit 7bb6facdea
4 changed files with 70 additions and 4 deletions

View File

@@ -0,0 +1,26 @@
# Generated by Django 5.0.10 on 2025-03-11 16:58
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0085_alter_agent_output_modes"),
]
operations = [
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
field=models.CharField(
choices=[
("openai", "Openai"),
("stability-ai", "Stabilityai"),
("replicate", "Replicate"),
("google", "Google"),
],
default="openai",
max_length=200,
),
),
]

View File

@@ -530,6 +530,7 @@ class TextToImageModelConfig(DbBaseModel):
OPENAI = "openai"
STABILITYAI = "stability-ai"
REPLICATE = "replicate"
GOOGLE = "google"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
@@ -547,11 +548,11 @@ class TextToImageModelConfig(DbBaseModel):
error[
"ai_model_api"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
if self.model_type != self.ModelType.OPENAI:
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI models."
error["api_key"] = "The API key field must be set for non OpenAI, non Google models."
if self.ai_model_api:
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI models."
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI, non Google models."
if error:
raise ValidationError(error)

View File

@@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Optional
import openai
import requests
from google import genai
from google.genai import types as gtypes
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
@@ -86,6 +88,8 @@ async def text_to_image(
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -99,6 +103,12 @@ async def text_to_image(
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message
return
except ValueError as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to an unknown error"
status_code = 500
yield image_url or image, status_code, message
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
@@ -215,3 +225,28 @@ def generate_image_with_replicate(
# Get the generated image
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
return io.BytesIO(requests.get(image_url).content).getvalue()
def generate_image_with_google(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"""Generate image using Google's AI over API"""
# Initialize the Google AI client
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
client = genai.Client(api_key=api_key)
# Configure image generation settings
config = gtypes.GenerateImagesConfig(number_of_images=1)
# Call the Gemini API to generate the image
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
if not response.generated_images:
raise ValueError("Failed to generate image using Google AI")
# Extract the image bytes from the first generated image
image_bytes = response.generated_images[0].image.image_bytes
# Convert to webp for faster loading
return convert_image_to_webp(image_bytes)

View File

@@ -1092,7 +1092,11 @@ async def generate_better_image_prompt(
online_results=simplified_online_results,
personality_context=personality_context,
)
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
elif model_type in [
TextToImageModelConfig.ModelType.STABILITYAI,
TextToImageModelConfig.ModelType.REPLICATE,
TextToImageModelConfig.ModelType.GOOGLE,
]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,