All downloads are free. Search and download functionality uses the official Maven repository.

com.redis.spring.batch.reader.StreamItemReader Maven / Gradle / Ivy

There is a newer version: 4.0.7
Show newest version
package com.redis.spring.batch.reader;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.springframework.batch.item.ExecutionContext;
import org.springframework.batch.item.ItemStreamException;
import org.springframework.batch.item.support.AbstractItemStreamItemReader;
import org.springframework.util.ClassUtils;

import com.redis.spring.batch.util.ConnectionUtils;

import io.lettuce.core.AbstractRedisClient;
import io.lettuce.core.Consumer;
import io.lettuce.core.ReadFrom;
import io.lettuce.core.RedisBusyException;
import io.lettuce.core.StreamMessage;
import io.lettuce.core.XGroupCreateArgs;
import io.lettuce.core.XReadArgs;
import io.lettuce.core.XReadArgs.StreamOffset;
import io.lettuce.core.api.StatefulConnection;
import io.lettuce.core.api.sync.RedisStreamCommands;
import io.lettuce.core.codec.RedisCodec;

public class StreamItemReader extends AbstractItemStreamItemReader>
        implements PollableItemReader> {

    public enum StreamAckPolicy {
        AUTO, MANUAL
    }

    public static final Duration DEFAULT_POLL_DURATION = Duration.ofSeconds(1);

    public static final String DEFAULT_OFFSET = "0-0";

    public static final Duration DEFAULT_BLOCK = Duration.ofMillis(100);

    public static final long DEFAULT_COUNT = 50;

    public static final StreamAckPolicy DEFAULT_ACK_POLICY = StreamAckPolicy.AUTO;

    private final AbstractRedisClient client;

    private final RedisCodec codec;

    private final K stream;

    private final Consumer consumer;

    private String offset = DEFAULT_OFFSET;

    private Duration block = DEFAULT_BLOCK;

    private long count = DEFAULT_COUNT;

    private StreamAckPolicy ackPolicy = DEFAULT_ACK_POLICY;

    private StatefulConnection connection;

    private Iterator> iterator = Collections.emptyIterator();

    private MessageReader messageReader;

    private String lastId;

    private RedisStreamCommands commands;

    private ReadFrom readFrom;

    public StreamItemReader(AbstractRedisClient client, RedisCodec codec, K stream, Consumer consumer) {
        setName(ClassUtils.getShortName(getClass()));
        this.client = client;
        this.codec = codec;
        this.stream = stream;
        this.consumer = consumer;
    }

    public ReadFrom getReadFrom() {
        return readFrom;
    }

    public void setReadFrom(ReadFrom readFrom) {
        this.readFrom = readFrom;
    }

    public String getOffset() {
        return offset;
    }

    public void setOffset(String offset) {
        this.offset = offset;
    }

    public Duration getBlock() {
        return block;
    }

    public void setBlock(Duration block) {
        this.block = block;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }

    public StreamAckPolicy getAckPolicy() {
        return ackPolicy;
    }

    public void setAckPolicy(StreamAckPolicy policy) {
        this.ackPolicy = policy;
    }

    private XReadArgs args(long blockMillis) {
        return XReadArgs.Builder.count(count).block(blockMillis);
    }

    @Override
    public synchronized void open(ExecutionContext executionContext) {
        super.open(executionContext);
        if (!isOpen()) {
            doOpen();
        }
    }

    private void doOpen() {
        connection = ConnectionUtils.supplier(client, codec, readFrom).get();
        commands = ConnectionUtils.sync(connection);
        StreamOffset streamOffset = StreamOffset.from(stream, offset);
        XGroupCreateArgs args = XGroupCreateArgs.Builder.mkstream(true);
        try {
            commands.xgroupCreate(streamOffset, consumer.getGroup(), args);
        } catch (RedisBusyException e) {
            // Consumer Group name already exists, ignore
        }
        lastId = offset;
        messageReader = reader();
    }

    public boolean isOpen() {
        return messageReader != null;
    }

    @Override
    public synchronized void close() {
        if (isOpen()) {
            doClose();
        }
        super.close();
    }

    private void doClose() {
        messageReader = null;
        lastId = null;
        connection.close();
        connection = null;
        commands = null;
    }

    private MessageReader reader() {
        if (ackPolicy == StreamAckPolicy.MANUAL) {
            return new ExplicitAckPendingMessageReader();
        }
        return new AutoAckPendingMessageReader();
    }

    @Override
    public void update(ExecutionContext executionContext) throws ItemStreamException {
        // Do nothing
    }

    @Override
    public StreamMessage read() throws Exception {
        return poll(DEFAULT_POLL_DURATION.toMillis(), TimeUnit.MILLISECONDS);
    }

    @Override
    public synchronized StreamMessage poll(long timeout, TimeUnit unit) throws PollingException {
        if (!iterator.hasNext()) {
            List> messages = messageReader.read(unit.toMillis(timeout));
            if (messages == null || messages.isEmpty()) {
                return null;
            }
            iterator = messages.iterator();
        }
        return iterator.next();
    }

    public List> readMessages() {
        return messageReader.read(block.toMillis());
    }

    /**
     * Acks given messages
     * 
     * @param messages to be acked
     */
    public Long ack(Iterable> messages) {
        if (messages == null) {
            return 0L;
        }
        Stream ids = StreamSupport.stream(messages.spliterator(), false).map(StreamMessage::getId);
        return doAck(ids.toArray(String[]::new));
    }

    /**
     * Acks given message ids
     * 
     * @param ids message ids to be acked
     * @return
     */
    public Long ack(String... ids) {
        if (ids.length == 0) {
            return 0L;
        }
        lastId = ids[ids.length - 1];
        return doAck(ids);
    }

    private Long doAck(String... ids) {
        if (ids.length == 0) {
            return 0L;
        }
        return commands.xack(stream, consumer.getGroup(), ids);
    }

    public static class StreamId implements Comparable {

        public static final StreamId ZERO = StreamId.of(0, 0);

        private final long millis;

        private final long sequence;

        public StreamId(long millis, long sequence) {
            this.millis = millis;
            this.sequence = sequence;
        }

        private static void checkPositive(String id, long number) {
            if (number < 0) {
                throw new IllegalArgumentException(String.format("not an id: %s", id));
            }
        }

        public static StreamId parse(String id) {
            int off = id.indexOf("-");
            if (off == -1) {
                long millis = Long.parseLong(id);
                checkPositive(id, millis);
                return StreamId.of(millis, 0L);
            }
            long millis = Long.parseLong(id.substring(0, off));
            checkPositive(id, millis);
            long sequence = Long.parseLong(id.substring(off + 1));
            checkPositive(id, sequence);
            return StreamId.of(millis, sequence);
        }

        public static StreamId of(long millis, long sequence) {
            return new StreamId(millis, sequence);
        }

        public String toStreamId() {
            return millis + "-" + sequence;
        }

        @Override
        public String toString() {
            return toStreamId();
        }

        @Override
        public int compareTo(StreamId o) {
            long diff = millis - o.millis;
            if (diff != 0) {
                return Long.signum(diff);
            }
            return Long.signum(sequence - o.sequence);
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof StreamId)) {
                return false;
            }
            StreamId o = (StreamId) obj;
            return o.millis == millis && o.sequence == sequence;
        }

        @Override
        public int hashCode() {
            long val = millis * 31 * sequence;
            return (int) (val ^ (val >> 32));
        }

    }

    private interface MessageReader {

        /**
         * Reads messages from a stream
         * 
         * @param commands Synchronous executed commands for Streams
         * @param args Stream read command args
         * @return list of messages retrieved from the stream or empty list if no messages available
         * @throws MessageReadException
         */
        List> read(long blockMillis);

    }

    private class ExplicitAckPendingMessageReader implements MessageReader {

        @SuppressWarnings("unchecked")
        protected List> readMessages(XReadArgs args) {
            return recover(commands.xreadgroup(consumer, args, StreamOffset.from(stream, DEFAULT_OFFSET)));
        }

        protected List> recover(List> messages) {
            if (messages.isEmpty()) {
                return messages;
            }
            List> recoveredMessages = new ArrayList<>();
            List> messagesToAck = new ArrayList<>();
            StreamId recoveryId = StreamId.parse(lastId);
            for (StreamMessage message : messages) {
                StreamId messageId = StreamId.parse(message.getId());
                if (messageId.compareTo(recoveryId) > 0) {
                    recoveredMessages.add(message);
                    lastId = message.getId();
                } else {
                    messagesToAck.add(message);
                }
            }
            ack(messagesToAck);
            return recoveredMessages;
        }

        protected MessageReader messageReader() {
            return new ExplicitAckMessageReader();
        }

        @Override
        public List> read(long blockMillis) {
            List> messages;
            messages = readMessages(args(blockMillis));
            if (messages.isEmpty()) {
                messageReader = messageReader();
                return messageReader.read(blockMillis);
            }
            return messages;
        }

    }

    private class ExplicitAckMessageReader implements MessageReader {

        @SuppressWarnings("unchecked")
        @Override
        public List> read(long blockMillis) {
            return commands.xreadgroup(consumer, args(blockMillis), StreamOffset.lastConsumed(stream));
        }

    }

    private class AutoAckPendingMessageReader extends ExplicitAckPendingMessageReader {

        @Override
        protected StreamItemReader.MessageReader messageReader() {
            return new AutoAckMessageReader();
        }

        @Override
        protected List> recover(List> messages) {
            ack(messages);
            return Collections.emptyList();
        }

    }

    private class AutoAckMessageReader extends ExplicitAckMessageReader {

        @Override
        public List> read(long blockMillis) {
            List> messages = super.read(blockMillis);
            ack(messages);
            return messages;
        }

    }

    public long streamLength() {
        return commands.xlen(stream);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy