Merge pull request #586 from khoj-ai/features/misc-image-and-online-improvements

Improvements to chat functionality and image generation
This commit is contained in:
sabaimran
2023-12-17 23:28:08 +05:30
committed by GitHub
5 changed files with 81 additions and 22 deletions

View File

@@ -179,9 +179,14 @@
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
imageMarkdown += "\n\n";
if (inferredQueries) {
const inferredQuery = inferredQueries?.[0];
imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
}
@@ -357,6 +362,11 @@
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
@@ -454,7 +464,13 @@
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
}
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
@@ -572,7 +588,7 @@
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
});
})
.catch(err => {
@@ -903,11 +919,12 @@
}
.input-row-button {
background: var(--background-color);
border: none;
border: 1px solid var(--main-text-color);
box-shadow: 0 0 11px #aaa;
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
padding: 0;
line-height: 1.5em;
cursor: pointer;
transition: background 0.3s ease-in-out;
@@ -989,7 +1006,6 @@
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;

View File

@@ -188,9 +188,14 @@ To get started, just start typing below. You can also type / to see a list of co
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
imageMarkdown += "\n\n";
if (inferredQueries) {
const inferredQuery = inferredQueries?.[0];
imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
}
@@ -362,6 +367,11 @@ To get started, just start typing below. You can also type / to see a list of co
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
@@ -543,7 +553,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
});
})
.catch(err => {
@@ -952,7 +962,7 @@ To get started, just start typing below. You can also type / to see a list of co
grid-template-columns: auto 32px 32px;
grid-column-gap: 10px;
grid-row-gap: 10px;
background: #f9fafc
background: var(--background-color);
}
.option:hover {
box-shadow: 0 0 11px #aaa;
@@ -974,9 +984,9 @@ To get started, just start typing below. You can also type / to see a list of co
}
.input-row-button {
background: var(--background-color);
border: none;
border: 1px solid var(--main-text-color);
box-shadow: 0 0 11px #aaa;
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;

View File

@@ -109,6 +109,18 @@ Question: {query}
""".strip()
)
## Image Generation
## --
# Prompt template used to expand a user's raw query into a richer, more
# detailed prompt before it is sent to the image generation model.
# Expects a single template variable: {query}.
image_generation_improve_prompt = PromptTemplate.from_template(
    """
Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation.
Query: {query}
Improved Query:"""
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
@@ -295,10 +307,13 @@ Q:"""
# --
help_message = PromptTemplate.from_template(
"""
**/notes**: Chat using the information in your knowledge base.
**/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
**/default**: Chat using your knowledge base and Khoj's general knowledge for context.
**/help**: Show this help message.
- **/notes**: Chat using the information in your knowledge base.
- **/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
- **/default**: Chat using your knowledge base and Khoj's general knowledge for context.
- **/online**: Chat using the internet as a source of information.
- **/image**: Generate an image based on your message.
- **/help**: Show this help message.
You are using the **{model}** model on the **{device}**.
**version**: {version}

View File

@@ -721,7 +721,7 @@ async def chat(
metadata={"conversation_command": conversation_command.value},
**common.__dict__,
)
image, status_code = await text_to_image(q)
image, status_code, improved_image_prompt = await text_to_image(q)
if image is None:
content_obj = {
"image": image,
@@ -729,8 +729,10 @@ async def chat(
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
content_obj = {"image": image, "intentType": "text-to-image"}
await sync_to_async(save_to_conversation_log)(
q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt]
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.

View File

@@ -146,6 +146,20 @@ async def generate_online_subqueries(q: str) -> List[str]:
return [q]
async def generate_better_image_prompt(q: str) -> str:
    """
    Rewrite the given query into a more detailed prompt for image generation.
    """
    # Fill the improvement template with the raw query, then ask the chat
    # model to produce the expanded image-generation prompt.
    formatted_prompt = prompts.image_generation_improve_prompt.format(query=q)
    model_response = await send_message_to_model_wrapper(formatted_prompt)
    return model_response.strip()
async def send_message_to_model_wrapper(
message: str,
):
@@ -170,11 +184,13 @@ async def send_message_to_model_wrapper(
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
return send_message_to_model(
openai_response = send_message_to_model(
message=message,
api_key=api_key,
model=chat_model,
)
return openai_response.content
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -250,27 +266,27 @@ def generate_chat_response(
return chat_response, metadata
async def text_to_image(message: str) -> Tuple[Optional[str], int, Optional[str]]:
    """
    Generate an image from the given message.

    Returns a tuple of:
    - base64-encoded image (or None on failure),
    - HTTP status code (200 on success, 501 if no model is configured, 500 on generation error),
    - the improved image prompt actually sent to the model (or None if generation was never attempted).
    """
    status_code = 200
    image = None
    # Initialize up front so the final return is safe on every path.
    # Previously this was only assigned inside the OpenAI branch, so the
    # 501 (unconfigured) path raised UnboundLocalError at the return.
    improved_image_prompt = None

    # Fetch the user's configured text-to-image model, if any.
    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        text2image_model = text_to_image_config.model_name
        # Expand the user's query into a more detailed prompt before generation.
        improved_image_prompt = await generate_better_image_prompt(message)
        try:
            response = state.openai_client.images.generate(
                prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
            )
            image = response.data[0].b64_json
        except openai.OpenAIError as e:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
            status_code = 500

    return image, status_code, improved_image_prompt
class ApiUserRateLimiter: