Support image generation with Gemini Nano Banana

This commit is contained in:
Debanjum
2025-11-18 19:11:32 -08:00
parent dd4381c25c
commit da493be417
3 changed files with 151 additions and 41 deletions

View File

@@ -87,7 +87,7 @@ dependencies = [
"django_apscheduler == 0.7.0", "django_apscheduler == 0.7.0",
"anthropic == 0.52.0", "anthropic == 0.52.0",
"docx2txt == 0.8", "docx2txt == 0.8",
"google-genai == 1.51.0", "google-genai == 1.52.0",
"google-auth ~= 2.23.3", "google-auth ~= 2.23.3",
"pyjson5 == 1.6.7", "pyjson5 == 1.6.7",
"resend == 1.0.1", "resend == 1.0.1",

View File

@@ -27,10 +27,11 @@ from khoj.database.models import (
TextToImageModelConfig, TextToImageModelConfig,
) )
from khoj.processor.conversation.google.utils import _is_retryable_error from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url
from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
from khoj.routers.storage import upload_generated_image_to_bucket from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import convert_image_to_webp, timer from khoj.utils.helpers import convert_image_to_webp, is_none_or_empty, timer
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -74,27 +75,31 @@ async def text_to_image(
elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]: elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)] image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}
# Generate a better image prompt # Generate a better image prompt
# Use the user's message, chat history, and other context # Use the user's message, chat history, and other context
image_prompt_response = await generate_better_image_prompt( if not is_multimodal_model(text2image_model):
message, if send_status_func:
image_chat_history, async for event in send_status_func("**Enhancing the Painting Prompt**"):
location_data=location_data, yield {ChatEvent.STATUS: event}
note_references=references,
online_results=online_results, image_prompt_response = await generate_better_image_prompt(
model_type=text_to_image_config.model_type, message,
query_images=query_images, image_chat_history,
user=user, location_data=location_data,
agent=agent, note_references=references,
query_files=query_files, online_results=online_results,
tracer=tracer, model_type=text_to_image_config.model_type,
) query_images=query_images,
image_prompt = image_prompt_response["description"] user=user,
image_shape = image_prompt_response["shape"] agent=agent,
query_files=query_files,
tracer=tracer,
)
image_prompt = image_prompt_response["description"]
image_shape = image_prompt_response["shape"]
else:
image_prompt = message
image_shape = None
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"): async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -115,7 +120,12 @@ async def text_to_image(
) )
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE: elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
webp_image_bytes = generate_image_with_google( webp_image_bytes = generate_image_with_google(
image_prompt, text_to_image_config, text2image_model, image_shape image_prompt,
text_to_image_config,
text2image_model,
image_shape,
chat_history=chat_history,
query_images=query_images,
) )
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message: if "content_policy_violation" in e.message:
@@ -322,6 +332,8 @@ def generate_image_with_google(
text_to_image_config: TextToImageModelConfig, text_to_image_config: TextToImageModelConfig,
text2image_model: str, text2image_model: str,
shape: ImageShape = ImageShape.SQUARE, shape: ImageShape = ImageShape.SQUARE,
chat_history: List[ChatMessageModel] = [],
query_images: List[str] = [],
): ):
"""Generate image using Google's AI over API""" """Generate image using Google's AI over API"""
@@ -337,24 +349,122 @@ def generate_image_with_google(
else: # Square else: # Square
aspect_ratio = "1:1" aspect_ratio = "1:1"
# Configure image generation settings image_bytes = None
config = gtypes.GenerateImagesConfig( if is_multimodal_model(text2image_model):
number_of_images=1, # Format chat history for Gemini
safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE, contents = format_messages_for_gemini(improved_image_prompt, text2image_model, chat_history, query_images)
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
include_rai_reason=True,
output_mime_type="image/png",
aspect_ratio=aspect_ratio,
)
# Call the Gemini API to generate the image # Configure image generation settings
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config) config = gtypes.GenerateContentConfig(
response_modalities=["IMAGE"], image_config=gtypes.ImageConfig(aspect_ratio=None)
)
if not response.generated_images: # Call the Gemini API to generate the image
raise ValueError("Failed to generate image using Google AI") response = client.models.generate_content(
contents=contents,
model=text2image_model,
config=config,
)
# Extract the image bytes from the first generated image # Extract the image bytes from the first generated image
image_bytes = response.generated_images[0].image.image_bytes for part in response.parts or []:
if part.inline_data is not None:
image = part.as_image()
image_bytes = image.image_bytes
break
if not image_bytes:
raise ValueError("Failed to generate image using Google AI")
else:
# Configure image generation settings
config = gtypes.GenerateImagesConfig(
number_of_images=1,
safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
include_rai_reason=True,
output_mime_type="image/png",
aspect_ratio=aspect_ratio,
)
# Call the Gemini API to generate the image
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
if not response.generated_images:
raise ValueError("Failed to generate image using Google AI")
# Extract the image bytes from the first generated image
image_bytes = response.generated_images[0].image.image_bytes
# Convert to webp for faster loading # Convert to webp for faster loading
return convert_image_to_webp(image_bytes) return convert_image_to_webp(image_bytes)
def format_messages_for_gemini(
    improved_image_prompt: str,
    text2image_model: str,
    chat_history: List[ChatMessageModel] = [],
    query_images: List[str] = [],
) -> List[gtypes.Content]:
    """Format chat messages for Gemini multimodal models.

    Builds a list of gtypes.Content from prior chat turns, then appends a final
    user turn containing any query images followed by the improved image prompt.

    Assistant messages with generated images are reframed as user messages to
    enable multi-turn image editing with gemini-3 models, without storing and
    passing the thought_signature those models require.

    NOTE: the mutable list defaults are safe here — both lists are only
    iterated, never mutated.
    """
    contents: List[gtypes.Content] = []
    for chat in chat_history:
        role = "model" if chat.by == "khoj" else "user"
        parts: List[gtypes.Part] = []

        # Reframe assistant messages to gemini 3 as user messages.
        # This enables multi-turn image edits without storing, passing the
        # thought_signature required by gemini 3 models.
        if role == "model" and text2image_model.startswith("gemini-3"):
            if chat.images:
                parts.append(gtypes.Part.from_text(text="This is the image you previously generated:"))
                parts += [_image_part(image_data) for image_data in chat.images]
            else:
                parts.append(gtypes.Part.from_text(text="This is the message you previously sent:"))
            parts += _text_parts(chat.message)
            contents.append(gtypes.Content(role="user", parts=parts))
            continue

        # Handle regular messages: attached images first, then message text.
        parts += [_image_part(image_data) for image_data in chat.images or []]
        parts += _text_parts(chat.message)
        contents.append(gtypes.Content(role=role, parts=parts))

    # Current turn: query images (if any) followed by the image prompt.
    query_parts = [_image_part(img) for img in query_images or []]
    query_parts.append(gtypes.Part.from_text(text=improved_image_prompt))
    contents.append(gtypes.Content(role="user", parts=query_parts))
    return contents


