mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Add support for Google Imagen AI models for image generation
Use the new Google GenAI client to generate images with Imagen
This commit is contained in:
@@ -0,0 +1,26 @@
|
|||||||
|
# Generated by Django 5.0.10 on 2025-03-11 16:58
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0085_alter_agent_output_modes"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="texttoimagemodelconfig",
|
||||||
|
name="model_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[
|
||||||
|
("openai", "Openai"),
|
||||||
|
("stability-ai", "Stabilityai"),
|
||||||
|
("replicate", "Replicate"),
|
||||||
|
("google", "Google"),
|
||||||
|
],
|
||||||
|
default="openai",
|
||||||
|
max_length=200,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -530,6 +530,7 @@ class TextToImageModelConfig(DbBaseModel):
|
|||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
STABILITYAI = "stability-ai"
|
STABILITYAI = "stability-ai"
|
||||||
REPLICATE = "replicate"
|
REPLICATE = "replicate"
|
||||||
|
GOOGLE = "google"
|
||||||
|
|
||||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
@@ -547,11 +548,11 @@ class TextToImageModelConfig(DbBaseModel):
|
|||||||
error[
|
error[
|
||||||
"ai_model_api"
|
"ai_model_api"
|
||||||
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||||
if self.model_type != self.ModelType.OPENAI:
|
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
error["api_key"] = "The API key field must be set for non OpenAI models."
|
error["api_key"] = "The API key field must be set for non OpenAI, non Google models."
|
||||||
if self.ai_model_api:
|
if self.ai_model_api:
|
||||||
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI models."
|
error["ai_model_api"] = "AI Model API cannot be set for non OpenAI, non Google models."
|
||||||
if error:
|
if error:
|
||||||
raise ValidationError(error)
|
raise ValidationError(error)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Optional
|
|||||||
|
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types as gtypes
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
|
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
|
||||||
@@ -86,6 +88,8 @@ async def text_to_image(
|
|||||||
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
|
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
|
||||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
||||||
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
|
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
|
||||||
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
|
||||||
|
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
if "content_policy_violation" in e.message:
|
if "content_policy_violation" in e.message:
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
@@ -99,6 +103,12 @@ async def text_to_image(
|
|||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
yield image_url or image, status_code, message
|
yield image_url or image, status_code, message
|
||||||
return
|
return
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
|
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to an unknown error"
|
||||||
|
status_code = 500
|
||||||
|
yield image_url or image, status_code, message
|
||||||
|
return
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
||||||
@@ -215,3 +225,28 @@ def generate_image_with_replicate(
|
|||||||
# Get the generated image
|
# Get the generated image
|
||||||
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
|
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
|
||||||
return io.BytesIO(requests.get(image_url).content).getvalue()
|
return io.BytesIO(requests.get(image_url).content).getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_with_google(
|
||||||
|
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||||
|
):
|
||||||
|
"""Generate image using Google's AI over API"""
|
||||||
|
|
||||||
|
# Initialize the Google AI client
|
||||||
|
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
|
||||||
|
client = genai.Client(api_key=api_key)
|
||||||
|
|
||||||
|
# Configure image generation settings
|
||||||
|
config = gtypes.GenerateImagesConfig(number_of_images=1)
|
||||||
|
|
||||||
|
# Call the Gemini API to generate the image
|
||||||
|
response = client.models.generate_images(model=text2image_model, prompt=improved_image_prompt, config=config)
|
||||||
|
|
||||||
|
if not response.generated_images:
|
||||||
|
raise ValueError("Failed to generate image using Google AI")
|
||||||
|
|
||||||
|
# Extract the image bytes from the first generated image
|
||||||
|
image_bytes = response.generated_images[0].image.image_bytes
|
||||||
|
|
||||||
|
# Convert to webp for faster loading
|
||||||
|
return convert_image_to_webp(image_bytes)
|
||||||
|
|||||||
@@ -1092,7 +1092,11 @@ async def generate_better_image_prompt(
|
|||||||
online_results=simplified_online_results,
|
online_results=simplified_online_results,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
|
elif model_type in [
|
||||||
|
TextToImageModelConfig.ModelType.STABILITYAI,
|
||||||
|
TextToImageModelConfig.ModelType.REPLICATE,
|
||||||
|
TextToImageModelConfig.ModelType.GOOGLE,
|
||||||
|
]:
|
||||||
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
||||||
query=q,
|
query=q,
|
||||||
chat_history=conversation_history,
|
chat_history=conversation_history,
|
||||||
|
|||||||
Reference in New Issue
Block a user