All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.networknt.limit.RateLimiter Maven / Gradle / Ivy

Go to download

A handler that performs rate limiting in order to protect public-facing services from attacks.

The newest version!
package com.networknt.limit;

import com.networknt.exception.FrameworkException;
import com.networknt.limit.key.KeyResolver;
import com.networknt.status.Status;
import com.networknt.utility.Constants;
import io.undertow.server.HttpServerExchange;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Instant;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/**
 *  Rate limit logic for light-4j framework. The config will define in the limit.yml config file.
 *
 * By default, Rate limit will handle on the server(service) level. But framework support client and address level limitation
 *
 * @author Gavin Chen
 */
public class RateLimiter {
    private static final String LIMIT_KEY_NOT_FOUND = "ERR10073";
    protected LimitConfig config;

    private final Map> serverTimeMap = new ConcurrentHashMap<>();

    private final Map>> directTimeMap = new ConcurrentHashMap<>();
    private static final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
    static final String UNKNOWN_PREFIX = "/unknown/prefix"; // this is the bucket for all other request path that is not defined in the config
    static final String ADDRESS_TYPE = "address";
    static final String CLIENT_TYPE = "client";
    static final String USER_TYPE = "user";

    private KeyResolver clientIdKeyResolver;
    private KeyResolver addressKeyResolver;
    private KeyResolver userIdKeyResolver;


    /**
     * Load config and initial model by Rate limit key.
     * @param config LimitConfig object
     * @throws Exception runtime exception
     */
    public RateLimiter(LimitConfig config) throws Exception {
        this.config = config;
        if (LimitKey.SERVER.equals(config.getKey())) {
            if (this.config.getServer()!=null && !this.config.getServer().isEmpty()) {
                this.config.getServer().forEach((k,v)->serverTimeMap.put(k, new ConcurrentHashMap<>()));
            }
        } else if (LimitKey.ADDRESS.equals(config.getKey())) {
            if (this.config.getAddress()!=null) {
                if (this.config.getAddress().getDirectMaps()!=null && !this.config.getAddress().getDirectMaps().isEmpty()) {
                    this.config.getAddress().getDirectMaps().forEach((k,v)->{
                        Map> directMap = new ConcurrentHashMap<>();
                        v.forEach(i->{
                            directMap.put(i.getUnit(), new ConcurrentHashMap<>());
                        });
                        directTimeMap.put(k, directMap);
                    });
                }
            }
            String addressKey = this.config.getAddressKeyResolver()==null? "com.networknt.limit.key.RemoteAddressKeyResolver":this.config.getAddressKeyResolver();
            addressKeyResolver = (KeyResolver)Class.forName(addressKey).getDeclaredConstructor().newInstance();
        } else if (LimitKey.CLIENT.equals(config.getKey())) {
            if (this.config.getClient()!=null) {
                if (this.config.getClient().getDirectMaps()!=null && !this.config.getClient().getDirectMaps().isEmpty()) {
                    this.config.getClient().getDirectMaps().forEach((k,v)->{
                        Map> directMap = new ConcurrentHashMap<>();
                        v.forEach(i->{
                            directMap.put(i.getUnit(), new ConcurrentHashMap<>());
                        });
                        directTimeMap.put(k, directMap);
                    });
                }
            }
            String clientIdKey = this.config.getClientIdKeyResolver()==null? "com.networknt.limit.key.JwtClientIdKeyResolver":this.config.getClientIdKeyResolver();
            clientIdKeyResolver = (KeyResolver)Class.forName(clientIdKey).getDeclaredConstructor().newInstance();
        } else if (LimitKey.USER.equals(config.getKey())) {
            if (this.config.getUser()!=null) {
                if (this.config.getUser().getDirectMaps()!=null && !this.config.getUser().getDirectMaps().isEmpty()) {
                    this.config.getUser().getDirectMaps().forEach((k,v)->{
                        Map> directMap = new ConcurrentHashMap<>();
                        v.forEach(i->{
                            directMap.put(i.getUnit(), new ConcurrentHashMap<>());
                        });
                        directTimeMap.put(k, directMap);
                    });
                }
            }
            String userIdKey = this.config.getUserIdKeyResolver()==null? "com.networknt.limit.key.JwtUserIdKeyResolver":this.config.getUserIdKeyResolver();
            userIdKeyResolver = (KeyResolver)Class.forName(userIdKey).getDeclaredConstructor().newInstance();
        }
    }

    public RateLimitResponse handleRequest(final HttpServerExchange exchange, LimitKey limitKey) {
        if (LimitKey.ADDRESS.equals(limitKey)) {
            String address = addressKeyResolver.resolve(exchange);
            if(address == null) {
                logger.error("Failed to resolve the address with the address resolver " + addressKeyResolver.getClass().getPackageName());
                Status status = new Status(LIMIT_KEY_NOT_FOUND, LimitKey.ADDRESS, addressKeyResolver.toString());
                throw new FrameworkException(status);
            }
            String path = exchange.getRequestPath();
            return isAllowDirect(address, path, ADDRESS_TYPE);
        } else if (LimitKey.CLIENT.equals(limitKey)) {
            String clientId = clientIdKeyResolver.resolve(exchange);
            if(clientId == null) {
                logger.error("Failed to resolve the clientId with the clientId resolver " + addressKeyResolver.getClass().getPackageName()  + ". You must put the limit handler after the security handler in the request/response chain.");
                Status status = new Status(LIMIT_KEY_NOT_FOUND, LimitKey.CLIENT, clientIdKeyResolver.getClass().getPackageName());
                throw new FrameworkException(status);
            }
            String path = exchange.getRequestPath();
            return isAllowDirect(clientId, path, CLIENT_TYPE);
        } else if (LimitKey.USER.equals(limitKey)) {
            String userId = userIdKeyResolver.resolve(exchange);
            if(userId == null) {
                logger.error("Failed to resolve the userId with the userId resolver " + userIdKeyResolver.getClass().getPackageName()  + ". You must put the limit handler after the security handler in the request/response chain.");
                Status status = new Status(LIMIT_KEY_NOT_FOUND, LimitKey.USER, userIdKeyResolver.getClass().getPackageName());
                throw new FrameworkException(status);
            }
            String path = exchange.getRequestPath();
            return isAllowDirect(userId, path, USER_TYPE);
        } else  {
            //By default, the key is server
            String path = exchange.getRequestPath();
            return isAllowByServer(path);
        }
    }

    /**
     * Handle logic for direct rate limit setting for address, client and user.
     * Use the type for differential the address/client/user
     * @param directKey direct key
     * @param path String
     * @param type String
     * @return RateLimitResponse response
     */
    protected RateLimitResponse isAllowDirect(String directKey, String path, String type) {
        long currentTimeWindow = Instant.now().getEpochSecond();
        Map> localTimeMap;

        String keyWithPath = directKey + LimitConfig.SEPARATE_KEY + path;
        List rateLimit;
        String mapKey = directKey;
        if (ADDRESS_TYPE.equalsIgnoreCase(type)) {
            if (config.getAddress() != null && config.getAddress().directMaps.containsKey(keyWithPath)) {
                rateLimit = config.getAddress().directMaps.get(keyWithPath);
                mapKey = keyWithPath;
            } else if (config.getAddress() != null && config.getAddress().directMaps.containsKey(directKey)) {
                rateLimit = config.getAddress().directMaps.get(directKey);
            } else {
                if(logger.isTraceEnabled()) logger.trace("both keyWithPath and directKey not found in the config, use the default rate limit");
                rateLimit = config.rateLimit;
                // check if the map is already created
                synchronized(this) {
                    Map> directMap = directTimeMap.get(mapKey);
                    if (directMap == null) {
                        if (logger.isTraceEnabled())
                            logger.trace("directMap is null, create a new one for the key {}", mapKey);
                        Map> map = new ConcurrentHashMap<>();
                        this.config.getRateLimit().forEach(i -> {
                            map.put(i.getUnit(), new ConcurrentHashMap<>());
                        });
                        directTimeMap.put(mapKey, map);
                    }
                }
            }
        } else if(CLIENT_TYPE.equalsIgnoreCase(type)) {
            if (config.getClient() != null && config.getClient().directMaps.containsKey(keyWithPath)) {
                rateLimit = config.getClient().directMaps.get(keyWithPath);
                mapKey = keyWithPath;
            } else if (config.getClient() != null && config.getClient().directMaps.containsKey(directKey)) {
                rateLimit = config.getClient().directMaps.get(directKey);
            } else {
                rateLimit = config.rateLimit;
                synchronized(this) {
                    Map> directMap = directTimeMap.get(mapKey);
                    if(directMap == null) {
                        Map> map = new ConcurrentHashMap<>();
                        this.config.getRateLimit().forEach(i->{
                            map.put(i.getUnit(), new ConcurrentHashMap<>());
                        });
                        directTimeMap.put(mapKey, map);
                    }
                }
            }
        } else {
            if (config.getUser() != null && config.getUser().directMaps.containsKey(keyWithPath)) {
                rateLimit = config.getUser().directMaps.get(keyWithPath);
                mapKey = keyWithPath;
            } else if (config.getUser() != null && config.getUser().directMaps.containsKey(directKey)) {
                rateLimit = config.getUser().directMaps.get(directKey);
            } else {
                rateLimit = config.rateLimit;
                synchronized (this) {
                    Map> directMap = directTimeMap.get(mapKey);
                    if(directMap == null) {
                        Map> map = new ConcurrentHashMap<>();
                        this.config.getRateLimit().forEach(i->{
                            map.put(i.getUnit(), new ConcurrentHashMap<>());
                        });
                        directTimeMap.put(mapKey, map);
                    }
                }
            }
        }
        localTimeMap = directTimeMap.get(mapKey);
        synchronized(this) {
            for (LimitQuota limitQuota: rateLimit) {
                Map timeMap =  localTimeMap.get(limitQuota.getUnit());
                if (timeMap.isEmpty()) {
                    if(logger.isTraceEnabled()) logger.trace("timeMap is empty, put the first entry");
                    timeMap.put(currentTimeWindow, new AtomicLong(1L));
                    return new RateLimitResponse(true, null);
                } else {
                    Long countInOverallTime = removeOldEntriesForUser(currentTimeWindow, timeMap, limitQuota.unit);
                    if(logger.isTraceEnabled()) logger.trace("countInOverallTime: " + countInOverallTime + " limitQuota.value: " + limitQuota.value);
                    if (countInOverallTime < limitQuota.value) {
                        //Handle new time windows
                        Long newCount = timeMap.getOrDefault(currentTimeWindow, new AtomicLong(0)).longValue() + 1;
                        if(logger.isTraceEnabled()) logger.trace("newCount: " + newCount);
                        timeMap.put(currentTimeWindow, new AtomicLong(newCount));
                        logger.debug("CurrentTimeWindow:" + currentTimeWindow +" Result:true "+ " Count:"+countInOverallTime);
                        return new RateLimitResponse(true, null);
                    } else {
                        String reset = getRateLimitReset(currentTimeWindow, timeMap, limitQuota);
                        return new RateLimitResponse(false, buildHeaders(countInOverallTime, limitQuota, reset));
                    }
                }
            }
        }
        return null;
    }

