All downloads are free. Search and download functionality uses the official Maven repository.

com.giffing.bucket4j.spring.boot.starter.config.aspect.RateLimitAspect Maven / Gradle / Ivy

There is a newer version: 0.12.7
Show newest version
package com.giffing.bucket4j.spring.boot.starter.config.aspect;

import com.giffing.bucket4j.spring.boot.starter.config.cache.SyncCacheResolver;
import com.giffing.bucket4j.spring.boot.starter.context.*;
import com.giffing.bucket4j.spring.boot.starter.context.properties.MethodProperties;
import com.giffing.bucket4j.spring.boot.starter.context.properties.Metrics;
import com.giffing.bucket4j.spring.boot.starter.context.properties.RateLimit;
import com.giffing.bucket4j.spring.boot.starter.service.RateLimitService;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Aspect
@Component
@RequiredArgsConstructor
@Slf4j
public class RateLimitAspect {

    private final RateLimitService rateLimitService;

    private final List methodProperties;

    private final SyncCacheResolver syncCacheResolver;

    private Map> rateLimitConfigResults = new HashMap<>();

    @PostConstruct
    public void init() {
        for(var methodProperty : methodProperties) {
            var proxyManagerWrapper = syncCacheResolver.resolve(methodProperty.getCacheName());
            var rateLimitConfig = RateLimitService.RateLimitConfig.builder()
                    .rateLimits(List.of(methodProperty.getRateLimit()))
                    .metricHandlers(List.of())
                    .executePredicates(Map.of())
                    .cacheName(methodProperty.getCacheName())
                    .configVersion(0)
                    .keyFunction((rl, sr) -> {
                        KeyFilter keyFilter = rateLimitService.getKeyFilter(sr.getRootObject().getName(), rl);
                        return keyFilter.key(sr);
                    })
                    .metrics(new Metrics())
                    .proxyWrapper(proxyManagerWrapper)
                    .build();
            var rateLimitConfigResult = rateLimitService.configureRateLimit(rateLimitConfig);
            rateLimitConfigResults.put(methodProperty.getName(), rateLimitConfigResult);
        }
    }

    @Pointcut("execution(public * *(..))")
    public void publicMethod() {}

    @Pointcut("@annotation(com.giffing.bucket4j.spring.boot.starter.context.RateLimiting)")
    private void methodsAnnotatedWithRateLimitAnnotation() {
    }

    @Pointcut("@within(com.giffing.bucket4j.spring.boot.starter.context.RateLimiting) && publicMethod()")
    private void classAnnotatedWithRateLimitAnnotation(){

    }

    @Around("methodsAnnotatedWithRateLimitAnnotation() || classAnnotatedWithRateLimitAnnotation()")
    public Object processMethodsAnnotatedWithRateLimitAnnotation(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();

        var ignoreRateLimitAnnotation = getAnnotationFromMethodOrClass(method, IgnoreRateLimiting.class);
        // if the class or method is annotated with IgnoreRateLimiting we will skip rate limiting
        if(ignoreRateLimitAnnotation != null){
            return joinPoint.proceed();
        }

        var rateLimitAnnotation = getAnnotationFromMethodOrClass(method, RateLimiting.class);

        Method fallbackMethod = null;
        if(rateLimitAnnotation.fallbackMethodName() != null) {
            var fallbackMethods = Arrays.stream(method.getDeclaringClass().getMethods())
                    .filter(p -> p.getName().equals(rateLimitAnnotation.fallbackMethodName()))
                    .toList();
            if(fallbackMethods.size() > 1) {
                throw new IllegalStateException("Found " + fallbackMethods.size() + " fallbackMethods for " + rateLimitAnnotation.fallbackMethodName());
            }
            if(!fallbackMethods.isEmpty()) {
                fallbackMethod = joinPoint.getTarget().getClass().getMethod(rateLimitAnnotation.fallbackMethodName(), ((MethodSignature) joinPoint.getSignature()).getParameterTypes());
            }
        }

        Map params = collectExpressionParameter(
                joinPoint.getArgs(),
                signature.getParameterNames());

        assertValidCacheName(rateLimitAnnotation);

        var annotationRateLimit = buildMainRateLimitConfiguration(rateLimitAnnotation);
        var rateLimitConfigResult = rateLimitConfigResults.get(rateLimitAnnotation.name());

        RateLimitConsumedResult consumedResult = performRateLimit(rateLimitConfigResult, method, params, annotationRateLimit);

        Object methodResult;

        if (consumedResult.allConsumed()) {
            // no rate limit - execute the surrounding method
            methodResult = joinPoint.proceed();
            performPostRateLimit(rateLimitConfigResult, method, methodResult);
        } else if (fallbackMethod != null){
            return fallbackMethod.invoke(joinPoint.getTarget(), joinPoint.getArgs());
        } else {
            throw new RateLimitException();
        }

        return methodResult;
    }

    private   R getAnnotationFromMethodOrClass(Method method, Class rateLimitingAnnotation) {
        R rateLimitAnnotation;
        if(method.getAnnotation(rateLimitingAnnotation) != null) {
            rateLimitAnnotation = method.getAnnotation(rateLimitingAnnotation);
        } else {
            rateLimitAnnotation = method.getDeclaringClass().getAnnotation(rateLimitingAnnotation);
        }
        return rateLimitAnnotation;
    }

    private static void performPostRateLimit(RateLimitService.RateLimitConfigresult rateLimitConfigResult, Method method, Object methodResult) {
        for (var rlc : rateLimitConfigResult.getPostRateLimitChecks()) {
            var result = rlc.rateLimit(method, methodResult);
            if (result != null) {
                log.debug("post-rate-limit;remaining-tokens:{}", result.getRateLimitResult().getRemainingTokens());
            }
        }
    }

    private static RateLimitConsumedResult performRateLimit(RateLimitService.RateLimitConfigresult rateLimitConfigResult, Method method, Map params, RateLimit annotationRateLimit) {
        boolean allConsumed = true;
        Long remainingLimit = null;
        for (RateLimitCheck rl : rateLimitConfigResult.getRateLimitChecks()) {
            var wrapper = rl.rateLimit(new ExpressionParams<>(method).addParams(params), annotationRateLimit);
            if (wrapper != null && wrapper.getRateLimitResult() != null) {
                var rateLimitResult = wrapper.getRateLimitResult();
                if (rateLimitResult.isConsumed()) {
                    remainingLimit = RateLimitService.getRemainingLimit(remainingLimit, rateLimitResult);
                } else {
                    allConsumed = false;
                    break;
                }
            }
        }
        if(allConsumed) {
            log.debug("rate-limit-remaining;limit:{}", remainingLimit);
        }
        return new RateLimitConsumedResult(allConsumed, remainingLimit);
    }

    private record RateLimitConsumedResult(boolean allConsumed, Long remainingLimit) {
    }

    /*
     * Uses the configuration of the annotation to crate a main RateLimit which overrides
     * the configuration from the property files.
     */
    private static RateLimit buildMainRateLimitConfiguration(RateLimiting rateLimitAnnotation) {
        var annotationRateLimit = new RateLimit();
        annotationRateLimit.setExecuteCondition(rateLimitAnnotation.executeCondition());
        annotationRateLimit.setCacheKey(rateLimitAnnotation.cacheKey());
        annotationRateLimit.setSkipCondition(rateLimitAnnotation.skipCondition());
        return annotationRateLimit;
    }

    private void assertValidCacheName(RateLimiting rateLimitAnnotation) {
        if(!rateLimitConfigResults.containsKey(rateLimitAnnotation.name())) {
            throw new IllegalStateException("Could not find cache " + rateLimitAnnotation.name());
        }
    }

    private static Map collectExpressionParameter(Object[] args, String[] parameterNames) {
        Map params = new HashMap<>();
        for (int i = 0; i< args.length; i++) {
            log.debug("expresion-params;name:{};arg:{}", parameterNames[i], args[i]);
            params.put(parameterNames[i], args[i]);
        }
        return params;
    }




}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy