All downloads are free. The search and download functions use the official Maven repository.

com.giffing.bucket4j.spring.boot.starter.service.RateLimitService Maven / Gradle / Ivy

There is a newer version: 0.12.7
Show newest version
package com.giffing.bucket4j.spring.boot.starter.service;


import com.giffing.bucket4j.spring.boot.starter.config.cache.ProxyManagerWrapper;
import com.giffing.bucket4j.spring.boot.starter.context.*;
import com.giffing.bucket4j.spring.boot.starter.context.metrics.MetricBucketListener;
import com.giffing.bucket4j.spring.boot.starter.context.metrics.MetricHandler;
import com.giffing.bucket4j.spring.boot.starter.context.metrics.MetricTagResult;
import com.giffing.bucket4j.spring.boot.starter.context.properties.*;
import com.giffing.bucket4j.spring.boot.starter.exception.ExecutePredicateInstantiationException;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.BucketConfiguration;
import io.github.bucket4j.ConfigurationBuilder;
import lombok.Builder;
import lombok.Data;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.InvocationTargetException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;

@Slf4j
@RequiredArgsConstructor
public class RateLimitService {

    private final ExpressionService expressionService;

    @Builder
    @Data
    public static class RateLimitConfig {
        @NonNull
        private List rateLimits;
        @NonNull
        private List metricHandlers;
        @NonNull
        private Map> executePredicates;
        @NonNull
        private String cacheName;
        @NonNull
        private ProxyManagerWrapper proxyWrapper;
        @NonNull
        private BiFunction, String> keyFunction;
        @NonNull
        private Metrics metrics;
        private long configVersion;
    }

    @Builder
    @Data
    public static class RateLimitConfigresult {
        private List> rateLimitChecks;
        private List> postRateLimitChecks;
    }

    public  RateLimitConfigresult configureRateLimit(RateLimitConfig rateLimitConfig) {


        var executePredicates = rateLimitConfig.getExecutePredicates();
        

        List> rateLimitChecks = new ArrayList<>();
        List> postRateLimitChecks = new ArrayList<>();
        rateLimitConfig.getRateLimits().forEach(rl -> {
            log.debug("RL: {}", rl.toString());
            var bucketConfiguration = prepareBucket4jConfigurationBuilder(rl).build();
            var executionPredicate = prepareExecutionPredicates(rl, executePredicates);
            var skipPredicate = prepareSkipPredicates(rl, executePredicates);

            RateLimitCheck rlc = (expressionParams, overridableRateLimit) -> {

                var rlToUse = rl.copy();
                rlToUse.consumeNotNullValues(overridableRateLimit);

                var skipRateLimit = performSkipRateLimitCheck(rlToUse, executionPredicate, skipPredicate, expressionParams);
                boolean isEstimation = rlToUse.getPostExecuteCondition() != null;
                RateLimitResultWrapper rateLimitResultWrapper = null;
                if (!skipRateLimit) {

                    rateLimitResultWrapper = tryConsume(rateLimitConfig, expressionParams, rlToUse, isEstimation, bucketConfiguration);
                }
                return rateLimitResultWrapper;
            };
            rateLimitChecks.add(rlc);


            if (rl.getPostExecuteCondition() != null) {
                log.debug("PRL: {}", rl);
                PostRateLimitCheck postRlc = (request, response) -> {
                    ExpressionParams expressionParams = new ExpressionParams<>(request);
                    var skipRateLimit = performPostSkipRateLimitCheck(rl,
                            executionPredicate,
                            skipPredicate,
                            expressionParams,
                            response);
                    boolean isEstimation = false;
                    RateLimitResultWrapper rateLimitResultWrapper = null;
                    if (!skipRateLimit) {
                        rateLimitResultWrapper = tryConsume(rateLimitConfig, expressionParams, rl, isEstimation, bucketConfiguration);
                    }
                    return rateLimitResultWrapper;
                };
                postRateLimitChecks.add(postRlc);

            }
        });

        return new RateLimitConfigresult<>(rateLimitChecks, postRateLimitChecks);
    }

    private  RateLimitResultWrapper tryConsume(RateLimitConfig rateLimitConfig, ExpressionParams expressionParams, RateLimit rlToUse, boolean isEstimation, BucketConfiguration bucketConfiguration) {
        RateLimitResultWrapper rateLimitResultWrapper;
        var metricHandlers = rateLimitConfig.getMetricHandlers();
        var cacheName = rateLimitConfig.getCacheName();
        var metrics = rateLimitConfig.getMetrics();
        var keyFunction = rateLimitConfig.getKeyFunction();
        var proxyWrapper = rateLimitConfig.getProxyWrapper();
        var configVersion = rateLimitConfig.getConfigVersion();

        var key = keyFunction.apply(rlToUse, expressionParams);
        var metricBucketListener = createMetricListener(cacheName, metrics, metricHandlers, expressionParams);
        log.debug("try-and-consume;key:{};tokens:{}", key, rlToUse.getNumTokens());
        rateLimitResultWrapper = proxyWrapper.tryConsumeAndReturnRemaining(
                key,
                rlToUse.getNumTokens(),
                isEstimation,
                bucketConfiguration,
                metricBucketListener,
                configVersion,
                rlToUse.getTokensInheritanceStrategy()
        );
        return rateLimitResultWrapper;
    }


    private  boolean performPostSkipRateLimitCheck(RateLimit rl,
                                                         Predicate executionPredicate,
                                                         Predicate skipPredicate,
                                                         ExpressionParams expressionParams,
                                                         P response
    ) {
        var skipRateLimit = performSkipRateLimitCheck(
                rl, executionPredicate,
                skipPredicate, expressionParams);

        if (!skipRateLimit && rl.getPostExecuteCondition() != null) {
            Condition

condition = exp -> expressionService.parseBoolean(rl.getPostExecuteCondition(), exp); skipRateLimit = !condition.evaluate(new ExpressionParams<>(response).addParams(expressionParams.getParams())); log.debug("skip-rate-limit - post-execute-condition: {}", skipRateLimit); } return skipRateLimit; } private boolean performSkipRateLimitCheck(RateLimit rl, Predicate executionPredicate, Predicate skipPredicate, ExpressionParams expressionParams) { boolean skipRateLimit = false; if (rl.getSkipCondition() != null) { Condition expresison = exp -> expressionService.parseBoolean(rl.getSkipCondition(), exp); skipRateLimit = expresison.evaluate(expressionParams); log.debug("skip-rate-limit - skip-condition: {}", skipRateLimit); } if (!skipRateLimit) { skipRateLimit = skipPredicate.test(expressionParams.getRootObject()); log.debug("skip-rate-limit - skip-predicates: {}", skipRateLimit); } if (!skipRateLimit && rl.getExecuteCondition() != null) { Condition condition = exp -> expressionService.parseBoolean(rl.getExecuteCondition(), exp); skipRateLimit = !condition.evaluate(expressionParams); log.debug("skip-rate-limit - execute-condition: {}", skipRateLimit); } if (!skipRateLimit) { skipRateLimit = !executionPredicate.test(expressionParams.getRootObject()); log.debug("skip-rate-limit - execute-predicates: {}", skipRateLimit); } return skipRateLimit; } public List getMetricTagResults(ExpressionParams expressionParams, Metrics metrics) { return metrics .getTags() .stream() .map(metricMetaTag -> { var value = expressionService.parseString(metricMetaTag.getExpression(), expressionParams); return new MetricTagResult(metricMetaTag.getKey(), value, metricMetaTag.getTypes()); }).toList(); } /** * Creates the key filter lambda which is responsible to decide how the rate limit will be performed. The key * is the unique identifier like an IP address or a username. 
* * @param url is used to generated a unique cache key * @param rateLimit the {@link RateLimit} configuration which holds the skip condition string * @return should not been null. If no filter key type is matching a plain 1 is returned so that all requests uses the same key. */ public KeyFilter getKeyFilter(String url, RateLimit rateLimit) { return expressionParams -> { String value = expressionService.parseString(rateLimit.getCacheKey(), expressionParams); return url + "-" + value; }; } private ConfigurationBuilder prepareBucket4jConfigurationBuilder(RateLimit rl) { var configBuilder = BucketConfiguration.builder(); for (BandWidth bandWidth : rl.getBandwidths()) { long capacity = bandWidth.getCapacity(); long refillCapacity = bandWidth.getRefillCapacity() != null ? bandWidth.getRefillCapacity() : bandWidth.getCapacity(); var refillPeriod = Duration.of(bandWidth.getTime(), bandWidth.getUnit()); var bucket4jBandWidth = switch (bandWidth.getRefillSpeed()) { case GREEDY -> Bandwidth.builder().capacity(capacity).refillGreedy(refillCapacity, refillPeriod).id(bandWidth.getId()); case INTERVAL -> Bandwidth.builder().capacity(capacity).refillIntervally(refillCapacity, refillPeriod).id(bandWidth.getId()); }; if (bandWidth.getInitialCapacity() != null) { bucket4jBandWidth = bucket4jBandWidth.initialTokens(bandWidth.getInitialCapacity()); } configBuilder = configBuilder.addLimit(bucket4jBandWidth.build()); } return configBuilder; } private MetricBucketListener createMetricListener(String cacheName, Metrics metrics, List metricHandlers, ExpressionParams expressionParams) { var metricTagResults = getMetricTags(metrics, expressionParams); return new MetricBucketListener( cacheName, metricHandlers, metrics.getTypes(), metricTagResults); } private List getMetricTags( Metrics metrics, ExpressionParams expressionParams) { return getMetricTagResults(expressionParams, metrics); } public void addDefaultMetricTags(Bucket4JBootProperties properties, Bucket4JConfiguration filter) { if 
(!properties.getDefaultMetricTags().isEmpty()) { var metricTags = filter.getMetrics().getTags(); var filterMetricTagKeys = metricTags .stream() .map(MetricTag::getKey) .collect(Collectors.toSet()); properties.getDefaultMetricTags().forEach(defaultTag -> { if (!filterMetricTagKeys.contains(defaultTag.getKey())) { metricTags.add(defaultTag); } }); } } private Predicate prepareExecutionPredicates(RateLimit rl, Map> executePredicates) { return rl.getExecutePredicates() .stream() .map(p -> createPredicate(p, executePredicates)) .reduce(Predicate::and) .orElseGet(() -> p -> true); } private Predicate prepareSkipPredicates(RateLimit rl, Map> executePredicates) { return rl.getSkipPredicates() .stream() .map(p -> createPredicate(p, executePredicates)) .reduce(Predicate::and) .orElseGet(() -> p -> false); } protected Predicate createPredicate(ExecutePredicateDefinition pd, Map> executePredicates) { var predicate = executePredicates.getOrDefault(pd.getName(), null); log.debug("create-predicate;name:{};value:{}", pd.getName(), pd.getArgs()); try { @SuppressWarnings("unchecked") ExecutePredicate newPredicateInstance = predicate.getClass().getDeclaredConstructor().newInstance(); return newPredicateInstance.init(pd.getArgs()); } catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException e) { throw new ExecutePredicateInstantiationException(pd.getName(), predicate.getClass()); } } public static long getRemainingLimit(Long remaining, RateLimitResult rateLimitResult) { if (rateLimitResult != null && (remaining == null || rateLimitResult.getRemainingTokens() < remaining)) { remaining = rateLimitResult.getRemainingTokens(); } return remaining; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy