All downloads are free. The search and download features use the official Maven repository.

org.apache.kafka.clients.consumer.MockConsumer Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.clients.consumer;

import org.apache.kafka.clients.Metadata;
import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.internals.SubscriptionState;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.Metric;
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.common.utils.LogContext;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.Collections.singleton;
import static org.apache.kafka.clients.consumer.KafkaConsumer.DEFAULT_CLOSE_TIMEOUT_MS;


/**
 * A mock of the {@link Consumer} interface you can use for testing code that uses Kafka. This class is  not
 * threadsafe . However, you can use the {@link #schedulePollTask(Runnable)} method to write multithreaded tests
 * where a driver thread waits for {@link #poll(Duration)} to be called by a background thread and then can safely perform
 * operations during a callback.
 */
public class MockConsumer implements Consumer {

    private final Map> partitions;
    private final SubscriptionState subscriptions;
    private final Map beginningOffsets;
    private final Map endOffsets;
    private final Map committed;
    private final Queue pollTasks;
    private final Set paused;

    private Map>> records;
    private KafkaException pollException;
    private KafkaException offsetsException;
    private AtomicBoolean wakeup;
    private Duration lastPollTimeout;
    private boolean closed;
    private boolean shouldRebalance;

    public MockConsumer(OffsetResetStrategy offsetResetStrategy) {
        this.subscriptions = new SubscriptionState(new LogContext(), offsetResetStrategy);
        this.partitions = new HashMap<>();
        this.records = new HashMap<>();
        this.paused = new HashSet<>();
        this.closed = false;
        this.beginningOffsets = new HashMap<>();
        this.endOffsets = new HashMap<>();
        this.pollTasks = new LinkedList<>();
        this.pollException = null;
        this.wakeup = new AtomicBoolean(false);
        this.committed = new HashMap<>();
        this.shouldRebalance = false;
    }

    @Override
    public synchronized Set assignment() {
        return this.subscriptions.assignedPartitions();
    }

    /** Simulate a rebalance event. */
    public synchronized void rebalance(Collection newAssignment) {
        // TODO: Rebalance callbacks
        this.records.clear();
        this.subscriptions.assignFromSubscribed(newAssignment);
    }

    @Override
    public synchronized Set subscription() {
        return this.subscriptions.subscription();
    }

    @Override
    public synchronized void subscribe(Collection topics) {
        subscribe(topics, new NoOpConsumerRebalanceListener());
    }

    @Override
    public synchronized void subscribe(Pattern pattern, final ConsumerRebalanceListener listener) {
        ensureNotClosed();
        committed.clear();
        this.subscriptions.subscribe(pattern, listener);
        Set topicsToSubscribe = new HashSet<>();
        for (String topic: partitions.keySet()) {
            if (pattern.matcher(topic).matches() &&
                !subscriptions.subscription().contains(topic))
                topicsToSubscribe.add(topic);
        }
        ensureNotClosed();
        this.subscriptions.subscribeFromPattern(topicsToSubscribe);
        final Set assignedPartitions = new HashSet<>();
        for (final String topic : topicsToSubscribe) {
            for (final PartitionInfo info : this.partitions.get(topic)) {
                assignedPartitions.add(new TopicPartition(topic, info.partition()));
            }

        }
        subscriptions.assignFromSubscribed(assignedPartitions);
    }

    @Override
    public synchronized void subscribe(Pattern pattern) {
        subscribe(pattern, new NoOpConsumerRebalanceListener());
    }

    @Override
    public synchronized void subscribe(Collection topics, final ConsumerRebalanceListener listener) {
        ensureNotClosed();
        committed.clear();
        this.subscriptions.subscribe(new HashSet<>(topics), listener);
    }

    @Override
    public synchronized void assign(Collection partitions) {
        ensureNotClosed();
        committed.clear();
        this.subscriptions.assignFromUser(new HashSet<>(partitions));
    }

    @Override
    public synchronized void unsubscribe() {
        ensureNotClosed();
        committed.clear();
        subscriptions.unsubscribe();
    }

    @Deprecated
    @Override
    public synchronized ConsumerRecords poll(long timeout) {
        return poll(Duration.ofMillis(timeout));
    }

    @Override
    public synchronized ConsumerRecords poll(final Duration timeout) {
        ensureNotClosed();

        lastPollTimeout = timeout;

        // Synchronize around the entire execution so new tasks to be triggered on subsequent poll calls can be added in
        // the callback
        synchronized (pollTasks) {
            Runnable task = pollTasks.poll();
            if (task != null)
                task.run();
        }

        if (wakeup.get()) {
            wakeup.set(false);
            throw new WakeupException();
        }

        if (pollException != null) {
            RuntimeException exception = this.pollException;
            this.pollException = null;
            throw exception;
        }

        // Handle seeks that need to wait for a poll() call to be processed
        for (TopicPartition tp : subscriptions.assignedPartitions())
            if (!subscriptions.hasValidPosition(tp))
                updateFetchPosition(tp);

        // update the consumed offset
        final Map>> results = new HashMap<>();
        final List toClear = new ArrayList<>();

        for (Map.Entry>> entry : this.records.entrySet()) {
            if (!subscriptions.isPaused(entry.getKey())) {
                final List> recs = entry.getValue();
                for (final ConsumerRecord rec : recs) {
                    long position = subscriptions.position(entry.getKey()).offset;

                    if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) {
                        throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position));
                    }

                    if (assignment().contains(entry.getKey()) && rec.offset() >= position) {
                        results.computeIfAbsent(entry.getKey(), partition -> new ArrayList<>()).add(rec);
                        Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(Optional.empty(), rec.leaderEpoch());
                        SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
                                rec.offset() + 1, rec.leaderEpoch(), leaderAndEpoch);
                        subscriptions.position(entry.getKey(), newPosition);
                    }
                }
                toClear.add(entry.getKey());
            }
        }

        toClear.forEach(p -> this.records.remove(p));
        return new ConsumerRecords<>(results);
    }

    public synchronized void addRecord(ConsumerRecord record) {
        ensureNotClosed();
        TopicPartition tp = new TopicPartition(record.topic(), record.partition());
        Set currentAssigned = this.subscriptions.assignedPartitions();
        if (!currentAssigned.contains(tp))
            throw new IllegalStateException("Cannot add records for a partition that is not assigned to the consumer");
        List> recs = this.records.computeIfAbsent(tp, k -> new ArrayList<>());
        recs.add(record);
    }

    /**
     * @deprecated Use {@link #setPollException(KafkaException)} instead
     */
    @Deprecated
    public synchronized void setException(KafkaException exception) {
        setPollException(exception);
    }

    public synchronized void setPollException(KafkaException exception) {
        this.pollException = exception;
    }

    public synchronized void setOffsetsException(KafkaException exception) {
        this.offsetsException = exception;
    }

    @Override
    public synchronized void commitAsync(Map offsets, OffsetCommitCallback callback) {
        ensureNotClosed();
        for (Map.Entry entry : offsets.entrySet())
            committed.put(entry.getKey(), entry.getValue());
        if (callback != null) {
            callback.onComplete(offsets, null);
        }
    }

    @Override
    public synchronized void commitSync(Map offsets) {
        commitAsync(offsets, null);
    }

    @Override
    public synchronized void commitAsync() {
        commitAsync(null);
    }

    @Override
    public synchronized void commitAsync(OffsetCommitCallback callback) {
        ensureNotClosed();
        commitAsync(this.subscriptions.allConsumed(), callback);
    }

    @Override
    public synchronized void commitSync() {
        commitSync(this.subscriptions.allConsumed());
    }

    @Override
    public synchronized void commitSync(Duration timeout) {
        commitSync(this.subscriptions.allConsumed());
    }

    @Override
    public void commitSync(Map offsets, final Duration timeout) {
        commitSync(offsets);
    }

    @Override
    public synchronized void seek(TopicPartition partition, long offset) {
        ensureNotClosed();
        subscriptions.seek(partition, offset);
    }

    @Override
    public void seek(TopicPartition partition, OffsetAndMetadata offsetAndMetadata) {
        ensureNotClosed();
        subscriptions.seek(partition, offsetAndMetadata.offset());
    }

    @Deprecated
    @Override
    public synchronized OffsetAndMetadata committed(final TopicPartition partition) {
        return committed(singleton(partition)).get(partition);
    }

    @Deprecated
    @Override
    public OffsetAndMetadata committed(final TopicPartition partition, final Duration timeout) {
        return committed(partition);
    }

    @Override
    public synchronized Map committed(final Set partitions) {
        ensureNotClosed();

        return partitions.stream()
            .filter(committed::containsKey)
            .collect(Collectors.toMap(tp -> tp, tp -> subscriptions.isAssigned(tp) ?
                committed.get(tp) : new OffsetAndMetadata(0)));
    }

    @Override
    public synchronized Map committed(final Set partitions, final Duration timeout) {
        return committed(partitions);
    }

    @Override
    public synchronized long position(TopicPartition partition) {
        ensureNotClosed();
        if (!this.subscriptions.isAssigned(partition))
            throw new IllegalArgumentException("You can only check the position for partitions assigned to this consumer.");
        SubscriptionState.FetchPosition position = this.subscriptions.position(partition);
        if (position == null) {
            updateFetchPosition(partition);
            position = this.subscriptions.position(partition);
        }
        return position.offset;
    }

    @Override
    public synchronized long position(TopicPartition partition, final Duration timeout) {
        return position(partition);
    }

    @Override
    public synchronized void seekToBeginning(Collection partitions) {
        ensureNotClosed();
        subscriptions.requestOffsetReset(partitions, OffsetResetStrategy.EARLIEST);
    }

    public synchronized void updateBeginningOffsets(Map newOffsets) {
        beginningOffsets.putAll(newOffsets);
    }

    @Override
    public synchronized void seekToEnd(Collection partitions) {
        ensureNotClosed();
        subscriptions.requestOffsetReset(partitions, OffsetResetStrategy.LATEST);
    }

    public synchronized void updateEndOffsets(final Map newOffsets) {
        endOffsets.putAll(newOffsets);
    }

    @Override
    public synchronized Map metrics() {
        ensureNotClosed();
        return Collections.emptyMap();
    }

    @Override
    public synchronized List partitionsFor(String topic) {
        ensureNotClosed();
        return this.partitions.getOrDefault(topic, Collections.emptyList());
    }

    @Override
    public synchronized Map> listTopics() {
        ensureNotClosed();
        return partitions;
    }

    public synchronized void updatePartitions(String topic, List partitions) {
        ensureNotClosed();
        this.partitions.put(topic, partitions);
    }

    @Override
    public synchronized void pause(Collection partitions) {
        for (TopicPartition partition : partitions) {
            subscriptions.pause(partition);
            paused.add(partition);
        }
    }

    @Override
    public synchronized void resume(Collection partitions) {
        for (TopicPartition partition : partitions) {
            subscriptions.resume(partition);
            paused.remove(partition);
        }
    }

    @Override
    public synchronized Map offsetsForTimes(Map timestampsToSearch) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    @Override
    public synchronized Map beginningOffsets(Collection partitions) {
        if (offsetsException != null) {
            RuntimeException exception = this.offsetsException;
            this.offsetsException = null;
            throw exception;
        }
        Map result = new HashMap<>();
        for (TopicPartition tp : partitions) {
            Long beginningOffset = beginningOffsets.get(tp);
            if (beginningOffset == null)
                throw new IllegalStateException("The partition " + tp + " does not have a beginning offset.");
            result.put(tp, beginningOffset);
        }
        return result;
    }

    @Override
    public synchronized Map endOffsets(Collection partitions) {
        if (offsetsException != null) {
            RuntimeException exception = this.offsetsException;
            this.offsetsException = null;
            throw exception;
        }
        Map result = new HashMap<>();
        for (TopicPartition tp : partitions) {
            Long endOffset = endOffsets.get(tp);
            if (endOffset == null)
                throw new IllegalStateException("The partition " + tp + " does not have an end offset.");
            result.put(tp, endOffset);
        }
        return result;
    }

    @Override
    public void close() {
        close(Duration.ofMillis(DEFAULT_CLOSE_TIMEOUT_MS));
    }

    @Override
    public synchronized void close(Duration timeout) {
        this.closed = true;
    }

    public synchronized boolean closed() {
        return this.closed;
    }

    @Override
    public synchronized void wakeup() {
        wakeup.set(true);
    }

    /**
     * Schedule a task to be executed during a poll(). One enqueued task will be executed per {@link #poll(Duration)}
     * invocation. You can use this repeatedly to mock out multiple responses to poll invocations.
     * @param task the task to be executed
     */
    public synchronized void schedulePollTask(Runnable task) {
        synchronized (pollTasks) {
            pollTasks.add(task);
        }
    }

    public synchronized void scheduleNopPollTask() {
        schedulePollTask(() -> { });
    }

    public synchronized Set paused() {
        return Collections.unmodifiableSet(new HashSet<>(paused));
    }

    private void ensureNotClosed() {
        if (this.closed)
            throw new IllegalStateException("This consumer has already been closed.");
    }

    private void updateFetchPosition(TopicPartition tp) {
        if (subscriptions.isOffsetResetNeeded(tp)) {
            resetOffsetPosition(tp);
        } else if (!committed.containsKey(tp)) {
            subscriptions.requestOffsetReset(tp);
            resetOffsetPosition(tp);
        } else {
            subscriptions.seek(tp, committed.get(tp).offset());
        }
    }

    private void resetOffsetPosition(TopicPartition tp) {
        OffsetResetStrategy strategy = subscriptions.resetStrategy(tp);
        Long offset;
        if (strategy == OffsetResetStrategy.EARLIEST) {
            offset = beginningOffsets.get(tp);
            if (offset == null)
                throw new IllegalStateException("MockConsumer didn't have beginning offset specified, but tried to seek to beginning");
        } else if (strategy == OffsetResetStrategy.LATEST) {
            offset = endOffsets.get(tp);
            if (offset == null)
                throw new IllegalStateException("MockConsumer didn't have end offset specified, but tried to seek to end");
        } else {
            throw new NoOffsetForPartitionException(tp);
        }
        seek(tp, offset);
    }

    @Override
    public List partitionsFor(String topic, Duration timeout) {
        return partitionsFor(topic);
    }

    @Override
    public Map> listTopics(Duration timeout) {
        return listTopics();
    }

    @Override
    public Map offsetsForTimes(Map timestampsToSearch,
            Duration timeout) {
        return offsetsForTimes(timestampsToSearch);
    }

    @Override
    public Map beginningOffsets(Collection partitions, Duration timeout) {
        return beginningOffsets(partitions);
    }

    @Override
    public Map endOffsets(Collection partitions, Duration timeout) {
        return endOffsets(partitions);
    }

    @Override
    public OptionalLong currentLag(TopicPartition topicPartition) {
        if (endOffsets.containsKey(topicPartition)) {
            return OptionalLong.of(endOffsets.get(topicPartition) - position(topicPartition));
        } else {
            // if the test doesn't bother to set an end offset, we assume it wants to model being caught up.
            return OptionalLong.of(0L);
        }
    }

    @Override
    public ConsumerGroupMetadata groupMetadata() {
        return new ConsumerGroupMetadata("dummy.group.id", 1, "1", Optional.empty());
    }

    @Override
    public void enforceRebalance() {
        enforceRebalance(null);
    }

    @Override
    public void enforceRebalance(final String reason) {
        shouldRebalance = true;
    }

    public boolean shouldRebalance() {
        return shouldRebalance;
    }

    public void resetShouldRebalance() {
        shouldRebalance = false;
    }

    public Duration lastPollTimeout() {
        return lastPollTimeout;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy