Skip non-serializable binary values in content parts when counting tokens

This commit is contained in:
Debanjum
2025-11-18 10:59:31 -08:00
parent ec31df7154
commit a30c5f245d

View File

@@ -1021,14 +1021,20 @@ def count_tokens(
elif isinstance(part, dict): elif isinstance(part, dict):
# If part is a dict but not a recognized type, convert to JSON string # If part is a dict but not a recognized type, convert to JSON string
try: try:
message_content_parts.append(json.dumps(part)) # Skip non-serializable binary values for token counting
serializable_part = {
k: v for k, v in part.items() if not isinstance(v, (bytes, bytearray, memoryview))
}
message_content_parts.append(json.dumps(serializable_part))
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.") logger.warning(
f"Failed to serialize part {part} to JSON. Assume its an image for token counting.\n{e}."
)
image_count += 1 # Treat as an image/binary if serialization fails image_count += 1 # Treat as an image/binary if serialization fails
elif isinstance(part, str): elif isinstance(part, str):
message_content_parts.append(part) message_content_parts.append(part)
else: else:
logger.warning(f"Unknown message type: {part}. Skipping.") logger.warning(f"Unknown message type: {part}. Skip for token counting.")
message_content = "\n".join(message_content_parts).rstrip() message_content = "\n".join(message_content_parts).rstrip()
return len(encoder.encode(message_content)) + image_count * 500 return len(encoder.encode(message_content)) + image_count * 500
elif isinstance(message_content, str): elif isinstance(message_content, str):