// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.github.phantomthief.jedis.JedisHelper Maven / Gradle / Ivy

/**
 *
 */
package com.github.phantomthief.jedis;

import static com.github.phantomthief.tuple.Tuple.tuple;
import static com.github.phantomthief.util.MoreSuppliers.lazy;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Iterables.partition;
import static com.google.common.collect.Maps.newHashMapWithExpectedSize;
import static com.google.common.reflect.Reflection.newProxy;
import static java.lang.System.currentTimeMillis;
import static java.lang.System.nanoTime;
import static java.lang.reflect.Proxy.newProxyInstance;
import static java.util.Collections.singleton;
import static java.util.function.Function.identity;

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

import javax.annotation.CheckReturnValue;
import javax.annotation.Nonnull;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.github.phantomthief.jedis.OpInterceptor.JedisOpCall;
import com.github.phantomthief.jedis.exception.NoAvailablePoolException;
import com.github.phantomthief.jedis.exception.RethrowException;
import com.github.phantomthief.tuple.TwoTuple;
import com.github.phantomthief.util.CursorIteratorEx;
import com.github.phantomthief.util.TriFunction;

import redis.clients.jedis.BasicCommands;
import redis.clients.jedis.BinaryJedis;
import redis.clients.jedis.BinaryJedisCommands;
import redis.clients.jedis.BinaryRedisPipeline;
import redis.clients.jedis.BinaryShardedJedis;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCommands;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.PipelineBase;
import redis.clients.jedis.RedisPipeline;
import redis.clients.jedis.Response;
import redis.clients.jedis.ScanParams;
import redis.clients.jedis.ScanResult;
import redis.clients.jedis.ShardedJedis;
import redis.clients.jedis.ShardedJedisPipeline;
import redis.clients.jedis.ShardedJedisPool;
import redis.clients.jedis.Tuple;
import redis.clients.util.Pool;

/**
 * @author w.vela
 */
public class JedisHelper {

    /** Reply of {@code SETNX} when the key already existed and nothing was set. */
    public static final long SETNX_KEY_NOT_SET = 0L;
    /** Reply of {@code SETNX} when the key was set. */
    public static final long SETNX_KEY_SET = 1L;
    /** Sentinel value meaning "no limit". */
    public static final int LIMIT_INFINITY = -1;
    /** Upper score bound understood by Redis sorted-set range commands. */
    public static final String POSITIVE_INF = "+inf";
    /** Lower score bound understood by Redis sorted-set range commands. */
    public static final String NEGATIVE_INF = "-inf";
    /** {@code SET} flag: only set the key if it does not exist. */
    public static final String NOT_EXIST = "NX";
    /** {@code SET} flag: only set the key if it already exists. */
    public static final String ALREADY_EXIST = "XX";
    /** {@code SET} flag: expire time given in seconds. */
    public static final String SECONDS = "EX";
    /** {@code SET} flag: expire time given in milliseconds. */
    public static final String MILLISECONDS = "PX";
    private static final Logger logger = LoggerFactory.getLogger(JedisHelper.class);
    /** Default pipeline partition size, used when the builder was given a non-positive size. */
    private static final int PARTITION_SIZE = 100;

    /**
     * Placeholder key that lets single-result pipelines reuse the keyed pipeline machinery.
     */
    private static final Object EMPTY_KEY = new Object();

    /** Called once per command / pipeline partition to obtain the pool to borrow from. */
    private final Supplier poolFactory;
    private final int pipelinePartitionSize;

    /** Resource classes whose interfaces the command proxies implement. */
    private final Class jedisType;
    private final Class binaryJedisType;

    // Proxies are created lazily, once, on first use.
    private final Supplier basicCommandsSupplier = lazy(this::getBasic0);
    private final Supplier jedisCommandsSupplier = lazy(this::get0);
    private final Supplier binaryJedisCommandsSupplier = lazy(
            this::getBinary0);

