Support generating images with different aspect ratios

You can now specify the shape of the images to be generated. It can be
one of portrait, landscape, or square.
This commit is contained in:
Debanjum
2025-08-26 19:20:47 -07:00
parent 5a2cae3756
commit 1e81b51abc
4 changed files with 82 additions and 19 deletions

View File

@@ -122,7 +122,8 @@ Your image description will be transformed into an image by an AI model on your
# Instructions
- Retain important information and follow instructions by the user when composing the image description.
- Weave in the context provided below if it will enhance the image.
- Specify desired elements, lighting, mood, and composition.
- Specify desired elements, lighting, mood, and composition in the description.
- Decide the shape best suited to render the image. It can be one of square, portrait or landscape.
- Add specific, fine position details. Mention painting style, camera parameters to compose the image.
- Transform any negations in user instructions into positive alternatives.
Instead of saying what should NOT be in the image, describe what SHOULD be there instead.
@@ -142,9 +143,8 @@ Your image description will be transformed into an image by an AI model on your
## Online References
{online_results}
Now generate a vivid description of the image to be rendered.
Image Description:
Now generate a vivid description of the image and image shape to be rendered.
Your response should be a JSON object with 'description' and 'shape' fields specified.
""".strip()
)

View File

@@ -27,7 +27,7 @@ from khoj.database.models import (
TextToImageModelConfig,
)
from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import convert_image_to_webp, timer
@@ -80,7 +80,7 @@ async def text_to_image(
# Generate a better image prompt
# Use the user's message, chat history, and other context
image_prompt = await generate_better_image_prompt(
image_prompt_response = await generate_better_image_prompt(
message,
image_chat_history,
location_data=location_data,
@@ -93,6 +93,8 @@ async def text_to_image(
query_files=query_files,
tracer=tracer,
)
image_prompt = image_prompt_response["description"]
image_shape = image_prompt_response["shape"]
if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
@@ -102,13 +104,19 @@ async def text_to_image(
with timer(f"Generate image with {text_to_image_config.model_type}", logger):
try:
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
webp_image_bytes = generate_image_with_openai(
image_prompt, text_to_image_config, text2image_model, image_shape
)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
webp_image_bytes = generate_image_with_replicate(
image_prompt, text_to_image_config, text2image_model, image_shape
)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
webp_image_bytes = generate_image_with_google(
image_prompt, text_to_image_config, text2image_model, image_shape
)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -159,7 +167,10 @@ async def text_to_image(
reraise=True,
)
def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
improved_image_prompt: str,
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
):
"Generate image using OpenAI (compatible) API"
@@ -175,12 +186,21 @@ def generate_image_with_openai(
elif state.openai_client:
openai_client = state.openai_client
# Convert shape to size for OpenAI
if shape == ImageShape.PORTRAIT:
size = "1024x1536"
elif shape == ImageShape.LANDSCAPE:
size = "1536x1024"
else: # Square
size = "1024x1024"
# Generate image using OpenAI API
OPENAI_IMAGE_GEN_STYLE = "vivid"
response = openai_client.images.generate(
prompt=improved_image_prompt,
model=text2image_model,
style=OPENAI_IMAGE_GEN_STYLE,
size=size,
response_format="b64_json",
)
@@ -227,10 +247,22 @@ def generate_image_with_stability(
reraise=True,
)
def generate_image_with_replicate(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
improved_image_prompt: str,
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
):
"Generate image using Replicate API"
# Convert shape to aspect ratio for Replicate
# Replicate supports only 1:1, 3:4, and 4:3 aspect ratios
if shape == ImageShape.PORTRAIT:
aspect_ratio = "3:4"
elif shape == ImageShape.LANDSCAPE:
aspect_ratio = "4:3"
else: # Square
aspect_ratio = "1:1"
# Create image generation task on Replicate
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
@@ -241,7 +273,7 @@ def generate_image_with_replicate(
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"aspect_ratio": aspect_ratio,
"output_format": "webp",
"output_quality": 100,
}
@@ -286,7 +318,10 @@ def generate_image_with_replicate(
reraise=True,
)
def generate_image_with_google(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
improved_image_prompt: str,
text_to_image_config: TextToImageModelConfig,
text2image_model: str,
shape: ImageShape = ImageShape.SQUARE,
):
"""Generate image using Google's AI over API"""
@@ -294,6 +329,14 @@ def generate_image_with_google(
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
client = genai.Client(api_key=api_key)
# Convert shape to aspect ratio for Google
if shape == ImageShape.PORTRAIT:
aspect_ratio = "3:4"
elif shape == ImageShape.LANDSCAPE:
aspect_ratio = "4:3"
else: # Square
aspect_ratio = "1:1"
# Configure image generation settings
config = gtypes.GenerateImagesConfig(
number_of_images=1,
@@ -301,6 +344,7 @@ def generate_image_with_google(
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
include_rai_reason=True,
output_mime_type="image/png",
aspect_ratio=aspect_ratio,
)
# Call the Gemini API to generate the image

View File

@@ -113,6 +113,7 @@ from khoj.utils import state
from khoj.utils.helpers import (
LRU,
ConversationCommand,
ImageShape,
ToolDefinition,
get_file_type,
in_debug_mode,
@@ -1076,7 +1077,7 @@ async def generate_better_image_prompt(
agent: Agent = None,
query_files: str = "",
tracer: dict = {},
) -> str:
) -> dict:
"""
Generate a better image prompt from the given query
"""
@@ -1104,10 +1105,16 @@ async def generate_better_image_prompt(
personality_context=personality_context,
)
class ImagePromptResponse(BaseModel):
description: str = Field(description="Enhanced image description")
shape: ImageShape = Field(
description="Aspect ratio/shape best suited to render the image: Portrait, Landscape, or Square"
)
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(
raw_response = await send_message_to_model_wrapper(
q,
system_message=enhance_image_system_message,
query_images=query_images,
@@ -1115,13 +1122,19 @@ async def generate_better_image_prompt(
chat_history=conversation_history,
agent_chat_model=agent_chat_model,
user=user,
response_type="json_object",
response_schema=ImagePromptResponse,
tracer=tracer,
)
response_text = response.text.strip()
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
response_text = response_text[1:-1]
return response_text
# Parse the structured response
try:
response = clean_json(raw_response.text)
parsed_response = pyjson5.loads(response)
return parsed_response
except Exception:
# Fallback to user query as image description
return {"description": q, "shape": ImageShape.SQUARE}
async def search_documents(

View File

@@ -875,6 +875,12 @@ def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> s
return output_data_uri
class ImageShape(str, Enum):
    """Shape of an image to be generated: portrait, landscape or square.

    Members mix in ``str``, so each one compares equal to its string value
    (for example ``ImageShape.PORTRAIT == "Portrait"``).

    Lookup by value tolerates differences in letter case and surrounding
    whitespace, so ``ImageShape("portrait")`` and ``ImageShape(" SQUARE ")``
    resolve to the matching member. This is useful when the shape arrives as
    free-form text rather than as an exact member value. Values that match no
    member still raise ``ValueError``.
    """

    PORTRAIT = "Portrait"
    LANDSCAPE = "Landscape"
    SQUARE = "Square"

    @classmethod
    def _missing_(cls, value):
        """Resolve a value that differs from a member only by case or padding.

        Called by ``Enum`` only after the exact-value lookup has failed.
        Returning ``None`` lets ``Enum`` raise its usual ``ValueError``.
        """
        if isinstance(value, str):
            normalized = value.strip().casefold()
            for member in cls:
                if member.value.casefold() == normalized:
                    return member
        return None
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
"""
Truncate large output files and drop image file data from code results.