package org.opensearch.migrations.replay.kafka;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

import org.opensearch.migrations.replay.datatypes.ITrafficStreamKey;
import org.opensearch.migrations.replay.tracing.IKafkaConsumerContexts;
import org.opensearch.migrations.replay.tracing.ITrafficSourceContexts;
import org.opensearch.migrations.replay.tracing.KafkaConsumerContexts;
import org.opensearch.migrations.replay.tracing.RootReplayerContext;
import org.opensearch.migrations.replay.traffic.source.ITrafficCaptureSource;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.event.Level;

/**
 * This is a wrapper around Kafka's Consumer class that tracks partitions and their current
 * (asynchronously 'committed' by the calling contexts) offsets.  It manages those offsets and the
 * 'active' set of records that have been returned by this consumer, decides when to pause the
 * poll() loop, and handles consumer rebalances.
 */
@Slf4j
public class TrackingKafkaConsumer implements ConsumerRebalanceListener {
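    // The OrderedKeyHolder below pairs a Kafka offset with its ITrafficStreamKey so that holders
    // can be kept in per-partition PriorityQueues (nextSetOfKeysContextsBeingCommitted) and handed
    // back to onCommitKeyCallback in ascending offset order once the corresponding commit succeeds.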
    @AllArgsConstructor
    private static class OrderedKeyHolder implements Comparable<OrderedKeyHolder> {
        @Getter
        final long offset;
        @Getter
        @NonNull
        final ITrafficStreamKey tsk;

        @Override
        public int compareTo(OrderedKeyHolder o) {
            return Long.compare(offset, o.offset);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

            OrderedKeyHolder that = (OrderedKeyHolder) o;

            if (offset != that.offset) return false;
            return tsk.equals(that.tsk);
        }

        @Override
        public int hashCode() {
            return Long.valueOf(offset).hashCode();
        }
    }

    /**
     * The keep-alive should already be set to a fraction of the consumer's max poll interval
     * (that is done outside of this class).  The keep-alive tells this class how often the caller
     * should be interacting with it through touch() and poll() calls.  Given that, we want the
     * poll timeout to be long enough not to overwhelm the broker or client with many empty poll()
     * responses, but not so long that, when there are no messages, there isn't enough time left to
     * commit offsets, which happens after poll() returns (on the same thread, as the Consumer
     * contract requires).
     */
    public static final int POLL_TIMEOUT_KEEP_ALIVE_DIVISOR = 4;
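    // For example, a keepAliveInterval of 1 second yields a poll() timeout of 250ms in
    // safePollWithSwallowedRuntimeExceptions() below.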

    @NonNull
    private final RootReplayerContext globalContext;
    private final Consumer<String, byte[]> kafkaConsumer;

    final String topic;
    private final Clock clock;
    /**
     * This collection holds the definitive list, as per the rebalance callback, of the partitions
     * that are currently assigned to this consumer.  The objects are removed when partitions are
     * revoked and new objects are only created/inserted when partitions are assigned.  That means
     * that the generation of each OffsetLifecycleTracker value may differ.
     */
    final Map<Integer, OffsetLifecycleTracker> partitionToOffsetLifecycleTrackerMap;
    private final Object commitDataLock = new Object();
    // loosening visibility so that a unit test can read this
    final Map<TopicPartition, OffsetAndMetadata> nextSetOfCommitsMap;
    final Map<TopicPartition, PriorityQueue<OrderedKeyHolder>> nextSetOfKeysContextsBeingCommitted;
    final java.util.function.Consumer<ITrafficStreamKey> onCommitKeyCallback;
    private final Duration keepAliveInterval;
    private final AtomicReference<Instant> lastTouchTimeRef;
    private final AtomicInteger consumerConnectionGeneration;
    private final AtomicInteger kafkaRecordsLeftToCommitEventually;
    private final AtomicBoolean kafkaRecordsReadyToCommit;

    public TrackingKafkaConsumer(
        @NonNull RootReplayerContext globalContext,
        Consumer<String, byte[]> kafkaConsumer,
        String topic,
        Duration keepAliveInterval,
        Clock c,
        java.util.function.Consumer<ITrafficStreamKey> onCommitKeyCallback
    ) {
        this.globalContext = globalContext;
        this.kafkaConsumer = kafkaConsumer;
        this.topic = topic;
        this.clock = c;
        this.partitionToOffsetLifecycleTrackerMap = new HashMap<>();
        this.nextSetOfCommitsMap = new HashMap<>();
        this.nextSetOfKeysContextsBeingCommitted = new HashMap<>();
        this.lastTouchTimeRef = new AtomicReference<>(Instant.EPOCH);
        consumerConnectionGeneration = new AtomicInteger();
        kafkaRecordsLeftToCommitEventually = new AtomicInteger();
        kafkaRecordsReadyToCommit = new AtomicBoolean();
        this.keepAliveInterval = keepAliveInterval;
        this.onCommitKeyCallback = onCommitKeyCallback;
    }
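
    // Illustrative usage sketch (an assumption, not code from this module; variable names here are
    // hypothetical): callers construct the underlying Kafka consumer themselves and register this
    // object as the rebalance listener when subscribing, for example:
    //
    //   var kafkaConsumer = new KafkaConsumer<String, byte[]>(kafkaProperties);
    //   var trackingConsumer = new TrackingKafkaConsumer(rootContext, kafkaConsumer, topic,
    //       keepAliveInterval, Clock.systemUTC(), key -> log.info("committed key={}", key));
    //   kafkaConsumer.subscribe(List.of(topic), trackingConsumer);
    //
    // The keepAliveInterval is expected to be a fraction of the consumer's max poll interval
    // (see the class-level and POLL_TIMEOUT_KEEP_ALIVE_DIVISOR javadoc above).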

    @Override
    public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
        if (partitions.isEmpty()) {
            log.atDebug().setMessage(() -> this + " revoked no partitions.").log();
            return;
        }

        new KafkaConsumerContexts.AsyncListeningContext(globalContext).onPartitionsRevoked(partitions);
        synchronized (commitDataLock) {
            safeCommit(globalContext::createCommitContext);
            partitions.forEach(p -> {
                var tp = new TopicPartition(topic, p.partition());
                nextSetOfCommitsMap.remove(tp);
                nextSetOfKeysContextsBeingCommitted.remove(tp);
                partitionToOffsetLifecycleTrackerMap.remove(p.partition());
            });
            kafkaRecordsLeftToCommitEventually.set(
                partitionToOffsetLifecycleTrackerMap.values().stream().mapToInt(OffsetLifecycleTracker::size).sum()
            );
            kafkaRecordsReadyToCommit.set(!nextSetOfCommitsMap.values().isEmpty());
            log.atWarn()
                .setMessage(
                    () -> this
                        + " partitions revoked for "
                        + partitions.stream().map(String::valueOf).collect(Collectors.joining(","))
                )
                .log();
        }
    }

    @Override
    public void onPartitionsAssigned(Collection<TopicPartition> newPartitions) {
        if (newPartitions.isEmpty()) {
            log.atInfo().setMessage(() -> this + " assigned no new partitions.").log();
            return;
        }

        new KafkaConsumerContexts.AsyncListeningContext(globalContext).onPartitionsAssigned(newPartitions);
        synchronized (commitDataLock) {
            consumerConnectionGeneration.incrementAndGet();
            newPartitions.forEach(
                p -> partitionToOffsetLifecycleTrackerMap.computeIfAbsent(
                    p.partition(),
                    x -> new OffsetLifecycleTracker(consumerConnectionGeneration.get())
                )
            );
            log.atInfo()
                .setMessage(
                    () -> this
                        + " partitions added for "
                        + newPartitions.stream().map(String::valueOf).collect(Collectors.joining(","))
                )
                .log();
        }
    }

    public void close() {
        log.atInfo()
            .setMessage(
                () -> "Kafka consumer closing.  "
                    + "Committing (implicitly by Kafka's consumer): "
                    + nextCommitsToString()
            )
            .log();
        kafkaConsumer.close();
    }

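    // Optional.empty() means nothing is in flight, so the caller has no touch() deadline.
    // Instant.now() means commits are already staged and the caller should touch() right away;
    // otherwise the next deadline is keepAliveInterval past the last touch()/poll().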
    public Optional<Instant> getNextRequiredTouch() {
        var lastTouchTime = lastTouchTimeRef.get();
        var r = kafkaRecordsLeftToCommitEventually.get() == 0
            ? Optional.<Instant>empty()
            : Optional.of(kafkaRecordsReadyToCommit.get() ? Instant.now() : lastTouchTime.plus(keepAliveInterval));
        log.atTrace()
            .setMessage(
                () -> "returning next required touch at "
                    + r.map(t -> "" + t).orElse("N/A")
                    + " from a lastTouchTime of "
                    + lastTouchTime
            )
            .log();
        return r;
    }

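    // touch() keeps this consumer's group membership alive without consuming data: all assigned
    // partitions are paused, a zero-duration poll() satisfies the client's poll-interval
    // requirement, the partitions are resumed, and any staged offsets are then committed.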
    public void touch(ITrafficSourceContexts.IBackPressureBlockContext context) {
        try (var touchCtx = context.createNewTouchContext()) {
            log.trace("touch() called.");
            pause();
            try (var pollCtx = touchCtx.createNewPollContext()) {
                var records = kafkaConsumer.poll(Duration.ZERO);
                if (!records.isEmpty()) {
                    throw new IllegalStateException(
                        "Expected no entries once the consumer was paused.  "
                            + "This may have happened because a new assignment slipped into the consumer AFTER pause calls."
                    );
                }
            } catch (IllegalStateException e) {
                throw e;
            } catch (RuntimeException e) {
                log.atWarn()
                    .setCause(e)
                    .setMessage(
                        "Unable to poll the topic: "
                            + topic
                            + " with our Kafka consumer. "
                            + "Swallowing and awaiting next metadata refresh to try again."
                    )
                    .log();
            } finally {
                resume();
            }
            safeCommit(context::createCommitContext);
            lastTouchTimeRef.set(clock.instant());
        }
    }

    private void pause() {
        var activePartitions = kafkaConsumer.assignment();
        try {
            kafkaConsumer.pause(activePartitions);
        } catch (IllegalStateException e) {
            log.atError()
                .setCause(e)
                .setMessage(
                    () -> "Unable to pause the topic partitions: "
                        + topic
                        + ".  "
                        + "The active partitions passed here : "
                        + activePartitions.stream().map(String::valueOf).collect(Collectors.joining(","))
                        + ".  "
                        + "The active partitions as tracked here are: "
                        + getActivePartitions().stream().map(String::valueOf).collect(Collectors.joining(","))
                        + ".  "
                        + "The active partitions according to the consumer:  "
                        + kafkaConsumer.assignment().stream().map(String::valueOf).collect(Collectors.joining(","))
                )
                .log();
        }
    }

    private void resume() {
        var activePartitions = kafkaConsumer.assignment();
        try {
            kafkaConsumer.resume(activePartitions);
        } catch (IllegalStateException e) {
            log.atError()
                .setCause(e)
                .setMessage(
                    () -> "Unable to resume the topic partitions: "
                        + topic
                        + ".  "
                        + "This may not be a fatal error for the entire process as the consumer should eventually"
                        + " rejoin and rebalance.  "
                        + "The active partitions passed here : "
                        + activePartitions.stream().map(String::valueOf).collect(Collectors.joining(","))
                        + ".  "
                        + "The active partitions as tracked here are: "
                        + getActivePartitions().stream().map(String::valueOf).collect(Collectors.joining(","))
                        + ".  "
                        + "The active partitions according to the consumer:  "
                        + kafkaConsumer.assignment().stream().map(String::valueOf).collect(Collectors.joining(","))
                )
                .log();
        }
    }

    private Collection<TopicPartition> getActivePartitions() {
        return partitionToOffsetLifecycleTrackerMap.keySet()
            .stream()
            .map(p -> new TopicPartition(topic, p))
            .collect(Collectors.toList());
    }

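    // Offsets staged by commitKafkaKey() are flushed both before and after the poll so that
    // commitSync() always runs on the thread that owns the consumer (KafkaConsumer is not safe
    // for concurrent use from multiple threads).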
    public <T> Stream<T> getNextBatchOfRecords(
        ITrafficSourceContexts.IReadChunkContext context,
        BiFunction<KafkaCommitOffsetData, ConsumerRecord<String, byte[]>, T> builder
    ) {
        safeCommit(context::createCommitContext);
        var records = safePollWithSwallowedRuntimeExceptions(context);
        safeCommit(context::createCommitContext);
        return applyBuilder(builder, records);
    }

    private <T> Stream<T> applyBuilder(
        BiFunction<KafkaCommitOffsetData, ConsumerRecord<String, byte[]>, T> builder,
        ConsumerRecords<String, byte[]> records
    ) {
        return StreamSupport.stream(records.spliterator(), false).map(kafkaRecord -> {
            var offsetTracker = partitionToOffsetLifecycleTrackerMap.get(kafkaRecord.partition());
            var offsetDetails = new PojoKafkaCommitOffsetData(
                offsetTracker.consumerConnectionGeneration,
                kafkaRecord.partition(),
                kafkaRecord.offset()
            );
            offsetTracker.add(offsetDetails.getOffset());
            kafkaRecordsLeftToCommitEventually.incrementAndGet();
            log.atTrace().setMessage(() -> "records in flight=" + kafkaRecordsLeftToCommitEventually.get()).log();
            return builder.apply(offsetDetails, kafkaRecord);
        });
    }

    private ConsumerRecords<String, byte[]> safePollWithSwallowedRuntimeExceptions(
        ITrafficSourceContexts.IReadChunkContext context
    ) {
        try {
            lastTouchTimeRef.set(clock.instant());
            ConsumerRecords<String, byte[]> records;
            try (var pollContext = context.createPollContext()) {
                records = kafkaConsumer.poll(keepAliveInterval.dividedBy(POLL_TIMEOUT_KEEP_ALIVE_DIVISOR));
            }
            log.atLevel(records.isEmpty() ? Level.TRACE : Level.INFO)
                .setMessage(
                    () -> "Kafka consumer poll has fetched "
                        + records.count()
                        + " records.  "
                        + "Records in flight="
                        + kafkaRecordsLeftToCommitEventually.get()
                )
                .log();
            log.atTrace()
                .setMessage("{}")
                .addArgument(
                    () -> "All positions: {"
                        + kafkaConsumer.assignment()
                            .stream()
                            .map(tp -> tp + ": " + kafkaConsumer.position(tp))
                            .collect(Collectors.joining(","))
                        + "}"
                )
                .log();
            log.atTrace()
                .setMessage("{}")
                .addArgument(
                    () -> "All previously COMMITTED positions: {"
                        + kafkaConsumer.assignment()
                            .stream()
                            .map(tp -> tp + ": " + kafkaConsumer.committed(tp))
                            .collect(Collectors.joining(","))
                        + "}"
                )
                .log();
            return records;
        } catch (RuntimeException e) {
            log.atWarn()
                .setCause(e)
                .setMessage(
                    "Unable to poll the topic: {} with our Kafka consumer. "
                        + "Swallowing and awaiting next metadata refresh to try again."
                )
                .addArgument(topic)
                .log();
            return new ConsumerRecords<>(Collections.emptyMap());
        }
    }

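    // commitKafkaKey() only stages an offset for a later safeCommit(); it returns IGNORED when the
    // partition was reassigned since the record was read (stale generation), AFTER_NEXT_READ when
    // the partition's commit position can advance, and BLOCKED_BY_OTHER_COMMITS when earlier
    // offsets on the partition are still outstanding.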
    ITrafficCaptureSource.CommitResult commitKafkaKey(ITrafficStreamKey streamKey, KafkaCommitOffsetData kafkaTsk) {
        OffsetLifecycleTracker tracker;
        synchronized (commitDataLock) {
            tracker = partitionToOffsetLifecycleTrackerMap.get(kafkaTsk.getPartition());
        }
        if (tracker == null || tracker.consumerConnectionGeneration != kafkaTsk.getGeneration()) {
            log.atWarn()
                .setMessage(
                    () -> "trafficKey's generation ("
                        + kafkaTsk.getGeneration()
                        + ") is not current ("
                        + (Optional.ofNullable(tracker)
                            .map(t -> "new generation=" + t.consumerConnectionGeneration)
                            .orElse("Partition unassigned"))
                        + ").  Dropping this commit request since the record would "
                        + "have been handled again by a current consumer within this process or another. Full key="
                        + kafkaTsk
                )
                .log();
            return ITrafficCaptureSource.CommitResult.IGNORED;
        }

        var p = kafkaTsk.getPartition();
        Optional<Long> newHeadValue;

        var k = new TopicPartition(topic, p);

        newHeadValue = tracker.removeAndReturnNewHead(kafkaTsk.getOffset());
        return newHeadValue.map(o -> {
            var v = new OffsetAndMetadata(o);
            log.atDebug().setMessage(() -> "Adding new commit " + k + "->" + v + " to map").log();
            synchronized (commitDataLock) {
                addKeyContextForEventualCommit(streamKey, kafkaTsk, k);
                nextSetOfCommitsMap.put(k, v);
            }
            return ITrafficCaptureSource.CommitResult.AFTER_NEXT_READ;
        }).orElseGet(() -> {
            synchronized (commitDataLock) {
                addKeyContextForEventualCommit(streamKey, kafkaTsk, k);
            }
            return ITrafficCaptureSource.CommitResult.BLOCKED_BY_OTHER_COMMITS;
        });
    }

    private void addKeyContextForEventualCommit(
        ITrafficStreamKey streamKey,
        KafkaCommitOffsetData kafkaTsk,
        TopicPartition k
    ) {
        nextSetOfKeysContextsBeingCommitted.computeIfAbsent(k, k2 -> new PriorityQueue<>())
            .add(new OrderedKeyHolder(kafkaTsk.getOffset(), streamKey));
    }

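    // safeCommit() snapshots the staged offsets under commitDataLock, commits the snapshot outside
    // the lock, and only afterwards removes the committed entries and releases their keys, so a
    // slow commitSync() never blocks commitKafkaKey() from staging new offsets.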
    private void safeCommit(Supplier<IKafkaConsumerContexts.ICommitScopeContext> commitContextSupplier) {
        HashMap<TopicPartition, OffsetAndMetadata> nextCommitsMapCopy;
        IKafkaConsumerContexts.ICommitScopeContext context = null;
        synchronized (commitDataLock) {
            if (nextSetOfCommitsMap.isEmpty()) {
                return;
            }
            context = commitContextSupplier.get();
            nextCommitsMapCopy = new HashMap<>(nextSetOfCommitsMap);
        }
        try {
            safeCommitStatic(context, kafkaConsumer, nextCommitsMapCopy);
            synchronized (commitDataLock) {
                nextCommitsMapCopy.entrySet()
                    .stream()
                    .forEach(
                        kvp -> callbackUpTo(
                            onCommitKeyCallback,
                            nextSetOfKeysContextsBeingCommitted.get(kvp.getKey()),
                            kvp.getValue().offset()
                        )
                    );
                nextCommitsMapCopy.forEach((k, v) -> nextSetOfCommitsMap.remove(k));
            }
            // This function is only ever called in a thread-safe way, mutually exclusively with any
            // call other than commitKafkaKey().  Since commitKafkaKey() doesn't alter
            // partitionToOffsetLifecycleTrackerMap, these lines can stay outside of the commitDataLock mutex.
            log.trace("partitionToOffsetLifecycleTrackerMap=" + partitionToOffsetLifecycleTrackerMap);
            kafkaRecordsLeftToCommitEventually.set(
                partitionToOffsetLifecycleTrackerMap.values().stream().mapToInt(OffsetLifecycleTracker::size).sum()
            );
            log.atDebug()
                .setMessage(() -> "Done committing now records in flight=" + kafkaRecordsLeftToCommitEventually.get())
                .log();
        } catch (RuntimeException e) {
            log.atWarn()
                .setCause(e)
                .setMessage(
                    () -> "Error while committing.  "
                        + "Another consumer may already be processing messages before these commits.  "
                        + "Commits ARE NOT being discarded here, with the expectation that the revoked callback "
                        + "(onPartitionsRevoked) will be called.  "
                        + "Within that method, commits for unassigned partitions will be discarded.  "
                        + "After that, touch() or poll() will trigger another commit attempt."
                        + "Those calls will occur in the near future if assigned partitions have pending commits."
                        + nextSetOfCommitsMap.entrySet()
                            .stream()
                            .map(kvp -> kvp.getKey() + "->" + kvp.getValue())
                            .collect(Collectors.joining(","))
                )
                .log();
        } finally {
            if (context != null) {
                context.close();
            }
        }
    }

    private static void safeCommitStatic(
        IKafkaConsumerContexts.ICommitScopeContext context,
        Consumer<String, byte[]> kafkaConsumer,
        HashMap<TopicPartition, OffsetAndMetadata> nextCommitsMap
    ) {
        assert !nextCommitsMap.isEmpty();
        log.atDebug().setMessage(() -> "Committing " + nextCommitsMap).log();
        try (var kafkaContext = context.createNewKafkaCommitContext()) {
            kafkaConsumer.commitSync(nextCommitsMap);
        }
    }

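    // Releases keys to the callback strictly in offset order; holders above upToOffset remain
    // queued until a later commit covers them.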
    private static void callbackUpTo(
        java.util.function.Consumer<ITrafficStreamKey> onCommitKeyCallback,
        PriorityQueue<OrderedKeyHolder> orderedKeyHolders,
        long upToOffset
    ) {
        for (var nextKeyHolder = orderedKeyHolders.peek(); nextKeyHolder != null
            && nextKeyHolder.offset <= upToOffset; nextKeyHolder = orderedKeyHolders.peek()) {
            onCommitKeyCallback.accept(nextKeyHolder.tsk);
            orderedKeyHolders.poll();
        }
    }

    String nextCommitsToString() {
        return "nextCommits="
            + nextSetOfCommitsMap.entrySet()
                .stream()
                .map(kvp -> kvp.getKey() + "->" + kvp.getValue())
                .collect(Collectors.joining(","));
    }

    @Override
    public String toString() {
        synchronized (commitDataLock) {
            int partitionCount = partitionToOffsetLifecycleTrackerMap.size();
            int commitsPending = nextSetOfCommitsMap.size();
            int recordsLeftToCommit = kafkaRecordsLeftToCommitEventually.get();
            boolean recordsReadyToCommit = kafkaRecordsReadyToCommit.get();
            return String.format(
                "TrackingKafkaConsumer{topic='%s', partitionCount=%d, commitsPending=%d, "
                    + "recordsLeftToCommit=%d, recordsReadyToCommit=%b}",
                topic,
                partitionCount,
                commitsPending,
                recordsLeftToCommit,
                recordsReadyToCommit
            );
        }
    }
}