    private final List> poolListeners;
    private final List> opListeners;
    private final List> pipelineOpListeners;

    private final List> opInterceptors;

    @SuppressWarnings("unchecked")
    private JedisHelper(Builder builder) {
        this.poolFactory = builder.poolFactory;
        this.pipelinePartitionSize = builder.pipelinePartitionSize;
        this.jedisType = builder.jedisType;
        this.binaryJedisType = builder.binaryJedisType;
        this.poolListeners = builder.poolListeners;
        this.opListeners = builder.opListeners;
        this.pipelineOpListeners = (List) builder.pipelineOpListeners;
        this.opInterceptors = builder.opInterceptors;
    }

    /**
     * Returns the name of the key that stores {@code bit}: {@code keyPrefix + "_" + (bit / keyHashRange)}.
     */
    public static String getShardBitKey(long bit, String keyPrefix, int keyHashRange) {
        return keyPrefix + "_" + (bit / keyHashRange);
    }

    /**
     * Maps every bit in {@code bits} to its shard key (see {@link #getShardBitKey}).
     */
    public static Map getShardBitKeys(Collection bits, String keyPrefix,
            int keyHashRange) {
        Map result = new HashMap<>();
        for (Long bit : bits) {
            result.put(bit, getShardBitKey(bit, keyPrefix, keyHashRange));
        }
        return result;
    }

    /**
     * Starts a builder backed by {@link ShardedJedis} / {@link BinaryShardedJedis}.
     */
    @SuppressWarnings("unchecked")
    public static Builder
            newShardedBuilder(Supplier poolFactory) {
        Builder builder = new Builder<>();
        builder.poolFactory = (Supplier) poolFactory;
        builder.jedisType = ShardedJedis.class;
        builder.binaryJedisType = BinaryShardedJedis.class;
        return builder;
    }

    /**
     * Starts a builder backed by {@link Jedis} / {@link BinaryJedis}.
     */
    @SuppressWarnings("unchecked")
    public static  Builder newBuilder(Supplier poolFactory) {
        Builder builder = new Builder<>();
        builder.poolFactory = (Supplier) poolFactory;
        builder.jedisType = Jedis.class;
        builder.binaryJedisType = BinaryJedis.class;
        return builder;
    }

    /** Runs one string pipeline whose commands are queued by {@code function}; results are discarded. */
    public void pipeline(Consumer function) {
        pipeline(p -> {
            function.accept(p);
            return null;
        });
    }

    /** Runs one binary pipeline whose commands are queued by {@code function}; results are discarded. */
    public void binaryPipeline(Consumer function) {
        binaryPipeline(p -> {
            function.accept(p);
            return null;
        });
    }

    /**
     * Runs one string pipeline and returns the value of the {@code Response} produced by
     * {@code function}. Returns {@code null} if the pipeline failed and no listener rethrew
     * (see {@link #pipeline0}).
     */
    public  V pipeline(Function> function) {
        return pipeline(singleton(EMPTY_KEY), (p, k) -> function.apply(p)).get(EMPTY_KEY);
    }

    /** Binary counterpart of {@link #pipeline(Function)}. */
    public  V binaryPipeline(Function> function) {
        return binaryPipeline(singleton(EMPTY_KEY), (p, k) -> function.apply(p)).get(EMPTY_KEY);
    }

    /** Keyed string pipeline without decoding; null values are kept. */
    public  Map pipeline(Iterable keys,
            BiFunction> function) {
        return pipeline(keys, function, identity());
    }

    /** Keyed binary pipeline without decoding; null values are kept. */
    public  Map binaryPipeline(Iterable keys,
            BiFunction> function) {
        return binaryPipeline(keys, function, identity());
    }

    /** Keyed string pipeline with a decoder; null values are kept (and passed to the decoder). */
    public  Map pipeline(Iterable keys,
            BiFunction> function, Function decoder) {
        return pipeline(keys, function, decoder, true);
    }

    /** Keyed binary pipeline with a decoder; null values are kept (and passed to the decoder). */
    public  Map binaryPipeline(Iterable keys,
            BiFunction> function, Function decoder) {
        return binaryPipeline(keys, function, decoder, true);
    }

    /**
     * Keyed string pipeline.
     *
     * @param keys keys to process; {@code null} yields an empty map
     * @param function queues the command for one key and returns its pending response;
     *        returning {@code null} skips the key
     * @param decoder applied to each raw reply value
     * @param includeNullValue when {@code false}, keys whose raw reply is {@code null} are omitted
     */
    public  Map pipeline(Iterable keys,
            BiFunction> function, Function decoder,
            boolean includeNullValue) {
        return pipeline0(keys, function, decoder, includeNullValue, this::generatePipeline);
    }

    /** Binary counterpart of {@link #pipeline(Iterable, BiFunction, Function, boolean)}. */
    public  Map binaryPipeline(Iterable keys,
            BiFunction> function, Function decoder,
            boolean includeNullValue) {
        return pipeline0(keys, function, decoder, includeNullValue, this::generateBinaryPipeline);
    }

    /**
     * Notifies every pipeline listener that a pipeline is starting on {@code pool} and collects
     * the per-listener context value each returns. A {@link RethrowException} from a listener
     * propagates its wrapped throwable; any other failure is logged and recorded as a
     * {@code null} context.
     */
    @SuppressWarnings({ "unchecked" })
    private  Map, Object> fireOnPipelineStarted(Object pool) {
        Map, Object> map = new HashMap<>();
        // never use collect, coz value may be null.
        for (PipelineOpListener pipelineOpListener : pipelineOpListeners) {
            Object value = null;
            try {
                value = pipelineOpListener.onPipelineStarted(pool);
                map.put((PipelineOpListener) pipelineOpListener, value);
            } catch (RethrowException e) {
                throw e.getThrowable();
            } catch (Throwable e) {
                logger.error("", e);
                map.put((PipelineOpListener) pipelineOpListener, value);
            }
        }
        return map;
    }

    /**
     * Shared pipeline implementation.
     *
     * <p>Keys are split into partitions of {@code pipelinePartitionSize}. Each partition gets a
     * fresh pool from {@code poolFactory} and one borrowed resource, queues its commands, syncs,
     * and decodes the replies into the result map.
     *
     * <p>A {@link Throwable} raised while processing a partition is not rethrown here. It is
     * handed to {@code PipelineOpListener.afterSync`, which may rethrow it via
     * {@link RethrowException}. Otherwise the keys of that partition are simply absent from the
     * result and the remaining partitions are still processed.
     */
    private  Map pipeline0(Iterable keys,
            BiFunction> function, Function decoder,
            boolean includeNullValue,
            TriFunction, Object>, TwoTuple> pipelineGenerator) {
        int size;
        if (keys instanceof Collection) {
            size = ((Collection) keys).size();
        } else {
            size = 16;
        }
        Map result = newHashMapWithExpectedSize(size);
        if (keys != null) {
            Iterable> partition = partition(keys, pipelinePartitionSize);

            for (List list : partition) {
                Object pool = poolFactory.get();
                Map, Object> started = fireOnPipelineStarted(
                        pool);
                Throwable t = null;
                try (J jedis = getJedis(pool)) {
                    TwoTuple tuple = pipelineGenerator.apply(pool, jedis,
                            started);
                    PipelineBase pipeline = tuple.getFirst();
                    P1 p1 = tuple.getSecond();
                    Map> thisMap = new HashMap<>(list.size());
                    for (K key : list) {
                        Response apply = function.apply(p1, key);
                        if (apply != null) {
                            thisMap.put(key, apply);
                        }
                    }
                    fireBeforeSync(pool, pipeline, started, t);
                    syncPipeline(pipeline);
                    thisMap.forEach((key, value) -> {
                        V rawValue = value.get();
                        if (rawValue != null || includeNullValue) {
                            T apply = decoder.apply(rawValue);
                            result.put(key, apply);
                        }
                    });
                } catch (Throwable e) {
                    // deliberately captured: listeners decide (via RethrowException) whether it propagates
                    t = e;
                } finally {
                    fireAfterSync(pool, started, t);
                }
            }
        }
        return result;
    }

