
package org.opensearch.migrations.replay.kafka;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;
import org.opensearch.migrations.replay.datatypes.ITrafficStreamKey;
import org.opensearch.migrations.replay.tracing.IKafkaConsumerContexts;
import org.opensearch.migrations.replay.tracing.ITrafficSourceContexts;
import org.opensearch.migrations.replay.tracing.KafkaConsumerContexts;
import org.opensearch.migrations.replay.tracing.RootReplayerContext;
import org.opensearch.migrations.replay.traffic.source.ITrafficCaptureSource;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.event.Level;
/**
* This is a wrapper around Kafka's Consumer class that tracks partitions
* and their current (asynchronously 'committed' by the calling contexts) offsets. It
* manages those offsets and the 'active' set of records that have been rendered by this
* consumer, decides when to pause the poll() loop, and handles consumer rebalances.
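*
* <p>A minimal usage sketch; the properties, topic name, and context variable below are
* illustrative assumptions, not part of this class:
* <pre>{@code
* // consumerProperties must configure bootstrap servers, a group id, and
* // String/byte[] deserializers; rootReplayerContext is assumed to exist.
* var kafkaConsumer = new KafkaConsumer<String, byte[]>(consumerProperties);
* var trackingConsumer = new TrackingKafkaConsumer(rootReplayerContext, kafkaConsumer,
*     "capture-topic", Duration.ofSeconds(1), Clock.systemUTC(), committedKey -> {});
* // Register the tracker as the rebalance listener so that it can rebuild its
* // per-partition offset state as partitions are revoked and assigned.
* kafkaConsumer.subscribe(List.of("capture-topic"), trackingConsumer);
* }</pre>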
*/
@Slf4j
public class TrackingKafkaConsumer implements ConsumerRebalanceListener {
@AllArgsConstructor
private static class OrderedKeyHolder implements Comparable<OrderedKeyHolder> {
@Getter
final long offset;
@Getter
@NonNull
final ITrafficStreamKey tsk;
@Override
public int compareTo(OrderedKeyHolder o) {
return Long.compare(offset, o.offset);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OrderedKeyHolder that = (OrderedKeyHolder) o;
if (offset != that.offset) return false;
return tsk.equals(that.tsk);
}
@Override
public int hashCode() {
return Long.valueOf(offset).hashCode();
}
}
/**
* The keep-alive should already be set to a fraction of the max poll timeout for
* the consumer (done outside of this class). The keep-alive tells this class how
* often the caller should be interacting with touch() and poll() calls. As such,
* we want each poll to be long enough to avoid overwhelming a broker or client with
* many empty poll() responses, but not so long, when there are no messages, that
* there isn't enough time left to commit offsets, which happens after we poll()
* (on the same thread, as per Consumer requirements).
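*
* <p>For example (the keep-alive value here is illustrative): with a keep-alive of
* 1 second, each poll() blocks for at most 1000ms / 4 = 250ms, leaving roughly
* three quarters of the keep-alive window for the commit work that follows.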
*/
public static final int POLL_TIMEOUT_KEEP_ALIVE_DIVISOR = 4;
@NonNull
private final RootReplayerContext globalContext;
private final Consumer<String, byte[]> kafkaConsumer;
final String topic;
private final Clock clock;
/**
* This collection holds the definitive list, as per the rebalance callback, of the partitions
* that are currently assigned to this consumer. Entries are removed when partitions are
* revoked and new entries are only created/inserted when partitions are assigned. That means
* the generations of the OffsetLifecycleTracker values may differ across partitions.
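* For example, a tracker created by an earlier assignment callback keeps its original
* generation even after later assignments have incremented the consumer's generation counter.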
*/
final Map<Integer, OffsetLifecycleTracker> partitionToOffsetLifecycleTrackerMap;
private final Object commitDataLock = new Object();
// loosening visibility so that a unit test can read this
final Map<TopicPartition, OffsetAndMetadata> nextSetOfCommitsMap;
final Map<TopicPartition, PriorityQueue<OrderedKeyHolder>> nextSetOfKeysContextsBeingCommitted;
final java.util.function.Consumer<ITrafficStreamKey> onCommitKeyCallback;
private final Duration keepAliveInterval;
private final AtomicReference<Instant> lastTouchTimeRef;
private final AtomicInteger consumerConnectionGeneration;
private final AtomicInteger kafkaRecordsLeftToCommitEventually;
private final AtomicBoolean kafkaRecordsReadyToCommit;
public TrackingKafkaConsumer(
@NonNull RootReplayerContext globalContext,
Consumer<String, byte[]> kafkaConsumer,
String topic,
Duration keepAliveInterval,
Clock c,
java.util.function.Consumer<ITrafficStreamKey> onCommitKeyCallback
) {
this.globalContext = globalContext;
this.kafkaConsumer = kafkaConsumer;
this.topic = topic;
this.clock = c;
this.partitionToOffsetLifecycleTrackerMap = new HashMap<>();
this.nextSetOfCommitsMap = new HashMap<>();
this.nextSetOfKeysContextsBeingCommitted = new HashMap<>();
this.lastTouchTimeRef = new AtomicReference<>(Instant.EPOCH);
consumerConnectionGeneration = new AtomicInteger();
kafkaRecordsLeftToCommitEventually = new AtomicInteger();
kafkaRecordsReadyToCommit = new AtomicBoolean();
this.keepAliveInterval = keepAliveInterval;
this.onCommitKeyCallback = onCommitKeyCallback;
}
@Override
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
if (partitions.isEmpty()) {
log.atDebug().setMessage(() -> this + " revoked no partitions.").log();
return;
}
new KafkaConsumerContexts.AsyncListeningContext(globalContext).onPartitionsRevoked(partitions);
synchronized (commitDataLock) {
safeCommit(globalContext::createCommitContext);
partitions.forEach(p -> {
var tp = new TopicPartition(topic, p.partition());
nextSetOfCommitsMap.remove(tp);
nextSetOfKeysContextsBeingCommitted.remove(tp);
partitionToOffsetLifecycleTrackerMap.remove(p.partition());
});
kafkaRecordsLeftToCommitEventually.set(
partitionToOffsetLifecycleTrackerMap.values().stream().mapToInt(OffsetLifecycleTracker::size).sum()
);
kafkaRecordsReadyToCommit.set(!nextSetOfCommitsMap.values().isEmpty());
log.atWarn()
.setMessage(
() -> this
+ " partitions revoked for "
+ partitions.stream().map(String::valueOf).collect(Collectors.joining(","))
)
.log();
}
}
@Override
public void onPartitionsAssigned(Collection<TopicPartition> newPartitions) {
if (newPartitions.isEmpty()) {
log.atInfo().setMessage(() -> this + " assigned no new partitions.").log();
return;
}
new KafkaConsumerContexts.AsyncListeningContext(globalContext).onPartitionsAssigned(newPartitions);
synchronized (commitDataLock) {
consumerConnectionGeneration.incrementAndGet();
newPartitions.forEach(
p -> partitionToOffsetLifecycleTrackerMap.computeIfAbsent(
p.partition(),
x -> new OffsetLifecycleTracker(consumerConnectionGeneration.get())
)
);
log.atInfo()
.setMessage(
() -> this
+ " partitions added for "
+ newPartitions.stream().map(String::valueOf).collect(Collectors.joining(","))
)
.log();
}
}
public void close() {
log.atInfo()
.setMessage(
() -> "Kafka consumer closing. "
+ "Committing (implicitly by Kafka's consumer): "
+ nextCommitsToString()
)
.log();
kafkaConsumer.close();
}
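/**
* Returns the latest time by which the caller should next call touch(), or an empty
* Optional when no records are in flight. If commits are already pending, the current
* instant is returned so that the caller flushes them promptly.
*/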
public Optional<Instant> getNextRequiredTouch() {
var lastTouchTime = lastTouchTimeRef.get();
Optional<Instant> r = kafkaRecordsLeftToCommitEventually.get() == 0
? Optional.empty()
: Optional.of(kafkaRecordsReadyToCommit.get() ? Instant.now() : lastTouchTime.plus(keepAliveInterval));
log.atTrace()
.setMessage(
() -> "returning next required touch at "
+ r.map(t -> "" + t).orElse("N/A")
+ " from a lastTouchTime of "
+ lastTouchTime
)
.log();
return r;
}
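/**
* Keeps the consumer's group membership alive without consuming new records: all assigned
* partitions are paused, a zero-duration poll() is issued (which should return nothing while
* paused), the partitions are resumed, pending offsets are committed, and the touch time is
* recorded for getNextRequiredTouch().
*/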
public void touch(ITrafficSourceContexts.IBackPressureBlockContext context) {
try (var touchCtx = context.createNewTouchContext()) {
log.trace("touch() called.");
pause();
try (var pollCtx = touchCtx.createNewPollContext()) {
var records = kafkaConsumer.poll(Duration.ZERO);
if (!records.isEmpty()) {
throw new IllegalStateException(
"Expected no entries once the consumer was paused. "
+ "This may have happened because a new assignment slipped into the consumer AFTER pause calls."
);
}
} catch (IllegalStateException e) {
throw e;
} catch (RuntimeException e) {
log.atWarn()
.setCause(e)
.setMessage(
"Unable to poll the topic: "
+ topic
+ " with our Kafka consumer. "
+ "Swallowing and awaiting next metadata refresh to try again."
)
.log();
} finally {
resume();
}
safeCommit(context::createCommitContext);
lastTouchTimeRef.set(clock.instant());
}
}
private void pause() {
var activePartitions = kafkaConsumer.assignment();
try {
kafkaConsumer.pause(activePartitions);
} catch (IllegalStateException e) {
log.atError()
.setCause(e)
.setMessage(
() -> "Unable to pause the topic partitions: "
+ topic
+ ". "
+ "The active partitions passed here : "
+ activePartitions.stream().map(String::valueOf).collect(Collectors.joining(","))
+ ". "
+ "The active partitions as tracked here are: "
+ getActivePartitions().stream().map(String::valueOf).collect(Collectors.joining(","))
+ ". "
+ "The active partitions according to the consumer: "
+ kafkaConsumer.assignment().stream().map(String::valueOf).collect(Collectors.joining(","))
)
.log();
}
}
private void resume() {
var activePartitions = kafkaConsumer.assignment();
try {
kafkaConsumer.resume(activePartitions);
} catch (IllegalStateException e) {
log.atError()
.setCause(e)
.setMessage(
() -> "Unable to resume the topic partitions: "
+ topic
+ ". "
+ "This may not be a fatal error for the entire process as the consumer should eventually"
+ " rejoin and rebalance. "
+ "The active partitions passed here : "
+ activePartitions.stream().map(String::valueOf).collect(Collectors.joining(","))
+ ". "
+ "The active partitions as tracked here are: "
+ getActivePartitions().stream().map(String::valueOf).collect(Collectors.joining(","))
+ ". "
+ "The active partitions according to the consumer: "
+ kafkaConsumer.assignment().stream().map(String::valueOf).collect(Collectors.joining(","))
)
.log();
}
}
private Collection<TopicPartition> getActivePartitions() {
return partitionToOffsetLifecycleTrackerMap.keySet()
.stream()
.map(p -> new TopicPartition(topic, p))
.collect(Collectors.toList());
}
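/**
* Polls for the next chunk of records, registering each record's offset with its partition's
* OffsetLifecycleTracker before passing the record to the supplied builder. Pending offsets
* are committed both before and after the poll, keeping all commits on the polling thread
* as the Consumer contract requires.
*/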
public <T> Stream<T> getNextBatchOfRecords(
ITrafficSourceContexts.IReadChunkContext context,
BiFunction<KafkaCommitOffsetData, ConsumerRecord<String, byte[]>, T> builder
) {
safeCommit(context::createCommitContext);
var records = safePollWithSwallowedRuntimeExceptions(context);
safeCommit(context::createCommitContext);
return applyBuilder(builder, records);
}
private <T> Stream<T> applyBuilder(
BiFunction<KafkaCommitOffsetData, ConsumerRecord<String, byte[]>, T> builder,
ConsumerRecords<String, byte[]> records
) {
return StreamSupport.stream(records.spliterator(), false).map(kafkaRecord -> {
var offsetTracker = partitionToOffsetLifecycleTrackerMap.get(kafkaRecord.partition());
var offsetDetails = new PojoKafkaCommitOffsetData(
offsetTracker.consumerConnectionGeneration,
kafkaRecord.partition(),
kafkaRecord.offset()
);
offsetTracker.add(offsetDetails.getOffset());
kafkaRecordsLeftToCommitEventually.incrementAndGet();
log.atTrace().setMessage(() -> "records in flight=" + kafkaRecordsLeftToCommitEventually.get()).log();
return builder.apply(offsetDetails, kafkaRecord);
});
}
private ConsumerRecords<String, byte[]> safePollWithSwallowedRuntimeExceptions(
ITrafficSourceContexts.IReadChunkContext context
) {
try {
lastTouchTimeRef.set(clock.instant());
ConsumerRecords<String, byte[]> records;
try (var pollContext = context.createPollContext()) {
records = kafkaConsumer.poll(keepAliveInterval.dividedBy(POLL_TIMEOUT_KEEP_ALIVE_DIVISOR));
}
log.atLevel(records.isEmpty() ? Level.TRACE : Level.INFO)
.setMessage(
() -> "Kafka consumer poll has fetched "
+ records.count()
+ " records. "
+ "Records in flight="
+ kafkaRecordsLeftToCommitEventually.get()
)
.log();
log.atTrace()
.setMessage("{}")
.addArgument(
() -> "All positions: {"
+ kafkaConsumer.assignment()
.stream()
.map(tp -> tp + ": " + kafkaConsumer.position(tp))
.collect(Collectors.joining(","))
+ "}"
)
.log();
log.atTrace()
.setMessage("{}")
.addArgument(
() -> "All previously COMMITTED positions: {"
+ kafkaConsumer.assignment()
.stream()
.map(tp -> tp + ": " + kafkaConsumer.committed(tp))
.collect(Collectors.joining(","))
+ "}"
)
.log();
return records;
} catch (RuntimeException e) {
log.atWarn()
.setCause(e)
.setMessage(
"Unable to poll the topic: {} with our Kafka consumer. "
+ "Swallowing and awaiting next metadata refresh to try again."
)
.addArgument(topic)
.log();
return new ConsumerRecords<>(Collections.emptyMap());
}
}
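/**
* Marks the given record's offset as processed. When that offset was the head of its
* partition's tracker, the partition's commit point advances and AFTER_NEXT_READ is
* returned; otherwise BLOCKED_BY_OTHER_COMMITS is returned until earlier offsets complete.
* Requests carrying a stale generation (e.g., from before a rebalance) are IGNORED.
*/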
ITrafficCaptureSource.CommitResult commitKafkaKey(ITrafficStreamKey streamKey, KafkaCommitOffsetData kafkaTsk) {
OffsetLifecycleTracker tracker;
synchronized (commitDataLock) {
tracker = partitionToOffsetLifecycleTrackerMap.get(kafkaTsk.getPartition());
}
if (tracker == null || tracker.consumerConnectionGeneration != kafkaTsk.getGeneration()) {
log.atWarn()
.setMessage(
() -> "trafficKey's generation ("
+ kafkaTsk.getGeneration()
+ ") is not current ("
+ (Optional.ofNullable(tracker)
.map(t -> "new generation=" + t.consumerConnectionGeneration)
.orElse("Partition unassigned"))
+ "). Dropping this commit request since the record would "
+ "have been handled again by a current consumer within this process or another. Full key="
+ kafkaTsk
)
.log();
return ITrafficCaptureSource.CommitResult.IGNORED;
}
var p = kafkaTsk.getPartition();
Optional<Long> newHeadValue;
var k = new TopicPartition(topic, p);
newHeadValue = tracker.removeAndReturnNewHead(kafkaTsk.getOffset());
return newHeadValue.map(o -> {
var v = new OffsetAndMetadata(o);
log.atDebug().setMessage(() -> "Adding new commit " + k + "->" + v + " to map").log();
synchronized (commitDataLock) {
addKeyContextForEventualCommit(streamKey, kafkaTsk, k);
nextSetOfCommitsMap.put(k, v);
}
return ITrafficCaptureSource.CommitResult.AFTER_NEXT_READ;
}).orElseGet(() -> {
synchronized (commitDataLock) {
addKeyContextForEventualCommit(streamKey, kafkaTsk, k);
}
return ITrafficCaptureSource.CommitResult.BLOCKED_BY_OTHER_COMMITS;
});
}
private void addKeyContextForEventualCommit(
ITrafficStreamKey streamKey,
KafkaCommitOffsetData kafkaTsk,
TopicPartition k
) {
nextSetOfKeysContextsBeingCommitted.computeIfAbsent(k, k2 -> new PriorityQueue<>())
.add(new OrderedKeyHolder(kafkaTsk.getOffset(), streamKey));
}
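/**
* Copies the pending commit map under the lock, performs a synchronous commit, then fires
* onCommitKeyCallback for every tracked key at or below each committed offset. On failure,
* the pending commits are retained for a later retry (see the catch block below).
*/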
private void safeCommit(Supplier<IKafkaConsumerContexts.ICommitScopeContext> commitContextSupplier) {
HashMap<TopicPartition, OffsetAndMetadata> nextCommitsMapCopy;
IKafkaConsumerContexts.ICommitScopeContext context = null;
synchronized (commitDataLock) {
if (nextSetOfCommitsMap.isEmpty()) {
return;
}
context = commitContextSupplier.get();
nextCommitsMapCopy = new HashMap<>(nextSetOfCommitsMap);
}
try {
safeCommitStatic(context, kafkaConsumer, nextCommitsMapCopy);
synchronized (commitDataLock) {
nextCommitsMapCopy.entrySet()
.stream()
.forEach(
kvp -> callbackUpTo(
onCommitKeyCallback,
nextSetOfKeysContextsBeingCommitted.get(kvp.getKey()),
kvp.getValue().offset()
)
);
nextCommitsMapCopy.forEach((k, v) -> nextSetOfCommitsMap.remove(k));
}
// This function will only ever be called in a threadsafe way, mutually exclusive from any
// other call other than commitKafkaKey(). Since commitKafkaKey() doesn't alter
// partitionToOffsetLifecycleTrackerMap, these lines can be outside of the commitDataLock mutex
log.trace("partitionToOffsetLifecycleTrackerMap=" + partitionToOffsetLifecycleTrackerMap);
kafkaRecordsLeftToCommitEventually.set(
partitionToOffsetLifecycleTrackerMap.values().stream().mapToInt(OffsetLifecycleTracker::size).sum()
);
log.atDebug()
.setMessage(() -> "Done committing now records in flight=" + kafkaRecordsLeftToCommitEventually.get())
.log();
} catch (RuntimeException e) {
log.atWarn()
.setCause(e)
.setMessage(
() -> "Error while committing. "
+ "Another consumer may already be processing messages before these commits. "
+ "Commits ARE NOT being discarded here, with the expectation that the revoked callback "
+ "(onPartitionsRevoked) will be called. "
+ "Within that method, commits for unassigned partitions will be discarded. "
+ "After that, touch() or poll() will trigger another commit attempt."
+ "Those calls will occur in the near future if assigned partitions have pending commits."
+ nextSetOfCommitsMap.entrySet()
.stream()
.map(kvp -> kvp.getKey() + "->" + kvp.getValue())
.collect(Collectors.joining(","))
)
.log();
} finally {
if (context != null) {
context.close();
}
}
}
private static void safeCommitStatic(
IKafkaConsumerContexts.ICommitScopeContext context,
Consumer<String, byte[]> kafkaConsumer,
HashMap<TopicPartition, OffsetAndMetadata> nextCommitsMap
) {
assert !nextCommitsMap.isEmpty();
log.atDebug().setMessage(() -> "Committing " + nextCommitsMap).log();
try (var kafkaContext = context.createNewKafkaCommitContext()) {
kafkaConsumer.commitSync(nextCommitsMap);
}
}
private static void callbackUpTo(
java.util.function.Consumer<ITrafficStreamKey> onCommitKeyCallback,
PriorityQueue<OrderedKeyHolder> orderedKeyHolders,
long upToOffset
) {
for (var nextKeyHolder = orderedKeyHolders.peek(); nextKeyHolder != null
&& nextKeyHolder.offset <= upToOffset; nextKeyHolder = orderedKeyHolders.peek()) {
onCommitKeyCallback.accept(nextKeyHolder.tsk);
orderedKeyHolders.poll();
}
}
String nextCommitsToString() {
return "nextCommits="
+ nextSetOfCommitsMap.entrySet()
.stream()
.map(kvp -> kvp.getKey() + "->" + kvp.getValue())
.collect(Collectors.joining(","));
}
@Override
public String toString() {
synchronized (commitDataLock) {
int partitionCount = partitionToOffsetLifecycleTrackerMap.size();
int commitsPending = nextSetOfCommitsMap.size();
int recordsLeftToCommit = kafkaRecordsLeftToCommitEventually.get();
boolean recordsReadyToCommit = kafkaRecordsReadyToCommit.get();
return String.format(
"TrackingKafkaConsumer{topic='%s', partitionCount=%d, commitsPending=%d, "
+ "recordsLeftToCommit=%d, recordsReadyToCommit=%b}",
topic,
partitionCount,
commitsPending,
recordsLeftToCommit,
recordsReadyToCommit
);
}
}
}