diff --git a/pyproject.toml b/pyproject.toml index 1c0a9158..3015ce43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ dependencies = [ "django_apscheduler == 0.7.0", "anthropic == 0.52.0", "docx2txt == 0.8", - "google-genai == 1.51.0", + "google-genai == 1.52.0", "google-auth ~= 2.23.3", "pyjson5 == 1.6.7", "resend == 1.0.1", diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 7ff83829..4ab1f241 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -27,10 +27,11 @@ from khoj.database.models import ( TextToImageModelConfig, ) from khoj.processor.conversation.google.utils import _is_retryable_error +from khoj.processor.conversation.utils import get_image_from_base64, get_image_from_url from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt from khoj.routers.storage import upload_generated_image_to_bucket from khoj.utils import state -from khoj.utils.helpers import convert_image_to_webp, timer +from khoj.utils.helpers import convert_image_to_webp, is_none_or_empty, timer from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) @@ -74,27 +75,31 @@ async def text_to_image( elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]: image_chat_history += [ChatMessageModel(by=chat.by, message=chat.message, intent=default_intent)] - if send_status_func: - async for event in send_status_func("**Enhancing the Painting Prompt**"): - yield {ChatEvent.STATUS: event} - # Generate a better image prompt # Use the user's message, chat history, and other context - image_prompt_response = await generate_better_image_prompt( - message, - image_chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - model_type=text_to_image_config.model_type, - query_images=query_images, - user=user, - agent=agent, - query_files=query_files, - tracer=tracer, - ) - 
image_prompt = image_prompt_response["description"] - image_shape = image_prompt_response["shape"] + if not is_multimodal_model(text2image_model): + if send_status_func: + async for event in send_status_func("**Enhancing the Painting Prompt**"): + yield {ChatEvent.STATUS: event} + + image_prompt_response = await generate_better_image_prompt( + message, + image_chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, + model_type=text_to_image_config.model_type, + query_images=query_images, + user=user, + agent=agent, + query_files=query_files, + tracer=tracer, + ) + image_prompt = image_prompt_response["description"] + image_shape = image_prompt_response["shape"] + else: + image_prompt = message + image_shape = None if send_status_func: async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"): @@ -115,7 +120,12 @@ async def text_to_image( ) elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE: webp_image_bytes = generate_image_with_google( - image_prompt, text_to_image_config, text2image_model, image_shape + image_prompt, + text_to_image_config, + text2image_model, + image_shape, + chat_history=chat_history, + query_images=query_images, ) except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: @@ -322,6 +332,8 @@ def generate_image_with_google( text_to_image_config: TextToImageModelConfig, text2image_model: str, shape: ImageShape = ImageShape.SQUARE, + chat_history: List[ChatMessageModel] = [], + query_images: List[str] = [], ): """Generate image using Google's AI over API""" @@ -337,24 +349,122 @@ def generate_image_with_google( else: # Square aspect_ratio = "1:1" - # Configure image generation settings - config = gtypes.GenerateImagesConfig( - number_of_images=1, - safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE, - person_generation=gtypes.PersonGeneration.ALLOW_ADULT, - 
include_rai_reason=True, - output_mime_type="image/png", - aspect_ratio=aspect_ratio, - ) + image_bytes = None + if is_multimodal_model(text2image_model): + # Format chat history for Gemini + contents = format_messages_for_gemini(improved_image_prompt, text2image_model, chat_history, query_images) - # Call the Gemini API to generate the image - response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config) + # Configure image generation settings + config = gtypes.GenerateContentConfig( + response_modalities=["IMAGE"], image_config=gtypes.ImageConfig(aspect_ratio=None) + ) - if not response.generated_images: - raise ValueError("Failed to generate image using Google AI") + # Call the Gemini API to generate the image + response = client.models.generate_content( + contents=contents, + model=text2image_model, + config=config, + ) - # Extract the image bytes from the first generated image - image_bytes = response.generated_images[0].image.image_bytes + # Extract the image bytes from the first generated image + for part in response.parts or []: + if part.inline_data is not None: + image = part.as_image() + image_bytes = image.image_bytes + break + if not image_bytes: + raise ValueError("Failed to generate image using Google AI") + else: + # Configure image generation settings + config = gtypes.GenerateImagesConfig( + number_of_images=1, + safety_filter_level=gtypes.SafetyFilterLevel.BLOCK_LOW_AND_ABOVE, + person_generation=gtypes.PersonGeneration.ALLOW_ADULT, + include_rai_reason=True, + output_mime_type="image/png", + aspect_ratio=aspect_ratio, + ) + + # Call the Gemini API to generate the image + response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config) + + if not response.generated_images: + raise ValueError("Failed to generate image using Google AI") + + # Extract the image bytes from the first generated image + image_bytes = 
def format_messages_for_gemini(
    improved_image_prompt: str,
    text2image_model: str,
    chat_history: List[ChatMessageModel] = None,
    query_images: List[str] = None,
) -> List[gtypes.Content]:
    """Format the chat history and current query for Gemini multimodal image models.

    Every chat turn becomes one `Content` whose parts are the turn's images followed
    by its text. The current query (query images, then the prompt text) is appended
    as the final user turn.

    For gemini 3 models, assistant turns are reframed as user turns that start with a
    short explanatory text part. This enables multi-turn image edits without storing
    and passing back the thought_signature required by gemini 3 models.

    Args:
        improved_image_prompt: Prompt text sent as the last part of the final user turn.
        text2image_model: Name of the Gemini model. Matched case-insensitively.
        chat_history: Previous chat messages. None is treated as an empty history.
        query_images: Images attached to the current query, each either an http(s) URL
            or base64 image data. None is treated as no images.

    Returns:
        The formatted contents, chat history first and the current query last.
    """
    # Lowercase before matching so this agrees with is_multimodal_model, which also ignores case.
    # Otherwise a mixed-case model name is treated as multimodal but silently skips the reframing.
    reframe_assistant_turns = text2image_model.lower().startswith("gemini-3")

    contents = []
    for chat in chat_history or []:
        sent_by_assistant = chat.by == "khoj"
        parts = []

        if sent_by_assistant and reframe_assistant_turns:
            # Present the assistant's earlier output as user-supplied context
            role = "user"
            if chat.images:
                preamble = "This is the image you previously generated:"
            else:
                preamble = "This is the message you previously sent:"
            parts.append(gtypes.Part.from_text(text=preamble))
        else:
            role = "model" if sent_by_assistant else "user"

        # Images go before text within a turn
        parts.extend(_image_parts_for_gemini(chat.images))
        parts.extend(_text_parts_for_gemini(chat.message))
        contents.append(gtypes.Content(role=role, parts=parts))

    # The current query is always the final user turn: its images first, then the prompt
    query_parts = _image_parts_for_gemini(query_images)
    query_parts.append(gtypes.Part.from_text(text=improved_image_prompt))
    contents.append(gtypes.Content(role="user", parts=query_parts))
    return contents


