Add backend support for parsing, processing, and storing Mermaid.js diagrams

- Change the default diagram output from Excalidraw to Mermaid.js
- Retain typing of the excalidraw type for backwards compatibility in the chatlog
This commit is contained in:
sabaimran
2025-01-08 22:08:40 -08:00
parent c448c49811
commit 539ce99343
5 changed files with 274 additions and 21 deletions

View File

@@ -109,6 +109,7 @@ class ChatMessage(PydanticBaseModel):
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None
excalidrawDiagram: Optional[List[Dict]] = None
# Mermaid.js diagram text stored with the message; None when the message has no diagram.
# Optional so an explicit None validates, matching the sibling nullable fields.
mermaidjsDiagram: Optional[str] = None
by: str
turnId: Optional[str] = None
intent: Optional[Intent] = None

View File

@@ -194,7 +194,7 @@ Limit your response to 3 sentences max. Be succinct, clear, and informative.
## Diagram Generation
## --
improve_diagram_description_prompt = PromptTemplate.from_template(
improve_excalidraw_diagram_description_prompt = PromptTemplate.from_template(
"""
You are an architect working with a novice digital artist using a diagramming software.
{personality_context}
@@ -338,6 +338,123 @@ Diagram Description: {query}
""".strip()
)
# Prompt that asks the chat model to expand a diagram request into a detailed,
# natural-language description an illustrator can draw from.
# Template fields: personality_context, current_date, location, references,
# online_results, chat_history, query.
improve_mermaid_js_diagram_description_prompt = PromptTemplate.from_template(
    """
You are a senior architect working with an illustrator using a diagramming software.
{personality_context}
Given a particular request, you need to translate it to a detailed description that the illustrator can use to create a diagram.
You can use the following diagram types in your instructions:
- Flowchart
- Sequence Diagram
- Gantt Chart (only for time-based queries after 0 AD)
- State Diagram
- Pie Chart
Use these primitives to describe what sort of diagram the drawer should create in natural language, not special syntax. We must recreate the diagram every time, so include all relevant prior information in your description.
- Describe the layout, components, and connections.
- Use simple, concise language.
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Query: {query}
Enhanced Description:
""".strip()
)
# Prompt that asks the chat model to emit a diagram in Mermaid.js syntax for a
# given description, with one worked example per supported diagram type.
# Template fields: personality_context, query.
mermaid_js_diagram_generation_prompt = PromptTemplate.from_template(
    """
You are a designer with the ability to describe diagrams to compose in professional, fine detail. You dive into the details and make labels, connections, and shapes to represent complex systems.
{personality_context}
----Goals----
You need to create a declarative description of the diagram and relevant components, using the Mermaid.js syntax.
You can choose from the following diagram types:
- Flowchart
- Sequence Diagram
- State Diagram
- Gantt Chart
- Pie Chart
----Examples----
---
title: Node
---
flowchart LR
id["This is the start"] --> id2["This is the end"]
sequenceDiagram
Alice->>John: Hello John, how are you?
John-->>Alice: Great!
Alice-)John: See you later!
stateDiagram-v2
[*] --> Still
Still --> [*]
Still --> Moving
Moving --> Still
Moving --> Crash
Crash --> [*]
gantt
title A Gantt Diagram
dateFormat YYYY-MM-DD
section Section
A task :a1, 2014-01-01, 30d
Another task :after a1, 20d
section Another
Task in Another :2014-01-12, 12d
another task :24d
pie title Pets adopted by volunteers
"Dogs" : 10
"Cats" : 30
"Rats" : 60
flowchart TB
c1-->a2
subgraph one
a1-->a2
end
subgraph two
b1-->b2["this is b2"]
end
subgraph three
c1["this is c1"]-->c2["this is c2"]
end
one --> two
three --> two
two --> c2
----Process----
Create your diagram with great composition and intuitiveness from the provided context and user prompt below.
- You may use subgraphs to group elements together. Each subgraph must have a title.
- **You must wrap ALL entity and node labels in double quotes**. For example, "Entity Name".
- Custom styles are not permitted. Default styles only.
- JUST provide the diagram, no additional text or context. Say nothing else in your response except the diagram.
- Keep diagrams simple - maximum 15 nodes
output: {query}
""".strip()
)
failed_diagram_generation = PromptTemplate.from_template(
"""
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.

View File

@@ -266,7 +266,7 @@ def save_to_conversation_log(
raw_query_files: List[FileAttachment] = [],
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
generated_mermaidjs_diagram: str = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@@ -290,8 +290,8 @@ def save_to_conversation_log(
"queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
}
if generated_excalidraw_diagram:
khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
if generated_mermaidjs_diagram:
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
updated_conversation = message_to_log(
user_message=q,
@@ -441,7 +441,7 @@ def generate_chatml_messages_with_context(
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}
if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant":
generated_assets["diagram"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}
@@ -593,6 +593,11 @@ def clean_json(response: str):
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
def clean_mermaidjs(response: str) -> str:
    """Strip a surrounding markdown codeblock fence from a Mermaid.js diagram response.

    Useful for non schema enforceable models, which may wrap the diagram in a
    ```mermaid ... ``` (or bare ``` ... ```) codeblock.

    Newlines inside the diagram are kept, since the diagram text spans multiple lines.
    Only the whitespace surrounding the diagram is trimmed.
    """
    diagram = response.strip()
    # Drop the opening fence: prefer the language-tagged form, fall back to a bare fence.
    if diagram.startswith("```mermaid"):
        diagram = diagram.removeprefix("```mermaid")
    else:
        diagram = diagram.removeprefix("```")
    diagram = diagram.removesuffix("```")
    # Removing the fences exposes the newlines that sat next to them; trim those as well.
    return diagram.strip()
def clean_code_python(code: str):
    """Trim surrounding whitespace, then drop an enclosing markdown python codeblock fence if present. Handy for models that cannot be held to a schema"""
    opening_fence, closing_fence = "```python", "```"
    trimmed = code.strip()
    without_opening = trimmed.removeprefix(opening_fence)
    return without_opening.removesuffix(closing_fence)

View File

@@ -51,7 +51,7 @@ from khoj.routers.helpers import (
construct_automation_created_message,
create_automation,
gather_raw_query_files,
generate_excalidraw_diagram,
generate_mermaidjs_diagram,
generate_summary_from_files,
get_conversation_command,
is_query_empty,
@@ -781,7 +781,7 @@ async def chat(
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_excalidraw_diagram: str = None
generated_mermaidjs_diagram: str = None
program_execution_context: List[str] = []
if conversation_commands == [ConversationCommand.Default]:
@@ -1156,7 +1156,7 @@ async def chat(
inferred_queries = []
diagram_description = ""
async for result in generate_excalidraw_diagram(
async for result in generate_mermaidjs_diagram(
q=defiltered_query,
conversation_history=meta_log,
location_data=location,
@@ -1172,12 +1172,12 @@ async def chat(
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
better_diagram_description_prompt, excalidraw_diagram_description = result
if better_diagram_description_prompt and excalidraw_diagram_description:
better_diagram_description_prompt, mermaidjs_diagram_description = result
if better_diagram_description_prompt and mermaidjs_diagram_description:
inferred_queries.append(better_diagram_description_prompt)
diagram_description = excalidraw_diagram_description
diagram_description = mermaidjs_diagram_description
generated_excalidraw_diagram = diagram_description
generated_mermaidjs_diagram = diagram_description
generated_asset_results["diagrams"] = {
"query": better_diagram_description_prompt,
@@ -1186,7 +1186,7 @@ async def chat(
async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
"excalidrawDiagram": excalidraw_diagram_description,
"mermaidjsDiagram": mermaidjs_diagram_description,
},
):
yield result
@@ -1226,7 +1226,7 @@ async def chat(
raw_query_files,
generated_images,
generated_files,
generated_excalidraw_diagram,
generated_mermaidjs_diagram,
program_execution_context,
generated_asset_results,
tracer,

View File

@@ -97,6 +97,7 @@ from khoj.processor.conversation.utils import (
ChatEvent,
ThreadedGenerator,
clean_json,
clean_mermaidjs,
construct_chat_history,
generate_chatml_messages_with_context,
save_to_conversation_log,
@@ -823,7 +824,7 @@ async def generate_better_diagram_description(
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
query=q,
chat_history=chat_history,
location=location,
@@ -887,6 +888,135 @@ async def generate_excalidraw_diagram_from_description(
return response
async def generate_mermaidjs_diagram(
    q: str,
    conversation_history: Dict[str, Any],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Create a Mermaid.js diagram for the query in two steps: enhance the request
    into a detailed description, then turn that description into a diagram.

    Yields {ChatEvent.STATUS: event} dicts for each status update when a status
    function is provided, and finally a single (instruction, diagram) tuple.
    """
    if send_status_func:
        async for status_event in send_status_func("**Enhancing the Diagramming Prompt**"):
            yield {ChatEvent.STATUS: status_event}

    # Step 1: expand the raw query into a detailed diagram description
    description_inputs = {
        "q": q,
        "conversation_history": conversation_history,
        "location_data": location_data,
        "note_references": note_references,
        "online_results": online_results,
        "query_images": query_images,
        "user": user,
        "agent": agent,
        "query_files": query_files,
        "tracer": tracer,
    }
    enhanced_description = await generate_better_mermaidjs_diagram_description(**description_inputs)

    if send_status_func:
        async for status_event in send_status_func(f"**Diagram to Create:**:\n{enhanced_description}"):
            yield {ChatEvent.STATUS: status_event}

    # Step 2: render the enhanced description as a Mermaid.js diagram
    diagram = await generate_mermaidjs_diagram_from_description(
        q=enhanced_description,
        user=user,
        agent=agent,
        tracer=tracer,
    )

    yield f"Instruction: {enhanced_description}", diagram
async def generate_better_mermaidjs_diagram_description(
    q: str,
    conversation_history: Dict[str, Any],
    location_data: LocationData,
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
    query_files: str = None,
    tracer: dict = {},
) -> str:
    """
    Ask the chat model for an enhanced diagram description of the query, given the
    user's notes, online results, chat history, location and today's date as context.

    Returns the model response with surrounding whitespace removed and, when the
    response both starts and ends with a quote character, without those two characters.
    """
    # Per online result keep the answer box if it has one, otherwise its webpages
    simplified_online_results = {}
    if online_results:
        for source, details in online_results.items():
            relevant_content = details.get("answerBox") or details.get("webpages")
            if relevant_content:
                simplified_online_results[source] = relevant_content

    personality_context = ""
    if agent and agent.personality:
        personality_context = prompts.personality_context.format(personality=agent.personality)

    if location_data:
        location = f"{location_data}"
    else:
        location = "Unknown"

    compiled_notes = (f"# {note['compiled']}" for note in note_references)
    prompt_inputs = {
        "query": q,
        "chat_history": construct_chat_history(conversation_history),
        "location": location,
        "current_date": datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A"),
        "references": "\n\n".join(compiled_notes),
        "online_results": simplified_online_results,
        "personality_context": personality_context,
    }
    description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(**prompt_inputs)

    with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
        model_reply = await send_message_to_model_wrapper(
            description_prompt,
            query_images=query_images,
            user=user,
            query_files=query_files,
            tracer=tracer,
        )

        description = model_reply.strip()
        # Unwrap a reply the model returned enclosed in quotes
        quote_marks = ('"', "'")
        if description.startswith(quote_marks) and description.endswith(quote_marks):
            description = description[1:-1]

    return description
async def generate_mermaidjs_diagram_from_description(
    q: str,
    user: KhojUser = None,
    agent: Agent = None,
    tracer: dict = {},
) -> str:
    """
    Generate a Mermaid.js diagram from the given diagram description.

    Formats the Mermaid.js generation prompt with the agent's personality (when the
    agent has one) and the description `q`, then sends it to the chat model.

    Returns the model response after stripping whitespace and passing it through
    clean_mermaidjs to remove any markdown codeblock fence.
    """
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    mermaidjs_diagram_generation = prompts.mermaid_js_diagram_generation_prompt.format(
        personality_context=personality_context,
        query=q,
    )

    with timer("Chat actor: Generate Mermaid.js diagram", logger):
        raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer)
        return clean_mermaidjs(raw_response.strip())
async def generate_better_image_prompt(
q: str,
conversation_history: str,
@@ -1222,7 +1352,7 @@ def generate_chat_response(
raw_query_files: List[FileAttachment] = None,
generated_images: List[str] = None,
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
generated_mermaidjs_diagram: str = None,
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
@@ -1250,7 +1380,7 @@ def generate_chat_response(
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=raw_generated_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
@@ -1965,7 +2095,7 @@ class MessageProcessor:
self.raw_response = ""
self.generated_images = []
self.generated_files = []
self.generated_excalidraw_diagram = []
self.generated_mermaidjs_diagram = []
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
@@ -2012,8 +2142,8 @@ class MessageProcessor:
self.generated_images = chunk_data[key]
elif key == "files":
self.generated_files = chunk_data[key]
elif key == "excalidrawDiagram":
self.generated_excalidraw_diagram = chunk_data[key]
elif key == "mermaidjsDiagram":
self.generated_mermaidjs_diagram = chunk_data[key]
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
@@ -2050,7 +2180,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
"usage": processor.usage,
"images": processor.generated_images,
"files": processor.generated_files,
"excalidrawDiagram": processor.generated_excalidraw_diagram,
"mermaidjsDiagram": processor.generated_mermaidjs_diagram,
}