// com.github.wanggit.access.frequency.config.AccessFrequencyBeanFactoryPostProcessor (artifact page header: Maven / Gradle / Ivy)
package com.github.wanggit.access.frequency.config;
import com.github.wanggit.access.frequency.annotations.AccessFrequency;
import com.github.wanggit.access.frequency.annotations.JoinToAccessFrequencyKey;
import com.github.wanggit.access.frequency.entity.AccessFrequencyMap;
import com.github.wanggit.access.frequency.entity.UrlParameter;
import com.github.wanggit.access.frequency.entity.UrlRate;
import com.github.wanggit.access.frequency.utils.MethodParameterUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Controller;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
/**
 * Registers access-frequency (rate-limit) rules for annotated controller methods.
 *
 * <p>For every bean annotated with {@link Controller} or {@link RestController}, each handler
 * method that carries both a request-mapping annotation and {@link AccessFrequency} is turned
 * into one {@link UrlRate} per resulting key and stored in {@link AccessFrequencyMap}. A key has
 * the form {@code <contextPath><classPath><methodPath> <HTTP_METHOD>}, for example
 * {@code /app/users/list GET}.
 *
 * <p>Only bean <em>types</em> are inspected. Beans are not instantiated here, because a
 * {@link BeanFactoryPostProcessor} runs before bean post-processors are registered and creating
 * controllers at this point would yield instances without injection or proxying.
 */
public class AccessFrequencyBeanFactoryPostProcessor implements BeanFactoryPostProcessor {

    private static final Logger logger = Logger.getLogger(AccessFrequencyBeanFactoryPostProcessor.class);

    /** Property keys that may hold the servlet context path, in lookup order. */
    private static final String[] CONTEXT_PATH_KEYS = {
        "server.contextPath",
        "server.context-path",
        "server.servlet.context-path"
    };

    /** Environment used to resolve the context path; set at the start of post-processing. */
    private Environment environment;

    /**
     * Scans all controller beans and registers the rate-limit rules declared on their methods.
     *
     * @param factory the bean factory being post-processed
     * @throws BeansException if the {@link Environment} bean cannot be obtained
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory factory) throws BeansException {
        environment = factory.getBean(Environment.class);

        // @RestController is itself annotated with @Controller, so the two lookups can return
        // the same bean name. De-duplicate while keeping first-seen order.
        Set<String> controllerNames = new LinkedHashSet<>();
        Collections.addAll(controllerNames, factory.getBeanNamesForAnnotation(Controller.class));
        Collections.addAll(controllerNames, factory.getBeanNamesForAnnotation(RestController.class));

        for (String name : controllerNames) {
            reduceName(name, factory);
        }
    }

    /**
     * Resolves the type of a bean without instantiating it.
     *
     * @param beanName name of the bean to look up
     * @param factory  the bean factory to query
     * @return the bean type, or {@code null} when it cannot be determined
     */
    private Class<?> resolveBeanType(String beanName, ConfigurableListableBeanFactory factory) {
        try {
            return factory.getType(beanName);
        } catch (BeansException e) {
            // Best effort: a bean whose type cannot be resolved is skipped, not fatal.
            if (logger.isDebugEnabled()) {
                logger.debug("Cannot resolve type of bean '" + beanName + "': " + e.getMessage(), e);
            }
            return null;
        }
    }

    /**
     * Registers a {@link UrlRate} for every rate-limited handler method of one controller bean.
     *
     * @param beanName name of the controller bean
     * @param factory  the bean factory that owns the bean
     */
    private void reduceName(String beanName, ConfigurableListableBeanFactory factory) {
        Class<?> beanType = resolveBeanType(beanName, factory);
        if (null == beanType) {
            return;
        }
        Class<?> clazz = getOriginalClazz(beanType);

        // Path(s) configured by a class-level @RequestMapping, if any.
        RequestMapping requestMapping = AnnotationUtils.findAnnotation(clazz, RequestMapping.class);
        String[] paths = (null != requestMapping) ? requestMapping.path() : null;

        for (Method method : ReflectionUtils.getAllDeclaredMethods(clazz)) {
            MethodMapping methodMapping = getMethodMappingPaths(method);
            // Methods that are not exposed as request handlers are ignored.
            if (null == methodMapping) {
                continue;
            }
            AccessFrequency accessFrequency = AnnotationUtils.findAnnotation(method, AccessFrequency.class);
            if (null == accessFrequency) {
                continue;
            }

            List<UrlParameter> urlParameters = new ArrayList<>();
            for (MethodParameter methodParameter : MethodParameterUtils.getMethodParameter(method)) {
                UrlParameter urlParameter = getParameterName(methodParameter);
                if (null != urlParameter) {
                    urlParameters.add(urlParameter);
                }
            }

            for (String url : shuffle(paths, methodMapping)) {
                UrlRate urlRate = new UrlRate();
                urlRate.setMessage(accessFrequency.message());
                urlRate.setParameters(urlParameters);
                urlRate.setTimeout(accessFrequency.timeInterval());
                urlRate.setTimes(accessFrequency.times());
                urlRate.setTriggerWith(accessFrequency.triggerWith());
                urlRate.setTimeUnit(accessFrequency.timeUnit());
                urlRate.setCodePath(clazz.getName() + "#" + method.getName());
                urlRate.setUrl(url);
                AccessFrequencyMap.put(url, urlRate);
                if (logger.isDebugEnabled()) {
                    logger.debug("Control Access Frequency " + urlRate.getUrl());
                }
            }
        }
    }

    /**
     * Builds the key descriptor for a parameter annotated with {@link JoinToAccessFrequencyKey}.
     *
     * <p>The name is taken from the first of {@link RequestParam}, {@link PathVariable} and
     * {@link CookieValue} that declares a non-empty {@code value}; otherwise the parameter's own
     * name is used. The type defaults to {@code REQUEST} when no annotation supplied a name.
     *
     * @param methodParameter the handler-method parameter to inspect
     * @return the descriptor, or {@code null} when the parameter is not part of the key
     */
    private UrlParameter getParameterName(MethodParameter methodParameter) {
        if (null == methodParameter.getParameterAnnotation(JoinToAccessFrequencyKey.class)) {
            return null;
        }

        String key = null;
        UrlParameter.Type type = null;

        RequestParam requestParam = methodParameter.getParameterAnnotation(RequestParam.class);
        if (null != requestParam && StringUtils.hasLength(requestParam.value())) {
            key = requestParam.value();
            type = UrlParameter.Type.REQUEST;
        }

        // No name from @RequestParam: try @PathVariable.
        if (!StringUtils.hasLength(key)) {
            PathVariable pathVariable = methodParameter.getParameterAnnotation(PathVariable.class);
            if (null != pathVariable && StringUtils.hasLength(pathVariable.value())) {
                key = pathVariable.value();
                type = UrlParameter.Type.PATH_VAR;
            }
        }

        // Still nothing: try @CookieValue.
        if (!StringUtils.hasLength(key)) {
            CookieValue cookieValue = methodParameter.getParameterAnnotation(CookieValue.class);
            if (null != cookieValue && StringUtils.hasLength(cookieValue.value())) {
                key = cookieValue.value();
                type = UrlParameter.Type.COOKIE;
            }
        }

        // No annotation supplied a name: fall back to the declared parameter name.
        // NOTE(review): this may be null if no parameter-name discoverer was set up for the
        // MethodParameter — confirm against MethodParameterUtils.
        if (!StringUtils.hasLength(key)) {
            key = methodParameter.getParameterName();
        }

        if (null == type) {
            type = UrlParameter.Type.REQUEST;
        }
        return new UrlParameter(key, type);
    }

    /**
     * Resolves the application's servlet context path from the environment.
     *
     * @return the configured context path without trailing slashes, or an empty string when
     *         none of the known property keys has a non-blank value
     */
    private String getServerContextPath() {
        for (String propertyKey : CONTEXT_PATH_KEYS) {
            String contextPath = environment.getProperty(propertyKey);
            if (StringUtils.hasText(contextPath)) {
                // A context path of "/" (or one ending in "/") would otherwise yield "//".
                return stripTrailingSlashes(contextPath);
            }
        }
        return "";
    }

    /**
     * Combines the context path, the class-level paths and the method-level entries into keys.
     *
     * @param paths         class-level mapping paths; {@code null} or empty when the controller
     *                      declares none
     * @param methodMapping method-level mapping supplying {@code "<path> <HTTP_METHOD>"} entries
     * @return every combination of class-level path and method-level entry, prefixed with the
     *         context path
     */
    private String[] shuffle(String[] paths, MethodMapping methodMapping) {
        List<String> urls = new ArrayList<>();
        String contextPath = getServerContextPath();
        String[] mpaths = methodMapping.getAllPaths();

        // A class-level @RequestMapping without a path yields an empty array; treat it like an
        // absent mapping so the method-level paths are still registered.
        if (null == paths || paths.length == 0) {
            for (String mpath : mpaths) {
                urls.add(contextPath + mpath);
            }
        } else {
            for (String cpath : paths) {
                String prefix = cpath.startsWith("/") ? cpath : "/" + cpath;
                // Method entries always start with "/", so drop trailing slashes here.
                prefix = stripTrailingSlashes(prefix);
                for (String mpath : mpaths) {
                    urls.add(contextPath + prefix + mpath);
                }
            }
        }
        return urls.toArray(new String[0]);
    }

    /**
     * Removes every trailing {@code '/'} from a path.
     *
     * @param path the path to trim; must not be {@code null}
     * @return the path without trailing slashes (possibly empty)
     */
    private static String stripTrailingSlashes(String path) {
        int end = path.length();
        while (end > 0 && path.charAt(end - 1) == '/') {
            end--;
        }
        return path.substring(0, end);
    }

    /**
     * Reads the request-mapping annotation of a method.
     *
     * <p>The HTTP-method-specific annotations are checked before the generic
     * {@link RequestMapping}. For a plain {@link RequestMapping}, the methods it declares are
     * used; when it declares none, all HTTP methods apply.
     *
     * @param method the candidate handler method
     * @return the mapping descriptor, or {@code null} when the method has no mapping annotation
     */
    private MethodMapping getMethodMappingPaths(Method method) {
        GetMapping getMapping = AnnotationUtils.findAnnotation(method, GetMapping.class);
        if (null != getMapping) {
            return new MethodMapping(getMapping.path(), RequestMethod.GET);
        }
        PostMapping postMapping = AnnotationUtils.findAnnotation(method, PostMapping.class);
        if (null != postMapping) {
            return new MethodMapping(postMapping.path(), RequestMethod.POST);
        }
        PutMapping putMapping = AnnotationUtils.findAnnotation(method, PutMapping.class);
        if (null != putMapping) {
            return new MethodMapping(putMapping.path(), RequestMethod.PUT);
        }
        DeleteMapping deleteMapping = AnnotationUtils.findAnnotation(method, DeleteMapping.class);
        if (null != deleteMapping) {
            return new MethodMapping(deleteMapping.path(), RequestMethod.DELETE);
        }
        PatchMapping patchMapping = AnnotationUtils.findAnnotation(method, PatchMapping.class);
        if (null != patchMapping) {
            return new MethodMapping(patchMapping.path(), RequestMethod.PATCH);
        }
        RequestMapping methodRequestMapping = AnnotationUtils.findAnnotation(method, RequestMapping.class);
        if (null != methodRequestMapping) {
            // An empty method() array makes MethodMapping fall back to all HTTP methods.
            return new MethodMapping(methodRequestMapping.path(), methodRequestMapping.method());
        }
        return null;
    }

    /**
     * Unwraps CGLIB proxy classes to reach the user-declared class.
     *
     * @param clazz the possibly proxied class
     * @return the first class in the superclass chain that is not a CGLIB proxy class
     */
    private Class<?> getOriginalClazz(Class<?> clazz) {
        Class<?> current = clazz;
        while (ClassUtils.isCglibProxyClass(current)) {
            current = current.getSuperclass();
        }
        return current;
    }

    /**
     * Paths and HTTP methods declared by one handler method's mapping annotation.
     */
    static class MethodMapping {

        private final String[] paths;
        private final RequestMethod[] requestMethods;

        /**
         * @param paths the mapping paths declared on the method
         * @param rms   the HTTP methods the mapping applies to; when {@code null} or empty, all
         *              {@link RequestMethod} values are used
         */
        MethodMapping(String[] paths, RequestMethod... rms) {
            this.paths = paths;
            if (null != rms && rms.length > 0) {
                this.requestMethods = rms;
            } else {
                this.requestMethods = RequestMethod.values();
            }
        }

        /**
         * Expands the mapping into one entry per (HTTP method, path) pair.
         *
         * <p>NOTE(review): a mapping that declares no path yields no entries, so such a method
         * is never registered — confirm whether that is intended.
         *
         * @return entries of the form {@code "/<path> <HTTP_METHOD>"}, grouped by HTTP method
         */
        public String[] getAllPaths() {
            List<String> entries = new ArrayList<>();
            for (RequestMethod requestMethod : requestMethods) {
                for (String path : paths) {
                    String normalized = path.startsWith("/") ? path : "/" + path;
                    entries.add(normalized + " " + requestMethod.name());
                }
            }
            return entries.toArray(new String[0]);
        }
    }
}