org.redisson.RedissonRateLimiter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of redisson-all · Show documentation
Easy Redis Java client and Real-Time Data Platform. Valkey-compatible. Sync/Async/RxJava3/Reactive API. Client-side caching. Over 50 Redis-based Java objects and services: JCache API, Apache Tomcat, Hibernate, Spring, Set, Multimap, SortedSet, Map, List, Queue, Deque, Semaphore, Lock, AtomicLong, Map Reduce, Bloom filter, Scheduler, RPC
The newest version!
/**
* Copyright (c) 2013-2024 Nikita Koksharov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.redisson;
import org.redisson.api.*;
import org.redisson.client.codec.LongCodec;
import org.redisson.client.codec.StringCodec;
import org.redisson.client.handler.State;
import org.redisson.client.protocol.RedisCommand;
import org.redisson.client.protocol.RedisCommands;
import org.redisson.client.protocol.decoder.MapEntriesDecoder;
import org.redisson.client.protocol.decoder.MultiDecoder;
import org.redisson.command.CommandAsyncExecutor;
import org.redisson.misc.CompletableFutureWrapper;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
/**
*
* @author Nikita Koksharov
*
*/
public class RedissonRateLimiter extends RedissonExpirable implements RRateLimiter {
/**
 * Creates a rate limiter handle bound to the given name.
 *
 * @param executor    executor used to run the limiter's Redis commands
 * @param limiterName name of this limiter, passed to the superclass
 */
public RedissonRateLimiter(CommandAsyncExecutor executor, String limiterName) {
    super(executor, limiterName);
}
/**
 * Returns the key of the shared sorted set in which granted permits are recorded.
 *
 * @return the raw name suffixed with {@code "permits"}
 */
String getPermitsName() {
    String rawName = getRawName();
    return suffixName(rawName, "permits");
}
/**
 * Returns the per-client variant of the permits key.
 *
 * <p>The key is the shared permits key suffixed with the service manager id
 * (presumably unique per client instance — verify).
 *
 * @return the client-scoped permits key
 */
String getClientPermitsName() {
    String sharedPermitsKey = getPermitsName();
    String clientId = getServiceManager().getId();
    return suffixName(sharedPermitsKey, clientId);
}
/**
 * Returns the key of the shared counter holding the number of currently available permits.
 *
 * @return the raw name suffixed with {@code "value"}
 */
String getValueName() {
    String rawName = getRawName();
    return suffixName(rawName, "value");
}
/**
 * Returns the per-client variant of the available-permits counter key.
 *
 * @return the shared value key suffixed with the service manager id
 */
String getClientValueName() {
    String sharedValueKey = getValueName();
    String clientId = getServiceManager().getId();
    return suffixName(sharedValueKey, clientId);
}
/**
 * Tries to take a single permit without waiting.
 *
 * @return {@code true} if the permit was granted
 */
@Override
public boolean tryAcquire() {
    final long singlePermit = 1L;
    return tryAcquire(singlePermit);
}
/**
 * Asynchronously tries to take a single permit without waiting.
 *
 * @return future completed with {@code true} if the permit was granted
 */
@Override
public RFuture<Boolean> tryAcquireAsync() {
    return tryAcquireAsync(1L);
}
/**
 * Tries to take {@code permits} permits without waiting.
 *
 * @param permits number of permits requested
 * @return {@code true} if the permits were granted
 */
@Override
public boolean tryAcquire(long permits) {
    // The script returns nil on success; EVAL_NULL_BOOLEAN presumably maps that to true — verify.
    RFuture<Boolean> attempt = tryAcquireAsync(RedisCommands.EVAL_NULL_BOOLEAN, permits);
    return get(attempt);
}
/**
 * Asynchronously tries to take {@code permits} permits without waiting.
 *
 * @param permits number of permits requested
 * @return future completed with {@code true} if the permits were granted
 */
@Override
public RFuture<Boolean> tryAcquireAsync(long permits) {
    return tryAcquireAsync(RedisCommands.EVAL_NULL_BOOLEAN, permits);
}
/**
 * Blocks until a single permit has been granted.
 */
@Override
public void acquire() {
    RFuture<Void> pending = acquireAsync();
    get(pending);
}
/**
 * Asynchronously acquires a single permit, waiting as long as necessary.
 *
 * @return future completed once the permit has been granted
 */
@Override
public RFuture<Void> acquireAsync() {
    return acquireAsync(1);
}
/**
 * Blocks until {@code permits} permits have been granted.
 *
 * @param permits number of permits requested
 */
@Override
public void acquire(long permits) {
    RFuture<Void> pending = acquireAsync(permits);
    get(pending);
}
/**
 * Asynchronously acquires {@code permits} permits, waiting as long as necessary.
 *
 * @param permits number of permits requested
 * @return future completed once the permits have been granted
 */
@Override
public RFuture<Void> acquireAsync(long permits) {
    // A timeout of -1 ms selects the retry loop's "no deadline" branch.
    CompletionStage<Void> f = tryAcquireAsync(permits, Duration.ofMillis(-1)).thenApply(res -> null);
    return new CompletableFutureWrapper<>(f);
}
/**
 * Tries to take a single permit, waiting up to the given timeout.
 *
 * @param timeout maximum time to wait
 * @param unit    unit of {@code timeout}
 * @return {@code true} if the permit was granted before the timeout elapsed
 */
@Override
public boolean tryAcquire(long timeout, TimeUnit unit) {
    RFuture<Boolean> attempt = tryAcquireAsync(timeout, unit);
    return get(attempt);
}
/**
 * Asynchronously tries to take a single permit, waiting up to the given timeout.
 *
 * @param timeout maximum time to wait
 * @param unit    unit of {@code timeout}
 * @return future completed with {@code true} if the permit was granted in time
 */
@Override
public RFuture<Boolean> tryAcquireAsync(long timeout, TimeUnit unit) {
    return tryAcquireAsync(1, timeout, unit);
}
/**
 * Tries to take a single permit, waiting up to {@code timeout}.
 *
 * @param timeout maximum time to wait
 * @return {@code true} if the permit was granted before the timeout elapsed
 */
@Override
public boolean tryAcquire(Duration timeout) {
    RFuture<Boolean> attempt = tryAcquireAsync(timeout);
    return get(attempt);
}
/**
 * Asynchronously tries to take a single permit, waiting up to {@code timeout}.
 *
 * @param timeout maximum time to wait
 * @return future completed with {@code true} if the permit was granted in time
 */
@Override
public RFuture<Boolean> tryAcquireAsync(Duration timeout) {
    return tryAcquireAsync(1, timeout);
}
/**
 * Tries to take {@code permits} permits, waiting up to {@code timeout}.
 *
 * @param permits number of permits requested
 * @param timeout maximum time to wait
 * @return {@code true} if the permits were granted before the timeout elapsed
 */
@Override
public boolean tryAcquire(long permits, Duration timeout) {
    RFuture<Boolean> attempt = tryAcquireAsync(permits, timeout);
    return get(attempt);
}
/**
 * Asynchronously tries to take {@code permits} permits, waiting up to {@code timeout}.
 *
 * @param permits number of permits requested
 * @param timeout maximum time to wait; converted to whole milliseconds
 * @return future completed with {@code true} if the permits were granted in time
 */
@Override
public RFuture<Boolean> tryAcquireAsync(long permits, Duration timeout) {
    CompletableFuture<Boolean> f = tryAcquireAsync(permits, timeout.toMillis());
    return new CompletableFutureWrapper<>(f);
}
/**
 * Tries to take {@code permits} permits, waiting up to the given timeout.
 *
 * @param permits number of permits requested
 * @param timeout maximum time to wait
 * @param unit    unit of {@code timeout}
 * @return {@code true} if the permits were granted before the timeout elapsed
 */
@Override
public boolean tryAcquire(long permits, long timeout, TimeUnit unit) {
    RFuture<Boolean> attempt = tryAcquireAsync(permits, timeout, unit);
    return get(attempt);
}
/**
 * Asynchronously tries to take {@code permits} permits, waiting up to the given timeout.
 *
 * @param permits number of permits requested
 * @param timeout maximum time to wait
 * @param unit    unit of {@code timeout}; the value is truncated to whole milliseconds
 * @return future completed with {@code true} if the permits were granted in time
 */
@Override
public RFuture<Boolean> tryAcquireAsync(long permits, long timeout, TimeUnit unit) {
    return tryAcquireAsync(permits, Duration.ofMillis(unit.toMillis(timeout)));
}
/**
 * Acquires {@code permits}, retrying until they are granted or the timeout is used up.
 *
 * <p>Each attempt runs the acquire script. The script yields {@code null} when the permits were
 * taken. Otherwise it yields the number of milliseconds to wait before the next attempt.
 *
 * @param permits         number of permits requested
 * @param timeoutInMillis maximum total wait in milliseconds; {@code -1} means wait indefinitely
 * @return future completed with {@code true} once the permits are granted, or {@code false}
 *         when the timeout is exhausted first
 */
private CompletableFuture<Boolean> tryAcquireAsync(long permits, long timeoutInMillis) {
    long s = System.currentTimeMillis();
    RFuture<Long> future = tryAcquireAsync(RedisCommands.EVAL_LONG, permits);
    return future.thenCompose(delay -> {
        if (delay == null) {
            // The permits were granted by this attempt.
            return CompletableFuture.completedFuture(true);
        }

        if (timeoutInMillis == -1) {
            // No deadline: wait for the suggested delay and try again, indefinitely.
            CompletableFuture<Boolean> f = new CompletableFuture<>();
            getServiceManager().newTimeout(t -> {
                CompletableFuture<Boolean> r = tryAcquireAsync(permits, timeoutInMillis);
                commandExecutor.transfer(r, f);
            }, delay, TimeUnit.MILLISECONDS);
            return f;
        }

        // Deduct the time spent on this attempt from the caller's budget.
        long el = System.currentTimeMillis() - s;
        long remains = timeoutInMillis - el;
        if (remains <= 0) {
            return CompletableFuture.completedFuture(false);
        }

        CompletableFuture<Boolean> f = new CompletableFuture<>();
        if (remains < delay) {
            // The suggested retry point lies past the deadline: wait out the remaining
            // budget, then report failure without another attempt.
            getServiceManager().newTimeout(t -> {
                f.complete(false);
            }, remains, TimeUnit.MILLISECONDS);
        } else {
            long start = System.currentTimeMillis();
            getServiceManager().newTimeout(t -> {
                // Timer firing may be late; re-check the budget before retrying.
                long elapsed = System.currentTimeMillis() - start;
                if (remains <= elapsed) {
                    f.complete(false);
                    return;
                }
                CompletableFuture<Boolean> r = tryAcquireAsync(permits, remains - elapsed);
                commandExecutor.transfer(r, f);
            }, delay, TimeUnit.MILLISECONDS);
        }
        return f;
    }).toCompletableFuture();
}
/**
 * Runs a single acquire attempt as an atomic Lua script.
 *
 * <p>KEYS: [1] configuration hash, [2] shared value key, [3] per-client value key,
 * [4] shared permits sorted set, [5] per-client permits sorted set. The script uses
 * keys 3/5 when the stored {@code type} is {@code '1'}, otherwise keys 2/4.
 *
 * <p>ARGV: [1] permits requested, [2] current client time in ms (used as the permit score),
 * [3] a random id packed into each sorted-set member. The id presumably keeps members
 * distinct when identical permit counts are recorded.
 *
 * <p>The script returns nil when the permits were granted. Otherwise it returns the number of
 * milliseconds until the oldest recorded permit leaves the interval window, plus 3.
 *
 * @param command eval command that decides how the script reply is decoded
 * @param value   number of permits requested
 * @param <T>     decoded result type of {@code command}
 * @return future with the decoded script reply
 */
private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
    byte[] random = getServiceManager().generateIdArray();
    return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, command,
            // Load configuration; fail if the limiter was never initialized.
            "local rate = redis.call('hget', KEYS[1], 'rate');"
            + "local interval = redis.call('hget', KEYS[1], 'interval');"
            + "local type = redis.call('hget', KEYS[1], 'type');"
            + "assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')"
            // Pick shared or per-client keys depending on the stored type.
            + "local valueName = KEYS[2];"
            + "local permitsName = KEYS[4];"
            + "if type == '1' then "
            + "valueName = KEYS[3];"
            + "permitsName = KEYS[5];"
            + "end;"
            + "assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount cannot exceed defined rate'); "
            + "local currentValue = redis.call('get', valueName); "
            + "local res;"
            + "if currentValue ~= false then "
            // Return permits whose timestamps fell out of the interval window to the pool.
            + "local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
            + "local released = 0; "
            + "for i, v in ipairs(expiredValues) do "
            + "local random, permits = struct.unpack('Bc0I', v);"
            + "released = released + permits;"
            + "end; "
            + "if released > 0 then "
            + "redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
            // If the counter would exceed the rate, recompute it from the permits still in the window.
            + "if tonumber(currentValue) + released > tonumber(rate) then "
            + "local values = redis.call('zrange', permitsName, 0, -1); "
            + "local used = 0; "
            + "for i, v in ipairs(values) do "
            + "local random, permits = struct.unpack('Bc0I', v);"
            + "used = used + permits;"
            + "end; "
            + "currentValue = tonumber(rate) - used; "
            + "else "
            + "currentValue = tonumber(currentValue) + released; "
            + "end; "
            + "redis.call('set', valueName, currentValue);"
            + "end;"
            // Not enough permits: report how long until the oldest permit expires.
            + "if tonumber(currentValue) < tonumber(ARGV[1]) then "
            + "local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); "
            + "res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));"
            + "else "
            + "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
            + "redis.call('decrby', valueName, ARGV[1]); "
            + "res = nil; "
            + "end; "
            + "else "
            // First use: start from the full rate and record this grant.
            + "redis.call('set', valueName, rate); "
            + "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
            + "redis.call('decrby', valueName, ARGV[1]); "
            + "res = nil; "
            + "end;"
            // Refresh expiry: keepAliveTime if configured, otherwise mirror the hash's TTL.
            + "local keepAliveTime = redis.call('hget', KEYS[1], 'keepAliveTime'); "
            + "if (keepAliveTime ~= false and tonumber(keepAliveTime) > 0) then "
            + "redis.call('pexpire', KEYS[1], keepAliveTime); "
            + "redis.call('pexpire', valueName, keepAliveTime); "
            + "redis.call('pexpire', permitsName, keepAliveTime); "
            + "else "
            + "local ttl = redis.call('pttl', KEYS[1]); "
            + "if ttl > 0 then "
            + "redis.call('pexpire', valueName, ttl); "
            + "redis.call('pexpire', permitsName, ttl); "
            + "end; "
            + "end; "
            + "return res;",
            Arrays.asList(getRawName(), getValueName(), getClientValueName(), getPermitsName(), getClientPermitsName()),
            value, System.currentTimeMillis(), random);
}
/**
 * Initializes the rate only if it has not been configured yet.
 *
 * @param type         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @param unit         unit of {@code rateInterval}
 * @return {@code true} if this call set the configuration
 */