def _image_part(image_data: str) -> gtypes.Part:
    """Convert an image reference (http(s) URL or base64 string) into an inline-bytes Gemini Part."""
    if image_data.startswith("http"):
        image = get_image_from_url(image_data, type="bytes")
    else:
        image = get_image_from_base64(image_data, type="bytes")
    return gtypes.Part.from_bytes(data=image.content, mime_type=image.type)


def _text_parts(message) -> List[gtypes.Part]:
    """Extract text Parts from a chat message that may be a str or a list of str/dict entries."""
    messages = message if isinstance(message, list) else [message]  # type: ignore[list-item]
    parts: List[gtypes.Part] = []
    for text in messages:
        if isinstance(text, dict) and not is_none_or_empty(text.get("text")):
            parts.append(gtypes.Part.from_text(text=text.get("text")))
        elif isinstance(text, str):
            parts.append(gtypes.Part.from_text(text=text))
    return parts
def is_multimodal_model(model_name: str) -> bool:
    """Check if the model can see and generate images"""
    # Case-insensitive exact match against the known image-capable Gemini models.
    return model_name.lower() in {"gemini-2.5-flash-image", "gemini-3-pro-image-preview"}

8
uv.lock generated
View File

@@ -871,7 +871,7 @@ wheels = [
[[package]] [[package]]
name = "google-genai" name = "google-genai"
version = "1.51.0" version = "1.52.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
@@ -883,9 +883,9 @@ dependencies = [
{ name = "typing-extensions" }, { name = "typing-extensions" },
{ name = "websockets" }, { name = "websockets" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/c3/1c/29245699c7c274ed5709b33b6a5192af2d57da5da3d2f189f222d1895336/google_genai-1.51.0.tar.gz", hash = "sha256:596c1ec964b70fec17a6ccfe6ee4edede31022584e8b1d33371d93037c4001b1", size = 258060, upload-time = "2025-11-18T05:32:47.068Z" } sdist = { url = "https://files.pythonhosted.org/packages/09/4e/0ad8585d05312074bb69711b2d81cfed69ce0ae441913d57bf169bed20a7/google_genai-1.52.0.tar.gz", hash = "sha256:a74e8a4b3025f23aa98d6a0f84783119012ca6c336fd68f73c5d2b11465d7fc5", size = 258743, upload-time = "2025-11-21T02:18:55.742Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/28/0185dcda66f1994171067cfdb0e44a166450239d5b11b3a8a281dd2da459/google_genai-1.51.0-py3-none-any.whl", hash = "sha256:bfb7d0c6ba48ba9bda539f0d5e69dad827d8735a8b1e4703bafa0a2945d293e1", size = 260483, upload-time = "2025-11-18T05:32:45.938Z" }, { url = "https://files.pythonhosted.org/packages/ec/66/03f663e7bca7abe9ccfebe6cb3fe7da9a118fd723a5abb278d6117e7990e/google_genai-1.52.0-py3-none-any.whl", hash = "sha256:c8352b9f065ae14b9322b949c7debab8562982f03bf71d44130cd2b798c20743", size = 261219, upload-time = "2025-11-21T02:18:54.515Z" },
] ]
[[package]] [[package]]
@@ -1312,7 +1312,7 @@ requires-dist = [
{ name = "freezegun", marker = "extra == 'dev'", specifier = ">=1.2.0" }, { name = "freezegun", marker = "extra == 'dev'", specifier = ">=1.2.0" },
{ name = "gitpython", marker = "extra == 'dev'", specifier = "~=3.1.43" }, { name = "gitpython", marker = "extra == 'dev'", specifier = "~=3.1.43" },
{ name = "google-auth", specifier = "~=2.23.3" }, { name = "google-auth", specifier = "~=2.23.3" },
{ name = "google-genai", specifier = "==1.51.0" }, { name = "google-genai", specifier = "==1.52.0" },
{ name = "gunicorn", marker = "extra == 'dev'", specifier = "==22.0.0" }, { name = "gunicorn", marker = "extra == 'dev'", specifier = "==22.0.0" },
{ name = "gunicorn", marker = "extra == 'prod'", specifier = "==22.0.0" }, { name = "gunicorn", marker = "extra == 'prod'", specifier = "==22.0.0" },
{ name = "httpx", specifier = "==0.28.1" }, { name = "httpx", specifier = "==0.28.1" },