def _image_parts_for_gemini(images) -> list:
    """Load each image reference (URL or base64 data) into a Gemini bytes part."""
    image_parts = []
    for image_data in images or []:
        # References starting with "http" are fetched; anything else is decoded as base64
        load_image = get_image_from_url if image_data.startswith("http") else get_image_from_base64
        image = load_image(image_data, type="bytes")
        # NOTE(review): assumes a failed load surfaces as None or as an object with empty
        # content rather than an exception - confirm against the loader implementations.
        if image is None or not image.content:
            logging.getLogger(__name__).warning("Skipping image that could not be loaded for Gemini")
            continue
        image_parts.append(gtypes.Part.from_bytes(data=image.content, mime_type=image.type))
    return image_parts


def _text_parts_for_gemini(message) -> list:
    """Convert a chat message (a string, or a list of strings / dicts with a "text" key) into text parts."""
    items = message if isinstance(message, list) else [message]
    text_parts = []
    for item in items:
        if isinstance(item, dict):
            text = item.get("text")
            # Dict entries without usable text are dropped
            if not is_none_or_empty(text):
                text_parts.append(gtypes.Part.from_text(text=text))
        elif isinstance(item, str):
            text_parts.append(gtypes.Part.from_text(text=item))
    return text_parts


# Lowercase names of the models that can both see and generate images
_MULTIMODAL_IMAGE_MODELS = frozenset({"gemini-2.5-flash-image", "gemini-3-pro-image-preview"})


def is_multimodal_model(model_name: str) -> bool:
    """Check if the model can see and generate images.

    The name is matched case-insensitively against the known multimodal image models.
    A missing or empty model name is reported as not multimodal instead of raising.
    """
    if not model_name:
        return False
    return model_name.lower() in _MULTIMODAL_IMAGE_MODELS
"sha256:a74e8a4b3025f23aa98d6a0f84783119012ca6c336fd68f73c5d2b11465d7fc5", size = 258743, upload-time = "2025-11-21T02:18:55.742Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/28/0185dcda66f1994171067cfdb0e44a166450239d5b11b3a8a281dd2da459/google_genai-1.51.0-py3-none-any.whl", hash = "sha256:bfb7d0c6ba48ba9bda539f0d5e69dad827d8735a8b1e4703bafa0a2945d293e1", size = 260483, upload-time = "2025-11-18T05:32:45.938Z" }, + { url = "https://files.pythonhosted.org/packages/ec/66/03f663e7bca7abe9ccfebe6cb3fe7da9a118fd723a5abb278d6117e7990e/google_genai-1.52.0-py3-none-any.whl", hash = "sha256:c8352b9f065ae14b9322b949c7debab8562982f03bf71d44130cd2b798c20743", size = 261219, upload-time = "2025-11-21T02:18:54.515Z" }, ] [[package]] @@ -1312,7 +1312,7 @@ requires-dist = [ { name = "freezegun", marker = "extra == 'dev'", specifier = ">=1.2.0" }, { name = "gitpython", marker = "extra == 'dev'", specifier = "~=3.1.43" }, { name = "google-auth", specifier = "~=2.23.3" }, - { name = "google-genai", specifier = "==1.51.0" }, + { name = "google-genai", specifier = "==1.52.0" }, { name = "gunicorn", marker = "extra == 'dev'", specifier = "==22.0.0" }, { name = "gunicorn", marker = "extra == 'prod'", specifier = "==22.0.0" }, { name = "httpx", specifier = "==0.28.1" },