All downloads are free. Search and download functionality uses the official Maven repository.

software.amazon.awssdk.services.sqs.internal.MessageMD5ChecksumInterceptor Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package software.amazon.awssdk.services.sqs.internal;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.services.sqs.endpoints.SqsClientContextParams;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResultEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Md5Utils;

/**
 * SQS operations on sending and receiving messages will return the MD5 digest of the message body
 * (and, when present, of the message attributes). This interceptor verifies that the message was
 * correctly transmitted by comparing the MD5 returned by SQS with a client-side calculation based
 * on the original request (for sends) or on the received content (for receives). A mismatch is
 * reported as an {@link SdkClientException}.
 *
 * <p>Validation runs unless the client context parameter
 * {@code SqsClientContextParams.CHECKSUM_VALIDATION_ENABLED} is explicitly set to {@code false}.
 */
@SdkInternalApi
public final class MessageMD5ChecksumInterceptor implements ExecutionInterceptor {

    private static final int INTEGER_SIZE_IN_BYTES = 4;
    private static final byte STRING_TYPE_FIELD_INDEX = 1;
    private static final byte BINARY_TYPE_FIELD_INDEX = 2;
    private static final byte STRING_LIST_TYPE_FIELD_INDEX = 3;
    private static final byte BINARY_LIST_TYPE_FIELD_INDEX = 4;

