mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support generating images with different aspect ratios
You can now specify shape of images to be generated. It can be one of portrait, landscape or square.
This commit is contained in:
@@ -122,7 +122,8 @@ Your image description will be transformed into an image by an AI model on your
|
||||
# Instructions
|
||||
- Retain important information and follow instructions by the user when composing the image description.
|
||||
- Weave in the context provided below if it will enhance the image.
|
||||
- Specify desired elements, lighting, mood, and composition.
|
||||
- Specify desired elements, lighting, mood, and composition in the description.
|
||||
- Decide the shape best suited to render the image. It can be one of square, portrait or landscape.
|
||||
- Add specific, fine position details. Mention painting style, camera parameters to compose the image.
|
||||
- Transform any negations in user instructions into positive alternatives.
|
||||
Instead of saying what should NOT be in the image, describe what SHOULD be there instead.
|
||||
@@ -142,9 +143,8 @@ Your image description will be transformed into an image by an AI model on your
|
||||
## Online References
|
||||
{online_results}
|
||||
|
||||
Now generate a vivid description of the image to be rendered.
|
||||
|
||||
Image Description:
|
||||
Now generate a vivid description of the image and image shape to be rendered.
|
||||
Your response should be a JSON object with 'description' and 'shape' fields specified.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from khoj.database.models import (
|
||||
TextToImageModelConfig,
|
||||
)
|
||||
from khoj.processor.conversation.google.utils import _is_retryable_error
|
||||
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
|
||||
from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
|
||||
from khoj.routers.storage import upload_generated_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import convert_image_to_webp, timer
|
||||
@@ -80,7 +80,7 @@ async def text_to_image(
|
||||
|
||||
# Generate a better image prompt
|
||||
# Use the user's message, chat history, and other context
|
||||
image_prompt = await generate_better_image_prompt(
|
||||
image_prompt_response = await generate_better_image_prompt(
|
||||
message,
|
||||
image_chat_history,
|
||||
location_data=location_data,
|
||||
@@ -93,6 +93,8 @@ async def text_to_image(
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
image_prompt = image_prompt_response["description"]
|
||||
image_shape = image_prompt_response["shape"]
|
||||
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
|
||||
@@ -102,13 +104,19 @@ async def text_to_image(
|
||||
with timer(f"Generate image with {text_to_image_config.model_type}", logger):
|
||||
try:
|
||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
|
||||
webp_image_bytes = generate_image_with_openai(
|
||||
image_prompt, text_to_image_config, text2image_model, image_shape
|
||||
)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
||||
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
|
||||
webp_image_bytes = generate_image_with_replicate(
|
||||
image_prompt, text_to_image_config, text2image_model, image_shape
|
||||
)
|
||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE:
|
||||
webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model)
|
||||
webp_image_bytes = generate_image_with_google(
|
||||
image_prompt, text_to_image_config, text2image_model, image_shape
|
||||
)
|
||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||
if "content_policy_violation" in e.message:
|
||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
@@ -159,7 +167,10 @@ async def text_to_image(
|
||||
reraise=True,
|
||||
)
|
||||
def generate_image_with_openai(
|
||||
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||
improved_image_prompt: str,
|
||||
text_to_image_config: TextToImageModelConfig,
|
||||
text2image_model: str,
|
||||
shape: ImageShape = ImageShape.SQUARE,
|
||||
):
|
||||
"Generate image using OpenAI (compatible) API"
|
||||
|
||||
@@ -175,12 +186,21 @@ def generate_image_with_openai(
|
||||
elif state.openai_client:
|
||||
openai_client = state.openai_client
|
||||
|
||||
# Convert shape to size for OpenAI
|
||||
if shape == ImageShape.PORTRAIT:
|
||||
size = "1024x1536"
|
||||
elif shape == ImageShape.LANDSCAPE:
|
||||
size = "1536x1024"
|
||||
else: # Square
|
||||
size = "1024x1024"
|
||||
|
||||
# Generate image using OpenAI API
|
||||
OPENAI_IMAGE_GEN_STYLE = "vivid"
|
||||
response = openai_client.images.generate(
|
||||
prompt=improved_image_prompt,
|
||||
model=text2image_model,
|
||||
style=OPENAI_IMAGE_GEN_STYLE,
|
||||
size=size,
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
@@ -227,10 +247,22 @@ def generate_image_with_stability(
|
||||
reraise=True,
|
||||
)
|
||||
def generate_image_with_replicate(
|
||||
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||
improved_image_prompt: str,
|
||||
text_to_image_config: TextToImageModelConfig,
|
||||
text2image_model: str,
|
||||
shape: ImageShape = ImageShape.SQUARE,
|
||||
):
|
||||
"Generate image using Replicate API"
|
||||
|
||||
# Convert shape to aspect ratio for Replicate
|
||||
# Replicate supports only 1:1, 3:4, and 4:3 aspect ratios
|
||||
if shape == ImageShape.PORTRAIT:
|
||||
aspect_ratio = "3:4"
|
||||
elif shape == ImageShape.LANDSCAPE:
|
||||
aspect_ratio = "4:3"
|
||||
else: # Square
|
||||
aspect_ratio = "1:1"
|
||||
|
||||
# Create image generation task on Replicate
|
||||
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
|
||||
headers = {
|
||||
@@ -241,7 +273,7 @@ def generate_image_with_replicate(
|
||||
"input": {
|
||||
"prompt": improved_image_prompt,
|
||||
"num_outputs": 1,
|
||||
"aspect_ratio": "1:1",
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": "webp",
|
||||
"output_quality": 100,
|
||||
}
|
||||
@@ -286,7 +318,10 @@ def generate_image_with_replicate(
|
||||
reraise=True,
|
||||
)
|
||||
def generate_image_with_google(
|
||||
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||
improved_image_prompt: str,
|
||||
text_to_image_config: TextToImageModelConfig,
|
||||
text2image_model: str,
|
||||
shape: ImageShape = ImageShape.SQUARE,
|
||||
):
|
||||
"""Generate image using Google's AI over API"""
|
||||
|
||||
@@ -294,6 +329,14 @@ def generate_image_with_google(
|
||||
api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Convert shape to aspect ratio for Google
|
||||
if shape == ImageShape.PORTRAIT:
|
||||
aspect_ratio = "3:4"
|
||||
elif shape == ImageShape.LANDSCAPE:
|
||||
aspect_ratio = "4:3"
|
||||
else: # Square
|
||||
aspect_ratio = "1:1"
|
||||
|
||||
# Configure image generation settings
|
||||
config = gtypes.GenerateImagesConfig(
|
||||
number_of_images=1,
|
||||
@@ -301,6 +344,7 @@ def generate_image_with_google(
|
||||
person_generation=gtypes.PersonGeneration.ALLOW_ADULT,
|
||||
include_rai_reason=True,
|
||||
output_mime_type="image/png",
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
|
||||
# Call the Gemini API to generate the image
|
||||
|
||||
@@ -113,6 +113,7 @@ from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
LRU,
|
||||
ConversationCommand,
|
||||
ImageShape,
|
||||
ToolDefinition,
|
||||
get_file_type,
|
||||
in_debug_mode,
|
||||
@@ -1076,7 +1077,7 @@ async def generate_better_image_prompt(
|
||||
agent: Agent = None,
|
||||
query_files: str = "",
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
"""
|
||||
@@ -1104,10 +1105,16 @@ async def generate_better_image_prompt(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
class ImagePromptResponse(BaseModel):
|
||||
description: str = Field(description="Enhanced image description")
|
||||
shape: ImageShape = Field(
|
||||
description="Aspect ratio/shape best suited to render the image: Portrait, Landscape, or Square"
|
||||
)
|
||||
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
q,
|
||||
system_message=enhance_image_system_message,
|
||||
query_images=query_images,
|
||||
@@ -1115,13 +1122,19 @@ async def generate_better_image_prompt(
|
||||
chat_history=conversation_history,
|
||||
agent_chat_model=agent_chat_model,
|
||||
user=user,
|
||||
response_type="json_object",
|
||||
response_schema=ImagePromptResponse,
|
||||
tracer=tracer,
|
||||
)
|
||||
response_text = response.text.strip()
|
||||
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
||||
response_text = response_text[1:-1]
|
||||
|
||||
return response_text
|
||||
# Parse the structured response
|
||||
try:
|
||||
response = clean_json(raw_response.text)
|
||||
parsed_response = pyjson5.loads(response)
|
||||
return parsed_response
|
||||
except Exception:
|
||||
# Fallback to user query as image description
|
||||
return {"description": q, "shape": ImageShape.SQUARE}
|
||||
|
||||
|
||||
async def search_documents(
|
||||
|
||||
@@ -875,6 +875,12 @@ def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> s
|
||||
return output_data_uri
|
||||
|
||||
|
||||
class ImageShape(str, Enum):
|
||||
PORTRAIT = "Portrait"
|
||||
LANDSCAPE = "Landscape"
|
||||
SQUARE = "Square"
|
||||
|
||||
|
||||
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
|
||||
"""
|
||||
Truncate large output files and drop image file data from code results.
|
||||
|
||||
Reference in New Issue
Block a user