mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 05:39:06 +00:00
Add support for Google Imagen AI models for image generation
Use the new Google GenAI client to generate images with Imagen
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
# Generated by Django 5.0.10 on 2025-03-11 16:58
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0085_alter_agent_output_modes"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="texttoimagemodelconfig",
|
||||
name="model_type",
|
||||
field=models.CharField(
|
||||
choices=[
|
||||
("openai", "Openai"),
|
||||
("stability-ai", "Stabilityai"),
|
||||
("replicate", "Replicate"),
|
||||
("google", "Google"),
|
||||
],
|
||||
default="openai",
|
||||
max_length=200,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -530,6 +530,7 @@ class TextToImageModelConfig(DbBaseModel):
|
||||
OPENAI = "openai"
|
||||
STABILITYAI = "stability-ai"
|
||||
REPLICATE = "replicate"
|
||||
GOOGLE = "google"
|
||||
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
@@ -547,11 +548,11 @@ class TextToImageModelConfig(DbBaseModel):
|
||||
error[
|
||||
"ai_model_api"
|
||||
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||
if self.model_type != self.ModelType.OPENAI:
|
||||
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
|
||||
if not self.api_key:
|
||||
error["api_key"] = "The API key field must be set for non OpenAI models."
|
||||
error["api_key"] = "The API key field must be set for non OpenAI, non Google models."
|
||||
if self.ai_model_api:
|
||||
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI models."
|
||||
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI, non Google models."
|
||||
if error:
|
||||
raise ValidationError(error)
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from google import genai
|
||||
from google.genai import types as gtypes
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
|
||||
@@ -86,6 +88,8 @@ async def text_to_image(
|
||||
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
||||
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
|
||||
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
|
||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||
if "content_policy_violation" in e.message:
|
||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
@@ -99,6 +103,12 @@ async def text_to_image(
|
||||
status_code = e.status_code # type: ignore
|
||||
yield image_url or image, status_code, message
|
||||
return
|
||||
except ValueError as e:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to an unknown error"
|
||||
status_code = 500
|
||||
yield image_url or image, status_code, message
|
||||
return
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
||||
@@ -215,3 +225,28 @@ def generate_image_with_replicate(
|
||||
# Get the generated image
|
||||
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
|
||||
return io.BytesIO(requests.get(image_url).content).getvalue()
|
||||
|
||||
|
||||
def generate_image_with_google(
|
||||
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||
):
|
||||
"""Generate image using Google's AI over API"""
|
||||
|
||||
# Initialize the Google AI client
|
||||
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Configure image generation settings
|
||||
config = gtypes.GenerateImagesConfig(number_of_images=1)
|
||||
|
||||
# Call the Gemini API to generate the image
|
||||
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
|
||||
|
||||
if not response.generated_images:
|
||||
raise ValueError("Failed to generate image using Google AI")
|
||||
|
||||
# Extract the image bytes from the first generated image
|
||||
image_bytes = response.generated_images[0].image.image_bytes
|
||||
|
||||
# Convert to webp for faster loading
|
||||
return convert_image_to_webp(image_bytes)
|
||||
|
||||
@@ -1092,7 +1092,11 @@ async def generate_better_image_prompt(
|
||||
online_results=simplified_online_results,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
|
||||
elif model_type in [
|
||||
TextToImageModelConfig.ModelType.STABILITYAI,
|
||||
TextToImageModelConfig.ModelType.REPLICATE,
|
||||
TextToImageModelConfig.ModelType.GOOGLE,
|
||||
]:
|
||||
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
|
||||
Reference in New Issue
Block a user