@Override
public boolean trySetRate(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
    RFuture<Boolean> pending = trySetRateAsync(type, rate, rateInterval, unit);
    return get(pending);
}
/**
 * Asynchronously initializes the rate only if it has not been configured yet, without keep-alive.
 *
 * @param type         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @param unit         unit of {@code rateInterval}
 * @return future completed with {@code true} if this call set the configuration
 */
@Override
public RFuture<Boolean> trySetRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
    return trySetRateAsync(type, rate, Duration.ofMillis(unit.toMillis(rateInterval)), Duration.ZERO);
}
/**
 * Overwrites the rate configuration and resets the permit state.
 *
 * @param type         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @param unit         unit of {@code rateInterval}
 */
@Override
public void setRate(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
    RFuture<Void> pending = setRateAsync(type, rate, rateInterval, unit);
    get(pending);
}
/**
 * Asynchronously overwrites the rate configuration (no keep-alive) and resets the permit state.
 *
 * @param type         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @param unit         unit of {@code rateInterval}
 * @return future completed when the configuration has been written
 */
@Override
public RFuture<Void> setRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
    return setRateAsync(type, rate, Duration.ofMillis(unit.toMillis(rateInterval)), Duration.ZERO);
}
/**
 * Asynchronously initializes the limiter configuration only if it is not already present.
 *
 * <p>Each configuration field is written with {@code HSETNX}. The result reflects whether the
 * {@code type} field was newly written. When {@code keepAliveTime} is positive and the type was
 * newly set, the configuration hash is given that TTL.
 *
 * @param type          limiter type, stored by ordinal (ordinal 1 selects per-client keys in the
 *                      acquire script)
 * @param rate          permits allowed per interval
 * @param rateInterval  interval length
 * @param keepAliveTime idle expiry for the limiter keys; {@link Duration#ZERO} disables it
 * @return future completed with {@code true} if this call set the configuration
 * @throws IllegalArgumentException if {@code keepAliveTime} is non-zero and shorter than
 *                                  {@code rateInterval}
 */
@Override
public RFuture<Boolean> trySetRateAsync(RateType type, long rate, Duration rateInterval, Duration keepAliveTime) {
    if (!keepAliveTime.equals(Duration.ZERO) && keepAliveTime.toMillis() < rateInterval.toMillis()) {
        throw new IllegalArgumentException("The parameter keepAliveTime should be greater than or equal to rateInterval");
    }
    return commandExecutor.evalWriteNoRetryAsync(getRawName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
            "redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);"
            + "redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);"
            + "redis.call('hsetnx', KEYS[1], 'keepAliveTime', ARGV[4]);"
            + "local res = redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);"
            + "if res == 1 and tonumber(ARGV[4]) > 0 then "
            + "redis.call('pexpire', KEYS[1], ARGV[4]); "
            + "end; "
            + "return res;",
            Collections.singletonList(getRawName()),
            rate, rateInterval.toMillis(), type.ordinal(), keepAliveTime.toMillis());
}
/**
 * Initializes the configuration, with keep-alive, only if it is not already present.
 *
 * @param mode          limiter type
 * @param rate          permits allowed per interval
 * @param rateInterval  interval length
 * @param keepAliveTime idle expiry for the limiter keys; {@link Duration#ZERO} disables it
 * @return {@code true} if this call set the configuration
 */
@Override
public boolean trySetRate(RateType mode, long rate, Duration rateInterval, Duration keepAliveTime) {
    RFuture<Boolean> pending = trySetRateAsync(mode, rate, rateInterval, keepAliveTime);
    return get(pending);
}
/**
 * Asynchronously initializes the configuration, without keep-alive, if it is not already present.
 *
 * @param mode         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @return future completed with {@code true} if this call set the configuration
 */
@Override
public RFuture<Boolean> trySetRateAsync(RateType mode, long rate, Duration rateInterval) {
    return trySetRateAsync(mode, rate, rateInterval, Duration.ZERO);
}
/**
 * Initializes the configuration, without keep-alive, only if it is not already present.
 *
 * @param mode         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @return {@code true} if this call set the configuration
 */
@Override
public boolean trySetRate(RateType mode, long rate, Duration rateInterval) {
    RFuture<Boolean> pending = trySetRateAsync(mode, rate, rateInterval);
    return get(pending);
}
/**
 * Overwrites the configuration (with keep-alive) and resets the permit state.
 *
 * @param mode          limiter type
 * @param rate          permits allowed per interval
 * @param rateInterval  interval length
 * @param keepAliveTime idle expiry for the limiter keys; {@link Duration#ZERO} disables it
 */
@Override
public void setRate(RateType mode, long rate, Duration rateInterval, Duration keepAliveTime) {
    RFuture<Void> pending = setRateAsync(mode, rate, rateInterval, keepAliveTime);
    get(pending);
}
/**
 * Asynchronously overwrites the limiter configuration and clears the permit state.
 *
 * <p>All configuration fields are written with {@code HSET}. If {@code keepAliveTime} is positive,
 * the configuration hash is given that TTL. The value/permits keys for the selected type are then
 * deleted, so the limiter restarts from the full rate on the next acquire.
 *
 * @param type          limiter type, stored by ordinal (ordinal 1 selects per-client keys)
 * @param rate          permits allowed per interval
 * @param rateInterval  interval length
 * @param keepAliveTime idle expiry for the limiter keys; {@link Duration#ZERO} disables it
 * @return future completed when the configuration has been written
 * @throws IllegalArgumentException if {@code keepAliveTime} is non-zero and shorter than
 *                                  {@code rateInterval}
 */
