Reduce logical complexity of constructing context from chat history
- Process chat history in default order instead of in reverse. This improves the legibility of context construction, at a minor performance cost from dropping messages at the front of the list.
- Handle multiple system messages by collating them into a list.
- Remove logic to drop the system role for gemma-2 and o1 models. Better to keep the code readable than to support old models.
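In outline, the construction now reads in the same order as the final prompt. Below is a minimal sketch of the new shape, using a hypothetical stand-in Message type rather than khoj's actual ChatMessage class:

from dataclasses import dataclass

@dataclass
class Message:  # stand-in for khoj's ChatMessage, for illustration only
    role: str
    content: str

def build_messages(system_prompt: str, chat_history: list[Message], user_query: str) -> list[Message]:
    messages: list[Message] = []
    # System prompt goes first instead of being appended last and reversed later.
    if system_prompt:
        messages.append(Message(role="system", content=system_prompt))
    # Chat history is walked oldest-to-newest and appended, not inserted at the front.
    messages.extend(chat_history)
    # The current user query comes last, so no final messages[::-1] reversal is needed.
    messages.append(Message(role="user", content=user_query))
    return messages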
@@ -262,7 +262,7 @@ def construct_question_history(
             continue

         message = chat.message
-        inferred_queries_list = chat.intent.inferred_queries or []
+        inferred_queries_list = chat.intent.inferred_queries or [] if chat.intent else []

         # Ensure inferred_queries_list is a list, defaulting to the original query in a list
         if not inferred_queries_list:
@@ -593,18 +593,10 @@ def generate_chatml_messages_with_context(
         role = "user" if chat.by == "you" else "assistant"

         # Legacy code to handle excalidraw diagrams prior to Dec 2024
-        if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
+        if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
             chat_message = (chat.intent.inferred_queries or [])[0]

-        if chat.queryFiles:
-            raw_query_files = chat.queryFiles
-            query_files_dict = dict()
-            for file in raw_query_files:
-                query_files_dict[file["name"]] = file["content"]
-
-            message_attached_files = gather_raw_query_files(query_files_dict)
-            chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
-
         # Add search and action context
         if not is_none_or_empty(chat.onlineContext):
             message_context += [
                 {
@@ -643,7 +635,7 @@ def generate_chatml_messages_with_context(

         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
-            chatml_messages.insert(0, reconstructed_context_message)
+            chatml_messages.append(reconstructed_context_message)

         # Add generated assets
         if not is_none_or_empty(chat.images) and role == "assistant":
@@ -664,8 +656,17 @@ def generate_chatml_messages_with_context(
                 )
             )

+        # Add user query with attached file, images or khoj response
+        if chat.queryFiles:
+            raw_query_files = chat.queryFiles
+            query_files_dict = dict()
+            for file in raw_query_files:
+                query_files_dict[file["name"]] = file["content"]
+
+            message_attached_files = gather_raw_query_files(query_files_dict)
+
         message_content = construct_structured_message(
-            chat_message, chat.images if role == "user" else [], model_type, vision_enabled
+            chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
         )

         reconstructed_message = ChatMessage(
@@ -673,19 +674,36 @@ def generate_chatml_messages_with_context(
             role=role,
             additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
         )
-        chatml_messages.insert(0, reconstructed_message)
+        chatml_messages.append(reconstructed_message)

         if len(chatml_messages) >= 3 * lookback_turns:
             break

     messages: list[ChatMessage] = []

+    if not is_none_or_empty(system_message):
+        messages.append(ChatMessage(content=system_message, role="system"))
+
+    if len(chatml_messages) > 0:
+        messages += chatml_messages
+
+    if program_execution_context:
+        program_context_text = "\n".join(program_execution_context)
+        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
+
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
+
+    if generated_files:
+        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
+        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
+
     if not is_none_or_empty(generated_asset_results):
         messages.append(
             ChatMessage(
-                content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n",
+                content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
                 role="user",
-            )
+            ),
         )

     if not is_none_or_empty(user_message):
@@ -698,23 +716,6 @@ def generate_chatml_messages_with_context(
             )
         )

-    if generated_files:
-        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
-        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
-
-    if program_execution_context:
-        program_context_text = "\n".join(program_execution_context)
-        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
-
-    if not is_none_or_empty(context_message):
-        messages.append(ChatMessage(content=context_message, role="user"))
-
-    if len(chatml_messages) > 0:
-        messages += chatml_messages
-
-    if not is_none_or_empty(system_message):
-        messages.append(ChatMessage(content=system_message, role="system"))
-
     # Normalize message content to list of chatml dictionaries
     for message in messages:
         if isinstance(message.content, str):
@@ -723,8 +724,8 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

-    # Return message in chronological order
-    return messages[::-1]
+    # Return messages in chronological order
+    return messages


 def get_encoder(
@@ -795,7 +796,9 @@ def count_tokens(

 def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
     """Count total tokens in messages including system message"""
-    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    system_message_tokens = (
+        sum([count_tokens(message.content, encoder) for message in system_message]) if system_message else 0
+    )
     message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
     total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
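As a rough worked example of the accounting above, with hypothetical numbers: two system messages totalling 20 tokens and three history messages totalling 60 tokens give an estimate of 60 + 20 + 4 * 3 = 92 tokens, since 4 tokens per message are reserved for the chat markup delimiters.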
@@ -812,11 +815,14 @@ def truncate_messages(
     encoder = get_encoder(model_name, tokenizer_name)

     # Extract system message from messages
-    system_message = None
-    for idx, message in enumerate(messages):
+    system_message = []
+    non_system_messages = []
+    for message in messages:
         if message.role == "system":
-            system_message = messages.pop(idx)
-            break
+            system_message.append(message)
+        else:
+            non_system_messages.append(message)
+    messages = non_system_messages

     # Drop older messages until under max supported prompt size by model
     total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
@@ -824,20 +830,20 @@ def truncate_messages(
     while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
         # If the last message has more than one content part, pop the oldest content part.
         # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
-        if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
+        if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
             # The oldest content part is earlier in content list. So pop from the front.
-            messages[-1].content.pop(0)
+            messages[0].content.pop(0)
         # Otherwise, pop the last message if it has only one content part or is a tool call.
         else:
             # The oldest message is the last one. So pop from the back.
-            dropped_message = messages.pop()
+            dropped_message = messages.pop(0)
             # Drop tool result pair of tool call, if tool call message has been removed
             if (
                 dropped_message.additional_kwargs.get("message_type") == "tool_call"
                 and messages
-                and messages[-1].additional_kwargs.get("message_type") == "tool_result"
+                and messages[0].additional_kwargs.get("message_type") == "tool_result"
             ):
-                messages.pop()
+                messages.pop(0)

         total_tokens, _ = count_total_tokens(messages, encoder, system_message)
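A minimal, self-contained sketch of the reworked truncation flow, using a hypothetical Message dataclass and a naive whitespace token counter instead of khoj's ChatMessage and real encoder, and omitting the multi-part content and tool-call pairing handling shown in the diff:

from dataclasses import dataclass

@dataclass
class Message:
    role: str
    content: str

def count(text: str) -> int:
    # Naive stand-in for the real tokenizer.
    return len(text.split())

def truncate(messages: list[Message], max_tokens: int) -> list[Message]:
    # Collate all system messages into a list instead of popping out a single one.
    system = [m for m in messages if m.role == "system"]
    history = [m for m in messages if m.role != "system"]

    def total() -> int:
        return sum(count(m.content) for m in system + history)

    # History is chronological, so the oldest message sits at the front: drop index 0.
    while total() > max_tokens and len(history) > 1:
        history.pop(0)

    # Prepend system messages on return; no gemma-2 / o1 role rewriting anymore.
    return system + history

With a small max_tokens budget, the oldest history entries are dropped first while the system messages always stay at the front of the returned list.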
@@ -876,11 +882,7 @@ def truncate_messages(
             f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
         )

-    if system_message:
-        # Default system message role is system.
-        # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
-        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
-    return messages + [system_message] if system_message else messages
+    return system_message + messages if system_message else messages


 def reciprocal_conversation_to_chatml(message_pair):
@@ -29,7 +29,7 @@ class TestTruncateMessage:
         # Arrange
         chat_history = generate_chat_history(5)
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
-        chat_history.append(big_chat_message)
+        chat_history.insert(0, big_chat_message)

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -68,7 +68,7 @@ class TestTruncateMessage:
         content_list += [{"type": "text", "text": "Question?"}]
         big_chat_message = ChatMessage(role="user", content=content_list)
         copy_big_chat_message = deepcopy(big_chat_message)
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
@@ -92,7 +92,7 @@ class TestTruncateMessage:
         chat_history = generate_chat_history(5)
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.model_copy()
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
@@ -104,8 +104,8 @@ class TestTruncateMessage:
         assert len(chat_history) == 1, (
             "Only most recent message should be present as it itself is larger than context size"
         )
-        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
-        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
         assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
         assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
@@ -116,7 +116,7 @@ class TestTruncateMessage:
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.model_copy()

-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
@@ -130,8 +130,8 @@ class TestTruncateMessage:
         ) # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
-        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
         assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
         assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"