All downloads are free. Search and download functionality uses the official Maven repository.

software.amazon.lambda.powertools.sqs.internal.BatchContext Maven / Gradle / Ivy

package software.amazon.lambda.powertools.sqs.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.GetQueueAttributesRequest;
import software.amazon.awssdk.services.sqs.model.GetQueueAttributesResponse;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.QueueAttributeName;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
import software.amazon.lambda.powertools.sqs.SQSBatchProcessingException;
import software.amazon.lambda.powertools.sqs.SqsUtils;

import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage;
import static java.lang.String.format;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.toList;

public final class BatchContext {
    private static final Logger LOG = LoggerFactory.getLogger(BatchContext.class);
    private static final Map QUEUE_ARN_TO_DLQ_URL_MAPPING = new HashMap<>();

    private final Map messageToException = new HashMap<>();
    private final List success = new ArrayList<>();

    private final SqsClient client;

    public BatchContext(SqsClient client) {
        this.client = client;
    }

    public void addSuccess(SQSMessage event) {
        success.add(event);
    }

    public void addFailure(SQSMessage event, Exception e) {
        messageToException.put(event, e);
    }

    @SafeVarargs
    public final  void processSuccessAndHandleFailed(final List successReturns,
                                                        final boolean suppressException,
                                                        final boolean deleteNonRetryableMessageFromQueue,
                                                        final Class... nonRetryableExceptions) {
        if (hasFailures()) {

            List exceptions = new ArrayList<>();
            List failedMessages = new ArrayList<>();
            Map nonRetryableMessageToException = new HashMap<>();

            if (nonRetryableExceptions.length == 0) {
                exceptions.addAll(messageToException.values());
                failedMessages.addAll(messageToException.keySet());
            } else {
                messageToException.forEach((sqsMessage, exception) -> {
                    boolean nonRetryableException = isNonRetryableException(exception, nonRetryableExceptions);

                    if (nonRetryableException) {
                        nonRetryableMessageToException.put(sqsMessage, exception);
                    } else {
                        exceptions.add(exception);
                        failedMessages.add(sqsMessage);
                    }
                });
            }

            List messagesToBeDeleted = new ArrayList<>(success);

            if (!nonRetryableMessageToException.isEmpty() && deleteNonRetryableMessageFromQueue) {
                messagesToBeDeleted.addAll(nonRetryableMessageToException.keySet());
            } else if (!nonRetryableMessageToException.isEmpty()) {

                boolean isMovedToDlq = moveNonRetryableMessagesToDlqIfConfigured(nonRetryableMessageToException);

                if (!isMovedToDlq) {
                    exceptions.addAll(nonRetryableMessageToException.values());
                    failedMessages.addAll(nonRetryableMessageToException.keySet());
                }
            }

            deleteMessagesFromQueue(messagesToBeDeleted);

            processFailedMessages(successReturns, suppressException, exceptions, failedMessages);
        }
    }

    private  void processFailedMessages(List successReturns,
                                           boolean suppressException,
                                           List exceptions,
                                           List failedMessages) {
        if (failedMessages.isEmpty()) {
            return;
        }

        if (suppressException) {
            List messageIds = failedMessages.stream().
                    map(SQSMessage::getMessageId)
                    .collect(toList());

            LOG.debug(format("[%s] records failed processing, but exceptions are suppressed. " +
                    "Failed messages %s", failedMessages.size(), messageIds));
        } else {
            throw new SQSBatchProcessingException(exceptions, failedMessages, successReturns);
        }
    }

    private boolean isNonRetryableException(Exception exception, Class[] nonRetryableExceptions) {
        return Arrays.stream(nonRetryableExceptions)
                .anyMatch(aClass -> aClass.isInstance(exception));
    }

    private boolean moveNonRetryableMessagesToDlqIfConfigured(Map nonRetryableMessageToException) {
        Optional dlqUrl = fetchDlqUrl(nonRetryableMessageToException);

        if (!dlqUrl.isPresent()) {
            return false;
        }

        List dlqMessages = nonRetryableMessageToException.keySet().stream()
                .map(sqsMessage -> {
                    Map messageAttributesMap = new HashMap<>();

                    sqsMessage.getMessageAttributes().forEach((s, messageAttribute) -> {
                        MessageAttributeValue.Builder builder = MessageAttributeValue.builder();

                        builder
                                .dataType(messageAttribute.getDataType())
                                .stringValue(messageAttribute.getStringValue());

                        if (null != messageAttribute.getBinaryValue()) {
                            builder.binaryValue(SdkBytes.fromByteBuffer(messageAttribute.getBinaryValue()));
                        }

                        messageAttributesMap.put(s, builder.build());
                    });

                    return SendMessageBatchRequestEntry.builder()
                            .messageBody(sqsMessage.getBody())
                            .id(sqsMessage.getMessageId())
                            .messageAttributes(messageAttributesMap)
                            .build();
                })
                .collect(toList());

        List sendMessageBatchResponses = batchRequest(dlqMessages, 10, entriesToSend -> {

            SendMessageBatchResponse sendMessageBatchResponse = client.sendMessageBatch(SendMessageBatchRequest.builder()
                    .entries(entriesToSend)
                    .queueUrl(dlqUrl.get())
                    .build());


            LOG.debug("Response from send batch message to DLQ request {}", sendMessageBatchResponse);

            return sendMessageBatchResponse;
        });

        return sendMessageBatchResponses.stream()
                .filter(response -> null != response && response.hasFailed())
                .peek(sendMessageBatchResponse -> LOG.error("Failed sending message to the DLQ. Entire batch will be re processed. Check if needed permissions are configured for the function. Response: {}", sendMessageBatchResponse))
                .count()  == 0;
    }


    private Optional fetchDlqUrl(Map nonRetryableMessageToException) {
        return nonRetryableMessageToException.keySet().stream()
                .findFirst()
                .map(sqsMessage -> QUEUE_ARN_TO_DLQ_URL_MAPPING.computeIfAbsent(sqsMessage.getEventSourceArn(), sourceArn -> {
                    String queueUrl = url(sourceArn);

                    GetQueueAttributesResponse queueAttributes = client.getQueueAttributes(GetQueueAttributesRequest.builder()
                            .attributeNames(QueueAttributeName.REDRIVE_POLICY)
                            .queueUrl(queueUrl)
                            .build());

                    return ofNullable(queueAttributes.attributes().get(QueueAttributeName.REDRIVE_POLICY))
                            .map(policy -> {
                                try {
                                    return SqsUtils.objectMapper().readTree(policy);
                                } catch (JsonProcessingException e) {
                                    LOG.debug("Unable to parse Re drive policy for queue {}. Even if DLQ exists, failed messages will be send back to main queue.", queueUrl, e);
                                    return null;
                                }
                            })
                            .map(node -> node.get("deadLetterTargetArn"))
                            .map(JsonNode::asText)
                            .map(this::url)
                            .orElse(null);
                }));
    }

    private boolean hasFailures() {
        return !messageToException.isEmpty();
    }

    private void deleteMessagesFromQueue(final List messages) {
        if (!messages.isEmpty()) {

            List entries = messages.stream().map(m -> DeleteMessageBatchRequestEntry.builder()
                    .id(m.getMessageId())
                    .receiptHandle(m.getReceiptHandle())
                    .build()).collect(toList());

            batchRequest(entries, 10, entriesToDelete -> {
                DeleteMessageBatchRequest request = DeleteMessageBatchRequest.builder()
                        .queueUrl(url(messages.get(0).getEventSourceArn()))
                        .entries(entriesToDelete)
                        .build();

                DeleteMessageBatchResponse deleteMessageBatchResponse = client.deleteMessageBatch(request);

                LOG.debug("Response from delete request {}", deleteMessageBatchResponse);

                return deleteMessageBatchResponse;
            });
        }
    }

    private  List batchRequest(final List listOFEntries,
                                        final int size,
                                        final Function, R> batchLogic) {

        return IntStream.range(0, listOFEntries.size())
                .filter(index -> index % size == 0)
                .mapToObj(index -> listOFEntries.subList(index, Math.min(index + size, listOFEntries.size())))
                .map(batchLogic)
                .collect(Collectors.toList());
    }

    private String url(String queueArn) {
        String[] arnArray = queueArn.split(":");
        return String.format("https://sqs.%s.amazonaws.com/%s/%s", arnArray[3], arnArray[4], arnArray[5]);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy