Persist the train of thought in the conversation history

This commit is contained in:
sabaimran
2024-10-26 23:46:15 -07:00
parent 9e8ac7f89e
commit a121d67b10
5 changed files with 117 additions and 50 deletions

View File

@@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { InlineLoading } from "../loading/loading"; import { InlineLoading } from "../loading/loading";
import { Lightbulb, ArrowDown } from "@phosphor-icons/react"; import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
import AgentProfileCard from "../profileCard/profileCard"; import AgentProfileCard from "../profileCard/profileCard";
import { getIconFromIconName } from "@/app/common/iconUtils"; import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page"; import { AgentData } from "@/app/agents/page";
import React from "react"; import React from "react";
import { useIsMobileWidth } from "@/app/common/utils"; import { useIsMobileWidth } from "@/app/common/utils";
import { Button } from "@/components/ui/button";
interface ChatResponse { interface ChatResponse {
status: string; status: string;
@@ -40,26 +41,51 @@ interface ChatHistoryProps {
customClassName?: string; customClassName?: string;
} }
function constructTrainOfThought( interface TrainOfThoughtComponentProps {
trainOfThought: string[], trainOfThought: string[];
lastMessage: boolean, lastMessage: boolean;
agentColor: string, agentColor: string;
key: string, key: string;
completed: boolean = false, completed?: boolean;
) { }
const lastIndex = trainOfThought.length - 1;
return (
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
{!completed && <InlineLoading className="float-right" />}
{trainOfThought.map((train, index) => ( function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
<TrainOfThought const lastIndex = props.trainOfThought.length - 1;
key={`train-${index}`} const [collapsed, setCollapsed] = useState(props.completed);
message={train}
primary={index === lastIndex && lastMessage && !completed} return (
agentColor={agentColor} <div className={`${!collapsed ? styles.trainOfThought : ""} shadow-sm`} key={props.key}>
/> {!props.completed && <InlineLoading className="float-right" />}
))} {collapsed ? (
<Button
className="w-fit text-left justify-start content-start text-xs"
onClick={() => setCollapsed(false)}
variant="ghost"
size="sm"
>
What was my train of thought?
</Button>
) : (
<Button
className="w-fit text-left justify-start content-start text-xs"
onClick={() => setCollapsed(true)}
variant="ghost"
size="sm"
>
<XCircle size={16} className="mr-1" />
Close
</Button>
)}
{!collapsed &&
props.trainOfThought.map((train, index) => (
<TrainOfThought
key={`train-${index}`}
message={train}
primary={index === lastIndex && props.lastMessage && !props.completed}
agentColor={props.agentColor}
/>
))}
</div> </div>
); );
} }
@@ -265,25 +291,39 @@ export default function ChatHistory(props: ChatHistoryProps) {
{data && {data &&
data.chat && data.chat &&
data.chat.map((chatMessage, index) => ( data.chat.map((chatMessage, index) => (
<ChatMessage <>
key={`${index}fullHistory`} {chatMessage.trainOfThought && chatMessage.by === "khoj" && (
ref={ <TrainOfThoughtComponent
// attach ref to the second last message to handle scroll on page load trainOfThought={chatMessage.trainOfThought?.map(
index === data.chat.length - 2 (train) => train.data,
? latestUserMessageRef )}
: // attach ref to the newest fetched message to handle scroll on fetch lastMessage={false}
// note: stabilize index selection against last page having less messages than fetchMessageCount agentColor={data?.agent?.color || "orange"}
index === key={`${index}trainOfThought`}
data.chat.length - (currentPage - 1) * fetchMessageCount completed={true}
? latestFetchedMessageRef />
: null )}
} <ChatMessage
isMobileWidth={isMobileWidth} key={`${index}fullHistory`}
chatMessage={chatMessage} ref={
customClassName="fullHistory" // attach ref to the second last message to handle scroll on page load
borderLeftColor={`${data?.agent?.color}-500`} index === data.chat.length - 2
isLastMessage={index === data.chat.length - 1} ? latestUserMessageRef
/> : // attach ref to the newest fetched message to handle scroll on fetch
// note: stabilize index selection against last page having less messages than fetchMessageCount
index ===
data.chat.length -
(currentPage - 1) * fetchMessageCount
? latestFetchedMessageRef
: null
}
isMobileWidth={isMobileWidth}
chatMessage={chatMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
/>
</>
))} ))}
{props.incomingMessages && {props.incomingMessages &&
props.incomingMessages.map((message, index) => { props.incomingMessages.map((message, index) => {
@@ -305,14 +345,15 @@ export default function ChatHistory(props: ChatHistoryProps) {
customClassName="fullHistory" customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`} borderLeftColor={`${data?.agent?.color}-500`}
/> />
{message.trainOfThought && {message.trainOfThought && (
constructTrainOfThought( <TrainOfThoughtComponent
message.trainOfThought, trainOfThought={message.trainOfThought}
index === incompleteIncomingMessageIndex, lastMessage={index === incompleteIncomingMessageIndex}
data?.agent?.color || "orange", agentColor={data?.agent?.color || "orange"}
`${index}trainOfThought`, key={`${index}trainOfThought`}
message.completed, completed={message.completed}
)} />
)}
<ChatMessage <ChatMessage
key={`${index}incoming`} key={`${index}incoming`}
isMobileWidth={isMobileWidth} isMobileWidth={isMobileWidth}

View File

@@ -128,6 +128,11 @@ interface Intent {
"inferred-queries": string[]; "inferred-queries": string[];
} }
interface TrainOfThoughtObject {
type: string;
data: string;
}
export interface SingleChatMessage { export interface SingleChatMessage {
automationId: string; automationId: string;
by: string; by: string;
@@ -136,6 +141,7 @@ export interface SingleChatMessage {
context: Context[]; context: Context[];
onlineContext: OnlineContext; onlineContext: OnlineContext;
codeContext: CodeContext; codeContext: CodeContext;
trainOfThought?: TrainOfThoughtObject[];
rawQuery?: string; rawQuery?: string;
intent?: Intent; intent?: Intent;
agent?: AgentData; agent?: AgentData;

View File

@@ -146,7 +146,12 @@ class ChatEvent(Enum):
def message_to_log( def message_to_log(
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[] user_message,
chat_response,
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
train_of_thought=[],
): ):
"""Create json logs from messages, metadata for conversation log""" """Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = { default_khoj_message_metadata = {
@@ -182,6 +187,7 @@ def save_to_conversation_log(
automation_id: str = None, automation_id: str = None,
query_images: List[str] = None, query_images: List[str] = None,
tracer: Dict[str, Any] = {}, tracer: Dict[str, Any] = {},
train_of_thought: List[Any] = [],
): ):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log( updated_conversation = message_to_log(
@@ -197,8 +203,10 @@ def save_to_conversation_log(
"onlineContext": online_results, "onlineContext": online_results,
"codeContext": code_results, "codeContext": code_results,
"automationId": automation_id, "automationId": automation_id,
"trainOfThought": train_of_thought,
}, },
conversation_log=meta_log.get("chat", []), conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,
) )
ConversationAdapters.save_conversation( ConversationAdapters.save_conversation(
user, user,

View File

@@ -570,7 +570,9 @@ async def chat(
user: KhojUser = request.user.object user: KhojUser = request.user.object
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
q = unquote(q) q = unquote(q)
train_of_thought = []
nonlocal conversation_id nonlocal conversation_id
tracer: dict = { tracer: dict = {
"mid": f"{uuid.uuid4()}", "mid": f"{uuid.uuid4()}",
"cid": conversation_id, "cid": conversation_id,
@@ -590,7 +592,7 @@ async def chat(
uploaded_images.append(uploaded_image) uploaded_images.append(uploaded_image)
async def send_event(event_type: ChatEvent, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft nonlocal connection_alive, ttft, train_of_thought
if not connection_alive or await request.is_disconnected(): if not connection_alive or await request.is_disconnected():
connection_alive = False connection_alive = False
logger.warning(f"User {user} disconnected from {common.client} client") logger.warning(f"User {user} disconnected from {common.client} client")
@@ -598,8 +600,11 @@ async def chat(
try: try:
if event_type == ChatEvent.END_LLM_RESPONSE: if event_type == ChatEvent.END_LLM_RESPONSE:
collect_telemetry() collect_telemetry()
if event_type == ChatEvent.START_LLM_RESPONSE: elif event_type == ChatEvent.START_LLM_RESPONSE:
ttft = time.perf_counter() - start_time ttft = time.perf_counter() - start_time
elif event_type == ChatEvent.STATUS:
train_of_thought.append({"type": event_type.value, "data": data})
if event_type == ChatEvent.MESSAGE: if event_type == ChatEvent.MESSAGE:
yield data yield data
elif event_type == ChatEvent.REFERENCES or stream: elif event_type == ChatEvent.REFERENCES or stream:
@@ -810,6 +815,7 @@ async def chat(
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought,
) )
return return
@@ -854,6 +860,7 @@ async def chat(
automation_id=automation.id, automation_id=automation.id,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought,
) )
async for result in send_llm_response(llm_response): async for result in send_llm_response(llm_response):
yield result yield result
@@ -1061,6 +1068,7 @@ async def chat(
online_results=online_results, online_results=online_results,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought,
) )
content_obj = { content_obj = {
"intentType": intent_type, "intentType": intent_type,
@@ -1118,6 +1126,7 @@ async def chat(
online_results=online_results, online_results=online_results,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought,
) )
async for result in send_llm_response(json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
@@ -1144,6 +1153,7 @@ async def chat(
researched_results, researched_results,
uploaded_images, uploaded_images,
tracer, tracer,
train_of_thought,
) )
# Send Response # Send Response

View File

@@ -1113,6 +1113,7 @@ def generate_chat_response(
meta_research: str = "", meta_research: str = "",
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
tracer: dict = {}, tracer: dict = {},
train_of_thought: List[Any] = [],
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response = None
@@ -1137,6 +1138,7 @@ def generate_chat_response(
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=query_images, query_images=query_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought,
) )
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)