All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.wanggit.access.frequency.config.AccessFrequencyBeanFactoryPostProcessor Maven / Gradle / Ivy

package com.github.wanggit.access.frequency.config;

import com.github.wanggit.access.frequency.annotations.AccessFrequency;
import com.github.wanggit.access.frequency.annotations.JoinToAccessFrequencyKey;
import com.github.wanggit.access.frequency.entity.AccessFrequencyMap;
import com.github.wanggit.access.frequency.entity.UrlParameter;
import com.github.wanggit.access.frequency.entity.UrlRate;
import com.github.wanggit.access.frequency.utils.MethodParameterUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Controller;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * Scans controller beans for handler methods annotated with {@link AccessFrequency} and registers
 * one {@link UrlRate} per resulting {@code "<url> <HTTP method>"} key in {@link AccessFrequencyMap}.
 *
 * <p>Only bean <em>types</em> are inspected. Controller beans are deliberately not instantiated
 * here: creating beans from a {@link BeanFactoryPostProcessor} is premature, because it happens
 * before bean post-processors (such as {@code @Autowired} injection or AOP proxying) are applied.
 */
public class AccessFrequencyBeanFactoryPostProcessor implements BeanFactoryPostProcessor {

    private static final Logger logger = Logger.getLogger(AccessFrequencyBeanFactoryPostProcessor.class);

    /**
     * Property names checked, in order, for the servlet context path: the Spring Boot 1.x names
     * first (as before), then the Spring Boot 2.x+ names.
     */
    private static final String[] CONTEXT_PATH_PROPERTIES = {
            "server.contextPath",
            "server.context-path",
            "server.servlet.contextPath",
            "server.servlet.context-path"
    };