    /*
     * Constant strings for composing error message.
     */
    private static final String MD5_MISMATCH_ERROR_MESSAGE =
            "MD5 returned by SQS does not match the calculation on the original request. " +
            "(MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")";
    // Argument order: message ID, checked part, client-side MD5, returned MD5.
    private static final String MD5_MISMATCH_ERROR_MESSAGE_WITH_ID =
            "MD5 returned by SQS does not match the calculation on the original request. " +
            "(Message ID: %s, MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")";
    private static final String UNMATCHED_BATCH_ENTRY_ERROR_MESSAGE =
            "SendMessageBatch response contains a successful entry with ID %s that does not match any entry "
            + "in the original request; unable to verify its MD5 checksum.";
    private static final String MESSAGE_BODY = "message body";
    private static final String MESSAGE_ATTRIBUTES = "message attributes";

    private static final Logger log = Logger.loggerFor(MessageMD5ChecksumInterceptor.class);

    /**
     * Dispatches to the operation-specific MD5 check after a successful SendMessage,
     * ReceiveMessage or SendMessageBatch call. Other operations are ignored.
     *
     * @throws SdkClientException if a returned MD5 does not match the client-side calculation
     */
    @Override
    public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
        SdkResponse response = context.response();
        SdkRequest originalRequest = context.request();

        if (response != null && validateMessageMD5Enabled(executionAttributes)) {
            if (originalRequest instanceof SendMessageRequest) {
                SendMessageRequest sendMessageRequest = (SendMessageRequest) originalRequest;
                SendMessageResponse sendMessageResult = (SendMessageResponse) response;
                sendMessageOperationMd5Check(sendMessageRequest, sendMessageResult);

            } else if (originalRequest instanceof ReceiveMessageRequest) {
                ReceiveMessageResponse receiveMessageResult = (ReceiveMessageResponse) response;
                receiveMessageResultMd5Check(receiveMessageResult);

            } else if (originalRequest instanceof SendMessageBatchRequest) {
                SendMessageBatchRequest sendMessageBatchRequest = (SendMessageBatchRequest) originalRequest;
                SendMessageBatchResponse sendMessageBatchResult = (SendMessageBatchResponse) response;
                sendMessageBatchOperationMd5Check(sendMessageBatchRequest, sendMessageBatchResult);
            }
        }
    }

    /**
     * Returns whether MD5 validation should run. Validation is on by default. It is skipped only
     * when the client context parameter is explicitly {@code false}. A missing context-parameter
     * map is treated the same as an unset parameter (enabled) rather than failing with an NPE.
     */
    private static boolean validateMessageMD5Enabled(ExecutionAttributes executionAttributes) {
        AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS);
        if (clientContextParams == null) {
            return true;
        }
        Boolean enableMd5Validation = clientContextParams.get(SqsClientContextParams.CHECKSUM_VALIDATION_ENABLED);
        return enableMd5Validation == null || enableMd5Validation;
    }

    /**
     * Throw an exception if the MD5 checksums returned in the SendMessageResponse do not match the
     * client-side calculation based on the original message in the SendMessageRequest.
     */
    private static void sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest,
                                                     SendMessageResponse sendMessageResult) {
        String clientSideBodyMd5 = calculateMessageBodyMd5(sendMessageRequest.messageBody());
        verifyMd5(null, MESSAGE_BODY, clientSideBodyMd5, sendMessageResult.md5OfMessageBody());

        Map<String, MessageAttributeValue> messageAttrSent = sendMessageRequest.messageAttributes();
        if (messageAttrSent != null && !messageAttrSent.isEmpty()) {
            String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent);
            verifyMd5(null, MESSAGE_ATTRIBUTES, clientSideAttrMd5, sendMessageResult.md5OfMessageAttributes());
        }
    }

    /**
     * Throw an exception if the MD5 checksums included in the ReceiveMessageResponse do not match the
     * client-side calculation on the received messages. The first mismatching message aborts the check.
     */
    private static void receiveMessageResultMd5Check(ReceiveMessageResponse receiveMessageResult) {
        if (receiveMessageResult.messages() == null) {
            return;
        }
        for (Message messageReceived : receiveMessageResult.messages()) {
            String clientSideBodyMd5 = calculateMessageBodyMd5(messageReceived.body());
            verifyMd5(null, MESSAGE_BODY, clientSideBodyMd5, messageReceived.md5OfBody());

            Map<String, MessageAttributeValue> messageAttr = messageReceived.messageAttributes();
            if (messageAttr != null && !messageAttr.isEmpty()) {
                String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr);
                verifyMd5(null, MESSAGE_ATTRIBUTES, clientSideAttrMd5, messageReceived.md5OfMessageAttributes());
            }
        }
    }

    /**
     * Throw an exception if the MD5 checksums returned in the SendMessageBatchResponse do not match
     * the client-side calculation based on the original messages in the SendMessageBatchRequest.
     * Only successful result entries are checked; each is matched to its request entry by ID.
     *
     * @throws SdkClientException on an MD5 mismatch, or if a successful result entry's ID does not
     *                            correspond to any request entry
     */
    private static void sendMessageBatchOperationMd5Check(SendMessageBatchRequest sendMessageBatchRequest,
                                                          SendMessageBatchResponse sendMessageBatchResult) {
        Map<String, SendMessageBatchRequestEntry> idToRequestEntryMap = new HashMap<>();
        if (sendMessageBatchRequest.entries() != null) {
            for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) {
                idToRequestEntryMap.put(entry.id(), entry);
            }
        }

        if (sendMessageBatchResult.successful() == null) {
            return;
        }
        for (SendMessageBatchResultEntry entry : sendMessageBatchResult.successful()) {
            SendMessageBatchRequestEntry requestEntry = idToRequestEntryMap.get(entry.id());
            if (requestEntry == null) {
                // Without the original entry there is nothing to compare against; report it explicitly
                // instead of failing with an opaque NullPointerException.
                throw SdkClientException.builder()
                                        .message(String.format(UNMATCHED_BATCH_ENTRY_ERROR_MESSAGE, entry.id()))
                                        .build();
            }

            String clientSideBodyMd5 = calculateMessageBodyMd5(requestEntry.messageBody());
            verifyMd5(entry.id(), MESSAGE_BODY, clientSideBodyMd5, entry.md5OfMessageBody());

            Map<String, MessageAttributeValue> messageAttr = requestEntry.messageAttributes();
            if (messageAttr != null && !messageAttr.isEmpty()) {
                String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr);
                verifyMd5(entry.id(), MESSAGE_ATTRIBUTES, clientSideAttrMd5, entry.md5OfMessageAttributes());
            }
        }
    }

    /**
     * Throws an {@link SdkClientException} if the client-side MD5 differs from the MD5 returned by SQS.
     *
     * @param messageId     batch entry ID to include in the error, or {@code null} to omit it
     * @param checkedPart   which part was hashed ({@link #MESSAGE_BODY} or {@link #MESSAGE_ATTRIBUTES})
     * @param clientSideMd5 hex MD5 computed locally (never null)
     * @param md5Returned   hex MD5 returned by SQS; a null value is treated as a mismatch
     */
    private static void verifyMd5(String messageId, String checkedPart, String clientSideMd5, String md5Returned) {
        if (clientSideMd5.equals(md5Returned)) {
            return;
        }
        // Arguments must follow each template's placeholder order; the ID template starts with the message ID.
        String errorMessage = messageId == null
                              ? String.format(MD5_MISMATCH_ERROR_MESSAGE, checkedPart, clientSideMd5, md5Returned)
                              : String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, messageId, checkedPart,
                                              clientSideMd5, md5Returned);
        throw SdkClientException.builder()
                                .message(errorMessage)
                                .build();
    }

    /**
     * Returns the hex-encoded MD5 hash String of the given message body (UTF-8 encoded).
     *
     * @throws SdkClientException if the hash cannot be computed (including a null body)
     */
    private static String calculateMessageBodyMd5(String messageBody) {
        log.debug(() -> "Message body: " + messageBody);
        byte[] expectedMd5;
        try {
            expectedMd5 = Md5Utils.computeMD5Hash(messageBody.getBytes(StandardCharsets.UTF_8));
        } catch (Exception e) {
            throw SdkClientException.builder()
                                    .message("Unable to calculate the MD5 hash of the message body. " + e.getMessage())
                                    .cause(e)
                                    .build();
        }
        String expectedMd5Hex = BinaryUtils.toHex(expectedMd5);
        log.debug(() -> "Expected  MD5 of message body: " + expectedMd5Hex);
        return expectedMd5Hex;
    }

    /**
     * Returns the hex-encoded MD5 hash String of the given message attributes.
     *
     * <p>Attributes are processed in ascending name order. Each attribute contributes its
     * length-prefixed name, its length-prefixed data type, a one-byte value-kind marker, and the
     * length-prefixed value (or each list member).
     *
     * @throws SdkClientException if the hash cannot be computed (e.g. a null attribute value or data type)
     */
    private static String calculateMessageAttributesMd5(final Map<String, MessageAttributeValue> messageAttributes) {
        log.debug(() -> "Message attributes: " + messageAttributes);
        List<String> sortedAttributeNames = new ArrayList<>(messageAttributes.keySet());
        Collections.sort(sortedAttributeNames);

        MessageDigest md5Digest;
        try {
            md5Digest = MessageDigest.getInstance("MD5");

            for (String attrName : sortedAttributeNames) {
                MessageAttributeValue attrValue = messageAttributes.get(attrName);

                // Encoded Name
                updateLengthAndBytes(md5Digest, attrName);

                // Encoded Type
                updateLengthAndBytes(md5Digest, attrValue.dataType());

                // Encoded Value: a one-byte marker identifying which value field is set, then the value itself.
                if (attrValue.stringValue() != null) {
                    md5Digest.update(STRING_TYPE_FIELD_INDEX);
                    updateLengthAndBytes(md5Digest, attrValue.stringValue());
                } else if (attrValue.binaryValue() != null) {
                    md5Digest.update(BINARY_TYPE_FIELD_INDEX);
                    updateLengthAndBytes(md5Digest, attrValue.binaryValue().asByteBuffer());
                } else if (attrValue.stringListValues() != null &&
                           attrValue.stringListValues().size() > 0) {
                    md5Digest.update(STRING_LIST_TYPE_FIELD_INDEX);
                    for (String strListMember : attrValue.stringListValues()) {
                        updateLengthAndBytes(md5Digest, strListMember);
                    }
                } else if (attrValue.binaryListValues() != null &&
                           attrValue.binaryListValues().size() > 0) {
                    md5Digest.update(BINARY_LIST_TYPE_FIELD_INDEX);
                    for (SdkBytes byteListMember : attrValue.binaryListValues()) {
                        updateLengthAndBytes(md5Digest, byteListMember.asByteBuffer());
                    }
                }
            }
        } catch (Exception e) {
            throw SdkClientException.builder()
                                    .message("Unable to calculate the MD5 hash of the message attributes. " + e.getMessage())
                                    .cause(e)
                                    .build();
        }

        String expectedMd5Hex = BinaryUtils.toHex(md5Digest.digest());
        log.debug(() -> "Expected  MD5 of message attributes: " + expectedMd5Hex);
        return expectedMd5Hex;
    }

    /**
     * Update the digest using a sequence of bytes that consists of the length (in 4 bytes, big-endian,
     * the {@link ByteBuffer} default) of the UTF-8 encoded input String followed by those encoded bytes.
     */
    private static void updateLengthAndBytes(MessageDigest digest, String str) {
        byte[] utf8Encoded = str.getBytes(StandardCharsets.UTF_8);
        ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(utf8Encoded.length);
        digest.update(lengthBytes.array());
        digest.update(utf8Encoded);
    }

    /**
     * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the
     * input ByteBuffer's remaining bytes followed by those bytes. A read-only view is consumed, so
     * the caller's buffer position is left untouched.
     */
    private static void updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue) {
        ByteBuffer readOnlyBuffer = binaryValue.asReadOnlyBuffer();
        int size = readOnlyBuffer.remaining();
        ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(size);
        digest.update(lengthBytes.array());
        digest.update(readOnlyBuffer);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy