Merge pull request #1054 from khoj-ai/features/add-support-for-mermaidjs

We've been having issues generating diagrams with Excalidraw that are any degree of complexity. By contrast, LLMs are able to handle Mermaid.js syntax a lot better, as it's much more forgiving and has a simpler declarative style. Refer to https://mermaid.js.org/.

Update so that new diagrams are generated with Mermaid.js, while old diagrams generated with Excalidraw can still be viewed.
This commit is contained in:
sabaimran
2025-01-15 11:55:12 -08:00
committed by GitHub
13 changed files with 1395 additions and 44 deletions

View File

@@ -487,7 +487,8 @@ export class KhojChatView extends KhojPaneView {
inferredQueries?: string[],
conversationId?: string,
images?: string[],
excalidrawDiagram?: string
excalidrawDiagram?: string,
mermaidjsDiagram?: string
) {
if (!message) return;
@@ -496,8 +497,9 @@ export class KhojChatView extends KhojPaneView {
intentType?.includes("text-to-image") ||
intentType === "excalidraw" ||
(images && images.length > 0) ||
mermaidjsDiagram ||
excalidrawDiagram) {
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram);
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram, mermaidjsDiagram);
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
} else {
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
@@ -517,7 +519,7 @@ export class KhojChatView extends KhojPaneView {
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
}
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string {
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string, mermaidjsDiagram?: string): string {
let imageMarkdown = "";
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
@@ -529,6 +531,8 @@ export class KhojChatView extends KhojPaneView {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
imageMarkdown = redirectMessage;
} else if (mermaidjsDiagram) {
imageMarkdown = "```mermaid\n" + mermaidjsDiagram + "\n```";
} else if (images && images.length > 0) {
imageMarkdown += images.map(image => `![](${image})`).join('\n\n');
imageMarkdown += message;
@@ -901,6 +905,7 @@ export class KhojChatView extends KhojPaneView {
chatBodyEl.dataset.conversationId ?? "",
chatLog.images,
chatLog.excalidrawDiagram,
chatLog.mermaidjsDiagram,
);
// push the user messages to the chat history
if (chatLog.by === "you") {
@@ -1005,7 +1010,7 @@ export class KhojChatView extends KhojPaneView {
}
handleJsonResponse(jsonData: any): void {
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) {
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.mermaidjsDiagram) {
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
} else if (jsonData.response) {
this.chatMessageState.rawResponse = jsonData.response;
@@ -1395,6 +1400,8 @@ export class KhojChatView extends KhojPaneView {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
rawResponse += redirectMessage;
} else if (imageJson.mermaidjsDiagram) {
rawResponse += imageJson.mermaidjsDiagram;
}
// If response has detail field, response is an error message.

View File

@@ -19,7 +19,7 @@ export interface MessageMetadata {
export interface GeneratedAssetsData {
images: string[];
excalidrawDiagram: string;
mermaidjsDiagram: string;
files: AttachedFileText[];
}
@@ -114,8 +114,8 @@ export function processMessageChunk(
currentMessage.generatedImages = generatedAssets.images;
}
if (generatedAssets.excalidrawDiagram) {
currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
if (generatedAssets.mermaidjsDiagram) {
currentMessage.generatedMermaidjsDiagram = generatedAssets.mermaidjsDiagram;
}
if (generatedAssets.files) {

View File

@@ -418,7 +418,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
conversationId: props.conversationId,
images: message.generatedImages,
queryFiles: message.generatedFiles,
excalidrawDiagram: message.generatedExcalidrawDiagram,
mermaidjsDiagram: message.generatedMermaidjsDiagram,
turnId: messageTurnId,
}}
conversationId={props.conversationId}

View File

@@ -53,6 +53,7 @@ import { DialogTitle } from "@radix-ui/react-dialog";
import { convertBytesToText } from "@/app/common/utils";
import { ScrollArea } from "@/components/ui/scroll-area";
import { getIconFromFilename } from "@/app/common/iconUtils";
import Mermaid from "../mermaid/mermaid";
const md = new markdownIt({
html: true,
@@ -164,6 +165,7 @@ export interface SingleChatMessage {
turnId?: string;
queryFiles?: AttachedFileText[];
excalidrawDiagram?: string;
mermaidjsDiagram?: string;
}
export interface StreamMessage {
@@ -182,9 +184,11 @@ export interface StreamMessage {
turnId?: string;
queryFiles?: AttachedFileText[];
excalidrawDiagram?: string;
mermaidjsDiagram?: string;
generatedFiles?: AttachedFileText[];
generatedImages?: string[];
generatedExcalidrawDiagram?: string;
generatedMermaidjsDiagram?: string;
}
export interface ChatHistoryData {
@@ -271,6 +275,7 @@ interface ChatMessageProps {
turnId?: string;
generatedImage?: string;
excalidrawDiagram?: string;
mermaidjsDiagram?: string;
generatedFiles?: AttachedFileText[];
}
@@ -358,6 +363,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
const [isPlaying, setIsPlaying] = useState<boolean>(false);
const [interrupted, setInterrupted] = useState<boolean>(false);
const [excalidrawData, setExcalidrawData] = useState<string>("");
const [mermaidjsData, setMermaidjsData] = useState<string>("");
const interruptedRef = useRef<boolean>(false);
const messageRef = useRef<HTMLDivElement>(null);
@@ -401,6 +407,10 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
setExcalidrawData(props.chatMessage.excalidrawDiagram);
}
if (props.chatMessage.mermaidjsDiagram) {
setMermaidjsData(props.chatMessage.mermaidjsDiagram);
}
// Replace LaTeX delimiters with placeholders
message = message
.replace(/\\\(/g, "LEFTPAREN")
@@ -718,6 +728,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
dangerouslySetInnerHTML={{ __html: markdownRendered }}
/>
{excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
{mermaidjsData && <Mermaid chart={mermaidjsData} />}
</div>
<div className={styles.teaserReferencesContainer}>
<TeaserReferencesSection

View File

@@ -1,4 +1,9 @@
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { CircleNotch } from "@phosphor-icons/react";
import { AppSidebar } from "../appSidebar/appSidebar";
import { Separator } from "@/components/ui/separator";
import { useIsMobileWidth } from "@/app/common/utils";
import { KhojLogoType } from "../logo/khojLogo";
interface LoadingProps {
className?: string;
@@ -7,21 +12,39 @@ interface LoadingProps {
}
export default function Loading(props: LoadingProps) {
const isMobileWidth = useIsMobileWidth();
return (
// NOTE: We can display usage tips here for casual learning moments.
<div
className={
props.className ||
"bg-background opacity-50 flex items-center justify-center h-screen"
}
>
<div>
{props.message || "Loading"}{" "}
<span>
<CircleNotch className="inline animate-spin h-5 w-5" />
</span>
<SidebarProvider>
<AppSidebar conversationId={""} />
<SidebarInset>
<header className="flex h-16 shrink-0 items-center gap-2 border-b px-4">
<SidebarTrigger className="-ml-1" />
<Separator orientation="vertical" className="mr-2 h-4" />
{isMobileWidth ? (
<a className="p-0 no-underline" href="/">
<KhojLogoType className="h-auto w-16" />
</a>
) : (
<h2 className="text-lg">Ask Anything</h2>
)}
</header>
</SidebarInset>
<div
className={
props.className ||
"bg-background opacity-50 flex items-center justify-center h-full w-full fixed top-0 left-0 z-50"
}
>
<div>
{props.message || "Loading"}{" "}
<span>
<CircleNotch className="inline animate-spin h-5 w-5" />
</span>
</div>
</div>
</div>
</SidebarProvider>
);
}

View File

@@ -0,0 +1,173 @@
import React, { useEffect, useState, useRef } from "react";
import mermaid from "mermaid";
import { Download, Info } from "@phosphor-icons/react";
import { Button } from "@/components/ui/button";
interface MermaidProps {
chart: string;
}
const Mermaid: React.FC<MermaidProps> = ({ chart }) => {
const [mermaidError, setMermaidError] = useState<string | null>(null);
const [mermaidId] = useState(`mermaid-chart-${Math.random().toString(12).substring(7)}`);
const elementRef = useRef<HTMLDivElement>(null);
useEffect(() => {
mermaid.initialize({
startOnLoad: false,
});
mermaid.parseError = (error) => {
console.error("Mermaid errors:", error);
// Extract error message from error object
// Parse error message safely
let errorMessage;
try {
errorMessage = typeof error === "string" ? JSON.parse(error) : error;
} catch (e) {
errorMessage = error?.toString() || "Unknown error";
}
console.log("Mermaid error message:", errorMessage);
if (errorMessage.str !== "element is null") {
setMermaidError(
"Something went wrong while rendering the diagram. Please try again later or downvote the message if the issue persists.",
);
} else {
setMermaidError(null);
}
};
mermaid.contentLoaded();
}, []);
const handleExport = async () => {
if (!elementRef.current) return;
try {
// Get SVG element
const svgElement = elementRef.current.querySelector("svg");
if (!svgElement) throw new Error("No SVG found");
// Get SVG viewBox dimensions
const viewBox = svgElement.getAttribute("viewBox")?.split(" ").map(Number) || [
0, 0, 0, 0,
];
const [, , viewBoxWidth, viewBoxHeight] = viewBox;
// Create canvas with viewBox dimensions
const canvas = document.createElement("canvas");
const scale = 2; // For better resolution
canvas.width = viewBoxWidth * scale;
canvas.height = viewBoxHeight * scale;
const ctx = canvas.getContext("2d");
if (!ctx) throw new Error("Failed to get canvas context");
// Convert SVG to data URL
const svgData = new XMLSerializer().serializeToString(svgElement);
const svgBlob = new Blob([svgData], { type: "image/svg+xml;charset=utf-8" });
const svgUrl = URL.createObjectURL(svgBlob);
// Create and load image
const img = new Image();
img.src = svgUrl;
await new Promise((resolve, reject) => {
img.onload = () => {
// Scale context for better resolution
ctx.scale(scale, scale);
ctx.drawImage(img, 0, 0, viewBoxWidth, viewBoxHeight);
canvas.toBlob((blob) => {
if (!blob) {
reject(new Error("Failed to create blob"));
return;
}
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `mermaid-diagram-${Date.now()}.png`;
a.click();
// Cleanup
URL.revokeObjectURL(url);
URL.revokeObjectURL(svgUrl);
resolve(true);
}, "image/png");
};
img.onerror = () => reject(new Error("Failed to load SVG"));
});
} catch (error) {
console.error("Error exporting diagram:", error);
setMermaidError("Failed to export diagram");
}
};
useEffect(() => {
if (elementRef.current) {
elementRef.current.removeAttribute("data-processed");
mermaid
.run({
nodes: [elementRef.current],
})
.then(() => {
setMermaidError(null);
})
.catch((error) => {
let errorMessage;
try {
errorMessage = typeof error === "string" ? JSON.parse(error) : error;
} catch (e) {
errorMessage = error?.toString() || "Unknown error";
}
console.log("Mermaid error message:", errorMessage);
if (errorMessage.str !== "element is null") {
setMermaidError(
"Something went wrong while rendering the diagram. Please try again later or downvote the message if the issue persists.",
);
} else {
setMermaidError(null);
}
});
}
}, [chart]);
return (
<div>
{mermaidError ? (
<div className="flex items-center gap-2 bg-red-100 border border-red-500 rounded-md p-3 mt-3 text-red-900 text-sm">
<Info className="w-12 h-12" />
<span>Error rendering diagram: {mermaidError}</span>
</div>
) : (
<div
id={mermaidId}
ref={elementRef}
className="mermaid"
style={{
width: "auto",
height: "auto",
boxSizing: "border-box",
overflow: "auto",
}}
>
{chart}
</div>
)}
{!mermaidError && (
<Button onClick={handleExport} variant={"secondary"} className="mt-3">
<Download className="w-5 h-5" />
Export as PNG
</Button>
)}
</div>
);
};
export default Mermaid;

View File

@@ -58,6 +58,7 @@
"lucide-react": "^0.468.0",
"markdown-it": "^14.1.0",
"markdown-it-highlightjs": "^4.1.0",
"mermaid": "^11.4.1",
"next": "14.2.15",
"nodemon": "^3.1.3",
"postcss": "^8.4.38",

File diff suppressed because it is too large Load Diff

View File

@@ -109,6 +109,7 @@ class ChatMessage(PydanticBaseModel):
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None
excalidrawDiagram: Optional[List[Dict]] = None
mermaidjsDiagram: str = None
by: str
turnId: Optional[str] = None
intent: Optional[Intent] = None

View File

@@ -194,7 +194,7 @@ Limit your response to 3 sentences max. Be succinct, clear, and informative.
## Diagram Generation
## --
improve_diagram_description_prompt = PromptTemplate.from_template(
improve_excalidraw_diagram_description_prompt = PromptTemplate.from_template(
"""
You are an architect working with a novice digital artist using a diagramming software.
{personality_context}
@@ -338,6 +338,123 @@ Diagram Description: {query}
""".strip()
)
improve_mermaid_js_diagram_description_prompt = PromptTemplate.from_template(
"""
You are a senior architect working with an illustrator using a diagramming software.
{personality_context}
Given a particular request, you need to translate it to to a detailed description that the illustrator can use to create a diagram.
You can use the following diagram types in your instructions:
- Flowchart
- Sequence Diagram
- Gantt Chart (only for time-based queries after 0 AD)
- State Diagram
- Pie Chart
Use these primitives to describe what sort of diagram the drawer should create in natural language, not special syntax. We must recreate the diagram every time, so include all relevant prior information in your description.
- Describe the layout, components, and connections.
- Use simple, concise language.
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Query: {query}
Enhanced Description:
""".strip()
)
mermaid_js_diagram_generation_prompt = PromptTemplate.from_template(
"""
You are a designer with the ability to describe diagrams to compose in professional, fine detail. You dive into the details and make labels, connections, and shapes to represent complex systems.
{personality_context}
----Goals----
You need to create a declarative description of the diagram and relevant components, using the Mermaid.js syntax.
You can choose from the following diagram types:
- Flowchart
- Sequence Diagram
- State Diagram
- Gantt Chart
- Pie Chart
----Examples----
---
title: Node
---
flowchart LR
id["This is the start"] --> id2["This is the end"]
sequenceDiagram
Alice->>John: Hello John, how are you?
John-->>Alice: Great!
Alice-)John: See you later!
stateDiagram-v2
[*] --> Still
Still --> [*]
Still --> Moving
Moving --> Still
Moving --> Crash
Crash --> [*]
gantt
title A Gantt Diagram
dateFormat YYYY-MM-DD
section Section
A task :a1, 2014-01-01, 30d
Another task :after a1, 20d
section Another
Task in Another :2014-01-12, 12d
another task :24d
pie title Pets adopted by volunteers
"Dogs" : 10
"Cats" : 30
"Rats" : 60
flowchart TB
subgraph "Group 1"
a1["Start Node"] --> a2["End Node"]
end
subgraph "Group 2"
b1["Process 1"] --> b2["Process 2"]
end
subgraph "Group 3"
c1["Input"] --> c2["Output"]
end
a["Group 1"] --> b["Group 2"]
c["Group 3"] --> d["Group 2"]
----Process----
Create your diagram with great composition and intuitiveness from the provided context and user prompt below.
- You may use subgraphs to group elements together. Each subgraph must have a title.
- **You must wrap ALL entity and node labels in double quotes**, example: "My Node Label"
- **All nodes MUST use the id["label"] format**. For example: node1["My Node Label"]
- Custom style are not permitted. Default styles only.
- JUST provide the diagram, no additional text or context. Say nothing else in your response except the diagram.
- Keep diagrams simple - maximum 15 nodes
- Every node inside a subgraph MUST use square bracket notation: id["label"]
output: {query}
""".strip()
)
failed_diagram_generation = PromptTemplate.from_template(
"""
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.

View File

@@ -266,7 +266,7 @@ def save_to_conversation_log(
raw_query_files: List[FileAttachment] = [],
generated_images: List[str] = [],
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
generated_mermaidjs_diagram: str = None,
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@@ -290,8 +290,8 @@ def save_to_conversation_log(
"queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
}
if generated_excalidraw_diagram:
khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
if generated_mermaidjs_diagram:
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
updated_conversation = message_to_log(
user_message=q,
@@ -441,7 +441,7 @@ def generate_chatml_messages_with_context(
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}
if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant":
generated_assets["diagram"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}
@@ -593,6 +593,11 @@ def clean_json(response: str):
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
def clean_mermaidjs(response: str):
"""Remove any markdown mermaidjs codeblock and newline formatting if present. Useful for non schema enforceable models"""
return response.strip().removeprefix("```mermaid").removesuffix("```")
def clean_code_python(code: str):
"""Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
return code.strip().removeprefix("```python").removesuffix("```")

View File

@@ -51,7 +51,7 @@ from khoj.routers.helpers import (
construct_automation_created_message,
create_automation,
gather_raw_query_files,
generate_excalidraw_diagram,
generate_mermaidjs_diagram,
generate_summary_from_files,
get_conversation_command,
is_query_empty,
@@ -781,7 +781,7 @@ async def chat(
generated_images: List[str] = []
generated_files: List[FileAttachment] = []
generated_excalidraw_diagram: str = None
generated_mermaidjs_diagram: str = None
program_execution_context: List[str] = []
if conversation_commands == [ConversationCommand.Default]:
@@ -1161,7 +1161,7 @@ async def chat(
inferred_queries = []
diagram_description = ""
async for result in generate_excalidraw_diagram(
async for result in generate_mermaidjs_diagram(
q=defiltered_query,
conversation_history=meta_log,
location_data=location,
@@ -1177,12 +1177,12 @@ async def chat(
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
better_diagram_description_prompt, excalidraw_diagram_description = result
if better_diagram_description_prompt and excalidraw_diagram_description:
better_diagram_description_prompt, mermaidjs_diagram_description = result
if better_diagram_description_prompt and mermaidjs_diagram_description:
inferred_queries.append(better_diagram_description_prompt)
diagram_description = excalidraw_diagram_description
diagram_description = mermaidjs_diagram_description
generated_excalidraw_diagram = diagram_description
generated_mermaidjs_diagram = diagram_description
generated_asset_results["diagrams"] = {
"query": better_diagram_description_prompt,
@@ -1191,7 +1191,7 @@ async def chat(
async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
"excalidrawDiagram": excalidraw_diagram_description,
"mermaidjsDiagram": mermaidjs_diagram_description,
},
):
yield result
@@ -1231,7 +1231,7 @@ async def chat(
raw_query_files,
generated_images,
generated_files,
generated_excalidraw_diagram,
generated_mermaidjs_diagram,
program_execution_context,
generated_asset_results,
tracer,

View File

@@ -97,6 +97,7 @@ from khoj.processor.conversation.utils import (
ChatEvent,
ThreadedGenerator,
clean_json,
clean_mermaidjs,
construct_chat_history,
generate_chatml_messages_with_context,
save_to_conversation_log,
@@ -823,7 +824,7 @@ async def generate_better_diagram_description(
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
query=q,
chat_history=chat_history,
location=location,
@@ -887,6 +888,133 @@ async def generate_excalidraw_diagram_from_description(
return response
async def generate_mermaidjs_diagram(
q: str,
conversation_history: Dict[str, Any],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
tracer: dict = {},
):
if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
yield {ChatEvent.STATUS: event}
better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description(
q=q,
conversation_history=conversation_history,
location_data=location_data,
note_references=note_references,
online_results=online_results,
query_images=query_images,
user=user,
agent=agent,
query_files=query_files,
tracer=tracer,
)
if send_status_func:
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
yield {ChatEvent.STATUS: event}
mermaidjs_diagram_description = await generate_mermaidjs_diagram_from_description(
q=better_diagram_description_prompt,
user=user,
agent=agent,
tracer=tracer,
)
inferred_queries = f"Instruction: {better_diagram_description_prompt}"
yield inferred_queries, mermaidjs_diagram_description
async def generate_better_mermaidjs_diagram_description(
q: str,
conversation_history: Dict[str, Any],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
query_files: str = None,
tracer: dict = {},
) -> str:
"""
Generate a diagram description from the given query and context
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
location = f"{location_data}" if location_data else "Unknown"
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
chat_history = construct_chat_history(conversation_history)
simplified_online_results = {}
if online_results:
for result in online_results:
if online_results[result].get("answerBox"):
simplified_online_results[result] = online_results[result]["answerBox"]
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
query=q,
chat_history=chat_history,
location=location,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
query_files=query_files,
tracer=tracer,
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
return response
async def generate_mermaidjs_diagram_from_description(
q: str,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
mermaidjs_diagram_generation = prompts.mermaid_js_diagram_generation_prompt.format(
personality_context=personality_context,
query=q,
)
with timer("Chat actor: Generate Mermaid.js diagram", logger):
raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer)
return clean_mermaidjs(raw_response.strip())
async def generate_better_image_prompt(
q: str,
conversation_history: str,
@@ -1224,7 +1352,7 @@ def generate_chat_response(
raw_query_files: List[FileAttachment] = None,
generated_images: List[str] = None,
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
generated_mermaidjs_diagram: str = None,
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
@@ -1252,7 +1380,7 @@ def generate_chat_response(
raw_query_files=raw_query_files,
generated_images=generated_images,
raw_generated_files=raw_generated_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
)
@@ -1967,7 +2095,7 @@ class MessageProcessor:
self.raw_response = ""
self.generated_images = []
self.generated_files = []
self.generated_excalidraw_diagram = []
self.generated_mermaidjs_diagram = []
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
@@ -2014,8 +2142,8 @@ class MessageProcessor:
self.generated_images = chunk_data[key]
elif key == "files":
self.generated_files = chunk_data[key]
elif key == "excalidrawDiagram":
self.generated_excalidraw_diagram = chunk_data[key]
elif key == "mermaidjsDiagram":
self.generated_mermaidjs_diagram = chunk_data[key]
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
if "image" in json_data or "details" in json_data:
@@ -2052,7 +2180,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
"usage": processor.usage,
"images": processor.generated_images,
"files": processor.generated_files,
"excalidrawDiagram": processor.generated_excalidraw_diagram,
"mermaidjsDiagram": processor.generated_mermaidjs_diagram,
}