diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 5e7d67fa..1be520f9 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -122,7 +122,8 @@ Your image description will be transformed into an image by an AI model on your # Instructions - Retain important information and follow instructions by the user when composing the image description. - Weave in the context provided below if it will enhance the image. -- Specify desired elements, lighting, mood, and composition. +- Specify desired elements, lighting, mood, and composition in the description. +- Decide the shape best suited to render the image. It can be one of square, portrait or landscape. - Add specific, fine position details. Mention painting style, camera parameters to compose the image. - Transform any negations in user instructions into positive alternatives. Instead of saying what should NOT be in the image, describe what SHOULD be there instead. @@ -142,9 +143,8 @@ Your image description will be transformed into an image by an AI model on your ## Online References {online_results} -Now generate a vivid description of the image to be rendered. - -Image Description: +Now generate a vivid description of the image and image shape to be rendered. +Your response should be a JSON object with 'description' and 'shape' fields specified. """.strip() ) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index f62e751c..7ff83829 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -27,7 +27,7 @@ from khoj.database.models import ( TextToImageModelConfig, ) from khoj.processor.conversation.google.utils import _is_retryable_error -from khoj.routers.helpers import ChatEvent, generate_better_image_prompt +from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt from khoj.routers.storage import upload_generated_image_to_bucket from khoj.utils import state from khoj.utils.helpers import convert_image_to_webp, timer @@ -80,7 +80,7 @@ async def text_to_image( # Generate a better image prompt # Use the user's message, chat history, and other context - image_prompt = await generate_better_image_prompt( + image_prompt_response = await generate_better_image_prompt( message, image_chat_history, location_data=location_data, @@ -93,6 +93,8 @@ async def text_to_image( query_files=query_files, tracer=tracer, ) + image_prompt = image_prompt_response["description"] + image_shape = image_prompt_response["shape"] if send_status_func: async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"): @@ -102,13 +104,19 @@ async def text_to_image( with timer(f"Generate image with {text_to_image_config.model_type}", logger): try: if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: - webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model) + webp_image_bytes = generate_image_with_openai( + image_prompt, text_to_image_config, text2image_model, image_shape + ) elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model) elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE: - webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model) + webp_image_bytes = generate_image_with_replicate( + image_prompt, text_to_image_config, text2image_model, image_shape + ) elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.GOOGLE: - webp_image_bytes = generate_image_with_google(image_prompt, text_to_image_config, text2image_model) + webp_image_bytes = generate_image_with_google( + image_prompt, text_to_image_config, text2image_model, image_shape + ) except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") @@ -159,7 +167,10 @@ async def text_to_image( reraise=True, ) def generate_image_with_openai( - improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str + improved_image_prompt: str, + text_to_image_config: TextToImageModelConfig, + text2image_model: str, + shape: ImageShape = ImageShape.SQUARE, ): "Generate image using OpenAI (compatible) API" @@ -175,12 +186,21 @@ def generate_image_with_openai( elif state.openai_client: openai_client = state.openai_client + # Convert shape to size for OpenAI + if shape == ImageShape.PORTRAIT: + size = "1024x1536" + elif shape == ImageShape.LANDSCAPE: + size = "1536x1024" + else: # Square + size = "1024x1024" + # Generate image using OpenAI API OPENAI_IMAGE_GEN_STYLE = "vivid" response = openai_client.images.generate( prompt=improved_image_prompt, model=text2image_model, style=OPENAI_IMAGE_GEN_STYLE, + size=size, response_format="b64_json", ) @@ -227,10 +247,22 @@ def generate_image_with_stability( reraise=True, ) def generate_image_with_replicate( - improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str + improved_image_prompt: str, + text_to_image_config: TextToImageModelConfig, + text2image_model: str, + shape: ImageShape = ImageShape.SQUARE, ): "Generate image using Replicate API" + # Convert shape to aspect ratio for Replicate + # Replicate supports only 1:1, 3:4, and 4:3 aspect ratios + if shape == ImageShape.PORTRAIT: + aspect_ratio = "3:4" + elif shape == ImageShape.LANDSCAPE: + aspect_ratio = "4:3" + else: # Square + aspect_ratio = "1:1" + # Create image generation task on Replicate replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions" headers = { @@ -241,7 +273,7 @@ def generate_image_with_replicate( "input": { "prompt": improved_image_prompt, "num_outputs": 1, - "aspect_ratio": "1:1", + "aspect_ratio": aspect_ratio, "output_format": "webp", "output_quality": 100, } @@ -286,7 +318,10 @@ def generate_image_with_replicate( reraise=True, ) def generate_image_with_google( - improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str + improved_image_prompt: str, + text_to_image_config: TextToImageModelConfig, + text2image_model: str, + shape: ImageShape = ImageShape.SQUARE, ): """Generate image using Google's AI over API""" @@ -294,6 +329,14 @@ def generate_image_with_google( api_key = text_to_image_config.api_key or text_to_image_config.ai_model_api.api_key client = genai.Client(api_key=api_key) + # Convert shape to aspect ratio for Google + if shape == ImageShape.PORTRAIT: + aspect_ratio = "3:4" + elif shape == ImageShape.LANDSCAPE: + aspect_ratio = "4:3" + else: # Square + aspect_ratio = "1:1" + # Configure image generation settings config = gtypes.GenerateImagesConfig( number_of_images=1, @@ -301,6 +344,7 @@ def generate_image_with_google( person_generation=gtypes.PersonGeneration.ALLOW_ADULT, include_rai_reason=True, output_mime_type="image/png", + aspect_ratio=aspect_ratio, ) # Call the Gemini API to generate the image diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e9cf748d..ac274a12 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -113,6 +113,7 @@ from khoj.utils import state from khoj.utils.helpers import ( LRU, ConversationCommand, + ImageShape, ToolDefinition, get_file_type, in_debug_mode, @@ -1076,7 +1077,7 @@ async def generate_better_image_prompt( agent: Agent = None, query_files: str = "", tracer: dict = {}, -) -> str: +) -> dict: """ Generate a better image prompt from the given query """ @@ -1104,10 +1105,16 @@ async def generate_better_image_prompt( personality_context=personality_context, ) + class ImagePromptResponse(BaseModel): + description: str = Field(description="Enhanced image description") + shape: ImageShape = Field( + description="Aspect ratio/shape best suited to render the image: Portrait, Landscape, or Square" + ) + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Generate contextual image prompt", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( q, system_message=enhance_image_system_message, query_images=query_images, @@ -1115,13 +1122,19 @@ async def generate_better_image_prompt( chat_history=conversation_history, agent_chat_model=agent_chat_model, user=user, + response_type="json_object", + response_schema=ImagePromptResponse, tracer=tracer, ) - response_text = response.text.strip() - if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")): - response_text = response_text[1:-1] - return response_text + # Parse the structured response + try: + response = clean_json(raw_response.text) + parsed_response = pyjson5.loads(response) + return parsed_response + except Exception: + # Fallback to user query as image description + return {"description": q, "shape": ImageShape.SQUARE} async def search_documents( diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 5efafe15..bdd8c60c 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -875,6 +875,12 @@ def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> s return output_data_uri +class ImageShape(str, Enum): + PORTRAIT = "Portrait" + LANDSCAPE = "Landscape" + SQUARE = "Square" + + def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]: """ Truncate large output files and drop image file data from code results.