Increase user visibility into more errors during image generation

Catch OpenAI connection error and errors during better image prompt
generation
This commit is contained in:
Debanjum Singh Solanky
2024-03-07 09:47:01 +05:30
parent ff31759423
commit 12d32ac99c

View File

@@ -467,7 +467,7 @@ async def text_to_image(
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: [generated image redacted by admin]. Enhanced image prompt: {chat['intent']['inferred-queries'][0]}\n" chat_history += f"A: [generated image redacted by admin]. Enhanced image prompt: {chat['intent']['inferred-queries'][0]}\n"
try:
with timer("Improve the original user query", logger): with timer("Improve the original user query", logger):
improved_image_prompt = await generate_better_image_prompt( improved_image_prompt = await generate_better_image_prompt(
message, message,
@@ -476,7 +476,6 @@ async def text_to_image(
note_references=references, note_references=references,
online_results=online_results, online_results=online_results,
) )
try:
with timer("Generate image with OpenAI", logger): with timer("Generate image with OpenAI", logger):
response = state.openai_client.images.generate( response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
@@ -486,7 +485,7 @@ async def text_to_image(
with timer("Upload image to S3", logger): with timer("Upload image to S3", logger):
image_url = upload_image(image, user.uuid) image_url = upload_image(image, user.uuid)
return image, status_code, improved_image_prompt, image_url return image, status_code, improved_image_prompt, image_url
except openai.OpenAIError or openai.BadRequestError as e: except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message: if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore