All Downloads are FREE. Search and download functionalities are using the official Maven repository.

es.moki.ratelimitj.inmemory.request.InMemorySlidingWindowRequestRateLimiter Maven / Gradle / Ivy

package es.moki.ratelimitj.inmemory.request;

import de.jkeylockmanager.manager.KeyLockManager;
import de.jkeylockmanager.manager.KeyLockManagers;
import es.moki.ratelimitj.core.limiter.request.RequestLimitRule;
import es.moki.ratelimitj.core.limiter.request.RequestRateLimiter;
import es.moki.ratelimitj.core.time.SystemTimeSupplier;
import es.moki.ratelimitj.core.time.TimeSupplier;
import net.jodah.expiringmap.ExpirationPolicy;
import net.jodah.expiringmap.ExpiringMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.concurrent.ThreadSafe;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

import static es.moki.ratelimitj.core.RateLimitUtils.coalesce;
import static java.util.Objects.requireNonNull;

@ThreadSafe
public class InMemorySlidingWindowRequestRateLimiter implements RequestRateLimiter {

    private static final Logger LOG = LoggerFactory.getLogger(InMemorySlidingWindowRequestRateLimiter.class);

    private final Set rules;
    private final TimeSupplier timeSupplier;
    private final ExpiringMap> expiringKeyMap;
    private final KeyLockManager lockManager = KeyLockManagers.newLock();

    public InMemorySlidingWindowRequestRateLimiter(Set rules) {
        this(rules, new SystemTimeSupplier());
    }

    public InMemorySlidingWindowRequestRateLimiter(Set rules, TimeSupplier timeSupplier) {
        this.rules = rules;
        this.timeSupplier = timeSupplier;
        this.expiringKeyMap = ExpiringMap.builder().variableExpiration().build();
    }

    InMemorySlidingWindowRequestRateLimiter(ExpiringMap> expiringKeyMap, Set rules, TimeSupplier timeSupplier) {
        this.expiringKeyMap = expiringKeyMap;
        this.rules = rules;
        this.timeSupplier = timeSupplier;
    }

    @Override
    public boolean overLimitWhenIncremented(String key) {
        return overLimitWhenIncremented(key, 1);
    }

    @Override
    public boolean overLimitWhenIncremented(String key, int weight) {
        return eqOrGeLimit(key, weight, true);
    }

    @Override
    public boolean geLimitWhenIncremented(String key) {
        return geLimitWhenIncremented(key, 1);
    }

    @Override
    public boolean geLimitWhenIncremented(String key, int weight) {
        return eqOrGeLimit(key, weight, false);
    }

//    @Override
//    public boolean isOverLimit(String key) {
//        return overLimitWhenIncremented(key, 0);
//    }
//
//    @Override
//    public boolean isGeLimit(String key) {
//        return geLimitWhenIncremented(key, 0);
//    }

    @Override
    public boolean resetLimit(String key) {
        return expiringKeyMap.remove(key) != null;
    }

    private ConcurrentMap getMap(String key, int longestDuration) {

        // Currently unable to putIfAbsent when using jodah's expiry map so will wrap in a lock
        return lockManager.executeLocked(key, () -> {
            ConcurrentMap keyMap = expiringKeyMap.get(key);
            if (keyMap == null) {
                keyMap = new ConcurrentHashMap<>();
                expiringKeyMap.put(key, keyMap, ExpirationPolicy.CREATED, longestDuration, TimeUnit.SECONDS);
            }
            return keyMap;
        });
    }

    private boolean eqOrGeLimit(String key, int weight, boolean strictlyGreater) {

        requireNonNull(key, "key cannot be null");
        requireNonNull(rules, "rules cannot be null");
        if (rules.isEmpty()) {
            throw new IllegalArgumentException("at least one rule must be provided");
        }

        final long now = timeSupplier.get();
        // TODO implement cleanup
        final int longestDurationSeconds = rules.stream().map(RequestLimitRule::getDurationSeconds).reduce(Integer::max).orElse(0);
        List savedKeys = new ArrayList<>(rules.size());

        Map keyMap = getMap(key, longestDurationSeconds);
        boolean geLimit = false;

        // TODO perform each rule calculation in parallel
        for (RequestLimitRule rule : rules) {

            SavedKey savedKey = new SavedKey(now, rule.getDurationSeconds(), rule.getPrecision());
            savedKeys.add(savedKey);

            Long oldTs = keyMap.get(savedKey.tsKey);
            oldTs = oldTs != null ? oldTs : savedKey.trimBefore;

            if (oldTs > now) {
                // don't write in the past
                return true;
            }

            // discover what needs to be cleaned up
            long decr = 0;
            List dele = new ArrayList<>();
            long trim = Math.min(savedKey.trimBefore, oldTs + savedKey.blocks);

            for (long oldBlock = oldTs; oldBlock == trim - 1; oldBlock++) {
                String bkey = savedKey.countKey + oldBlock;
                Long bcount = keyMap.get(bkey);
                if (bcount != null) {
                    decr = decr + bcount;
                    dele.add(bkey);
                }
            }

            // handle cleanup
            Long cur;
            if (!dele.isEmpty()) {
                dele.forEach(keyMap::remove);
                final long decrement = decr;
                cur = keyMap.compute(savedKey.countKey, (k, v) -> v - decrement);
            } else {
                cur = keyMap.get(savedKey.countKey);
            }

            // check our limits
            long count = coalesce(cur, 0L) + weight;
            if (count > rule.getLimit()) {
                return true; // over limit, don't record request
            } else if (!strictlyGreater && count == rule.getLimit()) {
                geLimit = true; // at limit, do record request
            }
        }

        // there is enough resources, update the counts
        if (weight != 0) {
            for (SavedKey savedKey : savedKeys) {
                //update the current timestamp, count, and bucket count
                keyMap.put(savedKey.tsKey, savedKey.trimBefore);

                Long computedCountKeyValue = keyMap.compute(savedKey.countKey, (k, v) -> coalesce(v, 0L) + weight);
                Long computedCountKeyBlockIdValue = keyMap.compute(savedKey.countKey + savedKey.blockId, (k, v) -> coalesce(v, 0L) + weight);

                if (LOG.isDebugEnabled()) {
                    LOG.debug("{} {}={}", key, savedKey.countKey, computedCountKeyValue);
                    LOG.debug("{} {}={}", key, savedKey.countKey + savedKey.blockId, computedCountKeyBlockIdValue);
                }
            }
        }

        return geLimit;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy