From fdd4c0246168c91a2afd44500759ef48fb8cf6ec Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 25 Apr 2024 23:07:01 +0530 Subject: [PATCH] Use shorter prompt generator to prompt SD3 to create better images --- src/khoj/processor/conversation/prompts.py | 31 +++++++++++++++++++++- src/khoj/routers/helpers.py | 29 ++++++++++++++------ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index a1c7dff1..fcc9fc63 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -121,7 +121,7 @@ User's Notes: ## Image Generation ## -- -image_generation_improve_prompt = PromptTemplate.from_template( +image_generation_improve_prompt_dalle = PromptTemplate.from_template( """ You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt: @@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a Improved Query:""" ) +image_generation_improve_prompt_sd = PromptTemplate.from_template( + """ +You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image. +Use the provided context below to add specific, fine details to the image composition. +Retain any important information and follow any instructions from the original prompt. +Put any text to be rendered in the image within double quotes in your improved prompt. +You are provided with the following context to help enhance the original prompt: + +Today's Date: {current_date} +User's Location: {location} + +User's Notes: +{references} + +Online References: +{online_results} + +Conversation Log: +{chat_history} + +Original Prompt: "{query}" + +Now create an improved prompt using the context provided above to generate an image. +Retain any important information and follow any instructions from the original prompt. +Use the additional context from the user's notes, online references and conversation log to improve the image generation. + +Improved Prompt:""" +) + ## Online Search Conversation ## -- online_search_conversation = PromptTemplate.from_template( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8cc40f80..66807cc3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -453,12 +453,14 @@ async def generate_better_image_prompt( location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, + model_type: Optional[str] = None, ) -> str: """ Generate a better image prompt from the given query """ today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + model_type = model_type or TextToImageModelConfig.ModelType.OPENAI if location_data: location = f"{location_data.city}, {location_data.region}, {location_data.country}" @@ -477,14 +479,24 @@ async def generate_better_image_prompt( elif online_results[result].get("webpages"): simplified_online_results[result] = online_results[result]["webpages"] - image_prompt = prompts.image_generation_improve_prompt.format( - query=q, - chat_history=conversation_history, - location=location_prompt, - current_date=today_date, - references=user_references, - online_results=simplified_online_results, - ) + if model_type == TextToImageModelConfig.ModelType.OPENAI: + image_prompt = prompts.image_generation_improve_prompt_dalle.format( + query=q, + chat_history=conversation_history, + location=location_prompt, + current_date=today_date, + references=user_references, + online_results=simplified_online_results, + ) + elif model_type == TextToImageModelConfig.ModelType.STABILITYAI: + image_prompt = prompts.image_generation_improve_prompt_sd.format( + query=q, + chat_history=conversation_history, + location=location_prompt, + current_date=today_date, + references=user_references, + online_results=simplified_online_results, + ) summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() @@ -774,6 +786,7 @@ async def text_to_image( location_data=location_data, note_references=references, online_results=online_results, + model_type=text_to_image_config.model_type, ) if send_status_func: await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")