Support image generation with Gemini Nano Banana

This commit is contained in:
Debanjum
2025-11-18 19:11:32 -08:00
parent dd4381c25c
commit da493be417
3 changed files with 151 additions and 41 deletions

View File

@@ -27,10 +27,11 @@ from khoj.database.models import (
TextToImageModelConfig,
)
from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url
from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import convert_image_to_webp, timer
from khoj.utils.helpers import convert_image_to_webp, is_none_or_empty, timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@@ -74,27 +75,31 @@ async def text_to_image(
elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}
# Generate a better image prompt
# Use the user's message, chat history, and other context
image_prompt_response = await generate_better_image_prompt(
message,
image_chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
query_images=query_images,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
image_prompt = image_prompt_response["description"]
image_shape = image_prompt_response["shape"]
if not is_multimodal_model(text2image_model):
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}
image_prompt_response = await generate_better_image_prompt(
message,
image_chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
query_images=query_images,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
image_prompt = image_prompt_response["description"]
image_shape = image_prompt_response["shape"]
else:
image_prompt = message
image_shape = None
if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -115,7 +120,12 @@ async def text_to_image(
)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
webp_image_bytes = generate_image_with_google(
image_prompt, text_to_image_config, text2image_model, image_shape
image_prompt,
text_to_image_config,
text2image_model,
image_shape,
chat_history=chat_history,
query_images=query_images,
)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
@@ -322,6 +332,8 @@ def generate_image_with_google(
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
chat_history: List[ChatMessageModel] = [],
query_images: List[str] = [],
):
"""Generate image using Google's AI over API"""
@@ -337,24 +349,122 @@ def generate_image_with_google(
else: # Square
aspect_ratio = "1:1"
# Configure image generation settings
config = gtypes.GenerateImagesConfig(
number_of_images=1,
safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
include_rai_reason=True,
output_mime_type="image/png",
aspect_ratio=aspect_ratio,
)
image_bytes = None
if is_multimodal_model(text2image_model):
# Format chat history for Gemini
contents = format_messages_for_gemini(improved_image_prompt, text2image_model, chat_history, query_images)
# Call the Gemini API to generate the image
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
# Configure image generation settings
config = gtypes.GenerateContentConfig(
response_modalities=["IMAGE"], image_config=gtypes.ImageConfig(aspect_ratio=None)
)
if not response.generated_images:
raise ValueError("Failed to generate image using Google AI")
# Call the Gemini API to generate the image
response = client.models.generate_content(
contents=contents,
model=text2image_model,
config=config,
)
# Extract the image bytes from the first generated image
image_bytes = response.generated_images[0].image.image_bytes
# Extract the image bytes from the first generated image
for part in response.parts or []:
if part.inline_data is not None:
image = part.as_image()
image_bytes = image.image_bytes
break
if not image_bytes:
raise ValueError("Failed to generate image using Google AI")
else:
# Configure image generation settings
config = gtypes.GenerateImagesConfig(
number_of_images=1,
safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
include_rai_reason=True,
output_mime_type="image/png",
aspect_ratio=aspect_ratio,
)
# Call the Gemini API to generate the image
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
if not response.generated_images:
raise ValueError("Failed to generate image using Google AI")
# Extract the image bytes from the first generated image
image_bytes = response.generated_images[0].image.image_bytes
# Convert to webp for faster loading
return convert_image_to_webp(image_bytes)
def format_messages_for_gemini(
improved_image_prompt: str,
text2image_model: str,
chat_history: List[ChatMessageModel] = [],
query_images: List[str] = [],
) -> List[gtypes.Content]:
"""Format chat messages for Gemini multimodal models.
Reframes assistant messages with generated images as user messages to enable
multi-turn image editing with gemini 3 models.
"""
contents = []
for chat in chat_history:
role = "model" if chat.by == "khoj" else "user"
parts = []
# Reframe assistant messages to gemini 3 as user messages
# This enables multi-turn image edits without storing, passing thought_signature required by gemini 3 models
if role == "model" and text2image_model.startswith("gemini-3"):
if chat.images:
parts.append(gtypes.Part.from_text(text="This is the image you previously generated:"))
for image_data in chat.images:
if image_data.startswith("http"):
image = get_image_from_url(image_data, type="bytes")
else:
image = get_image_from_base64(image_data, type="bytes")
parts.append(gtypes.Part.from_bytes(data=image.content, mime_type=image.type))
else:
parts.append(gtypes.Part.from_text(text="This is the message you previously sent:"))
messages = chat.message if isinstance(chat.message, list) else [chat.message] # type: ignore[list-item]
for text in messages:
if isinstance(text, dict) and not is_none_or_empty(text.get("text")):
parts.append(gtypes.Part.from_text(text=text.get("text")))
elif isinstance(text, str):
parts.append(gtypes.Part.from_text(text=text))
contents.append(gtypes.Content(role="user", parts=parts))
continue
# Handle regular messages
for image_data in chat.images or []:
if image_data.startswith("http"):
image = get_image_from_url(image_data, type="bytes")
else:
image = get_image_from_base64(image_data, type="bytes")
parts.append(gtypes.Part.from_bytes(data=image.content, mime_type=image.type))
messages = chat.message if isinstance(chat.message, list) else [chat.message] # type: ignore[list-item]
for text in messages:
if isinstance(text, dict) and not is_none_or_empty(text.get("text")):
parts.append(gtypes.Part.from_text(text=text.get("text")))
elif isinstance(text, str):
parts.append(gtypes.Part.from_text(text=text))
contents.append(gtypes.Content(role=role, parts=parts))
query_parts = []
for img in query_images or []:
if img.startswith("http"):
image = get_image_from_url(img, type="bytes")
else:
image = get_image_from_base64(img, type="bytes")
query_parts.append(gtypes.Part.from_bytes(data=image.content, mime_type=image.type))
query_parts.append(gtypes.Part.from_text(text=improved_image_prompt))
contents += [gtypes.Content(role="user", parts=query_parts)]
return contents
def is_multimodal_model(model_name: str) -> bool:
"""Check if the model can see and generate images"""
multimodal_models = ["gemini-2.5-flash-image", "gemini-3-pro-image-preview"]
return model_name.lower() in multimodal_models