    /**
     * Handle logic for Server type (key = server) rate limit
     * @param path String
     * @return RateLimitResponse rate limit response
     */
    public synchronized RateLimitResponse isAllowByServer(String path) {
        long currentTimeWindow = Instant.now().getEpochSecond();
        Map timeMap = lookupServerTimeMap(path);  // defined and unknown one if not defined.
        LimitQuota limitQuota = config.getServer() != null ? lookupLimitQuota(path) : null;
        if(limitQuota == null) {
            limitQuota = this.config.getRateLimit().get(0);
        }
        if (timeMap.isEmpty()) {
            timeMap.put(currentTimeWindow, new AtomicLong(1L));
        } else {
            Long countInOverallTime = removeOldEntriesForUser(currentTimeWindow, timeMap, limitQuota.unit);
            if (countInOverallTime < limitQuota.value) {
                //Handle new time windows
                Long newCount = timeMap.getOrDefault(currentTimeWindow, new AtomicLong(0)).longValue() + 1;
                timeMap.put(currentTimeWindow, new AtomicLong(newCount));
                logger.debug("CurrentTimeWindow:" + currentTimeWindow +" Result:true "+ " Count:"+countInOverallTime);
                return new RateLimitResponse(true, null);
            } else {
                String reset = getRateLimitReset(currentTimeWindow, timeMap, limitQuota);
                return new RateLimitResponse(false, buildHeaders(countInOverallTime, limitQuota, reset));
            }
        }
        return new RateLimitResponse(true, null);
    }

    private Map lookupServerTimeMap(String path) {
        String prefix = null;
        for(String s: serverTimeMap.keySet()) {
            if(path.startsWith(s)) {
                prefix = s;
                break;
            }
        }
        if(prefix == null) {
            // the request path is not in the defined path prefix. Use the default path prefix UNKNOWN_PREFIX.
            if(!serverTimeMap.containsKey(UNKNOWN_PREFIX)) {
                Map timeMap = new ConcurrentHashMap<>();
                serverTimeMap.put(UNKNOWN_PREFIX, timeMap);
                return timeMap;
            } else {
                return serverTimeMap.get(UNKNOWN_PREFIX);
            }
        } else {
            return serverTimeMap.get(prefix);
        }

    }

    private LimitQuota lookupLimitQuota(String path) {
        String prefix = null;
        for(String s: this.config.getServer().keySet()) {
            if(path.startsWith(s)) {
                prefix = s;
                break;
            }
        }
        if(prefix == null) {
            return null;
        } else {
            return this.config.getServer().get(prefix);
        }
    }

    private String getRateLimitReset(Long currentTimeWindow, Map timeMap,  LimitQuota limitQuota) {
        String res = null;
        if (TimeUnit.SECONDS.equals(limitQuota.unit)){
            res = "1s";
        } else {
            Optional firstKey = timeMap.keySet().stream().findFirst();
            long reset = getWindow(limitQuota.unit) + firstKey.get() - currentTimeWindow;
            res = reset + "s";
        }
        return res;
    }

    private Map buildHeaders(Long countInOverallTime, LimitQuota limitQuota, String reset) {
        Map headers = new HashMap<>();
        headers.put(Constants.RATELIMIT_LIMIT, limitQuota.value + "/" + limitQuota.unit);
        headers.put(Constants.RATELIMIT_REMAINING, String.valueOf(limitQuota.value - countInOverallTime));
        if (reset!=null) {
            headers.put(Constants.RATELIMIT_RESET, reset);
        }

        return headers;
    }

    private  long removeOldEntriesForUser( long currentTimeWindow, Map timeWindowVSCountMap, TimeUnit unit)
    {
        List  oldEntriesToBeDeleted=new ArrayList<>();
        long overallCount=0L;
        for (Long timeWindow : timeWindowVSCountMap.keySet()) {
            //Mark old entries (Entries older than the oldest valid time window within the time limit) for deletion
            if ((currentTimeWindow - timeWindow) >= getWindow(unit))
                oldEntriesToBeDeleted.add(timeWindow);
            else
                overallCount+=timeWindowVSCountMap.get(timeWindow).longValue();
        }
        timeWindowVSCountMap.keySet().removeAll(oldEntriesToBeDeleted);
        return overallCount;
    }

    private int getWindow(TimeUnit unit) {
        if (TimeUnit.DAYS.equals(unit)) {
            return 24*60*60;
        } else if (TimeUnit.HOURS.equals(unit)) {
            return 60*60;
        } else if (TimeUnit.MINUTES.equals(unit)) {
            return 60;
        } else {
            return 1;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy