Support generating images with different aspect ratios

You can now specify the shape of the images to be generated. It can be one of
portrait, landscape, or square.
This commit is contained in:
Debanjum
2025-08-26 19:20:47 -07:00
parent 5a2cae3756
commit 1e81b51abc
4 changed files with 82 additions and 19 deletions

View File

@@ -122,7 +122,8 @@ Your image description will be transformed into an image by an AI model on your
# Instructions # Instructions
- Retain important information and follow instructions by the user when composing the image description. - Retain important information and follow instructions by the user when composing the image description.
- Weave in the context provided below if it will enhance the image. - Weave in the context provided below if it will enhance the image.
- Specify desired elements, lighting, mood, and composition. - Specify desired elements, lighting, mood, and composition in the description.
- Decide the shape best suited to render the image. It can be one of square, portrait or landscape.
- Add specific, fine position details. Mention painting style, camera parameters to compose the image. - Add specific, fine position details. Mention painting style, camera parameters to compose the image.
- Transform any negations in user instructions into positive alternatives. - Transform any negations in user instructions into positive alternatives.
Instead of saying what should NOT be in the image, describe what SHOULD be there instead. Instead of saying what should NOT be in the image, describe what SHOULD be there instead.
@@ -142,9 +143,8 @@ Your image description will be transformed into an image by an AI model on your
## Online References ## Online References
{online_results} {online_results}
Now generate a vivid description of the image to be rendered. Now generate a vivid description of the image and image shape to be rendered.
Your response should be a JSON object with 'description' and 'shape' fields specified.
Image Description:
""".strip() """.strip()
) )

View File