@Override
public RFuture<Void> setRateAsync(RateType type, long rate, Duration rateInterval, Duration keepAliveTime) {
    // Same constraint trySetRateAsync enforces: a keep-alive shorter than the interval could let the
    // limiter keys expire before a full rate interval has passed.
    if (!keepAliveTime.equals(Duration.ZERO) && keepAliveTime.toMillis() < rateInterval.toMillis()) {
        throw new IllegalArgumentException("The parameter keepAliveTime should be greater than or equal to rateInterval");
    }
    return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
            "local valueName = KEYS[2];"
            + "local permitsName = KEYS[4];"
            + "if ARGV[3] == '1' then "
            + " valueName = KEYS[3];"
            + " permitsName = KEYS[5];"
            + "end "
            + "redis.call('hset', KEYS[1], 'rate', ARGV[1]);"
            + "redis.call('hset', KEYS[1], 'interval', ARGV[2]);"
            + "redis.call('hset', KEYS[1], 'type', ARGV[3]);"
            + "redis.call('hset', KEYS[1], 'keepAliveTime', ARGV[4]);"
            + "if tonumber(ARGV[4]) > 0 then "
            + "redis.call('pexpire', KEYS[1], ARGV[4]); "
            + "end; "
            + "redis.call('del', valueName, permitsName);",
            Arrays.asList(getRawName(), getValueName(), getClientValueName(), getPermitsName(), getClientPermitsName()),
            rate, rateInterval.toMillis(), type.ordinal(), keepAliveTime.toMillis());
}
/**
 * Overwrites the configuration (no keep-alive) and resets the permit state.
 *
 * @param mode         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 */
@Override
public void setRate(RateType mode, long rate, Duration rateInterval) {
    RFuture<Void> pending = setRateAsync(mode, rate, rateInterval);
    get(pending);
}
/**
 * Asynchronously overwrites the configuration (no keep-alive) and resets the permit state.
 *
 * @param mode         limiter type
 * @param rate         permits allowed per interval
 * @param rateInterval interval length
 * @return future completed when the configuration has been written
 */
@Override
public RFuture<Void> setRateAsync(RateType mode, long rate, Duration rateInterval) {
    return setRateAsync(mode, rate, rateInterval, Duration.ZERO);
}
private static final RedisCommand HGETALL = new RedisCommand("HGETALL", new MapEntriesDecoder(new MultiDecoder() {
@Override
public RateLimiterConfig decode(List
© 2015 - 2024 Weber Informatics LLC | Privacy Policy