net.dv8tion.jda.api.requests.SequentialRestRateLimiter Maven / Gradle / Ivy
Show all versions of JDA Show documentation
/*
* Copyright 2015 Austin Keener, Michael Ritter, Florian Spieß, and the JDA contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.dv8tion.jda.api.requests;
import net.dv8tion.jda.api.utils.MiscUtil;
import net.dv8tion.jda.internal.utils.JDALogger;
import okhttp3.Headers;
import okhttp3.Response;
import org.slf4j.Logger;
import javax.annotation.Nonnull;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.ReentrantLock;
/**
* A bucket is determined via the Path+Method+Major in the following way:
*
* <ol>
* <li>Get Hash from Path+Method (we call this route)</li>
* <li>Get bucket from Hash+Major (we call this bucketid)</li>
* </ol>
*
*
* <p>If no hash is known we default to the constant "uninit" hash. The hash is loaded from HTTP responses using the "X-RateLimit-Bucket" response header.
* This hash is per Method+Path and can be stored indefinitely once received.
* Some endpoints don't return a hash; this means that the endpoint is uninit and will be in a queue with only the major parameter.
*
*
* <p>To explain this further, let's look at the example of message history. The endpoint to fetch message history is {@code GET/channels/{channel.id}/messages}.
* This endpoint does not have any rate limit (uninit) and will thus use the hash {@code uninit+GET/channels/{channel.id}/messages}.
* The bucket id for this will be {@code uninit+GET/channels/{channel.id}/messages:channel_id={channel.id}} where {@code {channel.id}} would be replaced with the respective id.
* This means you can fetch history concurrently for multiple channels, but it will be in sequence for the same channel.
*
*
* <p>If the endpoint is not uninit we will receive a hash on the first response.
* Once this happens every uninit bucket will start moving its queue to the correct bucket.
* This is done during the queue work iteration so many requests to one endpoint would be moved correctly.
*
*
* <p>For example, the first message sending:
* <pre>{@code
* public void onReady(ReadyEvent event) {
*   TextChannel channel = event.getJDA().getTextChannelById("123");
*   for (int i = 1; i <= 100; i++) {
*     channel.sendMessage("Message: " + i).queue();
*   }
* }
* }</pre>
*
* This will send 100 messages on startup. At this point we don't yet know the hash for this route, so we put them all in {@code uninit+POST/channels/{channel.id}/messages:channel_id=123}.
* The bucket iterates the requests in sync and gets the first response. This response provides the hash for this route, and we create a bucket for it.
* Once the response is handled we continue with the next request in the uninit bucket and notice the new bucket. We then move all related requests to this bucket.
*/
public final class SequentialRestRateLimiter implements RestRateLimiter
{
private static final Logger log = JDALogger.getLog(RestRateLimiter.class);
private static final String UNINIT_BUCKET = "uninit"; // we generate an uninit bucket for every major parameter configuration

// Completed by shutdown(); the callback handed to stop(...) is chained onto this future
private final CompletableFuture<?> shutdownHandle = new CompletableFuture<>();
// Handle of the periodic cleanup() task scheduled in the constructor, cancelled in shutdown()
private final Future<?> cleanupWorker;
private final RateLimitConfig config;

// Written while holding the lock.
// NOTE(review): isStopped(), runBucket() and scheduleElastic() read these without the lock - confirm whether they should be volatile
private boolean isStopped, isShutdown;

// Guards the maps below and the stop/shutdown state transitions
private final ReentrantLock lock = new ReentrantLock();
// Route -> Should we print warning for 429? AKA did we already hit it once before
private final Set<Route> hitRatelimit = new HashSet<>(5);
// Route -> Hash
private final Map<Route, String> hashes = new HashMap<>();
// Hash + Major Parameter -> Bucket
private final Map<String, Bucket> buckets = new HashMap<>();
// Bucket -> Rate-Limit Worker
private final Map<Bucket, Future<?>> rateLimitQueue = new HashMap<>();
public SequentialRestRateLimiter(@Nonnull RateLimitConfig config)
{
this.config = config;
this.cleanupWorker = config.getScheduler().scheduleAtFixedRate(this::cleanup, 30, 30, TimeUnit.SECONDS);
}
/**
 * Appends the task to the bucket that matches its route and makes sure a worker is scheduled for that bucket.
 *
 * @param task
 *        The work to queue
 */
@Override
public void enqueue(@Nonnull RestRateLimiter.Work task)
{
    MiscUtil.locked(lock, () -> {
        Route.CompiledRoute route = task.getRoute();
        Bucket target = getBucket(route);
        target.enqueue(task);
        // No-op if a worker is already registered for this bucket
        runBucket(target);
    });
}
/**
 * Marks this rate limiter as stopped and shuts it down, either immediately or once no requests remain.
 *
 * <p>The callback is only registered by the call that flips the limiter into the stopped state;
 * it runs when the shutdown handle completes.
 *
 * @param shutdown
 *        True, to shut down right away instead of waiting for queued requests
 * @param callback
 *        Runs once the shutdown has completed
 */
@Override
public void stop(boolean shutdown, @Nonnull Runnable callback)
{
    MiscUtil.locked(lock, () -> {
        boolean shutdownNow = shutdown;

        if (!isStopped)
        {
            isStopped = true;
            shutdownHandle.thenRun(callback);

            if (!shutdownNow)
            {
                // Count everything still queued across all buckets
                int pending = 0;
                for (Bucket bucket : buckets.values())
                    pending += bucket.getRequests().size();

                if (pending > 0)
                    log.info("Waiting for {} requests to finish.", pending);

                // Nothing left to wait for means we can shut down right now
                shutdownNow = pending == 0;
            }
        }

        if (shutdownNow && !isShutdown)
            shutdown();
    });
}
/**
 * Whether this rate limiter has been stopped.
 *
 * <p>The flag is set by {@link #stop(boolean, Runnable)} and is read here without taking the lock.
 *
 * @return True, if stop has been called
 */
@Override
public boolean isStopped()
{
return isStopped;
}
/**
 * Cancels every queued request that is neither a priority request nor already cancelled.
 *
 * <p>The requests are only marked as cancelled here; they stay in their bucket until they are skipped by the
 * bucket worker or purged by the cleanup worker.
 *
 * @return The number of requests that were cancelled by this call
 */
@Override
public int cancelRequests()
{
    return MiscUtil.locked(lock, () -> {
        // Empty buckets will be removed by the cleanup worker, which also checks for rate limit parameters
        // Explicit loop: the cancellation is the whole point here, so it must not live in a Stream.peek()
        // whose execution a count() terminal operation is allowed to skip.
        int cancelled = 0;
        for (Bucket bucket : buckets.values())
        {
            for (Work request : bucket.getRequests())
            {
                if (request.isPriority() || request.isCancelled())
                    continue;
                request.cancel();
                cancelled++;
            }
        }

        if (cancelled == 1)
            log.warn("Cancelled 1 request!");
        else if (cancelled > 1)
            log.warn("Cancelled {} requests!", cancelled);
        return cancelled;
    });
}
/**
 * Performs the final shutdown: stops the periodic cleanup, runs one last cleanup pass and completes the shutdown handle.
 *
 * <p>Callers in this class invoke it while holding the lock.
 */
private void shutdown()
{
isShutdown = true;
// Stop the periodic cleanup task; false = do not interrupt a run that is in progress
cleanupWorker.cancel(false);
// With isShutdown set, this pass cancels every request that is still queued
cleanup();
// Triggers the callback that stop(...) chained onto the handle
shutdownHandle.complete(null);
}
/**
 * Purges skipped requests and drops buckets that no longer carry useful state.
 *
 * <p>A bucket is only dropped when it has no queued requests and no worker registered in the rate limit queue.
 * If this pass removed nothing and the limiter is stopped but not yet shut down, {@link #shutdown()} is invoked,
 * regardless of whether any bucket still holds requests.
 */
private void cleanup()
{
    // This will remove buckets that are no longer needed every 30 seconds to avoid memory leakage
    // We will keep the hashes in memory since they are very limited (by the amount of possible routes)
    MiscUtil.locked(lock, () -> {
        int sizeBefore = buckets.size();

        // Only the values are needed; removing through the values iterator also removes the map entry
        Iterator<Bucket> iterator = buckets.values().iterator();
        while (iterator.hasNext())
        {
            Bucket bucket = iterator.next();
            if (isShutdown)
                bucket.requests.forEach(Work::cancel); // Cancel all requests
            bucket.requests.removeIf(Work::isSkipped); // Remove cancelled requests

            // Buckets with pending requests or a registered worker have to stay
            if (!bucket.requests.isEmpty() || rateLimitQueue.containsKey(bucket))
                continue;

            // Drop the drained bucket if any of these holds:
            //  - it is an uninit bucket
            //  - its reset has expired, so it carries no valuable information anymore
            //  - the rate limiter is stopped
            if (bucket.isUninit() || bucket.reset <= getNow() || isStopped)
                iterator.remove();
        }

        // Log how many buckets were removed
        int removed = sizeBefore - buckets.size();
        if (removed > 0)
            log.debug("Removed {} expired buckets", removed);
        else if (isStopped && !isShutdown)
            shutdown();
    });
}
/**
 * Looks up the cached hash for a route.
 *
 * @param route
 *        The base route
 *
 * @return The known hash, or {@code "uninit+" + route} if no hash has been cached for the route
 */
private String getRouteHash(Route route)
{
    String uninitHash = UNINIT_BUCKET + "+" + route;
    return hashes.getOrDefault(route, uninitHash);
}
/**
 * Resolves the bucket for a compiled route, creating it on first use.
 *
 * <p>The bucket id is the route hash followed by {@code ":"} and the major parameters of the route.
 *
 * @param route
 *        The compiled route of a request
 *
 * @return The existing or newly created bucket
 */
private Bucket getBucket(Route.CompiledRoute route)
{
    return MiscUtil.locked(lock, () ->
    {
        Route baseRoute = route.getBaseRoute();
        // hash (from the base route) + major parameters identify the bucket
        String bucketId = getRouteHash(baseRoute) + ":" + route.getMajorParameters();
        return this.buckets.computeIfAbsent(bucketId,
            id -> baseRoute.isInteractionBucket() ? new InteractionBucket(id) : new ClassicBucket(id));
    });
}
/**
 * Runs the bucket worker on the elastic pool, or inline when the elastic pool and the scheduler are the same instance.
 *
 * <p>Does nothing once the rate limiter has shut down.
 *
 * @param bucket
 *        The bucket whose queue should be worked on
 */
private void scheduleElastic(Bucket bucket)
{
    if (isShutdown)
        return;

    ExecutorService worker = config.getElastic();
    ScheduledExecutorService timer = config.getScheduler();

    try
    {
        if (worker != timer)
            worker.execute(bucket);
        else
            bucket.run(); // Avoid context switch if unnecessary
    }
    catch (RejectedExecutionException rejected)
    {
        // A rejection after shutdown is expected and not worth logging
        if (!isShutdown)
            log.error("Failed to execute bucket worker", rejected);
    }
    catch (Throwable throwable)
    {
        log.error("Caught throwable in bucket worker", throwable);
        // Errors are propagated, everything else is only logged
        if (throwable instanceof Error)
            throw (Error) throwable;
    }
}
/**
 * Schedules a worker for the bucket unless one is already registered in the rate limit queue.
 *
 * <p>The worker is delayed by the bucket's current rate limit in milliseconds. Does nothing after shutdown.
 *
 * @param bucket
 *        The bucket to schedule
 */
private void runBucket(Bucket bucket)
{
    if (isShutdown)
        return;

    MiscUtil.locked(lock, () -> {
        // Schedule a new bucket worker if no worker is running
        rateLimitQueue.computeIfAbsent(bucket, key -> {
            ScheduledExecutorService scheduler = config.getScheduler();
            long delay = bucket.getRateLimit();
            return scheduler.schedule(() -> scheduleElastic(bucket), delay, TimeUnit.MILLISECONDS);
        });
    });
}
/**
 * Parses a header value as a long, treating a missing header as zero.
 *
 * @param input
 *        The header value, possibly null
 *
 * @throws NumberFormatException
 *         If the value is present but not a valid long
 *
 * @return The parsed value, or 0 if the input is null
 */
private long parseLong(String input)
{
    if (input == null)
        return 0L;
    return Long.parseLong(input);
}
/**
 * Converts a header value given in fractional seconds into whole milliseconds, treating a missing header as zero.
 *
 * @param input
 *        The header value, possibly null
 *
 * @throws NumberFormatException
 *         If the value is present but not a valid double
 *
 * @return The value in milliseconds (truncated), or 0 if the input is null
 */
private long parseDouble(String input)
{
    if (input == null)
        return 0L;

    //The header value is using a double to represent milliseconds and seconds:
    // 5.250 this is 5 seconds and 250 milliseconds (5250 milliseconds)
    double seconds = Double.parseDouble(input);
    return (long) (seconds * 1000);
}
/**
 * The current time in epoch milliseconds. Bucket reset timestamps are compared against this value.
 *
 * @return {@link System#currentTimeMillis()}
 */
private long getNow()
{
return System.currentTimeMillis();
}
/**
 * Updates the rate limit state from the headers of a response.
 *
 * <p>If the response carries a bucket hash, the hash is cached for the base route (first one wins) and the bucket is
 * resolved again so that it reflects the hash. A 429 response updates either the global limit, the cloudflare limit
 * or the bucket itself. Any other response with a hash refreshes the remaining uses and the reset time of the bucket.
 *
 * <p>Exceptions raised while processing are logged and not propagated.
 *
 * @param route
 *        The compiled route of the request
 * @param response
 *        The HTTP response to read the rate limit headers from
 *
 * @return The bucket associated with the route after the update
 */
private Bucket updateBucket(Route.CompiledRoute route, Response response)
{
    return MiscUtil.locked(lock, () ->
    {
        try
        {
            Bucket target = getBucket(route);
            Headers responseHeaders = response.headers();

            boolean isGlobal = responseHeaders.get(GLOBAL_HEADER) != null;
            // A response without a "via" header is treated as coming from cloudflare
            boolean isCloudflare = responseHeaders.get("via") == null;
            String bucketHash = responseHeaders.get(HASH_HEADER);
            String scope = responseHeaders.get(SCOPE_HEADER);
            long timestamp = getNow();
            Route baseRoute = route.getBaseRoute();

            // Create a new bucket for the hash if needed
            if (bucketHash != null)
            {
                // Only the first hash seen for a route is cached
                if (this.hashes.putIfAbsent(baseRoute, bucketHash) == null)
                    log.debug("Caching bucket hash {} -> {}", baseRoute, bucketHash);

                target = getBucket(route);
            }

            if (response.code() == 429)
            {
                long retryAfter = parseLong(responseHeaders.get(RETRY_AFTER_HEADER)) * 1000; // seconds precision

                if (isGlobal)
                {
                    // Handle global rate limit if necessary
                    config.getGlobalRateLimit().setClassic(timestamp + retryAfter);
                    log.error("Encountered global rate limit! Retry-After: {} ms Scope: {}", retryAfter, scope);
                }
                else if (isCloudflare)
                {
                    // Handle cloudflare rate limits, this applies to all routes and uses seconds for retry-after
                    config.getGlobalRateLimit().setCloudflare(timestamp + retryAfter);
                    log.error("Encountered cloudflare rate limit! Retry-After: {} s", retryAfter / 1000);
                }
                else
                {
                    // Handle hard rate limit, pretty much just log that it happened
                    boolean quietFirstHit = hitRatelimit.add(baseRoute) && retryAfter < 60000;

                    // Update the bucket to the new information
                    target.remaining = 0;
                    target.reset = timestamp + retryAfter;

                    // don't log warning if we hit the rate limit for the first time, likely due to initialization of the bucket
                    // unless its a long retry-after delay (more than a minute)
                    if (quietFirstHit)
                        log.debug("Encountered 429 on route {} with bucket {} Retry-After: {} ms Scope: {}", baseRoute, target.bucketId, retryAfter, scope);
                    else
                        log.warn("Encountered 429 on route {} with bucket {} Retry-After: {} ms Scope: {}", baseRoute, target.bucketId, retryAfter, scope);
                }

                log.trace("Updated bucket {} to retry after {}", target.bucketId, target.reset - timestamp);
                return target;
            }

            // If hash is null this means we didn't get enough information to update a bucket
            if (bucketHash == null)
                return target;

            // Update the bucket parameters with new information
            String limitHeader = responseHeaders.get(LIMIT_HEADER);
            String remainingHeader = responseHeaders.get(REMAINING_HEADER);
            String resetAfterHeader = responseHeaders.get(RESET_AFTER_HEADER);
            String resetHeader = responseHeaders.get(RESET_HEADER);

            // bucket.limit = (int) Math.max(1L, parseLong(limitHeader));
            target.remaining = (int) parseLong(remainingHeader);
            // Relative mode adds the reset-after delay to the local clock, otherwise the reset header is used as-is
            target.reset = config.isRelative()
                    ? timestamp + parseDouble(resetAfterHeader)
                    : parseDouble(resetHeader);

            log.trace("Updated bucket {} to ({}/{}, {})", target.bucketId, target.remaining, limitHeader, target.reset - timestamp);
            return target;
        }
        catch (Exception e)
        {
            Bucket fallback = getBucket(route);
            log.error("Encountered Exception while updating a bucket. Route: {} Bucket: {} Code: {} Headers:\n{}",
                    route.getBaseRoute(), fallback, response.code(), response.headers(), e);
            return fallback;
        }
    });
}
/**
 * A queue of requests that share one rate limit, together with the worker logic that drains the queue in sequence.
 *
 * <p>Two buckets are equal when their bucket ids are equal. Subclasses decide which global limits apply
 * by implementing {@link #getGlobalRateLimit(long)}.
 */
private abstract class Bucket implements Runnable
{
    protected final String bucketId;
    // Concurrent deque: the worker in run() polls without holding the lock while cleanup() may purge entries
    protected final Deque<Work> requests = new ConcurrentLinkedDeque<>();

    // Epoch milliseconds at which the remaining uses are assumed to be available again
    protected long reset = 0;
    protected int remaining = 1;

    public Bucket(String bucketId)
    {
        this.bucketId = bucketId;
    }

    /**
     * Whether this bucket was created before a hash was known for its route.
     *
     * @return True, if the bucket id starts with the uninit marker
     */
    public boolean isUninit()
    {
        return bucketId.startsWith(UNINIT_BUCKET);
    }

    /**
     * Appends the request to the end of this bucket's queue.
     */
    public void enqueue(Work request)
    {
        requests.addLast(request);
    }

    /**
     * Re-queues a request that is not done yet: moves it to the bucket that now matches its route,
     * or puts it back at the front of this bucket if that bucket is this one.
     */
    public void retry(Work request)
    {
        if (!moveRequest(request))
            requests.addFirst(request);
    }

    public long getReset()
    {
        return reset;
    }

    public int getRemaining()
    {
        return remaining;
    }

    /**
     * The time in milliseconds until the applicable global rate limit expires, relative to the provided time.
     *
     * @param now
     *        The current time in epoch milliseconds
     *
     * @return Milliseconds until the global limit is over; zero or negative if no global limit is active
     */
    public abstract long getGlobalRateLimit(long now);

    /**
     * Computes how long this bucket has to wait before it may execute the next request.
     *
     * <p>As a side effect this resets the remaining uses to 1 when the reset time has passed.
     *
     * @return The backoff in milliseconds, or 0 if a request may be executed right away
     */
    public long getRateLimit()
    {
        long now = getNow();
        long global = getGlobalRateLimit(now);
        // Check if the bucket reset time has expired
        if (reset <= now)
        {
            // Update the remaining uses to the limit (we don't know better)
            remaining = 1;
        }

        // If there are remaining requests we don't need to do anything, otherwise return backoff in milliseconds
        return Math.max(global, remaining < 1 ? reset - now : 0L);
    }

    protected boolean isGlobalRateLimit()
    {
        return getGlobalRateLimit(getNow()) > 0;
    }

    /**
     * Unregisters the current worker of this bucket and either schedules a new one (requests remain)
     * or, when the limiter is stopped, removes the bucket and shuts down once no buckets are left.
     */
    protected void backoff()
    {
        // Schedule backoff if requests are not done
        MiscUtil.locked(lock, () -> {
            rateLimitQueue.remove(this);
            if (!requests.isEmpty())
                runBucket(this);
            else if (isStopped)
                buckets.remove(bucketId);
            if (isStopped && buckets.isEmpty())
                shutdown();
        });
    }

    @Nonnull
    public Queue<Work> getRequests()
    {
        return requests;
    }

    /**
     * Moves the request to the bucket that currently matches its route, if that bucket is not this one.
     *
     * @param request
     *        The request to move
     *
     * @return True, if the request was enqueued on a different bucket
     */
    protected boolean moveRequest(Work request)
    {
        return MiscUtil.locked(lock, () ->
        {
            // Attempt moving request to correct bucket if it has been created
            Bucket bucket = getBucket(request.getRoute());
            if (bucket != this)
            {
                bucket.enqueue(request);
                runBucket(bucket);
            }
            return bucket != this;
        });
    }