    private Environment environment;

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory factory) throws BeansException {
        environment = factory.getBean(Environment.class);
        // @RestController is itself meta-annotated with @Controller, so the same bean name is
        // usually returned by both lookups; the set guarantees each controller is scanned once.
        Set<String> controllerNames = new LinkedHashSet<>();
        Collections.addAll(controllerNames, factory.getBeanNamesForAnnotation(Controller.class));
        Collections.addAll(controllerNames, factory.getBeanNamesForAnnotation(RestController.class));
        for (String name : controllerNames) {
            reduceName(name, factory);
        }
    }

    /**
     * Resolves the type of a bean without instantiating it.
     *
     * @param beanName the bean name
     * @param factory  the bean factory being post-processed
     * @return the bean type, or {@code null} if it cannot be determined
     */
    private Class<?> getBeanType(String beanName, ConfigurableListableBeanFactory factory) {
        try {
            return factory.getType(beanName);
        } catch (BeansException e) {
            if (logger.isDebugEnabled()) {
                logger.debug("Cannot determine type of bean '" + beanName + "'", e);
            }
            return null;
        }
    }

    /**
     * Registers access-frequency rules for every {@link AccessFrequency}-annotated handler method
     * of the given controller bean.
     *
     * @param beanName name of a controller bean
     * @param factory  the bean factory being post-processed
     */
    private void reduceName(String beanName, ConfigurableListableBeanFactory factory) {
        Class<?> beanType = getBeanType(beanName, factory);
        if (null == beanType) {
            return;
        }
        Class<?> clazz = getOriginalClazz(beanType);
        // Paths of the controller-level @RequestMapping; they prefix every handler method.
        RequestMapping requestMapping = AnnotationUtils.findAnnotation(clazz, RequestMapping.class);
        String[] paths = (null != requestMapping) ? requestMapping.path() : null;
        for (Method method : ReflectionUtils.getAllDeclaredMethods(clazz)) {
            AccessFrequency accessFrequency = AnnotationUtils.findAnnotation(method, AccessFrequency.class);
            if (null == accessFrequency) {
                continue;
            }
            MethodMapping methodMapping = getMethodMappingPaths(method);
            // Methods that are not exposed as request handlers are ignored.
            if (null == methodMapping) {
                continue;
            }
            registerUrlRates(clazz, method, accessFrequency, paths, methodMapping);
        }
    }

    /**
     * Builds one {@link UrlRate} per URL/HTTP-method combination of a handler method and stores
     * it in {@link AccessFrequencyMap}.
     */
    private void registerUrlRates(Class<?> clazz, Method method, AccessFrequency accessFrequency,
                                  String[] paths, MethodMapping methodMapping) {
        List<UrlParameter> urlParameters = new ArrayList<>();
        for (MethodParameter methodParameter : MethodParameterUtils.getMethodParameter(method)) {
            UrlParameter urlParameter = getParameterName(methodParameter);
            if (null != urlParameter) {
                urlParameters.add(urlParameter);
            }
        }
        for (String url : shuffle(paths, methodMapping)) {
            UrlRate urlRate = new UrlRate();
            urlRate.setMessage(accessFrequency.message());
            urlRate.setParameters(urlParameters);
            urlRate.setTimeout(accessFrequency.timeInterval());
            urlRate.setTimes(accessFrequency.times());
            urlRate.setTriggerWith(accessFrequency.triggerWith());
            urlRate.setTimeUnit(accessFrequency.timeUnit());
            urlRate.setCodePath(clazz.getName() + "#" + method.getName());
            urlRate.setUrl(url);
            AccessFrequencyMap.put(url, urlRate);
            if (logger.isDebugEnabled()) {
                logger.debug("Control Access Frequency " + urlRate.getUrl());
            }
        }
    }

    /**
     * Describes a handler-method parameter annotated with {@link JoinToAccessFrequencyKey}.
     *
     * <p>The parameter's source comes from its binding annotation. {@code @RequestParam},
     * {@code @PathVariable} and {@code @CookieValue} are checked in that order, and the source is
     * {@code REQUEST} when none is present. Its key is the annotation's {@code value}/{@code name},
     * falling back to {@link MethodParameter#getParameterName()}.
     *
     * @param methodParameter the handler-method parameter to inspect
     * @return the parameter descriptor, or {@code null} if the parameter is not annotated with
     *         {@link JoinToAccessFrequencyKey}
     */
    private UrlParameter getParameterName(MethodParameter methodParameter) {
        if (null == methodParameter.getParameterAnnotation(JoinToAccessFrequencyKey.class)) {
            return null;
        }
        RequestParam requestParam = methodParameter.getParameterAnnotation(RequestParam.class);
        PathVariable pathVariable = methodParameter.getParameterAnnotation(PathVariable.class);
        CookieValue cookieValue = methodParameter.getParameterAnnotation(CookieValue.class);
        String key;
        UrlParameter.Type type;
        // The source type follows the binding annotation even without an explicit name, so a
        // bare "@PathVariable Long id" is read from the path rather than from request params.
        // Both value() and name() are read because the returned annotation may not be
        // synthesized, in which case the value/name alias is not resolved.
        if (null != requestParam) {
            key = firstNonEmpty(requestParam.value(), requestParam.name());
            type = UrlParameter.Type.REQUEST;
        } else if (null != pathVariable) {
            key = firstNonEmpty(pathVariable.value(), pathVariable.name());
            type = UrlParameter.Type.PATH_VAR;
        } else if (null != cookieValue) {
            key = firstNonEmpty(cookieValue.value(), cookieValue.name());
            type = UrlParameter.Type.COOKIE;
        } else {
            key = null;
            type = UrlParameter.Type.REQUEST;
        }
        // No explicit name configured: use the parameter's own name.
        if (!StringUtils.hasLength(key)) {
            key = methodParameter.getParameterName();
        }
        return new UrlParameter(key, type);
    }

    /** Returns {@code first} if it is non-empty, otherwise {@code second}. */
    private static String firstNonEmpty(String first, String second) {
        return StringUtils.hasLength(first) ? first : second;
    }

    /**
     * Returns the application's servlet context path, normalized to start with "/" and to have
     * no trailing "/". Returns an empty string when no context path is configured.
     *
     * @return the context path, or {@code ""}
     */
    private String getServerContextPath() {
        for (String property : CONTEXT_PATH_PROPERTIES) {
            String contextPath = environment.getProperty(property);
            if (StringUtils.hasText(contextPath)) {
                return normalizePrefix(contextPath.trim());
            }
        }
        return "";
    }

    /**
     * Expands the controller paths and the method mapping into full
     * {@code "<contextPath><classPath><methodPath> <HTTP method>"} keys.
     *
     * @param paths         controller-level paths; {@code null} or empty means no class prefix
     * @param methodMapping the handler method's paths and HTTP methods
     * @return every URL/HTTP-method combination
     */
    private String[] shuffle(String[] paths, MethodMapping methodMapping) {
        String contextPath = getServerContextPath();
        // A class-level @RequestMapping without a path contributes no prefix; it must not
        // suppress the method URLs.
        String[] classPaths = (null == paths || paths.length == 0) ? new String[]{""} : paths;
        List<String> urls = new ArrayList<>();
        for (String classPath : classPaths) {
            Collections.addAll(urls, methodMapping.getAllPaths(contextPath + normalizePrefix(classPath)));
        }
        return urls.toArray(new String[0]);
    }

    /**
     * Normalizes a path used as a prefix: adds a leading "/" and strips trailing "/" characters,
     * so {@code "users/"} becomes {@code "/users"} and {@code "/"} becomes {@code ""}.
     */
    private static String normalizePrefix(String path) {
        if (!StringUtils.hasLength(path)) {
            return "";
        }
        String normalized = path.startsWith("/") ? path : "/" + path;
        int end = normalized.length();
        while (end > 0 && normalized.charAt(end - 1) == '/') {
            end--;
        }
        return normalized.substring(0, end);
    }

    /** Adds a leading "/" to a non-empty method path; an empty path stays empty. */
    private static String normalizeMethodPath(String path) {
        if (!StringUtils.hasLength(path)) {
            return "";
        }
        return path.startsWith("/") ? path : "/" + path;
    }

    /**
     * Reads the request-mapping annotation of a handler method.
     *
     * @param method the candidate method
     * @return the mapped paths and HTTP methods, or {@code null} if the method is not a handler
     */
    private MethodMapping getMethodMappingPaths(Method method) {
        GetMapping getMapping = AnnotationUtils.findAnnotation(method, GetMapping.class);
        if (null != getMapping) {
            return new MethodMapping(getMapping.path(), RequestMethod.GET);
        }
        PostMapping postMapping = AnnotationUtils.findAnnotation(method, PostMapping.class);
        if (null != postMapping) {
            return new MethodMapping(postMapping.path(), RequestMethod.POST);
        }
        PutMapping putMapping = AnnotationUtils.findAnnotation(method, PutMapping.class);
        if (null != putMapping) {
            return new MethodMapping(putMapping.path(), RequestMethod.PUT);
        }
        DeleteMapping deleteMapping = AnnotationUtils.findAnnotation(method, DeleteMapping.class);
        if (null != deleteMapping) {
            return new MethodMapping(deleteMapping.path(), RequestMethod.DELETE);
        }
        PatchMapping patchMapping = AnnotationUtils.findAnnotation(method, PatchMapping.class);
        if (null != patchMapping) {
            return new MethodMapping(patchMapping.path(), RequestMethod.PATCH);
        }
        RequestMapping methodRequestMapping = AnnotationUtils.findAnnotation(method, RequestMapping.class);
        if (null != methodRequestMapping) {
            // Honour method = {...}; an empty array falls back to all HTTP methods in MethodMapping.
            return new MethodMapping(methodRequestMapping.path(), methodRequestMapping.method());
        }
        return null;
    }

    /**
     * Unwraps CGLIB-generated subclasses (as created for Spring proxies) until the user-declared
     * class is reached, so that its annotations and methods can be read.
     *
     * @param clazz a possibly proxied class
     * @return the original, user-declared class
     */
    private Class<?> getOriginalClazz(Class<?> clazz) {
        Class<?> current = clazz;
        Class<?> user = ClassUtils.getUserClass(current);
        while (user != current) {
            current = user;
            user = ClassUtils.getUserClass(current);
        }
        return current;
    }

    /** Paths and HTTP methods declared by a handler method's mapping annotation. */
    static class MethodMapping {
        private final String[] paths;
        private final RequestMethod[] requestMethods;

        /**
         * @param paths the mapped paths; {@code null}/empty means the handler maps to the
         *              class-level path itself (e.g. a bare {@code @GetMapping})
         * @param rms   the HTTP methods; {@code null}/empty means all of them
         */
        MethodMapping(String[] paths, RequestMethod... rms) {
            this.paths = (null != paths && paths.length > 0) ? paths : new String[]{""};
            this.requestMethods = (null != rms && rms.length > 0) ? rms : RequestMethod.values();
        }

        /**
         * Returns every {@code "<path> <HTTP method>"} combination, with no prefix.
         *
         * @return the combinations
         */
        public String[] getAllPaths() {
            return getAllPaths("");
        }

        /**
         * Returns every {@code "<prefix><path> <HTTP method>"} combination. A combination whose
         * URL would be empty is reported as "/".
         *
         * @param prefix an already-normalized prefix (context path plus class path)
         * @return the combinations
         */
        public String[] getAllPaths(String prefix) {
            List<String> list = new ArrayList<>();
            for (RequestMethod requestMethod : requestMethods) {
                for (String path : paths) {
                    String url = prefix + normalizeMethodPath(path);
                    if (url.isEmpty()) {
                        url = "/";
                    }
                    list.add(url + " " + requestMethod.name());
                }
            }
            return list.toArray(new String[0]);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy