mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 21:29:12 +00:00
Do not require OpenAI to generate image as local chat + sd3 works now
Previously, the text_to_image helper would only trigger the image generation flow if the OpenAI client was set up. This is no longer required, as the offline chat model + sd3 API works. So remove that check.
This commit is contained in:
@@ -768,88 +768,94 @@ async def text_to_image(
|
|||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
return image_url or image, status_code, message, intent_type.value
|
return image_url or image, status_code, message, intent_type.value
|
||||||
elif state.openai_client:
|
|
||||||
logger.info("Generating image with OpenAI")
|
text2image_model = text_to_image_config.model_name
|
||||||
text2image_model = text_to_image_config.model_name
|
chat_history = ""
|
||||||
chat_history = ""
|
for chat in conversation_log.get("chat", [])[-4:]:
|
||||||
for chat in conversation_log.get("chat", [])[-4:]:
|
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
|
||||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
|
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
chat_history += f"A: {chat['message']}\n"
|
||||||
chat_history += f"A: {chat['message']}\n"
|
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
|
||||||
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
|
chat_history += f"Q: Query: {chat['intent']['query']}\n"
|
||||||
chat_history += f"Q: Query: {chat['intent']['query']}\n"
|
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
|
|
||||||
try:
|
with timer("Improve the original user query", logger):
|
||||||
with timer("Improve the original user query", logger):
|
if send_status_func:
|
||||||
if send_status_func:
|
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
|
||||||
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
|
improved_image_prompt = await generate_better_image_prompt(
|
||||||
improved_image_prompt = await generate_better_image_prompt(
|
message,
|
||||||
message,
|
chat_history,
|
||||||
chat_history,
|
location_data=location_data,
|
||||||
location_data=location_data,
|
note_references=references,
|
||||||
note_references=references,
|
online_results=online_results,
|
||||||
online_results=online_results,
|
model_type=text_to_image_config.model_type,
|
||||||
model_type=text_to_image_config.model_type,
|
)
|
||||||
|
|
||||||
|
if send_status_func:
|
||||||
|
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||||
|
|
||||||
|
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
|
with timer("Generate image with OpenAI", logger):
|
||||||
|
try:
|
||||||
|
response = state.openai_client.images.generate(
|
||||||
|
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||||
)
|
)
|
||||||
if send_status_func:
|
image = response.data[0].b64_json
|
||||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
decoded_image = base64.b64decode(image)
|
||||||
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
|
if "content_policy_violation" in e.message:
|
||||||
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
|
status_code = e.status_code # type: ignore
|
||||||
|
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||||
|
return image_url or image, status_code, message, intent_type.value
|
||||||
|
else:
|
||||||
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
|
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||||
|
status_code = e.status_code # type: ignore
|
||||||
|
return image_url or image, status_code, message, intent_type.value
|
||||||
|
|
||||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||||
with timer("Generate image with OpenAI", logger):
|
with timer("Generate image with Stability AI", logger):
|
||||||
response = state.openai_client.images.generate(
|
try:
|
||||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
response = requests.post(
|
||||||
)
|
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||||
image = response.data[0].b64_json
|
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
||||||
decoded_image = base64.b64decode(image)
|
files={"none": ""},
|
||||||
|
data={
|
||||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
"prompt": improved_image_prompt,
|
||||||
with timer("Generate image with Stability AI", logger):
|
"model": text2image_model,
|
||||||
response = requests.post(
|
"mode": "text-to-image",
|
||||||
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
"output_format": "png",
|
||||||
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
"seed": 1032622926,
|
||||||
files={"none": ""},
|
"aspect_ratio": "1:1",
|
||||||
data={
|
},
|
||||||
"prompt": improved_image_prompt,
|
)
|
||||||
"model": text2image_model,
|
decoded_image = response.content
|
||||||
"mode": "text-to-image",
|
except requests.RequestException as e:
|
||||||
"output_format": "png",
|
|
||||||
"seed": 1032622926,
|
|
||||||
"aspect_ratio": "1:1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
decoded_image = response.content
|
|
||||||
|
|
||||||
with timer("Convert image to webp", logger):
|
|
||||||
# Convert png to webp for faster loading
|
|
||||||
image_io = io.BytesIO(decoded_image)
|
|
||||||
png_image = Image.open(image_io)
|
|
||||||
webp_image_io = io.BytesIO()
|
|
||||||
png_image.save(webp_image_io, "WEBP")
|
|
||||||
webp_image_bytes = webp_image_io.getvalue()
|
|
||||||
webp_image_io.close()
|
|
||||||
image_io.close()
|
|
||||||
|
|
||||||
with timer("Upload image to S3", logger):
|
|
||||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
|
||||||
if image_url:
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
|
||||||
else:
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
|
||||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
|
||||||
|
|
||||||
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
|
||||||
if "content_policy_violation" in e.message:
|
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
|
||||||
status_code = e.status_code # type: ignore
|
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
|
||||||
return image_url or image, status_code, message, intent_type.value
|
|
||||||
else:
|
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
message = f"Image generation failed with Stability AI error: {e}"
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
return image_url or image, status_code, message, intent_type.value
|
return image_url or image, status_code, message, intent_type.value
|
||||||
return image_url or image, status_code, response, intent_type.value
|
|
||||||
|
with timer("Convert image to webp", logger):
|
||||||
|
# Convert png to webp for faster loading
|
||||||
|
image_io = io.BytesIO(decoded_image)
|
||||||
|
png_image = Image.open(image_io)
|
||||||
|
webp_image_io = io.BytesIO()
|
||||||
|
png_image.save(webp_image_io, "WEBP")
|
||||||
|
webp_image_bytes = webp_image_io.getvalue()
|
||||||
|
webp_image_io.close()
|
||||||
|
image_io.close()
|
||||||
|
|
||||||
|
with timer("Upload image to S3", logger):
|
||||||
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
|
if image_url:
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
||||||
|
else:
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
|
|||||||
Reference in New Issue
Block a user