    /**
     * Executes a single request, applies the response to the rate limit state and re-queues the request if it is not done.
     *
     * @param request
     *        The request to execute
     *
     * @throws Error
     *         Errors raised during execution are logged and rethrown
     *
     * @return True, if execution failed with an exception and the worker loop should stop; false otherwise
     */
    protected boolean execute(Work request)
    {
        try
        {
            Response response = request.execute();
            if (response != null)
                updateBucket(request.getRoute(), response);
            if (!request.isDone())
                retry(request);
        }
        catch (Throwable ex)
        {
            log.error("Encountered exception trying to execute request", ex);
            if (ex instanceof Error)
                throw (Error) ex;
            return true;
        }
        return false;
    }

    /**
     * Worker loop: executes queued requests one after another until the queue is empty,
     * a rate limit requires a backoff, or a request fails with an exception. Always finishes with {@link #backoff()}
     * unless an exception or error escapes the loop.
     */
    @Override
    public void run()
    {
        log.trace("Bucket {} is running {} requests", bucketId, requests.size());
        while (!requests.isEmpty())
        {
            long rateLimit = getRateLimit();
            if (rateLimit > 0L)
            {
                // We need to backoff since we ran out of remaining uses or hit the global rate limit
                Work request = requests.peekFirst(); // this *should* not be null
                String baseRoute = request != null ? request.getRoute().getBaseRoute().toString() : "N/A";
                if (!isGlobalRateLimit() && rateLimit >= 1000 * 60 * 30) // 30 minutes
                    log.warn("Encountered long {} minutes Rate-Limit on route {}", TimeUnit.MILLISECONDS.toMinutes(rateLimit), baseRoute);
                log.debug("Backing off {} ms for bucket {} on route {}", rateLimit, bucketId, baseRoute);
                break;
            }

            // pollFirst rather than removeFirst: this loop does not hold the lock, so cleanup() may have purged
            // the last request since the isEmpty() check. removeFirst would throw NoSuchElementException,
            // skip backoff() and leave this bucket registered in the rate limit queue indefinitely.
            Work request = requests.pollFirst();
            if (request == null)
                break;
            if (request.isSkipped())
                continue;
            if (isUninit() && moveRequest(request))
                continue;
            if (execute(request))
                break;
        }

        backoff();
    }

    @Override
    public String toString()
    {
        return bucketId;
    }

    @Override
    public int hashCode()
    {
        return bucketId.hashCode();
    }

    @Override
    public boolean equals(Object obj)
    {
        if (obj == this)
            return true;
        if (!(obj instanceof Bucket))
            return false;
        return this.bucketId.equals(((Bucket) obj).bucketId);
    }
}
/**
 * Bucket whose global rate limit is the later of the classic global limit and the cloudflare limit.
 */
private class ClassicBucket extends Bucket
{
    public ClassicBucket(String bucketId)
    {
        super(bucketId);
    }

    /**
     * @param now
     *        The current time in epoch milliseconds
     *
     * @return Milliseconds until both the classic and the cloudflare limit have expired
     */
    @Override
    public long getGlobalRateLimit(long now)
    {
        GlobalRateLimit limits = config.getGlobalRateLimit();
        // Whichever limit expires later decides how long we have to wait
        long latestReset = Math.max(limits.getClassic(), limits.getCloudflare());
        return latestReset - now;
    }
}
/**
 * Bucket for interaction routes, which only honours the cloudflare limit as its global rate limit.
 */
private class InteractionBucket extends Bucket
{
    public InteractionBucket(@Nonnull String bucketId)
    {
        super(bucketId);
    }

    /**
     * @param now
     *        The current time in epoch milliseconds
     *
     * @return Milliseconds until the cloudflare limit has expired
     */
    @Override
    public long getGlobalRateLimit(long now)
    {
        // Only cloudflare bans apply to interactions
        long cloudflareReset = config.getGlobalRateLimit().getCloudflare();
        return cloudflareReset - now;
    }
}
}