io.github.dengchen2020.ratelimiter.RateLimiterInterceptor Maven / Gradle / Ivy
package io.github.dengchen2020.ratelimiter;
import io.github.dengchen2020.core.utils.IPUtils;
import io.github.dengchen2020.ratelimiter.annotation.RateLimit;
import io.github.dengchen2020.ratelimiter.annotation.RateLimitStrategy;
import io.github.dengchen2020.ratelimiter.exception.RateLimitException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import java.security.Principal;
import java.util.concurrent.TimeUnit;
/**
 * Rate-limiting interceptor.
 *
 * <p>Before a handler method runs, this looks for a {@link RateLimit} annotation. It checks the
 * method first and falls back to the handler's bean type. When an annotation is found, it derives
 * a limiter key from the request according to {@link RateLimit#strategy()}. It then asks either
 * the per-second or the per-minute {@link RateLimiter} whether the call should be limited. A
 * limited call is rejected with a {@link RateLimitException} carrying {@link RateLimit#errorMsg()}.
 *
 * @author dengchen
 * @since 2024/8/3
 */
public class RateLimiterInterceptor implements HandlerInterceptor {

    /** Consulted for every {@link RateLimit} whose time unit is not {@link TimeUnit#MINUTES}. */
    private final RateLimiter secondRateLimiter;

    /** Consulted for {@link RateLimit} declarations whose time unit is {@link TimeUnit#MINUTES}. */
    private final RateLimiter minuteRateLimiter;

    /**
     * Creates the interceptor.
     *
     * @param secondRateLimiter limiter used when the annotation's time unit is anything but minutes
     * @param minuteRateLimiter limiter used when the annotation's time unit is minutes
     */
    public RateLimiterInterceptor(RateLimiter secondRateLimiter, RateLimiter minuteRateLimiter) {
        this.secondRateLimiter = secondRateLimiter;
        this.minuteRateLimiter = minuteRateLimiter;
    }

    /**
     * Applies the {@link RateLimit} declared on the target handler, if any.
     *
     * <p>Handlers that are not {@link HandlerMethod}s, and handlers without the annotation, pass
     * straight through.
     *
     * @return always {@code true}; rejection is signalled by throwing instead
     * @throws RateLimitException when the selected limiter reports that the key is over its limit
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
        if (handler instanceof HandlerMethod handlerMethod) {
            RateLimit declared = findRateLimit(handlerMethod);
            if (declared != null) {
                String key = resolveLimitKey(declared.strategy(), request);
                // NOTE(review): only MINUTES selects the minute limiter. Every other unit (HOURS,
                // MILLISECONDS, ...) silently uses the second limiter. Confirm that is intended.
                RateLimiter limiter =
                        declared.timeUnit() == TimeUnit.MINUTES ? minuteRateLimiter : secondRateLimiter;
                if (limiter.limit(key, declared.value())) {
                    throw new RateLimitException(declared.errorMsg());
                }
            }
        }
        return true;
    }

    /** Returns the method-level annotation if present, otherwise the bean type's, or {@code null}. */
    private static RateLimit findRateLimit(HandlerMethod handlerMethod) {
        RateLimit onMethod = handlerMethod.getMethod().getAnnotation(RateLimit.class);
        return onMethod != null ? onMethod : handlerMethod.getBeanType().getAnnotation(RateLimit.class);
    }

    /**
     * Builds the limiter key for the given strategy.
     *
     * <p>NOTE(review): the parts are concatenated without a separator. Keys from different
     * strategies can also coincide; for example, an anonymous {@code user} key is exactly the
     * {@code ip} key. Endpoints using those strategies may therefore share a counter. Confirm
     * whether that is desired.
     */
    private static String resolveLimitKey(RateLimitStrategy strategy, HttpServletRequest request) {
        return switch (strategy) {
            case ip -> IPUtils.getIpAddr(request);
            case user -> principalNameOrIp(request);
            case uri -> request.getRequestURI() + request.getMethod();
            case userAndUri -> principalNameOrIp(request) + request.getRequestURI() + request.getMethod();
            case ipAndUri -> IPUtils.getIpAddr(request) + request.getRequestURI() + request.getMethod();
            case null, default -> IPUtils.getIpAddr(request) + request.getRequestURI() + request.getMethod();
        };
    }

    /** Returns the authenticated principal's name, or the client IP when there is no principal. */
    private static String principalNameOrIp(HttpServletRequest request) {
        Principal principal = request.getUserPrincipal();
        return principal != null ? principal.getName() : IPUtils.getIpAddr(request);
    }
}