diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index fab7ed54..360f0a51 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -82,6 +82,9 @@ def anthropic_completion_with_backoff( anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema) for tool in tools ] + # Cache tool definitions + last_tool = model_kwargs["tools"][-1] + last_tool["cache_control"] = {"type": "ephemeral"} elif response_schema: tool = create_tool_definition(response_schema) model_kwargs["tools"] = [ @@ -291,11 +294,10 @@ def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt else: system = None - # Anthropic requires the first message to be a 'user' message - if len(messages) == 1: + # Anthropic requires the first message to be a user message unless its a tool call + message_type = messages[0].additional_kwargs.get("message_type", None) + if len(messages) == 1 and message_type != "tool_call": messages[0].role = "user" - elif len(messages) > 1 and messages[0].role == "assistant": - messages = messages[1:] for message in messages: # Handle tool call and tool result message types from additional_kwargs @@ -361,18 +363,15 @@ def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt if isinstance(block, dict) and "cache_control" in block: del block["cache_control"] - # Add cache control to the last content block of second to last message. - # In research mode, this message content is list of iterations, updated after each research iteration. - # Caching it should improve research efficiency. - cache_message = messages[-2] + # Add cache control to the last content block of last message. + # Caching should improve research efficiency. 
+ cache_message = messages[-1] if isinstance(cache_message.content, list) and cache_message.content: # Add cache control to the last content block only if it's a text block with non-empty content last_block = cache_message.content[-1] - if ( - isinstance(last_block, dict) - and last_block.get("type") == "text" - and last_block.get("text") - and last_block.get("text").strip() + if isinstance(last_block, dict) and ( + (last_block.get("type") == "text" and last_block.get("text", "").strip()) + or (last_block.get("type") == "tool_result" and last_block.get("content", [])) ): last_block["cache_control"] = {"type": "ephemeral"} diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 36390389..7d3f8fb1 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -184,7 +184,7 @@ def construct_iteration_history( iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) for iteration in previous_iterations: - if not iteration.query: + if not iteration.query or isinstance(iteration.query, str): iteration_history.append( ChatMessageModel( by="you", @@ -336,7 +336,7 @@ def construct_tool_chat_history( ), } for iteration in previous_iterations: - if not iteration.query: + if not iteration.query or isinstance(iteration.query, str): chat_history.append( ChatMessageModel( by="you", @@ -806,6 +806,15 @@ def count_tokens( return len(encoder.encode(json.dumps(message_content))) +def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]: + """Count total tokens in messages including system message""" + system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0 + message_tokens = sum([count_tokens(message.content, encoder) for message in messages]) + # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.) 
+ total_tokens = message_tokens + system_message_tokens + 4 * len(messages) + return total_tokens, system_message_tokens + + def truncate_messages( messages: list[ChatMessage], max_prompt_size: int, @@ -824,23 +833,30 @@ def truncate_messages( break # Drop older messages until under max supported prompt size by model - # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.) - system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0 - tokens = sum([count_tokens(message.content, encoder) for message in messages]) - total_tokens = tokens + system_message_tokens + 4 * len(messages) + total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message) while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1): - if len(messages[-1].content) > 1: + # If the last message has more than one content part, pop the oldest content part. + # For tool calls, the whole message should be dropped, as truncating an assistant's tool call content annoys AI APIs. + if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call": # The oldest content part is earlier in content list. So pop from the front. messages[-1].content.pop(0) + # Otherwise, pop the last message if it has only one content part or is a tool call. else: # The oldest message is the last one. So pop from the back. 
- messages.pop() - tokens = sum([count_tokens(message.content, encoder) for message in messages]) - total_tokens = tokens + system_message_tokens + 4 * len(messages) + dropped_message = messages.pop() + # Drop tool result pair of tool call, if tool call message has been removed + if ( + dropped_message.additional_kwargs.get("message_type") == "tool_call" + and messages + and messages[-1].additional_kwargs.get("message_type") == "tool_result" + ): + messages.pop() + + total_tokens, _ = count_total_tokens(messages, encoder, system_message) # Truncate current message if still over max supported prompt size by model - total_tokens = tokens + system_message_tokens + 4 * len(messages) + total_tokens, _ = count_total_tokens(messages, encoder, system_message) if total_tokens > max_prompt_size: # At this point, a single message with a single content part of type dict should remain assert ( diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index eb613e46..449dbb95 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -48,17 +48,18 @@ class TestTruncateMessage: big_chat_message = ChatMessage(role="user", content=content_list) copy_big_chat_message = deepcopy(big_chat_message) chat_history = [big_chat_message] - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. 
Verify certain properties assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" - assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" def test_truncate_message_with_content_list(self): # Arrange @@ -68,11 +69,11 @@ class TestTruncateMessage: big_chat_message = ChatMessage(role="user", content=content_list) copy_big_chat_message = deepcopy(big_chat_message) chat_history.insert(0, big_chat_message) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. 
Verify certain properties @@ -83,7 +84,8 @@ class TestTruncateMessage: copy_big_chat_message.content ), "message content list should be modified" assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" - assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" def test_truncate_message_first_large(self): # Arrange @@ -91,11 +93,11 @@ class TestTruncateMessage: big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?")) copy_big_chat_message = big_chat_message.copy() chat_history.insert(0, big_chat_message) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. 
Verify certain properties @@ -104,7 +106,8 @@ class TestTruncateMessage: ), "Only most recent message should be present as it itself is larger than context size" assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" - assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" def test_truncate_message_large_system_message_first(self): # Arrange