From aa081913bf73b83fcb1634dbcbbf718e7051dd38 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 24 Jun 2025 02:13:04 -0700 Subject: [PATCH] Improve truncation with tool use and Anthropic caching - Cache last anthropic message. Given research mode now uses function calling paradigm and not the old research mode structure. - Cache tool definitions passed to anthropic models - Stop dropping first message if by assistant as seems like Anthropic API doesn't complain about it any more. - Drop tool result when tool call is truncated as invalid state - Do not truncate tool use message content, just drop the whole tool use message. AI model APIs need tool use assistant message content in specific form (e.g with thinking etc.). So dropping content items breaks expected tool use message content format. Handle tool use scenarios where iteration query isn't set for retry --- .../processor/conversation/anthropic/utils.py | 25 ++++++------ src/khoj/processor/conversation/utils.py | 38 +++++++++++++------ tests/test_conversation_utils.py | 21 +++++----- 3 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index fab7ed54..360f0a51 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -82,6 +82,9 @@ def anthropic_completion_with_backoff( anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema) for tool in tools ] + # Cache tool definitions + last_tool = model_kwargs["tools"][-1] + last_tool["cache_control"] = {"type": "ephemeral"} elif response_schema: tool = create_tool_definition(response_schema) model_kwargs["tools"] = [ @@ -291,11 +294,10 @@ def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt else: system = None - # Anthropic requires the first message to be a 'user' message - if len(messages) == 1: + # Anthropic requires the first message to be a user message unless its a tool call + message_type = messages[0].additional_kwargs.get("message_type", None) + if len(messages) == 1 and message_type != "tool_call": messages[0].role = "user" - elif len(messages) > 1 and messages[0].role == "assistant": - messages = messages[1:] for message in messages: # Handle tool call and tool result message types from additional_kwargs @@ -361,18 +363,15 @@ def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt if isinstance(block, dict) and "cache_control" in block: del block["cache_control"] - # Add cache control to the last content block of second to last message. - # In research mode, this message content is list of iterations, updated after each research iteration. - # Caching it should improve research efficiency. - cache_message = messages[-2] + # Add cache control to the last content block of last message. + # Caching should improve research efficiency. + cache_message = messages[-1] if isinstance(cache_message.content, list) and cache_message.content: # Add cache control to the last content block only if it's a text block with non-empty content last_block = cache_message.content[-1] - if ( - isinstance(last_block, dict) - and last_block.get("type") == "text" - and last_block.get("text") - and last_block.get("text").strip() + if isinstance(last_block, dict) and ( + (last_block.get("type") == "text" and last_block.get("text", "").strip()) + or (last_block.get("type") == "tool_result" and last_block.get("content", [])) ): last_block["cache_control"] = {"type": "ephemeral"} diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 36390389..7d3f8fb1 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -184,7 +184,7 @@ def construct_iteration_history( iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) for iteration in previous_iterations: - if not iteration.query: + if not iteration.query or isinstance(iteration.query, str): iteration_history.append( ChatMessageModel( by="you", @@ -336,7 +336,7 @@ def construct_tool_chat_history( ), } for iteration in previous_iterations: - if not iteration.query: + if not iteration.query or isinstance(iteration.query, str): chat_history.append( ChatMessageModel( by="you", @@ -806,6 +806,15 @@ def count_tokens( return len(encoder.encode(json.dumps(message_content))) +def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]: + """Count total tokens in messages including system message""" + system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0 + message_tokens = sum([count_tokens(message.content, encoder) for message in messages]) + # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.) + total_tokens = message_tokens + system_message_tokens + 4 * len(messages) + return total_tokens, system_message_tokens + + def truncate_messages( messages: list[ChatMessage], max_prompt_size: int, @@ -824,23 +833,30 @@ def truncate_messages( break # Drop older messages until under max supported prompt size by model - # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.) - system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0 - tokens = sum([count_tokens(message.content, encoder) for message in messages]) - total_tokens = tokens + system_message_tokens + 4 * len(messages) + total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message) while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1): - if len(messages[-1].content) > 1: + # If the last message has more than one content part, pop the oldest content part. + # For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs. + if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call": # The oldest content part is earlier in content list. So pop from the front. messages[-1].content.pop(0) + # Otherwise, pop the last message if it has only one content part or is a tool call. else: # The oldest message is the last one. So pop from the back. - messages.pop() - tokens = sum([count_tokens(message.content, encoder) for message in messages]) - total_tokens = tokens + system_message_tokens + 4 * len(messages) + dropped_message = messages.pop() + # Drop tool result pair of tool call, if tool call message has been removed + if ( + dropped_message.additional_kwargs.get("message_type") == "tool_call" + and messages + and messages[-1].additional_kwargs.get("message_type") == "tool_result" + ): + messages.pop() + + total_tokens, _ = count_total_tokens(messages, encoder, system_message) # Truncate current message if still over max supported prompt size by model - total_tokens = tokens + system_message_tokens + 4 * len(messages) + total_tokens, _ = count_total_tokens(messages, encoder, system_message) if total_tokens > max_prompt_size: # At this point, a single message with a single content part of type dict should remain assert ( diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index eb613e46..449dbb95 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -48,17 +48,18 @@ class TestTruncateMessage: big_chat_message = ChatMessage(role="user", content=content_list) copy_big_chat_message = deepcopy(big_chat_message) chat_history = [big_chat_message] - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" - assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" def test_truncate_message_with_content_list(self): # Arrange @@ -68,11 +69,11 @@ class TestTruncateMessage: big_chat_message = ChatMessage(role="user", content=content_list) copy_big_chat_message = deepcopy(big_chat_message) chat_history.insert(0, big_chat_message) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties @@ -83,7 +84,8 @@ class TestTruncateMessage: copy_big_chat_message.content ), "message content list should be modified" assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" - assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" def test_truncate_message_first_large(self): # Arrange @@ -91,11 +93,11 @@ class TestTruncateMessage: big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?")) copy_big_chat_message = big_chat_message.copy() chat_history.insert(0, big_chat_message) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties @@ -104,7 +106,8 @@ class TestTruncateMessage: ), "Only most recent message should be present as it itself is larger than context size" assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" - assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" def test_truncate_message_large_system_message_first(self): # Arrange