From 8a16f5a2afe3b6343c8bf661d63c7d1af5b4879c Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Mon, 25 Aug 2025 00:33:19 -0700
Subject: [PATCH] Reduce logical complexity of constructing context from chat
 history

- Process chat history in its default order instead of in reverse. This
  improves the legibility of context construction at the minor performance
  cost of dropping messages from the front of the list.
- Handle multiple system messages by collating them into a list.
- Remove the logic that drops the system role for gemma-2 and o1 models.
  Better to keep the code readable than to support old models.
---
 src/khoj/processor/conversation/utils.py | 102 ++++++++++++-----------
 tests/test_conversation_utils.py         |  16 ++--
 2 files changed, 60 insertions(+), 58 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index dcd66a2e..4feedf08 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -262,7 +262,7 @@ def construct_question_history(
             continue
 
         message = chat.message
-        inferred_queries_list = chat.intent.inferred_queries or []
+        inferred_queries_list = chat.intent.inferred_queries or [] if chat.intent else []
 
         # Ensure inferred_queries_list is a list, defaulting to the original query in a list
        if not inferred_queries_list:
@@ -593,18 +593,10 @@ def generate_chatml_messages_with_context(
         role = "user" if chat.by == "you" else "assistant"
 
         # Legacy code to handle excalidraw diagrams prior to Dec 2024
-        if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
+        if chat.by == "khoj" and chat.intent and "excalidraw" in chat.intent.type:
             chat_message = (chat.intent.inferred_queries or [])[0]
 
-        if chat.queryFiles:
-            raw_query_files = chat.queryFiles
-            query_files_dict = dict()
-            for file in raw_query_files:
-                query_files_dict[file["name"]] = file["content"]
-
-            message_attached_files = gather_raw_query_files(query_files_dict)
-            chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
-
+        # Add search and action context
         if not is_none_or_empty(chat.onlineContext):
             message_context += [
                 {
@@ -643,7 +635,7 @@ def generate_chatml_messages_with_context(
 
         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
-            chatml_messages.insert(0, reconstructed_context_message)
+            chatml_messages.append(reconstructed_context_message)
 
         # Add generated assets
         if not is_none_or_empty(chat.images) and role == "assistant":
@@ -664,8 +656,17 @@ def generate_chatml_messages_with_context(
                 )
             )
 
+        # Add user query with attached file, images or khoj response
+        if chat.queryFiles:
+            raw_query_files = chat.queryFiles
+            query_files_dict = dict()
+            for file in raw_query_files:
+                query_files_dict[file["name"]] = file["content"]
+
+            message_attached_files = gather_raw_query_files(query_files_dict)
+
         message_content = construct_structured_message(
-            chat_message, chat.images if role == "user" else [], model_type, vision_enabled
+            chat_message, chat.images if role == "user" else [], model_type, vision_enabled, message_attached_files
         )
 
         reconstructed_message = ChatMessage(
@@ -673,19 +674,36 @@ def generate_chatml_messages_with_context(
             content=message_content,
             role=role,
             additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
         )
-        chatml_messages.insert(0, reconstructed_message)
+        chatml_messages.append(reconstructed_message)
 
         if len(chatml_messages) >= 3 * lookback_turns:
             break
 
     messages: list[ChatMessage] = []
+    if not is_none_or_empty(system_message):
+        messages.append(ChatMessage(content=system_message, role="system"))
+
+    if len(chatml_messages) > 0:
+        messages += chatml_messages
+
+    if program_execution_context:
+        program_context_text = "\n".join(program_execution_context)
+        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
+
+    if not is_none_or_empty(context_message):
+        messages.append(ChatMessage(content=context_message, role="user"))
+
+    if generated_files:
+        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
+        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
+
     if not is_none_or_empty(generated_asset_results):
         messages.append(
             ChatMessage(
-                content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n",
+                content=prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results)),
                 role="user",
-            )
+            ),
         )
 
     if not is_none_or_empty(user_message):
         messages.append(
             ChatMessage(
                 content=construct_structured_message(
                     user_message, query_images, model_type, vision_enabled, query_files
                 ),
                 role="user",
             )
         )
 
@@ -698,23 +716,6 @@ def generate_chatml_messages_with_context(
-    if generated_files:
-        message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
-        messages.append(ChatMessage(content=message_attached_files, role="assistant"))
-
-    if program_execution_context:
-        program_context_text = "\n".join(program_execution_context)
-        context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
-
-    if not is_none_or_empty(context_message):
-        messages.append(ChatMessage(content=context_message, role="user"))
-
-    if len(chatml_messages) > 0:
-        messages += chatml_messages
-
-    if not is_none_or_empty(system_message):
-        messages.append(ChatMessage(content=system_message, role="system"))
-
     # Normalize message content to list of chatml dictionaries
     for message in messages:
         if isinstance(message.content, str):
@@ -723,8 +724,8 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
 
-    # Return message in chronological order
-    return messages[::-1]
+    # Return messages in chronological order
+    return messages
 
 
 def get_encoder(
@@ -795,7 +796,9 @@ def count_tokens(
 
 def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
     """Count total tokens in messages including system message"""
-    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    system_message_tokens = (
+        sum([count_tokens(message.content, encoder) for message in system_message]) if system_message else 0
+    )
     message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
     total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
@@ -812,11 +815,14 @@ def truncate_messages(
     encoder = get_encoder(model_name, tokenizer_name)
 
     # Extract system message from messages
-    system_message = None
-    for idx, message in enumerate(messages):
+    system_message = []
+    non_system_messages = []
+    for message in messages:
         if message.role == "system":
-            system_message = messages.pop(idx)
-            break
+            system_message.append(message)
+        else:
+            non_system_messages.append(message)
+    messages = non_system_messages
 
     # Drop older messages until under max supported prompt size by model
     total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
@@ -824,20 +830,20 @@ def truncate_messages(
     while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
-        # If the last message has more than one content part, pop the oldest content part.
-        # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
-        if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
+        # If the oldest message has more than one content part, pop its oldest content part.
+        # For tool calls, the whole message should be dropped; truncating an assistant's tool call content annoys AI APIs.
+        if len(messages[0].content) > 1 and messages[0].additional_kwargs.get("message_type") != "tool_call":
             # The oldest content part is earlier in content list. So pop from the front.
-            messages[-1].content.pop(0)
-        # Otherwise, pop the last message if it has only one content part or is a tool call.
+            messages[0].content.pop(0)
+        # Otherwise, pop the oldest message if it has only one content part or is a tool call.
         else:
-            # The oldest message is the last one. So pop from the back.
-            dropped_message = messages.pop()
+            # The oldest message is the first one. So pop from the front.
+            dropped_message = messages.pop(0)
             # Drop tool result pair of tool call, if tool call message has been removed
             if (
                 dropped_message.additional_kwargs.get("message_type") == "tool_call"
                 and messages
-                and messages[-1].additional_kwargs.get("message_type") == "tool_result"
+                and messages[0].additional_kwargs.get("message_type") == "tool_result"
             ):
-                messages.pop()
+                messages.pop(0)
 
     total_tokens, _ = count_total_tokens(messages, encoder, system_message)
@@ -876,11 +882,7 @@ def truncate_messages(
                 f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
             )
 
-    if system_message:
-        # Default system message role is system.
-        # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
-        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
-    return messages + [system_message] if system_message else messages
+    return system_message + messages if system_message else messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index f50b7fdb..88613086 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -29,7 +29,7 @@ class TestTruncateMessage:
         # Arrange
         chat_history = generate_chat_history(5)
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
-        chat_history.append(big_chat_message)
+        chat_history.insert(0, big_chat_message)
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -68,7 +68,7 @@ class TestTruncateMessage:
         content_list += [{"type": "text", "text": "Question?"}]
         big_chat_message = ChatMessage(role="user", content=content_list)
         copy_big_chat_message = deepcopy(big_chat_message)
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
@@ -92,7 +92,7 @@ class TestTruncateMessage:
         chat_history = generate_chat_history(5)
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.model_copy()
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
@@ -104,8 +104,8 @@ class TestTruncateMessage:
         assert len(chat_history) == 1, (
             "Only most recent message should be present as it itself is larger than context size"
         )
-        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
-        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
         assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
         assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
@@ -116,7 +116,7 @@ class TestTruncateMessage:
         big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.model_copy()
 
-        chat_history.insert(0, big_chat_message)
+        chat_history.append(big_chat_message)
         initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
@@ -130,8 +130,8 @@ class TestTruncateMessage:
         )  # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
-        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
         assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
         assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
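
Reviewer note: below is a minimal sketch, not khoj code, of the chronological truncation flow this patch adopts. `Msg` and the character-length token count are illustrative stand-ins for `ChatMessage` and the real encoder; it shows system messages collated into a list and history dropped from the front (oldest first) until the prompt fits.

from dataclasses import dataclass


@dataclass
class Msg:
    # Hypothetical stand-in for langchain's ChatMessage, for illustration only.
    role: str
    content: str


def truncate(messages: list[Msg], max_tokens: int) -> list[Msg]:
    # Collate all system messages into a list instead of assuming a single one.
    system = [m for m in messages if m.role == "system"]
    rest = [m for m in messages if m.role != "system"]
    # Messages are kept in chronological order, so the oldest is at the
    # front; drop from index 0 until the conversation fits the budget.
    # len() stands in for real token counting here.
    while rest and sum(len(m.content) for m in system + rest) > max_tokens:
        rest.pop(0)
    # System messages lead, then history, with the newest message last.
    return system + rest


history = [Msg("system", "be brief"), Msg("user", "hi " * 40), Msg("user", "bye")]
print(truncate(history, 40))  # the oldest, oversized user message is dropped first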