mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 05:29:12 +00:00
Support image generation with Gemini Nano Banana
This commit is contained in:
@@ -87,7 +87,7 @@ dependencies = [
|
||||
"django_apscheduler == 0.7.0",
|
||||
"anthropic == 0.52.0",
|
||||
"docx2txt == 0.8",
|
||||
"google-genai == 1.51.0",
|
||||
"google-genai == 1.52.0",
|
||||
"google-auth ~= 2.23.3",
|
||||
"pyjson5 == 1.6.7",
|
||||
"resend == 1.0.1",
|
||||
|
||||
@@ -27,10 +27,11 @@ from khoj.database.models import (
|
||||
TextToImageModelConfig,
|
||||
)
|
||||
from khoj.processor.conversation.google.utils import _is_retryable_error
|
||||
from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url
|
||||
from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
|
||||
from khoj.routers.storage import upload_generated_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import convert_image_to_webp, timer
|
||||
from khoj.utils.helpers import convert_image_to_webp, is_none_or_empty, timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,27 +75,31 @@ async def text_to_image(
|
||||
elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
|
||||
image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)]
|
||||
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
# Generate a better image prompt
|
||||
# Use the user's message, chat history, and other context
|
||||
image_prompt_response = await generate_better_image_prompt(
|
||||
message,
|
||||
image_chat_history,
|
||||
location_data=location_data,
|
||||
note_references=references,
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
image_prompt = image_prompt_response["description"]
|
||||
image_shape = image_prompt_response["shape"]
|
||||
if not is_multimodal_model(text2image_model):
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
image_prompt_response = await generate_better_image_prompt(
|
||||
message,
|
||||
image_chat_history,
|
||||
location_data=location_data,
|
||||
note_references=references,
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
image_prompt = image_prompt_response["description"]
|
||||
image_shape = image_prompt_response["shape"]
|
||||
else:
|
||||
image_prompt = message
|
||||
image_shape = None
|
||||
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
|
||||
@@ -115,7 +120,12 @@ async def text_to_image(
|
||||
)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
|
||||
webp_image_bytes = generate_image_with_google(
|
||||
image_prompt, text_to_image_config, text2image_model, image_shape
|
||||
image_prompt,
|
||||
text_to_image_config,
|
||||
text2image_model,
|
||||
image_shape,
|
||||
chat_history=chat_history,
|
||||
query_images=query_images,
|
||||
)
|
||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||
if "content_policy_violation" in e.message:
|
||||
@@ -322,6 +332,8 @@ def generate_image_with_google(
|
||||
text_to_image_config: TextToImageModelConfig,
|
||||
text2image_model: str,
|
||||
shape: ImageShape = ImageShape.SQUARE,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
query_images: List[str] = [],
|
||||
):
|
||||
"""Generate image using Google's AI over API"""
|
||||
|
||||
@@ -337,24 +349,122 @@ def generate_image_with_google(
|
||||
else: # Square
|
||||
aspect_ratio = "1:1"
|
||||
|
||||
# Configure image generation settings
|
||||
config = gtypes.GenerateImagesConfig(
|
||||
number_of_images=1,
|
||||
safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
|
||||
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
|
||||
include_rai_reason=True,
|
||||
output_mime_type="image/png",
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
image_bytes = None
|
||||
if is_multimodal_model(text2image_model):
|
||||
# Format chat history for Gemini
|
||||
contents = format_messages_for_gemini(improved_image_prompt, text2image_model, chat_history, query_images)
|
||||
|
||||
# Call the Gemini API to generate the image
|
||||
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
|
||||
# Configure image generation settings
|
||||
config = gtypes.GenerateContentConfig(
|
||||
response_modalities=["IMAGE"], image_config=gtypes.ImageConfig(aspect_ratio=None)
|
||||
)
|
||||
|
||||
if not response.generated_images:
|
||||
raise ValueError("Failed to generate image using Google AI")
|
||||
# Call the Gemini API to generate the image
|
||||
response = client.models.generate_content(
|
||||
contents=contents,
|
||||
model=text2image_model,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Extract the image bytes from the first generated image
|
||||
image_bytes = response.generated_images[0].image.image_bytes
|
||||
# Extract the image bytes from the first generated image
|
||||
for part in response.parts or []:
|
||||
if part.inline_data is not None:
|
||||
image = part.as_image()
|
||||
image_bytes = image.image_bytes
|
||||
break
|
||||
if not image_bytes:
|
||||
raise ValueError("Failed to generate image using Google AI")
|
||||
else:
|
||||
# Configure image generation settings
|
||||
config = gtypes.GenerateImagesConfig(
|
||||
number_of_images=1,
|
||||
safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
|
||||
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
|
||||
include_rai_reason=True,
|
||||
output_mime_type="image/png",
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
|
||||
# Call the Gemini API to generate the image
|
||||
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
|
||||
|
||||
if not response.generated_images:
|
||||
raise ValueError("Failed to generate image using Google AI")
|
||||
|
||||
# Extract the image bytes from the first generated image
|
||||
image_bytes = response.generated_images[0].image.image_bytes
|
||||
|
||||
# Convert to webp for faster loading
|
||||
return convert_image_to_webp(image_bytes)
|
||||
|
||||
|
||||
def _gemini_image_parts(images) -> List[gtypes.Part]:
    """Convert image references into Gemini inline-bytes parts.

    Each entry is either a reference starting with "http" (loaded via
    get_image_from_url) or, otherwise, treated as base64 data (decoded via
    get_image_from_base64). A None or empty input yields an empty list.
    """
    parts = []
    for image_data in images or []:
        if image_data.startswith("http"):
            image = get_image_from_url(image_data, type="bytes")
        else:
            image = get_image_from_base64(image_data, type="bytes")
        parts.append(gtypes.Part.from_bytes(data=image.content, mime_type=image.type))
    return parts


def _gemini_text_parts(message) -> List[gtypes.Part]:
    """Convert a chat message into Gemini text parts.

    The message may be a single item or a list of items. Plain strings are
    passed through as text; dicts contribute their non-empty "text" value.
    Any other item is ignored.
    """
    items = message if isinstance(message, list) else [message]
    parts = []
    for item in items:
        if isinstance(item, dict) and not is_none_or_empty(item.get("text")):
            parts.append(gtypes.Part.from_text(text=item.get("text")))
        elif isinstance(item, str):
            parts.append(gtypes.Part.from_text(text=item))
    return parts


def format_messages_for_gemini(
    improved_image_prompt: str,
    text2image_model: str,
    chat_history: List[ChatMessageModel] = [],
    query_images: List[str] = [],
) -> List[gtypes.Content]:
    """Format chat messages for Gemini multimodal models.

    Reframes assistant messages (including any generated images) as user messages
    when the model name starts with "gemini-3", to enable multi-turn image editing.

    Args:
        improved_image_prompt: Prompt text appended as the final user turn.
        text2image_model: Name of the image model; selects the gemini-3 reframing.
        chat_history: Prior chat messages to include as context. Never mutated here.
        query_images: Images (URLs or base64 strings) attached to the current query.
            Never mutated here.

    Returns:
        The conversation as a list of Gemini contents, ending with a user turn that
        holds the query images followed by the prompt text.
    """
    # Gemini 3 models require a thought_signature on replayed model turns.
    # Replaying those turns as user messages enables multi-turn image edits
    # without having to store and pass that signature around.
    reframe_model_turns = text2image_model.startswith("gemini-3")

    contents = []
    for chat in chat_history or []:
        role = "model" if chat.by == "khoj" else "user"
        image_parts = _gemini_image_parts(chat.images)
        text_parts = _gemini_text_parts(chat.message)

        # Skip turns that carry neither text nor images instead of sending
        # a content entry with no meaningful parts.
        if not image_parts and not text_parts:
            continue

        if role == "model" and reframe_model_turns:
            if image_parts:
                preamble = "This is the image you previously generated:"
            else:
                preamble = "This is the message you previously sent:"
            parts = [gtypes.Part.from_text(text=preamble), *image_parts, *text_parts]
            role = "user"
        else:
            # Regular messages: images first, then the message text
            parts = image_parts + text_parts

        contents.append(gtypes.Content(role=role, parts=parts))

    # Current query: attached images first, then the prompt text
    query_parts = _gemini_image_parts(query_images)
    query_parts.append(gtypes.Part.from_text(text=improved_image_prompt))
    contents.append(gtypes.Content(role="user", parts=query_parts))

    return contents
|
||||
|
||||
|
||||
def is_multimodal_model(model_name: str) -> bool:
    """Check if the model can see and generate images.

    The comparison is case-insensitive against the known multimodal image
    model names.

    Args:
        model_name: Name of the text-to-image model.

    Returns:
        True if the name matches a known multimodal image model, else False.
        A missing (None) or non-string name is treated as not multimodal
        instead of raising.
    """
    if not isinstance(model_name, str):
        return False
    multimodal_models = {"gemini-2.5-flash-image", "gemini-3-pro-image-preview"}
    return model_name.lower() in multimodal_models
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -871,7 +871,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "google-genai"
|
||||
version = "1.51.0"
|
||||
version = "1.52.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
@@ -883,9 +883,9 @@ dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/1c/29245699c7c274ed5709b33b6a5192af2d57da5da3d2f189f222d1895336/google_genai-1.51.0.tar.gz", hash = "sha256:596c1ec964b70fec17a6ccfe6ee4edede31022584e8b1d33371d93037c4001b1", size = 258060, upload-time = "2025-11-18T05:32:47.068Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/4e/0ad8585d05312074bb69711b2d81cfed69ce0ae441913d57bf169bed20a7/google_genai-1.52.0.tar.gz", hash = "sha256:a74e8a4b3025f23aa98d6a0f84783119012ca6c336fd68f73c5d2b11465d7fc5", size = 258743, upload-time = "2025-11-21T02:18:55.742Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/28/0185dcda66f1994171067cfdb0e44a166450239d5b11b3a8a281dd2da459/google_genai-1.51.0-py3-none-any.whl", hash = "sha256:bfb7d0c6ba48ba9bda539f0d5e69dad827d8735a8b1e4703bafa0a2945d293e1", size = 260483, upload-time = "2025-11-18T05:32:45.938Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/66/03f663e7bca7abe9ccfebe6cb3fe7da9a118fd723a5abb278d6117e7990e/google_genai-1.52.0-py3-none-any.whl", hash = "sha256:c8352b9f065ae14b9322b949c7debab8562982f03bf71d44130cd2b798c20743", size = 261219, upload-time = "2025-11-21T02:18:54.515Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1312,7 +1312,7 @@ requires-dist = [
|
||||
{ name = "freezegun", marker = "extra == 'dev'", specifier = ">=1.2.0" },
|
||||
{ name = "gitpython", marker = "extra == 'dev'", specifier = "~=3.1.43" },
|
||||
{ name = "google-auth", specifier = "~=2.23.3" },
|
||||
{ name = "google-genai", specifier = "==1.51.0" },
|
||||
{ name = "google-genai", specifier = "==1.52.0" },
|
||||
{ name = "gunicorn", marker = "extra == 'dev'", specifier = "==22.0.0" },
|
||||
{ name = "gunicorn", marker = "extra == 'prod'", specifier = "==22.0.0" },
|
||||
{ name = "httpx", specifier = "==0.28.1" },
|
||||
|
||||
Reference in New Issue
Block a user