diff --git a/src/khoj/database/migrations/0086_alter_texttoimagemodelconfig_model_type.py b/src/khoj/database/migrations/0086_alter_texttoimagemodelconfig_model_type.py
new file mode 100644
index 00000000..f55ba90e
--- /dev/null
+++ b/src/khoj/database/migrations/0086_alter_texttoimagemodelconfig_model_type.py
@@ -0,0 +1,26 @@
+# Generated by Django 5.0.10 on 2025-03-11 16:58
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0085_alter_agent_output_modes"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="texttoimagemodelconfig",
+            name="model_type",
+            field=models.CharField(
+                choices=[
+                    ("openai", "Openai"),
+                    ("stability-ai", "Stabilityai"),
+                    ("replicate", "Replicate"),
+                    ("google", "Google"),
+                ],
+                default="openai",
+                max_length=200,
+            ),
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index f366c15a..f9196f80 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -530,6 +530,7 @@ class TextToImageModelConfig(DbBaseModel):
         OPENAI = "openai"
         STABILITYAI = "stability-ai"
         REPLICATE = "replicate"
+        GOOGLE = "google"
 
     model_name = models.CharField(max_length=200, default="dall-e-3")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
@@ -547,11 +548,11 @@ class TextToImageModelConfig(DbBaseModel):
                 error[
                     "ai_model_api"
                 ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
-        if self.model_type != self.ModelType.OPENAI:
+        if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
             if not self.api_key:
-                error["api_key"] = "The API key field must be set for non OpenAI models."
+                error["api_key"] = "The API key field must be set for non OpenAI, non Google models."
             if self.ai_model_api:
-                error["ai_model_api"] = "AI Model API cannot be set for non OpenAI models."
+                error["ai_model_api"] = "AI Model API cannot be set for non OpenAI, non Google models."
         if error:
             raise ValidationError(error)
 
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 252e61eb..742fe737 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Optional
 
 import openai
 import requests
+from google import genai
+from google.genai import types as gtypes
 
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
@@ -86,6 +88,8 @@ async def text_to_image(
             webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
         elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
             webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+        elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
+            webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
     except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
         if "content_policy_violation" in e.message:
             logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -99,6 +103,12 @@ async def text_to_image(
             status_code = e.status_code  # type: ignore
         yield image_url or image, status_code, message
         return
+    except ValueError as e:
+        logger.error(f"Image Generation failed with {e}", exc_info=True)
+        message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to an unknown error"
+        status_code = 500
+        yield image_url or image, status_code, message
+        return
     except requests.RequestException as e:
         logger.error(f"Image Generation failed with {e}", exc_info=True)
         message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
@@ -215,3 +225,41 @@ def generate_image_with_replicate(
     # Get the generated image
     image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
     return io.BytesIO(requests.get(image_url).content).getvalue()
+
+
+def generate_image_with_google(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    """Generate an image using Google's AI over API and return it converted to webp.
+
+    Raises:
+        ValueError: If no API key is configured or the API returns no image data.
+    """
+
+    # Resolve the API key from the image model config, falling back to its linked AI model API.
+    # Model validation lets Google configs omit both, so guard before dereferencing ai_model_api.
+    ai_model_api = text_to_image_config.ai_model_api
+    api_key = text_to_image_config.api_key or (ai_model_api.api_key if ai_model_api else None)
+    if not api_key:
+        raise ValueError("No API key configured for Google image generation")
+
+    # Initialize the Google AI client
+    client = genai.Client(api_key=api_key)
+
+    # Configure image generation settings
+    config = gtypes.GenerateImagesConfig(number_of_images=1)
+
+    # Call the Google GenAI API to generate the image
+    response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
+
+    if not response.generated_images:
+        raise ValueError("Failed to generate image using Google AI")
+
+    # Extract the image bytes from the first generated image.
+    # Guard against an entry without an image payload instead of failing with an AttributeError.
+    generated_image = response.generated_images[0].image
+    if generated_image is None or not generated_image.image_bytes:
+        raise ValueError("Google AI returned no image data")
+
+    # Convert to webp for faster loading
+    return convert_image_to_webp(generated_image.image_bytes)
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 2bc392a0..75f38948 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1092,7 +1092,11 @@ async def generate_better_image_prompt(
             online_results=simplified_online_results,
             personality_context=personality_context,
         )
-    elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
+    elif model_type in [
+        TextToImageModelConfig.ModelType.STABILITYAI,
+        TextToImageModelConfig.ModelType.REPLICATE,
+        TextToImageModelConfig.ModelType.GOOGLE,
+    ]:
         image_prompt = prompts.image_generation_improve_prompt_sd.format(
             query=q,
             chat_history=conversation_history,