    /**
     * Calls {@code beforeSync} on every started listener; failures are logged and ignored.
     * {@code t} is not used (the only caller passes {@code null}).
     */
    private void fireBeforeSync(Object pool, PipelineBase pipeline,
            Map, Object> s, Throwable t) {
        s.forEach((p, v) -> {
            try {
                p.beforeSync(pool, pipeline, v);
            } catch (Throwable e) {
                logger.error("", e);
            }
        });
    }

    /**
     * Calls {@code afterSync} on every started listener with the partition's failure (or
     * {@code null}). A {@link RethrowException} propagates its wrapped throwable; other listener
     * failures are logged.
     */
    private void fireAfterSync(Object pool, Map, Object> s,
            Throwable t) {
        s.forEach((p, v) -> {
            try {
                p.afterSync(pool, v, t);
            } catch (RethrowException e) {
                throw e.getThrowable();
            } catch (Throwable e) {
                logger.error("", e);
            }
        });
    }

    /** Returns the (lazily created) pooled {@link BasicCommands} proxy. */
    public BasicCommands getBasic() {
        return basicCommandsSupplier.get();
    }

    private BasicCommands getBasic0() {
        return (BasicCommands) newProxyInstance(jedisType.getClassLoader(),
                jedisType.getInterfaces(), new PoolableJedisCommands());
    }

    /** Returns the (lazily created) pooled {@link JedisCommands} proxy. */
    public JedisCommands get() {
        return jedisCommandsSupplier.get();
    }

    private JedisCommands get0() {
        return (JedisCommands) newProxyInstance(jedisType.getClassLoader(),
                jedisType.getInterfaces(), new PoolableJedisCommands());
    }

    /** Returns the (lazily created) pooled {@link BinaryJedisCommands} proxy. */
    public BinaryJedisCommands getBinary() {
        return binaryJedisCommandsSupplier.get();
    }

    private BinaryJedisCommands getBinary0() {
        return (BinaryJedisCommands) newProxyInstance(binaryJedisType.getClassLoader(),
                binaryJedisType.getInterfaces(), new PoolableJedisCommands());
    }

    /** Syncs a {@link Pipeline} or {@link ShardedJedisPipeline}; other types are ignored. */
    private void syncPipeline(PipelineBase pipeline) {
        if (pipeline instanceof Pipeline) {
            ((Pipeline) pipeline).sync();
        } else if (pipeline instanceof ShardedJedisPipeline) {
            ((ShardedJedisPipeline) pipeline).sync();
        }
    }

    /**
     * Borrows a resource from {@code pool} and reports the attempt (success or failure) to
     * the pool listeners.
     *
     * @throws IllegalArgumentException if {@code pool} is not a {@link Pool}
     */
    @SuppressWarnings("unchecked")
    private J getJedis(Object pool) {
        if (pool instanceof Pool) {
            long borrowedTime = currentTimeMillis();
            long borrowedNanoTime = nanoTime();
            try {
                J resource = ((Pool) pool).getResource();
                firePoolListener(pool, borrowedTime, borrowedNanoTime, null);
                return resource;
            } catch (Throwable e) {
                firePoolListener(pool, borrowedTime, borrowedNanoTime, e);
                throw e;
            }
        } else {
            throw new IllegalArgumentException("invalid pool:" + pool);
        }
    }

    /** Notifies pool listeners of a borrow attempt; listener failures are logged and ignored. */
    private void firePoolListener(Object pool, long borrowedTime, long borrowedNanoTime, Throwable e) {
        for (PoolListener poolListener : poolListeners) {
            try {
                poolListener.onPoolBorrowed(pool, borrowedTime, borrowedNanoTime, e);
            } catch (Throwable ex) {
                logger.error("", ex);
            }
        }
    }

    /**
     * Opens a pipeline on {@code jedis}. Returns the raw pipeline (for syncing) paired with a
     * {@link RedisPipeline} proxy that reports each queued op to the pipeline listeners.
     */
    @SuppressWarnings("unchecked")
    private TwoTuple generatePipeline(Object pool, J jedis,
            Map, Object> startPipeline) {
        if (jedis instanceof Jedis) {
            Pipeline pipelined = ((Jedis) jedis).pipelined();
            PipelineBase p = pipelined;
            RedisPipeline p1 = newProxy(RedisPipeline.class,
                    new PipelineListenerHandler<>(pool, p, pipelineOpListeners, startPipeline));
            return tuple(p, p1);
        } else if (jedis instanceof ShardedJedis) {
            ShardedJedisPipeline pipelined = ((ShardedJedis) jedis).pipelined();
            PipelineBase p = pipelined;
            RedisPipeline p1 = newProxy(RedisPipeline.class,
                    new PipelineListenerHandler<>(pool, p, pipelineOpListeners, startPipeline));
            return tuple(p, p1);
        } else {
            throw new IllegalArgumentException("invalid jedis:" + jedis);
        }
    }

    /** Binary counterpart of {@link #generatePipeline}, exposing a {@link BinaryRedisPipeline}. */
    @SuppressWarnings("unchecked")
    private TwoTuple generateBinaryPipeline(Object pool, J jedis,
            Map, Object> startPipeline) {
        if (jedis instanceof Jedis) {
            Pipeline pipelined = ((Jedis) jedis).pipelined();
            PipelineBase p = pipelined;
            BinaryRedisPipeline p1 = newProxy(BinaryRedisPipeline.class,
                    new PipelineListenerHandler<>(pool, p, pipelineOpListeners, startPipeline));
            return tuple(p, p1);
        } else if (jedis instanceof ShardedJedis) {
            ShardedJedisPipeline pipelined = ((ShardedJedis) jedis).pipelined();
            PipelineBase p = pipelined;
            BinaryRedisPipeline p1 = newProxy(BinaryRedisPipeline.class,
                    new PipelineListenerHandler<>(pool, p, pipelineOpListeners, startPipeline));
            return tuple(p, p1);
        } else {
            throw new IllegalArgumentException("invalid jedis:" + jedis);
        }
    }

    /** Reads one sharded bit; {@code false} when no reply was recorded for it. */
    public boolean getShardBit(long bit, String keyPrefix, int keyHashRange) {
        return getShardBit(singleton(bit), keyPrefix, keyHashRange).getOrDefault(bit, false);
    }

    /** Reads many sharded bits in a pipeline ({@code GETBIT key bit % keyHashRange}). */
    public Map getShardBit(Collection bits, String keyPrefix,
            int keyHashRange) {
        return pipeline(bits, (p, bit) -> p.getbit(getShardBitKey(bit, keyPrefix, keyHashRange),
                bit % keyHashRange));
    }

    /**
     * Sums {@code BITCOUNT} over every shard key overlapping {@code [start, end]}. Whole shards
     * are counted, so set bits outside the range but inside those shards are included.
     */
    public long getShardBitCount(String keyPrefix, int keyHashRange, long start, long end) {
        return generateKeys(keyPrefix, keyHashRange, start, end).values().stream()
                .mapToLong(get()::bitcount).sum();
    }

    /**
     * Sets one sharded bit to {@code true} and returns its previous value.
     * NOTE(review): throws NPE on unboxing if the pipeline failed silently.
     */
    public boolean setShardBit(long bit, String keyPrefix, int keyHashRange) {
        return setShardBit(singleton(bit), keyPrefix, keyHashRange).get(bit);
    }

    /**
     * Sets one sharded bit to {@code value} and returns its previous value.
     * NOTE(review): throws NPE on unboxing if the pipeline failed silently.
     */
    public boolean setShardBit(long bit, String keyPrefix, int keyHashRange, boolean value) {
        return setShardBitSet(singleton(bit), keyPrefix, keyHashRange, value).get(bit);
    }

    /** Sets many sharded bits to {@code value} in a pipeline; maps each bit to its previous value. */
    public Map setShardBitSet(Collection bits, String keyPrefix,
            int keyHashRange, boolean value) {
        return pipeline(bits, (p, bit) -> p.setbit(getShardBitKey(bit, keyPrefix, keyHashRange),
                bit % keyHashRange, value));
    }

    /** Sets many sharded bits to {@code true}; maps each bit to its previous value. */
    public Map setShardBit(Collection bits, String keyPrefix,
            int keyHashRange) {
        return setShardBitSet(bits, keyPrefix, keyHashRange, true);
    }

    /** Deletes every shard key overlapping {@code [start, end]} (whole shards). */
    public void delShardBit(String keyPrefix, int keyHashRange, long start, long end) {
        Map allKeys = generateKeys(keyPrefix, keyHashRange, start, end);
        allKeys.values().forEach(get()::del);
    }

    /**
     * Streams the absolute offsets of all set bits stored in the shard keys overlapping
     * {@code [start, end]}. Whole shards are read, so offsets outside the range may appear.
     */
    public Stream iterateShardBit(String keyPrefix, int keyHashRange, long start, long end) {
        Map allKeys = generateKeys(keyPrefix, keyHashRange, start, end);
        return allKeys.entrySet().stream().flatMap(this::mapToLong);
    }

    /**
     * Lists, in ascending order, the shard keys overlapping {@code [start, end]}, keyed by the
     * first bit offset each shard covers. Returns an empty map when {@code start > end}.
     *
     * @throws IllegalArgumentException if {@code keyHashRange <= 0}
     */
    private Map generateKeys(String keyPrefix, int keyHashRange, long start,
            long end) {
        if (keyHashRange <= 0) {
            throw new IllegalArgumentException("keyHashRange must be positive:" + keyHashRange);
        }
        Map result = new LinkedHashMap<>();
        if (start > end) {
            return result;
        }
        // Walk shard indices instead of stepping bit offsets from `start`. Stepping by
        // keyHashRange from an unaligned start can jump past the shard that contains `end`.
        long firstShard = start / keyHashRange;
        long lastShard = end / keyHashRange;
        for (long shard = firstShard; shard <= lastShard; shard++) {
            long shardStartBit = shard * keyHashRange;
            result.put(shardStartBit, getShardBitKey(shardStartBit, keyPrefix, keyHashRange));
        }
        return result;
    }

    /**
     * Reads one shard key and returns the absolute offsets (shard start + local offset) of its
     * set bits. Bits are numbered most-significant-first within each byte, matching Redis
     * bitmap offsets.
     */
    private Stream mapToLong(Entry entry) {
        // Jedis encodes String keys as UTF-8; encode the same way so the same key is read
        // regardless of the platform default charset.
        byte[] bytes = getBinary().get(entry.getValue().getBytes(StandardCharsets.UTF_8));
        List result = new ArrayList<>();
        if (bytes != null && bytes.length > 0) {
            for (int i = 0; i < (bytes.length * 8); i++) {
                if ((bytes[i / 8] & (1 << (7 - (i % 8)))) != 0) {
                    result.add(entry.getKey() + i);
                }
            }
        }
        return result.stream();
    }

    /**
     * Streams all keys via {@code SCAN}. Only supported for plain {@link Jedis} resources.
     */
    @SuppressWarnings("RedundantTypeArguments")
    public Stream scan(ScanParams params) {
        // javac cannot infer types...
        return this. scan((j, c) -> {
            if (j instanceof Jedis) {
                return ((Jedis) j).scan(c, params);
            } else if (j instanceof ShardedJedis) {
                throw new UnsupportedOperationException();
            } else {
                throw new UnsupportedOperationException();
            }
        }, ScanResult::getStringCursor, "0").stream();
    }

