mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Merge pull request #1054 from khoj-ai/features/add-support-for-mermaidjs
We've been having issues generating diagrams with Excalidraw that are any degree of complexity. By contrast, LLMs are able to handle Mermaid.js syntax a lot better, as it's much more forgiving and has a simpler declarative style. Refer to https://mermaid.js.org/. Update so that new diagrams are generated with Mermaid.js, while old diagrams generated with Excalidraw can still be viewed.
This commit is contained in:
@@ -487,7 +487,8 @@ export class KhojChatView extends KhojPaneView {
|
||||
inferredQueries?: string[],
|
||||
conversationId?: string,
|
||||
images?: string[],
|
||||
excalidrawDiagram?: string
|
||||
excalidrawDiagram?: string,
|
||||
mermaidjsDiagram?: string
|
||||
) {
|
||||
if (!message) return;
|
||||
|
||||
@@ -496,8 +497,9 @@ export class KhojChatView extends KhojPaneView {
|
||||
intentType?.includes("text-to-image") ||
|
||||
intentType === "excalidraw" ||
|
||||
(images && images.length > 0) ||
|
||||
mermaidjsDiagram ||
|
||||
excalidrawDiagram) {
|
||||
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram);
|
||||
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram, mermaidjsDiagram);
|
||||
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
|
||||
} else {
|
||||
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
||||
@@ -517,7 +519,7 @@ export class KhojChatView extends KhojPaneView {
|
||||
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
|
||||
}
|
||||
|
||||
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string {
|
||||
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string, mermaidjsDiagram?: string): string {
|
||||
let imageMarkdown = "";
|
||||
if (intentType === "text-to-image") {
|
||||
imageMarkdown = ``;
|
||||
@@ -529,6 +531,8 @@ export class KhojChatView extends KhojPaneView {
|
||||
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
|
||||
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
|
||||
imageMarkdown = redirectMessage;
|
||||
} else if (mermaidjsDiagram) {
|
||||
imageMarkdown = "```mermaid\n" + mermaidjsDiagram + "\n```";
|
||||
} else if (images && images.length > 0) {
|
||||
imageMarkdown += images.map(image => ``).join('\n\n');
|
||||
imageMarkdown += message;
|
||||
@@ -901,6 +905,7 @@ export class KhojChatView extends KhojPaneView {
|
||||
chatBodyEl.dataset.conversationId ?? "",
|
||||
chatLog.images,
|
||||
chatLog.excalidrawDiagram,
|
||||
chatLog.mermaidjsDiagram,
|
||||
);
|
||||
// push the user messages to the chat history
|
||||
if (chatLog.by === "you") {
|
||||
@@ -1005,7 +1010,7 @@ export class KhojChatView extends KhojPaneView {
|
||||
}
|
||||
|
||||
handleJsonResponse(jsonData: any): void {
|
||||
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) {
|
||||
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.mermaidjsDiagram) {
|
||||
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
|
||||
} else if (jsonData.response) {
|
||||
this.chatMessageState.rawResponse = jsonData.response;
|
||||
@@ -1395,6 +1400,8 @@ export class KhojChatView extends KhojPaneView {
|
||||
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
|
||||
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
|
||||
rawResponse += redirectMessage;
|
||||
} else if (imageJson.mermaidjsDiagram) {
|
||||
rawResponse += imageJson.mermaidjsDiagram;
|
||||
}
|
||||
|
||||
// If response has detail field, response is an error message.
|
||||
|
||||
@@ -19,7 +19,7 @@ export interface MessageMetadata {
|
||||
|
||||
export interface GeneratedAssetsData {
|
||||
images: string[];
|
||||
excalidrawDiagram: string;
|
||||
mermaidjsDiagram: string;
|
||||
files: AttachedFileText[];
|
||||
}
|
||||
|
||||
@@ -114,8 +114,8 @@ export function processMessageChunk(
|
||||
currentMessage.generatedImages = generatedAssets.images;
|
||||
}
|
||||
|
||||
if (generatedAssets.excalidrawDiagram) {
|
||||
currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
|
||||
if (generatedAssets.mermaidjsDiagram) {
|
||||
currentMessage.generatedMermaidjsDiagram = generatedAssets.mermaidjsDiagram;
|
||||
}
|
||||
|
||||
if (generatedAssets.files) {
|
||||
|
||||
@@ -418,7 +418,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||
conversationId: props.conversationId,
|
||||
images: message.generatedImages,
|
||||
queryFiles: message.generatedFiles,
|
||||
excalidrawDiagram: message.generatedExcalidrawDiagram,
|
||||
mermaidjsDiagram: message.generatedMermaidjsDiagram,
|
||||
turnId: messageTurnId,
|
||||
}}
|
||||
conversationId={props.conversationId}
|
||||
|
||||
@@ -53,6 +53,7 @@ import { DialogTitle } from "@radix-ui/react-dialog";
|
||||
import { convertBytesToText } from "@/app/common/utils";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { getIconFromFilename } from "@/app/common/iconUtils";
|
||||
import Mermaid from "../mermaid/mermaid";
|
||||
|
||||
const md = new markdownIt({
|
||||
html: true,
|
||||
@@ -164,6 +165,7 @@ export interface SingleChatMessage {
|
||||
turnId?: string;
|
||||
queryFiles?: AttachedFileText[];
|
||||
excalidrawDiagram?: string;
|
||||
mermaidjsDiagram?: string;
|
||||
}
|
||||
|
||||
export interface StreamMessage {
|
||||
@@ -182,9 +184,11 @@ export interface StreamMessage {
|
||||
turnId?: string;
|
||||
queryFiles?: AttachedFileText[];
|
||||
excalidrawDiagram?: string;
|
||||
mermaidjsDiagram?: string;
|
||||
generatedFiles?: AttachedFileText[];
|
||||
generatedImages?: string[];
|
||||
generatedExcalidrawDiagram?: string;
|
||||
generatedMermaidjsDiagram?: string;
|
||||
}
|
||||
|
||||
export interface ChatHistoryData {
|
||||
@@ -271,6 +275,7 @@ interface ChatMessageProps {
|
||||
turnId?: string;
|
||||
generatedImage?: string;
|
||||
excalidrawDiagram?: string;
|
||||
mermaidjsDiagram?: string;
|
||||
generatedFiles?: AttachedFileText[];
|
||||
}
|
||||
|
||||
@@ -358,6 +363,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
const [isPlaying, setIsPlaying] = useState<boolean>(false);
|
||||
const [interrupted, setInterrupted] = useState<boolean>(false);
|
||||
const [excalidrawData, setExcalidrawData] = useState<string>("");
|
||||
const [mermaidjsData, setMermaidjsData] = useState<string>("");
|
||||
|
||||
const interruptedRef = useRef<boolean>(false);
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
@@ -401,6 +407,10 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
setExcalidrawData(props.chatMessage.excalidrawDiagram);
|
||||
}
|
||||
|
||||
if (props.chatMessage.mermaidjsDiagram) {
|
||||
setMermaidjsData(props.chatMessage.mermaidjsDiagram);
|
||||
}
|
||||
|
||||
// Replace LaTeX delimiters with placeholders
|
||||
message = message
|
||||
.replace(/\\\(/g, "LEFTPAREN")
|
||||
@@ -718,6 +728,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||
dangerouslySetInnerHTML={{ __html: markdownRendered }}
|
||||
/>
|
||||
{excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
|
||||
{mermaidjsData && <Mermaid chart={mermaidjsData} />}
|
||||
</div>
|
||||
<div className={styles.teaserReferencesContainer}>
|
||||
<TeaserReferencesSection
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||
import { CircleNotch } from "@phosphor-icons/react";
|
||||
import { AppSidebar } from "../appSidebar/appSidebar";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { useIsMobileWidth } from "@/app/common/utils";
|
||||
import { KhojLogoType } from "../logo/khojLogo";
|
||||
|
||||
interface LoadingProps {
|
||||
className?: string;
|
||||
@@ -7,21 +12,39 @@ interface LoadingProps {
|
||||
}
|
||||
|
||||
export default function Loading(props: LoadingProps) {
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
return (
|
||||
// NOTE: We can display usage tips here for casual learning moments.
|
||||
<div
|
||||
className={
|
||||
props.className ||
|
||||
"bg-background opacity-50 flex items-center justify-center h-screen"
|
||||
}
|
||||
>
|
||||
<div>
|
||||
{props.message || "Loading"}{" "}
|
||||
<span>
|
||||
<CircleNotch className="inline animate-spin h-5 w-5" />
|
||||
</span>
|
||||
<SidebarProvider>
|
||||
<AppSidebar conversationId={""} />
|
||||
<SidebarInset>
|
||||
<header className="flex h-16 shrink-0 items-center gap-2 border-b px-4">
|
||||
<SidebarTrigger className="-ml-1" />
|
||||
<Separator orientation="vertical" className="mr-2 h-4" />
|
||||
{isMobileWidth ? (
|
||||
<a className="p-0 no-underline" href="/">
|
||||
<KhojLogoType className="h-auto w-16" />
|
||||
</a>
|
||||
) : (
|
||||
<h2 className="text-lg">Ask Anything</h2>
|
||||
)}
|
||||
</header>
|
||||
</SidebarInset>
|
||||
<div
|
||||
className={
|
||||
props.className ||
|
||||
"bg-background opacity-50 flex items-center justify-center h-full w-full fixed top-0 left-0 z-50"
|
||||
}
|
||||
>
|
||||
<div>
|
||||
{props.message || "Loading"}{" "}
|
||||
<span>
|
||||
<CircleNotch className="inline animate-spin h-5 w-5" />
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</SidebarProvider>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
173
src/interface/web/app/components/mermaid/mermaid.tsx
Normal file
173
src/interface/web/app/components/mermaid/mermaid.tsx
Normal file
@@ -0,0 +1,173 @@
|
||||
import React, { useEffect, useState, useRef } from "react";
|
||||
import mermaid from "mermaid";
|
||||
import { Download, Info } from "@phosphor-icons/react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
interface MermaidProps {
|
||||
chart: string;
|
||||
}
|
||||
|
||||
const Mermaid: React.FC<MermaidProps> = ({ chart }) => {
|
||||
const [mermaidError, setMermaidError] = useState<string | null>(null);
|
||||
const [mermaidId] = useState(`mermaid-chart-${Math.random().toString(12).substring(7)}`);
|
||||
const elementRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
mermaid.initialize({
|
||||
startOnLoad: false,
|
||||
});
|
||||
|
||||
mermaid.parseError = (error) => {
|
||||
console.error("Mermaid errors:", error);
|
||||
// Extract error message from error object
|
||||
// Parse error message safely
|
||||
let errorMessage;
|
||||
try {
|
||||
errorMessage = typeof error === "string" ? JSON.parse(error) : error;
|
||||
} catch (e) {
|
||||
errorMessage = error?.toString() || "Unknown error";
|
||||
}
|
||||
|
||||
console.log("Mermaid error message:", errorMessage);
|
||||
|
||||
if (errorMessage.str !== "element is null") {
|
||||
setMermaidError(
|
||||
"Something went wrong while rendering the diagram. Please try again later or downvote the message if the issue persists.",
|
||||
);
|
||||
} else {
|
||||
setMermaidError(null);
|
||||
}
|
||||
};
|
||||
|
||||
mermaid.contentLoaded();
|
||||
}, []);
|
||||
|
||||
const handleExport = async () => {
|
||||
if (!elementRef.current) return;
|
||||
|
||||
try {
|
||||
// Get SVG element
|
||||
const svgElement = elementRef.current.querySelector("svg");
|
||||
if (!svgElement) throw new Error("No SVG found");
|
||||
|
||||
// Get SVG viewBox dimensions
|
||||
const viewBox = svgElement.getAttribute("viewBox")?.split(" ").map(Number) || [
|
||||
0, 0, 0, 0,
|
||||
];
|
||||
const [, , viewBoxWidth, viewBoxHeight] = viewBox;
|
||||
|
||||
// Create canvas with viewBox dimensions
|
||||
const canvas = document.createElement("canvas");
|
||||
const scale = 2; // For better resolution
|
||||
canvas.width = viewBoxWidth * scale;
|
||||
canvas.height = viewBoxHeight * scale;
|
||||
const ctx = canvas.getContext("2d");
|
||||
if (!ctx) throw new Error("Failed to get canvas context");
|
||||
|
||||
// Convert SVG to data URL
|
||||
const svgData = new XMLSerializer().serializeToString(svgElement);
|
||||
const svgBlob = new Blob([svgData], { type: "image/svg+xml;charset=utf-8" });
|
||||
const svgUrl = URL.createObjectURL(svgBlob);
|
||||
|
||||
// Create and load image
|
||||
const img = new Image();
|
||||
img.src = svgUrl;
|
||||
|
||||
await new Promise((resolve, reject) => {
|
||||
img.onload = () => {
|
||||
// Scale context for better resolution
|
||||
ctx.scale(scale, scale);
|
||||
ctx.drawImage(img, 0, 0, viewBoxWidth, viewBoxHeight);
|
||||
|
||||
canvas.toBlob((blob) => {
|
||||
if (!blob) {
|
||||
reject(new Error("Failed to create blob"));
|
||||
return;
|
||||
}
|
||||
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `mermaid-diagram-${Date.now()}.png`;
|
||||
a.click();
|
||||
|
||||
// Cleanup
|
||||
URL.revokeObjectURL(url);
|
||||
URL.revokeObjectURL(svgUrl);
|
||||
resolve(true);
|
||||
}, "image/png");
|
||||
};
|
||||
|
||||
img.onerror = () => reject(new Error("Failed to load SVG"));
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error exporting diagram:", error);
|
||||
setMermaidError("Failed to export diagram");
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (elementRef.current) {
|
||||
elementRef.current.removeAttribute("data-processed");
|
||||
|
||||
mermaid
|
||||
.run({
|
||||
nodes: [elementRef.current],
|
||||
})
|
||||
.then(() => {
|
||||
setMermaidError(null);
|
||||
})
|
||||
.catch((error) => {
|
||||
let errorMessage;
|
||||
try {
|
||||
errorMessage = typeof error === "string" ? JSON.parse(error) : error;
|
||||
} catch (e) {
|
||||
errorMessage = error?.toString() || "Unknown error";
|
||||
}
|
||||
|
||||
console.log("Mermaid error message:", errorMessage);
|
||||
|
||||
if (errorMessage.str !== "element is null") {
|
||||
setMermaidError(
|
||||
"Something went wrong while rendering the diagram. Please try again later or downvote the message if the issue persists.",
|
||||
);
|
||||
} else {
|
||||
setMermaidError(null);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [chart]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
{mermaidError ? (
|
||||
<div className="flex items-center gap-2 bg-red-100 border border-red-500 rounded-md p-3 mt-3 text-red-900 text-sm">
|
||||
<Info className="w-12 h-12" />
|
||||
<span>Error rendering diagram: {mermaidError}</span>
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
id={mermaidId}
|
||||
ref={elementRef}
|
||||
className="mermaid"
|
||||
style={{
|
||||
width: "auto",
|
||||
height: "auto",
|
||||
boxSizing: "border-box",
|
||||
overflow: "auto",
|
||||
}}
|
||||
>
|
||||
{chart}
|
||||
</div>
|
||||
)}
|
||||
{!mermaidError && (
|
||||
<Button onClick={handleExport} variant={"secondary"} className="mt-3">
|
||||
<Download className="w-5 h-5" />
|
||||
Export as PNG
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Mermaid;
|
||||
@@ -58,6 +58,7 @@
|
||||
"lucide-react": "^0.468.0",
|
||||
"markdown-it": "^14.1.0",
|
||||
"markdown-it-highlightjs": "^4.1.0",
|
||||
"mermaid": "^11.4.1",
|
||||
"next": "14.2.15",
|
||||
"nodemon": "^3.1.3",
|
||||
"postcss": "^8.4.38",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -109,6 +109,7 @@ class ChatMessage(PydanticBaseModel):
|
||||
images: Optional[List[str]] = None
|
||||
queryFiles: Optional[List[Dict]] = None
|
||||
excalidrawDiagram: Optional[List[Dict]] = None
|
||||
mermaidjsDiagram: str = None
|
||||
by: str
|
||||
turnId: Optional[str] = None
|
||||
intent: Optional[Intent] = None
|
||||
|
||||
@@ -194,7 +194,7 @@ Limit your response to 3 sentences max. Be succinct, clear, and informative.
|
||||
## Diagram Generation
|
||||
## --
|
||||
|
||||
improve_diagram_description_prompt = PromptTemplate.from_template(
|
||||
improve_excalidraw_diagram_description_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are an architect working with a novice digital artist using a diagramming software.
|
||||
{personality_context}
|
||||
@@ -338,6 +338,123 @@ Diagram Description: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
improve_mermaid_js_diagram_description_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a senior architect working with an illustrator using a diagramming software.
|
||||
{personality_context}
|
||||
|
||||
Given a particular request, you need to translate it to to a detailed description that the illustrator can use to create a diagram.
|
||||
|
||||
You can use the following diagram types in your instructions:
|
||||
- Flowchart
|
||||
- Sequence Diagram
|
||||
- Gantt Chart (only for time-based queries after 0 AD)
|
||||
- State Diagram
|
||||
- Pie Chart
|
||||
|
||||
Use these primitives to describe what sort of diagram the drawer should create in natural language, not special syntax. We must recreate the diagram every time, so include all relevant prior information in your description.
|
||||
|
||||
- Describe the layout, components, and connections.
|
||||
- Use simple, concise language.
|
||||
|
||||
Today's Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
User's Notes:
|
||||
{references}
|
||||
|
||||
Online References:
|
||||
{online_results}
|
||||
|
||||
Conversation Log:
|
||||
{chat_history}
|
||||
|
||||
Query: {query}
|
||||
|
||||
Enhanced Description:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
mermaid_js_diagram_generation_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a designer with the ability to describe diagrams to compose in professional, fine detail. You dive into the details and make labels, connections, and shapes to represent complex systems.
|
||||
{personality_context}
|
||||
|
||||
----Goals----
|
||||
You need to create a declarative description of the diagram and relevant components, using the Mermaid.js syntax.
|
||||
|
||||
You can choose from the following diagram types:
|
||||
- Flowchart
|
||||
- Sequence Diagram
|
||||
- State Diagram
|
||||
- Gantt Chart
|
||||
- Pie Chart
|
||||
|
||||
----Examples----
|
||||
---
|
||||
title: Node
|
||||
---
|
||||
|
||||
flowchart LR
|
||||
id["This is the start"] --> id2["This is the end"]
|
||||
|
||||
sequenceDiagram
|
||||
Alice->>John: Hello John, how are you?
|
||||
John-->>Alice: Great!
|
||||
Alice-)John: See you later!
|
||||
|
||||
stateDiagram-v2
|
||||
[*] --> Still
|
||||
Still --> [*]
|
||||
|
||||
Still --> Moving
|
||||
Moving --> Still
|
||||
Moving --> Crash
|
||||
Crash --> [*]
|
||||
|
||||
gantt
|
||||
title A Gantt Diagram
|
||||
dateFormat YYYY-MM-DD
|
||||
section Section
|
||||
A task :a1, 2014-01-01, 30d
|
||||
Another task :after a1, 20d
|
||||
section Another
|
||||
Task in Another :2014-01-12, 12d
|
||||
another task :24d
|
||||
|
||||
pie title Pets adopted by volunteers
|
||||
"Dogs" : 10
|
||||
"Cats" : 30
|
||||
"Rats" : 60
|
||||
|
||||
flowchart TB
|
||||
subgraph "Group 1"
|
||||
a1["Start Node"] --> a2["End Node"]
|
||||
end
|
||||
subgraph "Group 2"
|
||||
b1["Process 1"] --> b2["Process 2"]
|
||||
end
|
||||
subgraph "Group 3"
|
||||
c1["Input"] --> c2["Output"]
|
||||
end
|
||||
a["Group 1"] --> b["Group 2"]
|
||||
c["Group 3"] --> d["Group 2"]
|
||||
|
||||
----Process----
|
||||
Create your diagram with great composition and intuitiveness from the provided context and user prompt below.
|
||||
- You may use subgraphs to group elements together. Each subgraph must have a title.
|
||||
- **You must wrap ALL entity and node labels in double quotes**, example: "My Node Label"
|
||||
- **All nodes MUST use the id["label"] format**. For example: node1["My Node Label"]
|
||||
- Custom style are not permitted. Default styles only.
|
||||
- JUST provide the diagram, no additional text or context. Say nothing else in your response except the diagram.
|
||||
- Keep diagrams simple - maximum 15 nodes
|
||||
- Every node inside a subgraph MUST use square bracket notation: id["label"]
|
||||
|
||||
output: {query}
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
failed_diagram_generation = PromptTemplate.from_template(
|
||||
"""
|
||||
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.
|
||||
|
||||
@@ -266,7 +266,7 @@ def save_to_conversation_log(
|
||||
raw_query_files: List[FileAttachment] = [],
|
||||
generated_images: List[str] = [],
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_excalidraw_diagram: str = None,
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
train_of_thought: List[Any] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
@@ -290,8 +290,8 @@ def save_to_conversation_log(
|
||||
"queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
|
||||
}
|
||||
|
||||
if generated_excalidraw_diagram:
|
||||
khoj_message_metadata["excalidrawDiagram"] = generated_excalidraw_diagram
|
||||
if generated_mermaidjs_diagram:
|
||||
khoj_message_metadata["mermaidjsDiagram"] = generated_mermaidjs_diagram
|
||||
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
@@ -441,7 +441,7 @@ def generate_chatml_messages_with_context(
|
||||
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||
}
|
||||
|
||||
if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
|
||||
if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant":
|
||||
generated_assets["diagram"] = {
|
||||
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||
}
|
||||
@@ -593,6 +593,11 @@ def clean_json(response: str):
|
||||
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
|
||||
|
||||
|
||||
def clean_mermaidjs(response: str):
|
||||
"""Remove any markdown mermaidjs codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||
return response.strip().removeprefix("```mermaid").removesuffix("```")
|
||||
|
||||
|
||||
def clean_code_python(code: str):
|
||||
"""Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
|
||||
return code.strip().removeprefix("```python").removesuffix("```")
|
||||
|
||||
@@ -51,7 +51,7 @@ from khoj.routers.helpers import (
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
gather_raw_query_files,
|
||||
generate_excalidraw_diagram,
|
||||
generate_mermaidjs_diagram,
|
||||
generate_summary_from_files,
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
@@ -781,7 +781,7 @@ async def chat(
|
||||
|
||||
generated_images: List[str] = []
|
||||
generated_files: List[FileAttachment] = []
|
||||
generated_excalidraw_diagram: str = None
|
||||
generated_mermaidjs_diagram: str = None
|
||||
program_execution_context: List[str] = []
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
@@ -1161,7 +1161,7 @@ async def chat(
|
||||
inferred_queries = []
|
||||
diagram_description = ""
|
||||
|
||||
async for result in generate_excalidraw_diagram(
|
||||
async for result in generate_mermaidjs_diagram(
|
||||
q=defiltered_query,
|
||||
conversation_history=meta_log,
|
||||
location_data=location,
|
||||
@@ -1177,12 +1177,12 @@ async def chat(
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
better_diagram_description_prompt, excalidraw_diagram_description = result
|
||||
if better_diagram_description_prompt and excalidraw_diagram_description:
|
||||
better_diagram_description_prompt, mermaidjs_diagram_description = result
|
||||
if better_diagram_description_prompt and mermaidjs_diagram_description:
|
||||
inferred_queries.append(better_diagram_description_prompt)
|
||||
diagram_description = excalidraw_diagram_description
|
||||
diagram_description = mermaidjs_diagram_description
|
||||
|
||||
generated_excalidraw_diagram = diagram_description
|
||||
generated_mermaidjs_diagram = diagram_description
|
||||
|
||||
generated_asset_results["diagrams"] = {
|
||||
"query": better_diagram_description_prompt,
|
||||
@@ -1191,7 +1191,7 @@ async def chat(
|
||||
async for result in send_event(
|
||||
ChatEvent.GENERATED_ASSETS,
|
||||
{
|
||||
"excalidrawDiagram": excalidraw_diagram_description,
|
||||
"mermaidjsDiagram": mermaidjs_diagram_description,
|
||||
},
|
||||
):
|
||||
yield result
|
||||
@@ -1231,7 +1231,7 @@ async def chat(
|
||||
raw_query_files,
|
||||
generated_images,
|
||||
generated_files,
|
||||
generated_excalidraw_diagram,
|
||||
generated_mermaidjs_diagram,
|
||||
program_execution_context,
|
||||
generated_asset_results,
|
||||
tracer,
|
||||
|
||||
@@ -97,6 +97,7 @@ from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
generate_chatml_messages_with_context,
|
||||
save_to_conversation_log,
|
||||
@@ -823,7 +824,7 @@ async def generate_better_diagram_description(
|
||||
elif online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
|
||||
improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
@@ -887,6 +888,133 @@ async def generate_excalidraw_diagram_from_description(
|
||||
return response
|
||||
|
||||
|
||||
async def generate_mermaidjs_diagram(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description(
|
||||
q=q,
|
||||
conversation_history=conversation_history,
|
||||
location_data=location_data,
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
mermaidjs_diagram_description = await generate_mermaidjs_diagram_from_description(
|
||||
q=better_diagram_description_prompt,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
inferred_queries = f"Instruction: {better_diagram_description_prompt}"
|
||||
|
||||
yield inferred_queries, mermaidjs_diagram_description
|
||||
|
||||
|
||||
async def generate_better_mermaidjs_diagram_description(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a diagram description from the given query and context
|
||||
"""
|
||||
|
||||
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
|
||||
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
simplified_online_results = {}
|
||||
|
||||
if online_results:
|
||||
for result in online_results:
|
||||
if online_results[result].get("answerBox"):
|
||||
simplified_online_results[result] = online_results[result]["answerBox"]
|
||||
elif online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
improve_diagram_description_prompt,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def generate_mermaidjs_diagram_from_description(
|
||||
q: str,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
mermaidjs_diagram_generation = prompts.mermaid_js_diagram_generation_prompt.format(
|
||||
personality_context=personality_context,
|
||||
query=q,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer)
|
||||
return clean_mermaidjs(raw_response.strip())
|
||||
|
||||
|
||||
async def generate_better_image_prompt(
|
||||
q: str,
|
||||
conversation_history: str,
|
||||
@@ -1224,7 +1352,7 @@ def generate_chat_response(
|
||||
raw_query_files: List[FileAttachment] = None,
|
||||
generated_images: List[str] = None,
|
||||
raw_generated_files: List[FileAttachment] = [],
|
||||
generated_excalidraw_diagram: str = None,
|
||||
generated_mermaidjs_diagram: str = None,
|
||||
program_execution_context: List[str] = [],
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
tracer: dict = {},
|
||||
@@ -1252,7 +1380,7 @@ def generate_chat_response(
|
||||
raw_query_files=raw_query_files,
|
||||
generated_images=generated_images,
|
||||
raw_generated_files=raw_generated_files,
|
||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1967,7 +2095,7 @@ class MessageProcessor:
|
||||
self.raw_response = ""
|
||||
self.generated_images = []
|
||||
self.generated_files = []
|
||||
self.generated_excalidraw_diagram = []
|
||||
self.generated_mermaidjs_diagram = []
|
||||
|
||||
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
||||
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
|
||||
@@ -2014,8 +2142,8 @@ class MessageProcessor:
|
||||
self.generated_images = chunk_data[key]
|
||||
elif key == "files":
|
||||
self.generated_files = chunk_data[key]
|
||||
elif key == "excalidrawDiagram":
|
||||
self.generated_excalidraw_diagram = chunk_data[key]
|
||||
elif key == "mermaidjsDiagram":
|
||||
self.generated_mermaidjs_diagram = chunk_data[key]
|
||||
|
||||
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
|
||||
if "image" in json_data or "details" in json_data:
|
||||
@@ -2052,7 +2180,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||
"usage": processor.usage,
|
||||
"images": processor.generated_images,
|
||||
"files": processor.generated_files,
|
||||
"excalidrawDiagram": processor.generated_excalidraw_diagram,
|
||||
"mermaidjsDiagram": processor.generated_mermaidjs_diagram,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user