Reduce logical complexity of constructing context from chat history

- Process chat history in its default, chronological order instead of in
  reverse. This improves the legibility of context construction at the
  cost of a minor performance hit from dropping messages at the front of
  the list during truncation (sketched below).
- Handle multiple system messages by collating them into a list.
- Remove the logic that drops the system role for gemma-2 and o1 models.
  Keeping the code readable matters more than supporting those old models.
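
As a rough illustration of the intended shape (a standalone sketch with a
placeholder Msg class, not the project's actual ChatMessage type or function
signature), the prompt is now assembled oldest-first and never reversed:

    from dataclasses import dataclass

    @dataclass
    class Msg:  # stand-in for the ChatMessage type used in the diff below
        role: str
        content: str

    system_prompt = "You are Khoj."
    history = [("user", "Hi"), ("assistant", "Hello!")]  # oldest turn first
    user_query = "What's new?"

    # Build the prompt in chronological order: system, then history, then the current turn.
    messages = [Msg("system", system_prompt)]
    messages += [Msg(role, text) for role, text in history]
    messages.append(Msg("user", user_query))
    # No trailing reversal needed; truncation can drop the oldest entries from the front.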
Debanjum
2025-08-25 00:33:19 -07:00
parent 1e81b51abc
commit 8a16f5a2af
2 changed files with 60 additions and 58 deletions

View File

@@ -262,7 +262,7 @@ def construct_question_history(
                 continue
             message = chat.message
-            inferred_queries_list = chat.intent.inferred_queries or []
+            inferred_queries_list = chat.intent.inferred_queries or [] if chat.intent else []
             # Ensure inferred_queries_list is a list, defaulting to the original query in a list
             if not inferred_queries_list:
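
Note on the new expression above: Python's conditional expression binds more
loosely than "or", so the line groups as
(chat.intent.inferred_queries or []) if chat.intent else [].
A tiny standalone check of that grouping (illustrative only, not project code):

    queries, has_intent = None, True
    assert (queries or ["fallback"] if has_intent else []) == ["fallback"]
    assert (queries or ["fallback"] if not has_intent else []) == []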
@@ -593,18 +593,10 @@ def generate_chatml_messages_with_context(
         role = "user" if chat.by == "you" else "assistant"
 
         # Legacy code to handle excalidraw diagrams prior to Dec 2024
-        if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
+        if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
             chat_message = (chat.intent.inferred_queries or [])[0]
 
-        if chat.queryFiles:
-            raw_query_files = chat.queryFiles
-            query_files_dict = dict()
-            for file in raw_query_files:
-                query_files_dict[file["name"]] = file["content"]
-            message_attached_files = gather_raw_query_files(query_files_dict)
-            chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
-
         # Add search and action context
         if not is_none_or_empty(chat.onlineContext):
             message_context += [
                 {
@@ -643,7 +635,7 @@ def generate_chatml_messages_with_context(
         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
-            chatml_messages.insert(0, reconstructed_context_message)
+            chatml_messages.append(reconstructed_context_message)
 
         # Add generated assets
         if not is_none_or_empty(chat.images) and role == "assistant":
@@ -664,8 +656,17 @@ def generate_chatml_messages_with_context(
                 )
             )
 
         # Add user query with attached file, images or khoj response
+        if chat.queryFiles:
+            raw_query_files = chat.queryFiles
+            query_files_dict = dict()
+            for file in raw_query_files:
+                query_files_dict[file["name"]] = file["content"]
+            message_attached_files = gather_raw_query_files(query_files_dict)
+
         message_content = construct_structured_message(
-            chat_message, chat.images if role == "user" else [], model_type, vision_enabled
+            chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
         )
 
         reconstructed_message = ChatMessage(
@@ -673,19 +674,36 @@ def generate_chatml_messages_with_context(
             role=role,
             additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
         )
-        chatml_messages.insert(0, reconstructed_message)
+        chatml_messages.append(reconstructed_message)
 
         if len(chatml_messages) >= 3 * lookback_turns:
             break
 
     messages: list[ChatMessage] = []
+    if not is_none_or_empty(system_message):
+        messages.append(ChatMessage(content=system_message, role="system"))
+    if len(chatml_messages) > 0:
+        messages += chatml_messages
+
+    if program_execution_context:
+        program_context_text = "\n".join(program_execution_context)
+        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
+
+    if generated_files:
+        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
+        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
 
     if not is_none_or_empty(generated_asset_results):
         messages.append(
             ChatMessage(
-                content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n",
+                content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
                 role="user",
-            )
+            ),
         )
 
     if not is_none_or_empty(user_message):
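
Taken together, the additions in this hunk and the removals in the next one
reorder the prompt assembly. If I read the new code correctly, the roles run
top to bottom roughly as in this illustrative outline (not actual output):

    expected_roles = [
        "system",             # system prompt
        "user", "assistant",  # prior conversation turns, oldest first
        "user",               # search context plus program execution context
        "assistant",          # generated files, if any
        "user",               # generated assets summary, if any
        "user",               # the current user query
    ]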
@@ -698,23 +716,6 @@ def generate_chatml_messages_with_context(
             )
         )
 
-    if generated_files:
-        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
-        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
-
-    if program_execution_context:
-        program_context_text = "\n".join(program_execution_context)
-        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
-
-    if not is_none_or_empty(context_message):
-        messages.append(ChatMessage(content=context_message, role="user"))
-
-    if len(chatml_messages) > 0:
-        messages += chatml_messages
-
-    if not is_none_or_empty(system_message):
-        messages.append(ChatMessage(content=system_message, role="system"))
-
     # Normalize message content to list of chatml dictionaries
     for message in messages:
         if isinstance(message.content, str):
@@ -723,8 +724,8 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
 
-    # Return message in chronological order
-    return messages[::-1]
+    # Return messages in chronological order
+    return messages
 
 
 def get_encoder(
@@ -795,7 +796,9 @@ def count_tokens(
 def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
     """Count total tokens in messages including system message"""
-    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    system_message_tokens = (
+        sum([count_tokens(message.content, encoder) for message in system_message]) if system_message else 0
+    )
     message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
     total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
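
A rough, self-contained sketch of the accounting this function performs, with
plain strings and a whitespace splitter standing in for the real ChatMessage
objects and encoder (illustrative only):

    def count_tokens(text, encoder):
        return len(encoder(text))

    def count_total_tokens(messages, encoder, system_messages):
        system_tokens = sum(count_tokens(m, encoder) for m in system_messages) if system_messages else 0
        message_tokens = sum(count_tokens(m, encoder) for m in messages)
        # Reserve 4 tokens per non-system message for chat markup like <|im_start|> and <|im_end|>.
        return message_tokens + system_tokens + 4 * len(messages), system_tokens

    encoder = str.split  # dummy whitespace "tokenizer", for illustration only
    total, sys_tokens = count_total_tokens(["hello world", "ok"], encoder, ["be brief"])
    assert (total, sys_tokens) == (3 + 2 + 4 * 2, 2)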
@@ -812,11 +815,14 @@ def truncate_messages(
     encoder = get_encoder(model_name, tokenizer_name)
 
     # Extract system message from messages
-    system_message = None
-    for idx, message in enumerate(messages):
+    system_message = []
+    non_system_messages = []
+    for message in messages:
         if message.role == "system":
-            system_message = messages.pop(idx)
-            break
+            system_message.append(message)
+        else:
+            non_system_messages.append(message)
+    messages = non_system_messages
 
     # Drop older messages until under max supported prompt size by model
     total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
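
The partition above amounts to the pattern below, shown standalone with plain
(role, text) tuples instead of ChatMessage objects:

    msgs = [("system", "Be brief."), ("user", "Hi"), ("system", "Reply in English."), ("assistant", "Hello!")]
    system_messages = [m for m in msgs if m[0] == "system"]
    other_messages = [m for m in msgs if m[0] != "system"]
    assert len(system_messages) == 2  # every system message is kept, not just the first one found
    assert [role for role, _ in other_messages] == ["user", "assistant"]  # original order preserved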
@@ -824,20 +830,20 @@ def truncate_messages(
     while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
         # If the last message has more than one content part, pop the oldest content part.
         # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
-        if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
+        if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
             # The oldest content part is earlier in content list. So pop from the front.
-            messages[-1].content.pop(0)
+            messages[0].content.pop(0)
         # Otherwise, pop the last message if it has only one content part or is a tool call.
         else:
             # The oldest message is the last one. So pop from the back.
-            dropped_message = messages.pop()
+            dropped_message = messages.pop(0)
             # Drop tool result pair of tool call, if tool call message has been removed
             if (
                 dropped_message.additional_kwargs.get("message_type") == "tool_call"
                 and messages
-                and messages[-1].additional_kwargs.get("message_type") == "tool_result"
+                and messages[0].additional_kwargs.get("message_type") == "tool_result"
             ):
-                messages.pop()
+                messages.pop(0)
 
         total_tokens, _ = count_total_tokens(messages, encoder, system_message)
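
With messages held oldest-first, the loop above drops from index 0 until the
prompt fits. A minimal standalone sketch of that behaviour, using whitespace
token counts in place of the real encoder (not the project's code):

    def truncate(messages, max_tokens):
        # messages: list of strings, oldest first; drop from the front until under budget
        while sum(len(m.split()) for m in messages) > max_tokens and len(messages) > 1:
            messages.pop(0)
        return messages

    assert truncate(["old " * 50, "recent question?"], max_tokens=10) == ["recent question?"]

Note that list.pop(0) is linear in the number of remaining messages, which
appears to be the minor performance hit the commit message accepts in exchange
for simpler, reversal-free bookkeeping; a collections.deque would make the
front pop O(1) if it ever mattered.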
@@ -876,11 +882,7 @@ def truncate_messages(
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
)
if system_message:
# Default system message role is system.
# Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
return messages + [system_message] if system_message else messages
return system_message + messages if system_message else messages
def reciprocal_conversation_to_chatml(message_pair):

View File

@@ -29,7 +29,7 @@ class TestTruncateMessage:
         # Arrange
         chat_history = generate_chat_history(5)
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
-        chat_history.append(big_chat_message)
+        chat_history.insert(0, big_chat_message)
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -68,7 +68,7 @@ class TestTruncateMessage:
         content_list += [{"type": "text", "text": "Question?"}]
         big_chat_message = ChatMessage(role="user", content=content_list)
         copy_big_chat_message = deepcopy(big_chat_message)
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
@@ -92,7 +92,7 @@ class TestTruncateMessage:
         chat_history = generate_chat_history(5)
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.model_copy()
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
@@ -104,8 +104,8 @@ class TestTruncateMessage:
         assert len(chat_history) == 1, (
             "Only most recent message should be present as it itself is larger than context size"
         )
-        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
-        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
         assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
         assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
@@ -116,7 +116,7 @@ class TestTruncateMessage:
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.model_copy()
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
@@ -130,8 +130,8 @@ class TestTruncateMessage:
         )  # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
-        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
         assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
         assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"