    /** Streams all fields of hash {@code key} with default scan parameters. */
    public Stream> hscan(String key) {
        return hscan(key, new ScanParams());
    }

    /** Streams the fields of hash {@code key} via {@code HSCAN}. */
    @SuppressWarnings("RedundantTypeArguments")
    public Stream> hscan(String key, ScanParams params) {
        // javac cannot infer types...
        return this.> scan((j, c) -> {
            if (j instanceof Jedis) {
                return ((Jedis) j).hscan(key, c, params);
            } else if (j instanceof ShardedJedis) {
                return ((ShardedJedis) j).hscan(key, c, params);
            } else {
                throw new UnsupportedOperationException();
            }
        }, ScanResult::getStringCursor, "0").stream();
    }

    /** Streams all members of sorted set {@code key} with default scan parameters. */
    public Stream zscan(String key) {
        return zscan(key, new ScanParams());
    }

    /** Streams the members of sorted set {@code key} via {@code ZSCAN}. */
    @SuppressWarnings("RedundantTypeArguments")
    public Stream zscan(String key, ScanParams params) {
        // javac cannot infer types...
        return this. scan((j, c) -> {
            if (j instanceof Jedis) {
                return ((Jedis) j).zscan(key, c, params);
            } else if (j instanceof ShardedJedis) {
                return ((ShardedJedis) j).zscan(key, c, params);
            } else {
                throw new UnsupportedOperationException();
            }
        }, ScanResult::getStringCursor, "0").stream();
    }

    /** Streams all members of set {@code key} with default scan parameters. */
    public Stream sscan(String key) {
        return sscan(key, new ScanParams());
    }

    /** Streams the members of set {@code key} via {@code SSCAN}. */
    @SuppressWarnings("RedundantTypeArguments")
    public Stream sscan(String key, ScanParams params) {
        // javac cannot infer types...
        return this. scan((j, c) -> {
            if (j instanceof Jedis) {
                return ((Jedis) j).sscan(key, c, params);
            } else if (j instanceof ShardedJedis) {
                return ((ShardedJedis) j).sscan(key, c, params);
            } else {
                throw new UnsupportedOperationException();
            }
        }, ScanResult::getStringCursor, "0").stream();
    }

    /**
     * Builds a cursor iterator. Each page borrows a fresh resource, runs {@code scanFunction}
     * with the current cursor and returns the resource. Iteration ends when the cursor is
     * {@code null} or {@code "0"}.
     */
    private  CursorIteratorEx> scan(
            BiFunction> scanFunction,
            Function, K> cursorExtractor, K initCursor) {
        return CursorIteratorEx.newBuilder()
                .withDataRetriever((K cursor) -> {
                    Object pool = poolFactory.get();
                    try (J jedis = getJedis(pool)) {
                        return scanFunction.apply(jedis, cursor);
                    } catch (IOException e) {
                        // raised by closing the resource; keep the cause with a specific unchecked type
                        throw new UncheckedIOException(e);
                    }
                })
                .withCursorExtractor(cursorExtractor)
                .withDataExtractor((ScanResult s) -> s.getResult().iterator())
                .withEndChecker(s -> s == null || "0".equals(s))
                .withInitCursor(initCursor)
                .build();
    }

    /** Mutable builder for {@link JedisHelper}; listeners are invoked in the order added. */
    public static final class Builder {

        private Supplier poolFactory;
        private int pipelinePartitionSize;

        private Class jedisType;
        private Class binaryJedisType;

        private List> poolListeners = new ArrayList<>();
        private List> opListeners = new ArrayList<>();
        private List> pipelineOpListeners = new ArrayList<>();

        private List> opInterceptors = new ArrayList<>();

        /**
         * Sets the maximum number of keys per pipeline round-trip; non-positive values fall back to
         * the default (100).
         */
        @CheckReturnValue
        @Nonnull
        public Builder withPipelinePartitionSize(int size) {
            this.pipelinePartitionSize = size;
            return this;
        }

        /** Adds a listener notified after every attempt to borrow a resource from the pool. */
        @CheckReturnValue
        @Nonnull
        public Builder addPoolListener(@Nonnull PoolListener poolListener) {
            this.poolListeners.add(checkNotNull(poolListener));
            return this;
        }

        /**
         * @param op interceptors will be called as adding sequence.
         */
        @CheckReturnValue
        @Nonnull
        public Builder addOpInterceptor(@Nonnull OpInterceptor op) {
            this.opInterceptors.add(checkNotNull(op));
            return this;
        }

        /** Adds a listener notified when each proxied (non-pipeline) command completes. */
        @CheckReturnValue
        @Nonnull
        public Builder addOpListener(@Nonnull OpListener op) {
            this.opListeners.add(checkNotNull(op));
            return this;
        }

        /** Adds a listener notified around each pipeline partition. */
        @CheckReturnValue
        @Nonnull
        public Builder addPipelineOpListener(@Nonnull PipelineOpListener op) {
            this.pipelineOpListeners.add(checkNotNull(op));
            return this;
        }

        /** Applies defaults and creates the helper. */
        @SuppressWarnings("unchecked")
        @Nonnull
        public JedisHelper build() {
            ensure();
            return new JedisHelper(this);
        }

        private void ensure() {
            if (pipelinePartitionSize <= 0) {
                pipelinePartitionSize = PARTITION_SIZE;
            }
        }
    }

    /**
     * Dynamic-proxy handler behind {@link #get()}, {@link #getBinary()} and {@link #getBasic()}.
     * Every call borrows one resource from a fresh {@code poolFactory.get()}, runs the op
     * interceptors in order, invokes the method and reports the outcome to the op listeners.
     */
    private final class PoolableJedisCommands implements InvocationHandler {

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            long start = currentTimeMillis();
            long startNano = nanoTime();
            Object pool = poolFactory.get();
            if (pool == null) {
                NoAvailablePoolException exception = new NoAvailablePoolException();
                long costInNano = nanoTime() - startNano;
                for (OpListener opListener : opListeners) {
                    // guarded like the finally block below, so a failing listener can neither
                    // mask the NoAvailablePoolException nor skip the remaining listeners
                    try {
                        opListener.onComplete(null, start, startNano, method, args, costInNano, exception);
                    } catch (Throwable e) {
                        logger.error("", e);
                    }
                }
                throw exception;
            }
            Throwable t = null;
            try (J jedis = getJedis(pool)) {
                J thisJedis = jedis;
                for (OpInterceptor opInterceptor : opInterceptors) {
                    // chain: each interceptor sees the method/jedis/args produced by the previous one
                    JedisOpCall call = opInterceptor.interceptCall(pool, method, thisJedis, args);
                    if (call == null) {
                        continue;
                    }
                    method = call.getMethod();
                    thisJedis = call.getJedis();
                    args = call.getArgs();

                    if (call.hasFinalObject()) {
                        return call.getFinalObject();
                    }
                }
                return method.invoke(thisJedis, args);
            } catch (Throwable e) {
                if (e instanceof InvocationTargetException) {
                    Throwable target = ((InvocationTargetException) e).getTargetException();
                    t = target != null ? target : e;
                } else {
                    t = e;
                }
                // Rethrow the unwrapped cause. Rethrowing the InvocationTargetException would make
                // the proxy wrap it in an UndeclaredThrowableException, hiding the Jedis exception.
                throw t;
            } finally {
                long costInNano = nanoTime() - startNano;
                for (OpListener opListener : opListeners) {
                    try {
                        opListener.onComplete(pool, start, startNano, method, args, costInNano, t);
                    } catch (Throwable e) {
                        logger.error("", e);
                    }
                }
            }
        }
    }
}