mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-03 21:29:08 +00:00
Add backend support for parsing and processing and storing mermaidjs diagrams
- Replace default diagram output from excalidraw to mermaid - Retain typing of the excalidraw type for backwards compatibility in the chatlog
This commit is contained in:
@@ -109,6 +109,7 @@ class ChatMessage(PydanticBaseModel):
|
||||
images: Optional[List[str]] = None
|
||||
queryFiles: Optional[List[Dict]] = None
|
||||
excalidrawDiagram: Optional[List[Dict]] = None
|
||||
mermaidjsDiagram: str = None
|
||||
by: str
|
||||
turnId: Optional[str] = None
|
||||
intent: Optional[Intent] = None
|
||||
|
||||
@@ -194,7 +194,7 @@ Limit your response to 3 sentences max. Be succinct, clear, and informative.
|
||||
## Diagram Generation
|
||||
## --
|
||||
|
||||
improve_diagram_description_prompt = PromptTemplate.from_template(
|
||||
improve_excalidraw_diagram_description_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are an architect working with a novice digital artist using a diagramming software.
|
||||
{personality_context}
|
||||
@@ -338,6 +338,123 @@ Diagram Description: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
improve_mermaid_js_diagram_description_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a senior architect working with an illustrator using a diagramming software.
|
||||
{personality_context}
|
||||
|
||||
Given a particular request, you need to translate it to to a detailed description that the illustrator can use to create a diagram.
|
||||
|
||||
You can use the following diagram types in your instructions:
|
||||
- Flowchart
|
||||
- Sequence Diagram
|
||||
- Gantt Chart (only for time-based queries after 0 AD)
|
||||
- State Diagram
|
||||
- Pie Chart
|
||||
|
||||
Use these primitives to describe what sort of diagram the drawer should create in natural language, not special syntax. We must recreate the diagram every time, so include all relevant prior information in your description.
|
||||
|
||||
- Describe the layout, components, and connections.
|
||||
- Use simple, concise language.
|
||||
|
||||
Today's Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
User's Notes:
|
||||
{references}
|
||||
|
||||
Online References:
|
||||
{online_results}
|
||||
|
||||
Conversation Log:
|
||||
{chat_history}
|
||||
|
||||
Query: {query}
|
||||
|
||||
Enhanced Description:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
mermaid_js_diagram_generation_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a designer with the ability to describe diagrams to compose in professional, fine detail. You dive into the details and make labels, connections, and shapes to represent complex systems.
|
||||
{personality_context}
|
||||
|
||||
----Goals----
|
||||
You need to create a declarative description of the diagram and relevant components, using the Mermaid.js syntax.
|
||||
|
||||
You can choose from the following diagram types:
|
||||
- Flowchart
|
||||
- Sequence Diagram
|
||||
- State Diagram
|
||||
- Gantt Chart
|
||||
- Pie Chart
|
||||
|
||||
----Examples----
|
||||
---
|
||||
title: Node
|
||||
---
|
||||
|
||||
flowchart LR
|
||||
id["This is the start"] --> id2["This is the end"]
|
||||
|
||||
sequenceDiagram
|
||||
Alice->>John: Hello John, how are you?
|
||||
John-->>Alice: Great!
|
||||
Alice-)John: See you later!
|
||||
|
||||
stateDiagram-v2
|
||||
[*] --> Still
|
||||
Still --> [*]
|
||||
|
||||
Still --> Moving
|
||||
Moving --> Still
|
||||
Moving --> Crash
|
||||
Crash --> [*]
|
||||
|
||||
gantt
|
||||
title A Gantt Diagram
|
||||
dateFormat YYYY-MM-DD
|
||||
section Section
|
||||
A task :a1, 2014-01-01, 30d
|
||||
Another task :after a1, 20d
|
||||
section Another
|
||||
Task in Another :2014-01-12, 12d
|
||||
another task :24d
|
||||
|
||||
pie title Pets adopted by volunteers
|
||||
"Dogs" : 10
|
||||
"Cats" : 30
|
||||
"Rats" : 60
|
||||
|
||||
flowchart TB
|
||||
c1-->a2
|
||||
subgraph one
|
||||
a1-->a2
|
||||
end
|
||||
subgraph two
|
||||
b1-->b2["this is b2"]
|
||||
end
|
||||
subgraph three
|
||||
c1["this is c1"]-->c2["this is c2"]
|
||||
end
|
||||
one --> two
|
||||
three --> two
|
||||
two --> c2
|
||||
|
||||
----Process----
|
||||
Create your diagram with great composition and intuitiveness from the provided context and user prompt below.
|
||||
- You may use subgraphs to group elements together. Each subgraph must have a title.
|
||||
- **You must wrap ALL entity and node labels in double quotes**. For example, "Entity Name".
|
||||
- Custom style are not permitted. Default styles only.
|
||||
- JUST provide the diagram, no additional text or context. Say nothing else in your response except the diagram.
|
||||
- Keep diagrams simple - maximum 15 nodes
|
||||
|
||||
output: {query}
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
failed_diagram_generation = PromptTemplate.from_template(
|
||||
"""
|
||||
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.
|
||||
|
||||
@@ -266,7 +266,7 @@ def save_to_conversation_log(
|
||||
raw_query_files: List[FileAttachment] = [],
|
||||
generated_images: List[str] = [],
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_excalidraw_diagram: str = None,
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
train_of_thought: List[Any] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
@@ -290,8 +290,8 @@ def save_to_conversation_log(
|
||||
"queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
|
||||
}
|
||||
|
||||
if generated_excalidraw_diagram:
|
||||
khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
|
||||
if generated_mermaidjs_diagram:
|
||||
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
|
||||
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
@@ -441,7 +441,7 @@ def generate_chatml_messages_with_context(
|
||||
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||
}
|
||||
|
||||
if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
|
||||
if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant":
|
||||
generated_assets["diagram"] = {
|
||||
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||
}
|
||||
@@ -593,6 +593,11 @@ def clean_json(response: str):
|
||||
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
|
||||
|
||||
|
||||
def clean_mermaidjs(response: str):
|
||||
"""Remove any markdown mermaidjs codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||
return response.strip().removeprefix("```mermaid").removesuffix("```")
|
||||
|
||||
|
||||
def clean_code_python(code: str):
|
||||
"""Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||
return code.strip().removeprefix("```python").removesuffix("```")
|
||||
|
||||
@@ -51,7 +51,7 @@ from khoj.routers.helpers import (
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
gather_raw_query_files,
|
||||
generate_excalidraw_diagram,
|
||||
generate_mermaidjs_diagram,
|
||||
generate_summary_from_files,
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
@@ -781,7 +781,7 @@ async def chat(
|
||||
|
||||
generated_images: List[str] = []
|
||||
generated_files: List[FileAttachment] = []
|
||||
generated_excalidraw_diagram: str = None
|
||||
generated_mermaidjs_diagram: str = None
|
||||
program_execution_context: List[str] = []
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
@@ -1156,7 +1156,7 @@ async def chat(
|
||||
inferred_queries = []
|
||||
diagram_description = ""
|
||||
|
||||
async for result in generate_excalidraw_diagram(
|
||||
async for result in generate_mermaidjs_diagram(
|
||||
q=defiltered_query,
|
||||
conversation_history=meta_log,
|
||||
location_data=location,
|
||||
@@ -1172,12 +1172,12 @@ async def chat(
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
better_diagram_description_prompt, excalidraw_diagram_description = result
|
||||
if better_diagram_description_prompt and excalidraw_diagram_description:
|
||||
better_diagram_description_prompt, mermaidjs_diagram_description = result
|
||||
if better_diagram_description_prompt and mermaidjs_diagram_description:
|
||||
inferred_queries.append(better_diagram_description_prompt)
|
||||
diagram_description = excalidraw_diagram_description
|
||||
diagram_description = mermaidjs_diagram_description
|
||||
|
||||
generated_excalidraw_diagram = diagram_description
|
||||
generated_mermaidjs_diagram = diagram_description
|
||||
|
||||
generated_asset_results["diagrams"] = {
|
||||
"query": better_diagram_description_prompt,
|
||||
@@ -1186,7 +1186,7 @@ async def chat(
|
||||
async for result in send_event(
|
||||
ChatEvent.GENERATED_ASSETS,
|
||||
{
|
||||
"excalidrawDiagram": excalidraw_diagram_description,
|
||||
"mermaidjsDiagram": mermaidjs_diagram_description,
|
||||
},
|
||||
):
|
||||
yield result
|
||||
@@ -1226,7 +1226,7 @@ async def chat(
|
||||
raw_query_files,
|
||||
generated_images,
|
||||
generated_files,
|
||||
generated_excalidraw_diagram,
|
||||
generated_mermaidjs_diagram,
|
||||
program_execution_context,
|
||||
generated_asset_results,
|
||||
tracer,
|
||||
|
||||
@@ -97,6 +97,7 @@ from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
generate_chatml_messages_with_context,
|
||||
save_to_conversation_log,
|
||||
@@ -823,7 +824,7 @@ async def generate_better_diagram_description(
|
||||
elif online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
|
||||
improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
@@ -887,6 +888,135 @@ async def generate_excalidraw_diagram_from_description(
|
||||
return response
|
||||
|
||||
|
||||
async def generate_mermaidjs_diagram(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description(
|
||||
q=q,
|
||||
conversation_history=conversation_history,
|
||||
location_data=location_data,
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
mermaidjs_diagram_description = await generate_mermaidjs_diagram_from_description(
|
||||
q=better_diagram_description_prompt,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
inferred_queries = f"Instruction: {better_diagram_description_prompt}"
|
||||
|
||||
yield inferred_queries, mermaidjs_diagram_description
|
||||
|
||||
|
||||
async def generate_better_mermaidjs_diagram_description(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a diagram description from the given query and context
|
||||
"""
|
||||
|
||||
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
|
||||
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
simplified_online_results = {}
|
||||
|
||||
if online_results:
|
||||
for result in online_results:
|
||||
if online_results[result].get("answerBox"):
|
||||
simplified_online_results[result] = online_results[result]["answerBox"]
|
||||
elif online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
improve_diagram_description_prompt,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def generate_mermaidjs_diagram_from_description(
|
||||
q: str,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Dict[str, Any]:
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
mermaidjs_diagram_generation = prompts.mermaid_js_diagram_generation_prompt.format(
|
||||
personality_context=personality_context,
|
||||
query=q,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer)
|
||||
return clean_mermaidjs(raw_response.strip())
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def generate_better_image_prompt(
|
||||
q: str,
|
||||
conversation_history: str,
|
||||
@@ -1222,7 +1352,7 @@ def generate_chat_response(
|
||||
raw_query_files: List[FileAttachment] = None,
|
||||
generated_images: List[str] = None,
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_excalidraw_diagram: str = None,
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
program_execution_context: List[str] = [],
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
tracer: dict = {},
|
||||
@@ -1250,7 +1380,7 @@ def generate_chat_response(
|
||||
raw_query_files=raw_query_files,
|
||||
generated_images=generated_images,
|
||||
raw_generated_files=raw_generated_files,
|
||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1965,7 +2095,7 @@ class MessageProcessor:
|
||||
self.raw_response = ""
|
||||
self.generated_images = []
|
||||
self.generated_files = []
|
||||
self.generated_excalidraw_diagram = []
|
||||
self.generated_mermaidjs_diagram = []
|
||||
|
||||
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
||||
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
|
||||
@@ -2012,8 +2142,8 @@ class MessageProcessor:
|
||||
self.generated_images = chunk_data[key]
|
||||
elif key == "files":
|
||||
self.generated_files = chunk_data[key]
|
||||
elif key == "excalidrawDiagram":
|
||||
self.generated_excalidraw_diagram = chunk_data[key]
|
||||
elif key == "mermaidjsDiagram":
|
||||
self.generated_mermaidjs_diagram = chunk_data[key]
|
||||
|
||||
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
|
||||
if "image" in json_data or "details" in json_data:
|
||||
@@ -2050,7 +2180,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||
"usage": processor.usage,
|
||||
"images": processor.generated_images,
|
||||
"files": processor.generated_files,
|
||||
"excalidrawDiagram": processor.generated_excalidraw_diagram,
|
||||
"mermaidjsDiagram": processor.generated_mermaidjs_diagram,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user