@@ -27,7 +27,7 @@ from khoj.database.models import (
TextToImageModelConfig, TextToImageModelConfig,
) )
from khoj.processor.conversation.google.utils import _is_retryable_error from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
from khoj.routers.storage import upload_generated_image_to_bucket from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import convert_image_to_webp, timer from khoj.utils.helpers import convert_image_to_webp, timer
@@ -80,7 +80,7 @@ async def text_to_image(
# Generate a better image prompt # Generate a better image prompt
# Use the user's message, chat history, and other context # Use the user's message, chat history, and other context
image_prompt = await generate_better_image_prompt( image_prompt_response = await generate_better_image_prompt(
message, message,
image_chat_history, image_chat_history,
location_data=location_data, location_data=location_data,
@@ -93,6 +93,8 @@ async def text_to_image(
query_files=query_files, query_files=query_files,
tracer=tracer, tracer=tracer,
) )
image_prompt = image_prompt_response["description"]
image_shape = image_prompt_response["shape"]
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"): async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -102,13 +104,19 @@ async def text_to_image(
with timer(f"Generate image with {text_to_image_config.model_type}", logger): with timer(f"Generate image with {text_to_image_config.model_type}", logger):
try: try:
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model) webp_image_bytes = generate_image_with_openai(
image_prompt, text_to_image_config, text2image_model, image_shape
)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model) webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE: elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model) webp_image_bytes = generate_image_with_replicate(
image_prompt, text_to_image_config, text2image_model, image_shape
)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE: elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model) webp_image_bytes = generate_image_with_google(
image_prompt, text_to_image_config, text2image_model, image_shape
)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message: if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -159,7 +167,10 @@ async def text_to_image(
reraise=True, reraise=True,
) )
def generate_image_with_openai( def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str,
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
): ):
"Generate image using OpenAI (compatible) API" "Generate image using OpenAI (compatible) API"
@@ -175,12 +186,21 @@ def generate_image_with_openai(
elif state.openai_client: elif state.openai_client:
openai_client = state.openai_client openai_client = state.openai_client
# Convert shape to size for OpenAI
if shape == ImageShape.PORTRAIT:
size = "1024x1536"
elif shape == ImageShape.LANDSCAPE:
size = "1536x1024"
else: # Square
size = "1024x1024"
# Generate image using OpenAI API # Generate image using OpenAI API
OPENAI_IMAGE_GEN_STYLE = "vivid" OPENAI_IMAGE_GEN_STYLE = "vivid"
response = openai_client.images.generate( response = openai_client.images.generate(
prompt=improved_image_prompt, prompt=improved_image_prompt,
model=text2image_model, model=text2image_model,
style=OPENAI_IMAGE_GEN_STYLE, style=OPENAI_IMAGE_GEN_STYLE,
size=size,
response_format="b64_json", response_format="b64_json",
) )
@@ -227,10 +247,22 @@ def generate_image_with_stability(
reraise=True, reraise=True,
) )
def generate_image_with_replicate( def generate_image_with_replicate(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str,
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
): ):
"Generate image using Replicate API" "Generate image using Replicate API"
# Convert shape to aspect ratio for Replicate
# Replicate supports only 1:1, 3:4, and 4:3 aspect ratios
if shape == ImageShape.PORTRAIT:
aspect_ratio = "3:4"
elif shape == ImageShape.LANDSCAPE:
aspect_ratio = "4:3"
else: # Square
aspect_ratio = "1:1"
# Create image generation task on Replicate # Create image generation task on Replicate
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions" replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = { headers = {
@@ -241,7 +273,7 @@ def generate_image_with_replicate(
"input": { "input": {
"prompt": improved_image_prompt, "prompt": improved_image_prompt,
"num_outputs": 1, "num_outputs": 1,
"aspect_ratio": "1:1", "aspect_ratio": aspect_ratio,
"output_format": "webp", "output_format": "webp",
"output_quality": 100, "output_quality": 100,
} }
@@ -286,7 +318,10 @@ def generate_image_with_replicate(
reraise=True, reraise=True,
) )
def generate_image_with_google( def generate_image_with_google(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str,
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
): ):
"""Generate image using Google's AI over API""" """Generate image using Google's AI over API"""
@@ -294,6 +329,14 @@ def generate_image_with_google(
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
client = genai.Client(api_key=api_key) client = genai.Client(api_key=api_key)
# Convert shape to aspect ratio for Google
if shape == ImageShape.PORTRAIT:
aspect_ratio = "3:4"
elif shape == ImageShape.LANDSCAPE:
aspect_ratio = "4:3"
else: # Square
aspect_ratio = "1:1"
# Configure image generation settings # Configure image generation settings
config = gtypes.GenerateImagesConfig( config = gtypes.GenerateImagesConfig(
number_of_images=1, number_of_images=1,
@@ -301,6 +344,7 @@ def generate_image_with_google(
person_generation=gtypes.PersonGeneration.ALLOW_ADULT, person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
include_rai_reason=True, include_rai_reason=True,
output_mime_type="image/png", output_mime_type="image/png",
aspect_ratio=aspect_ratio,
) )
# Call the Gemini API to generate the image # Call the Gemini API to generate the image

View File

@@ -113,6 +113,7 @@ from khoj.utils import state
from khoj.utils.helpers import ( from khoj.utils.helpers import (
LRU, LRU,
ConversationCommand, ConversationCommand,
ImageShape,
ToolDefinition, ToolDefinition,
get_file_type, get_file_type,
in_debug_mode, in_debug_mode,
@@ -1076,7 +1077,7 @@ async def generate_better_image_prompt(
agent: Agent = None, agent: Agent = None,
query_files: str = "", query_files: str = "",
tracer: dict = {}, tracer: dict = {},
) -> str: ) -> dict:
""" """
Generate a better image prompt from the given query Generate a better image prompt from the given query
""" """
@@ -1104,10 +1105,16 @@ async def generate_better_image_prompt(
personality_context=personality_context, personality_context=personality_context,
) )
class ImagePromptResponse(BaseModel):
description: str = Field(description="Enhanced image description")
shape: ImageShape = Field(
description="Aspect ratio/shape best suited to render the image: Portrait, Landscape, or Square"
)
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper( raw_response = await send_message_to_model_wrapper(
q, q,
system_message=enhance_image_system_message, system_message=enhance_image_system_message,
query_images=query_images, query_images=query_images,
@@ -1115,13 +1122,19 @@ async def generate_better_image_prompt(
chat_history=conversation_history, chat_history=conversation_history,
agent_chat_model=agent_chat_model, agent_chat_model=agent_chat_model,
user=user, user=user,
response_type="json_object",
response_schema=ImagePromptResponse,
tracer=tracer, tracer=tracer,
) )
response_text = response.text.strip()
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
response_text = response_text[1:-1]
return response_text # Parse the structured response
try:
response = clean_json(raw_response.text)
parsed_response = pyjson5.loads(response)
return parsed_response
except Exception:
# Fallback to user query as image description
return {"description": q, "shape": ImageShape.SQUARE}
async def search_documents( async def search_documents(

View File

@@ -875,6 +875,12 @@ def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> s
return output_data_uri return output_data_uri
class ImageShape(str, Enum):
PORTRAIT = "Portrait"
LANDSCAPE = "Landscape"
SQUARE = "Square"
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]: def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
""" """
Truncate large output files and drop image file data from code results. Truncate large output files and drop image file data from code results.