All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.decathlon.tzatziki.kafka.KafkaInterceptor Maven / Gradle / Ivy

There is a newer version: 1.7.0
Show newest version
package com.decathlon.tzatziki.kafka;

import com.decathlon.tzatziki.utils.Fields;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.jetbrains.annotations.NotNull;
import org.springframework.kafka.core.ConsumerFactory;
import org.springframework.kafka.core.DefaultKafkaConsumerFactory;
import org.springframework.kafka.support.SendResult;
import org.springframework.stereotype.Component;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.time.Duration;
import java.util.*;

import static com.decathlon.tzatziki.utils.Asserts.awaitUntil;
import static java.util.Collections.synchronizedMap;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

@SuppressWarnings("unchecked")
@Aspect
@Component
@AllArgsConstructor
@Slf4j
public class KafkaInterceptor {

    private static final Set PROCESSED = new LinkedHashSet<>();
    private static final Set PROCESSING = new LinkedHashSet<>();
    private static final Map PAST_OFFSETS = synchronizedMap(new LinkedHashMap<>());
    private static final Map CURRENT_OFFSETS = synchronizedMap(new LinkedHashMap<>());
    public static boolean awaitForSuccessfullOnly;

    private static boolean enabled = true;

    public static void enable() {
        enabled = true;
    }

    public static void disable() {
        enabled = false;
    }

    @Around(value = "@annotation(org.springframework.kafka.annotation.KafkaListener)")
    public Object receiveMessages(ProceedingJoinPoint joinPoint) throws Throwable {
        Throwable throwable = null;
        try {
            return joinPoint.proceed();
        } catch (Throwable t) {
            throwable = t;
            throw t;
        } finally {
            if (throwable == null || !awaitForSuccessfullOnly) {
                PROCESSED.addAll(PROCESSING);
                PROCESSING.clear();
            }
        }
    }

    @Around("@annotation(org.springframework.context.annotation.Bean)")
    public Object beanCreation(ProceedingJoinPoint joinPoint) throws Throwable {
        Object bean = joinPoint.proceed();
        if (bean instanceof DefaultKafkaConsumerFactory) {
            return proxyOfConsumerFactory((DefaultKafkaConsumerFactory) bean);
        }
        return bean;
    }

    public static Map offsets() {
        return PAST_OFFSETS;
    }

    private ConsumerFactory proxyOfConsumerFactory(DefaultKafkaConsumerFactory consumerFactory) {
        return (ConsumerFactory) Proxy.newProxyInstance(ConsumerFactory.class.getClassLoader(), new Class[]{ConsumerFactory.class}, (proxy, method, args) -> {
            Object result = method.invoke(consumerFactory, args);
            if ("createConsumer".equals(method.getName())) {
                return createConsumerProxy((Consumer) result);
            }
            return result;
        });
    }

    private Object createConsumerProxy(Consumer consumer) {
        return Proxy.newProxyInstance(Consumer.class.getClassLoader(), new Class[]{Consumer.class}, (proxy, method, args) -> {
            try {
                return switch (method.getName()) {
                    case "poll" -> {
                        ConsumerRecords consumerRecords = (ConsumerRecords) method.invoke(consumer, args);
                        if (consumerRecords.count() == 0) {
                            yield consumerRecords;
                        }

                        for (ConsumerRecord record : consumerRecords) {
                            PROCESSING.add("%s-%s@%s".formatted(record.topic(), record.partition(), record.offset()));
                            CURRENT_OFFSETS.compute(topicPartitionOf(record),
                                    (t, current) -> Math.max(ofNullable(current).orElse(0L), record.offset() + 1));
                            adjustedOffsetFor(topicPartitionOf(record));
                        }

                        yield new ConsumerRecords<>(rewriteOffsets(Fields.getValue(consumerRecords, "records")));
                    }
                    case "seek" -> {
                        TopicPartition topicPartition = (TopicPartition) args[0];
                        if (args[1] instanceof Long) {
                            long offset = (long) args[1];
                            offset += adjustedOffsetFor(topicPartition);
                            consumer.seek(topicPartition, offset);
                        } else {
                            OffsetAndMetadata offsetAndMetadata = (OffsetAndMetadata) args[1];
                            offsetAndMetadata = new OffsetAndMetadata(
                                    adjustedOffsetFor(topicPartition) + offsetAndMetadata.offset(),
                                    offsetAndMetadata.metadata());
                            consumer.seek(topicPartition, offsetAndMetadata);
                        }
                        yield null;
                    }
                    case "position" -> {
                        TopicPartition topicPartition = (TopicPartition) args[0];
                        long offset = adjustedOffsetFor(topicPartition);
                        if (args.length == 2) {
                            yield consumer.position(topicPartition, (Duration) args[1]) - offset;
                        }
                        yield consumer.position(topicPartition) - offset;
                    }
                    case "endOffsets" -> {
                        Map endOffsets = (Map) method.invoke(consumer, args);
                        yield endOffsets.entrySet().stream().collect(toMap(Map.Entry::getKey, e -> e.getValue() - adjustedOffsetFor(e.getKey())));
                    }
                    case "seekToBeginning" -> {
                        Collection topicPartitions = (Collection) args[0];
                        topicPartitions.forEach(topicPartition -> consumer.seek(topicPartition, adjustedOffsetFor(topicPartition)));
                        yield null;
                    }
                    default -> method.invoke(consumer, args);
                };
            } catch (InvocationTargetException e) {
                throw e.getTargetException();
            }
        });
    }

    private  Map>> rewriteOffsets(Map>> records) {
        // rewriting the offsets in the messages
        records = records.entrySet().stream()
                .collect(toMap(Map.Entry::getKey, e -> e.getValue().stream()
                        .map(record -> new ConsumerRecord<>(
                                record.topic(),
                                record.partition(),
                                record.offset() - PAST_OFFSETS.getOrDefault(topicPartitionOf(record), 0L),
                                record.timestamp(),
                                record.timestampType(),
                                record.serializedKeySize(),
                                record.serializedValueSize(),
                                record.key(),
                                record.value(),
                                record.headers(),
                                record.leaderEpoch()))
                        .filter(record -> record.offset() >= 0)
                        .collect(toList())));
        return records;
    }

    @NotNull
    private TopicPartition topicPartitionOf(ConsumerRecord record) {
        return new TopicPartition(record.topic(), record.partition());
    }

    public static long adjustedOffsetFor(TopicPartition topicPartition) {
        long offset = enabled ? PAST_OFFSETS.getOrDefault(topicPartition, 0L) : 0L;
        if (offset > 0) {
            log.debug("adjusted offset for %s is %s".formatted(topicPartition, offset));
        }
        return offset;
    }

    public static SendResult waitUntilProcessed(SendResult result) {
        awaitUntil(() -> PROCESSED.contains(result.getRecordMetadata().toString()));
        PROCESSED.remove(result.getRecordMetadata().toString());
        return result;
    }

    public static void before() {
        PAST_OFFSETS.putAll(CURRENT_OFFSETS);
        CURRENT_OFFSETS.clear();
        PROCESSING.clear();
        PROCESSED.clear();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy