Increase user visibility into more errors during image generation

Catch OpenAI connection error and errors during better image prompt
generation
This commit is contained in:
Debanjum Singh Solanky
2024-03-07 09:47:01 +05:30
parent ff31759423
commit 12d32ac99c

View File

@@ -467,16 +467,15 @@ async def text_to_image(
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: [generated image redacted by admin]. Enhanced image prompt: {chat['intent']['inferred-queries'][0]}\n"
with timer("Improve the original user query", logger):
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
)
try:
with timer("Improve the original user query", logger):
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
)
with timer("Generate image with OpenAI", logger):
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
@@ -486,7 +485,7 @@ async def text_to_image(
with timer("Upload image to S3", logger):
image_url = upload_image(image, user.uuid)
return image, status_code, improved_image_prompt, image_url
except openai.OpenAIError or openai.BadRequestError as e:
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore