mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Persist the train of thought in the conversation history
This commit is contained in:
@@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
|||||||
|
|
||||||
import { InlineLoading } from "../loading/loading";
|
import { InlineLoading } from "../loading/loading";
|
||||||
|
|
||||||
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
|
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
|
||||||
|
|
||||||
import AgentProfileCard from "../profileCard/profileCard";
|
import AgentProfileCard from "../profileCard/profileCard";
|
||||||
import { getIconFromIconName } from "@/app/common/iconUtils";
|
import { getIconFromIconName } from "@/app/common/iconUtils";
|
||||||
import { AgentData } from "@/app/agents/page";
|
import { AgentData } from "@/app/agents/page";
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import { useIsMobileWidth } from "@/app/common/utils";
|
import { useIsMobileWidth } from "@/app/common/utils";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
|
||||||
interface ChatResponse {
|
interface ChatResponse {
|
||||||
status: string;
|
status: string;
|
||||||
@@ -40,26 +41,51 @@ interface ChatHistoryProps {
|
|||||||
customClassName?: string;
|
customClassName?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
function constructTrainOfThought(
|
interface TrainOfThoughtComponentProps {
|
||||||
trainOfThought: string[],
|
trainOfThought: string[];
|
||||||
lastMessage: boolean,
|
lastMessage: boolean;
|
||||||
agentColor: string,
|
agentColor: string;
|
||||||
key: string,
|
key: string;
|
||||||
completed: boolean = false,
|
completed?: boolean;
|
||||||
) {
|
}
|
||||||
const lastIndex = trainOfThought.length - 1;
|
|
||||||
return (
|
|
||||||
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
|
|
||||||
{!completed && <InlineLoading className="float-right" />}
|
|
||||||
|
|
||||||
{trainOfThought.map((train, index) => (
|
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||||
<TrainOfThought
|
const lastIndex = props.trainOfThought.length - 1;
|
||||||
key={`train-${index}`}
|
const [collapsed, setCollapsed] = useState(props.completed);
|
||||||
message={train}
|
|
||||||
primary={index === lastIndex && lastMessage && !completed}
|
return (
|
||||||
agentColor={agentColor}
|
<div className={`${!collapsed ? styles.trainOfThought : ""} shadow-sm`} key={props.key}>
|
||||||
/>
|
{!props.completed && <InlineLoading className="float-right" />}
|
||||||
))}
|
{collapsed ? (
|
||||||
|
<Button
|
||||||
|
className="w-fit text-left justify-start content-start text-xs"
|
||||||
|
onClick={() => setCollapsed(false)}
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
What was my train of thought?
|
||||||
|
</Button>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
className="w-fit text-left justify-start content-start text-xs"
|
||||||
|
onClick={() => setCollapsed(true)}
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
<XCircle size={16} className="mr-1" />
|
||||||
|
Close
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!collapsed &&
|
||||||
|
props.trainOfThought.map((train, index) => (
|
||||||
|
<TrainOfThought
|
||||||
|
key={`train-${index}`}
|
||||||
|
message={train}
|
||||||
|
primary={index === lastIndex && props.lastMessage && !props.completed}
|
||||||
|
agentColor={props.agentColor}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -265,25 +291,39 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||||||
{data &&
|
{data &&
|
||||||
data.chat &&
|
data.chat &&
|
||||||
data.chat.map((chatMessage, index) => (
|
data.chat.map((chatMessage, index) => (
|
||||||
<ChatMessage
|
<>
|
||||||
key={`${index}fullHistory`}
|
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
|
||||||
ref={
|
<TrainOfThoughtComponent
|
||||||
// attach ref to the second last message to handle scroll on page load
|
trainOfThought={chatMessage.trainOfThought?.map(
|
||||||
index === data.chat.length - 2
|
(train) => train.data,
|
||||||
? latestUserMessageRef
|
)}
|
||||||
: // attach ref to the newest fetched message to handle scroll on fetch
|
lastMessage={false}
|
||||||
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
agentColor={data?.agent?.color || "orange"}
|
||||||
index ===
|
key={`${index}trainOfThought`}
|
||||||
data.chat.length - (currentPage - 1) * fetchMessageCount
|
completed={true}
|
||||||
? latestFetchedMessageRef
|
/>
|
||||||
: null
|
)}
|
||||||
}
|
<ChatMessage
|
||||||
isMobileWidth={isMobileWidth}
|
key={`${index}fullHistory`}
|
||||||
chatMessage={chatMessage}
|
ref={
|
||||||
customClassName="fullHistory"
|
// attach ref to the second last message to handle scroll on page load
|
||||||
borderLeftColor={`${data?.agent?.color}-500`}
|
index === data.chat.length - 2
|
||||||
isLastMessage={index === data.chat.length - 1}
|
? latestUserMessageRef
|
||||||
/>
|
: // attach ref to the newest fetched message to handle scroll on fetch
|
||||||
|
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
||||||
|
index ===
|
||||||
|
data.chat.length -
|
||||||
|
(currentPage - 1) * fetchMessageCount
|
||||||
|
? latestFetchedMessageRef
|
||||||
|
: null
|
||||||
|
}
|
||||||
|
isMobileWidth={isMobileWidth}
|
||||||
|
chatMessage={chatMessage}
|
||||||
|
customClassName="fullHistory"
|
||||||
|
borderLeftColor={`${data?.agent?.color}-500`}
|
||||||
|
isLastMessage={index === data.chat.length - 1}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
))}
|
))}
|
||||||
{props.incomingMessages &&
|
{props.incomingMessages &&
|
||||||
props.incomingMessages.map((message, index) => {
|
props.incomingMessages.map((message, index) => {
|
||||||
@@ -305,14 +345,15 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||||||
customClassName="fullHistory"
|
customClassName="fullHistory"
|
||||||
borderLeftColor={`${data?.agent?.color}-500`}
|
borderLeftColor={`${data?.agent?.color}-500`}
|
||||||
/>
|
/>
|
||||||
{message.trainOfThought &&
|
{message.trainOfThought && (
|
||||||
constructTrainOfThought(
|
<TrainOfThoughtComponent
|
||||||
message.trainOfThought,
|
trainOfThought={message.trainOfThought}
|
||||||
index === incompleteIncomingMessageIndex,
|
lastMessage={index === incompleteIncomingMessageIndex}
|
||||||
data?.agent?.color || "orange",
|
agentColor={data?.agent?.color || "orange"}
|
||||||
`${index}trainOfThought`,
|
key={`${index}trainOfThought`}
|
||||||
message.completed,
|
completed={message.completed}
|
||||||
)}
|
/>
|
||||||
|
)}
|
||||||
<ChatMessage
|
<ChatMessage
|
||||||
key={`${index}incoming`}
|
key={`${index}incoming`}
|
||||||
isMobileWidth={isMobileWidth}
|
isMobileWidth={isMobileWidth}
|
||||||
|
|||||||
@@ -128,6 +128,11 @@ interface Intent {
|
|||||||
"inferred-queries": string[];
|
"inferred-queries": string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface TrainOfThoughtObject {
|
||||||
|
type: string;
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
|
||||||
export interface SingleChatMessage {
|
export interface SingleChatMessage {
|
||||||
automationId: string;
|
automationId: string;
|
||||||
by: string;
|
by: string;
|
||||||
@@ -136,6 +141,7 @@ export interface SingleChatMessage {
|
|||||||
context: Context[];
|
context: Context[];
|
||||||
onlineContext: OnlineContext;
|
onlineContext: OnlineContext;
|
||||||
codeContext: CodeContext;
|
codeContext: CodeContext;
|
||||||
|
trainOfThought?: TrainOfThoughtObject[];
|
||||||
rawQuery?: string;
|
rawQuery?: string;
|
||||||
intent?: Intent;
|
intent?: Intent;
|
||||||
agent?: AgentData;
|
agent?: AgentData;
|
||||||
|
|||||||
@@ -146,7 +146,12 @@ class ChatEvent(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def message_to_log(
|
def message_to_log(
|
||||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
user_message,
|
||||||
|
chat_response,
|
||||||
|
user_message_metadata={},
|
||||||
|
khoj_message_metadata={},
|
||||||
|
conversation_log=[],
|
||||||
|
train_of_thought=[],
|
||||||
):
|
):
|
||||||
"""Create json logs from messages, metadata for conversation log"""
|
"""Create json logs from messages, metadata for conversation log"""
|
||||||
default_khoj_message_metadata = {
|
default_khoj_message_metadata = {
|
||||||
@@ -182,6 +187,7 @@ def save_to_conversation_log(
|
|||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
tracer: Dict[str, Any] = {},
|
tracer: Dict[str, Any] = {},
|
||||||
|
train_of_thought: List[Any] = [],
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
@@ -197,8 +203,10 @@ def save_to_conversation_log(
|
|||||||
"onlineContext": online_results,
|
"onlineContext": online_results,
|
||||||
"codeContext": code_results,
|
"codeContext": code_results,
|
||||||
"automationId": automation_id,
|
"automationId": automation_id,
|
||||||
|
"trainOfThought": train_of_thought,
|
||||||
},
|
},
|
||||||
conversation_log=meta_log.get("chat", []),
|
conversation_log=meta_log.get("chat", []),
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
)
|
)
|
||||||
ConversationAdapters.save_conversation(
|
ConversationAdapters.save_conversation(
|
||||||
user,
|
user,
|
||||||
|
|||||||
@@ -570,7 +570,9 @@ async def chat(
|
|||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
|
train_of_thought = []
|
||||||
nonlocal conversation_id
|
nonlocal conversation_id
|
||||||
|
|
||||||
tracer: dict = {
|
tracer: dict = {
|
||||||
"mid": f"{uuid.uuid4()}",
|
"mid": f"{uuid.uuid4()}",
|
||||||
"cid": conversation_id,
|
"cid": conversation_id,
|
||||||
@@ -590,7 +592,7 @@ async def chat(
|
|||||||
uploaded_images.append(uploaded_image)
|
uploaded_images.append(uploaded_image)
|
||||||
|
|
||||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
nonlocal connection_alive, ttft
|
nonlocal connection_alive, ttft, train_of_thought
|
||||||
if not connection_alive or await request.is_disconnected():
|
if not connection_alive or await request.is_disconnected():
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.warning(f"User {user} disconnected from {common.client} client")
|
logger.warning(f"User {user} disconnected from {common.client} client")
|
||||||
@@ -598,8 +600,11 @@ async def chat(
|
|||||||
try:
|
try:
|
||||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||||
collect_telemetry()
|
collect_telemetry()
|
||||||
if event_type == ChatEvent.START_LLM_RESPONSE:
|
elif event_type == ChatEvent.START_LLM_RESPONSE:
|
||||||
ttft = time.perf_counter() - start_time
|
ttft = time.perf_counter() - start_time
|
||||||
|
elif event_type == ChatEvent.STATUS:
|
||||||
|
train_of_thought.append({"type": event_type.value, "data": data})
|
||||||
|
|
||||||
if event_type == ChatEvent.MESSAGE:
|
if event_type == ChatEvent.MESSAGE:
|
||||||
yield data
|
yield data
|
||||||
elif event_type == ChatEvent.REFERENCES or stream:
|
elif event_type == ChatEvent.REFERENCES or stream:
|
||||||
@@ -810,6 +815,7 @@ async def chat(
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -854,6 +860,7 @@ async def chat(
|
|||||||
automation_id=automation.id,
|
automation_id=automation.id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
)
|
)
|
||||||
async for result in send_llm_response(llm_response):
|
async for result in send_llm_response(llm_response):
|
||||||
yield result
|
yield result
|
||||||
@@ -1061,6 +1068,7 @@ async def chat(
|
|||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
)
|
)
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
@@ -1118,6 +1126,7 @@ async def chat(
|
|||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
@@ -1144,6 +1153,7 @@ async def chat(
|
|||||||
researched_results,
|
researched_results,
|
||||||
uploaded_images,
|
uploaded_images,
|
||||||
tracer,
|
tracer,
|
||||||
|
train_of_thought,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
# Send Response
|
||||||
|
|||||||
@@ -1113,6 +1113,7 @@ def generate_chat_response(
|
|||||||
meta_research: str = "",
|
meta_research: str = "",
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
|
train_of_thought: List[Any] = [],
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response = None
|
||||||
@@ -1137,6 +1138,7 @@ def generate_chat_response(
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
train_of_thought=train_of_thought,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
|
|||||||
Reference in New Issue
Block a user