All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.test4j.module.spring.SpringEnv Maven / Gradle / Ivy

package org.test4j.module.spring;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.test4j.Context;
import org.test4j.annotations.BeforeSpring;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static org.test4j.mock.faking.util.ReflectUtility.doThrow;

/**
 * Spring environment helpers for test4j.
 *
 * <p>Records, per test class, whether that class is a Spring test (i.e. carries one of the
 * annotations listed in {@code SPRING_TEST_ANNOTATIONS} on itself or a superclass). It also
 * forwards bean lookup, injection and initialisation to {@code SpringInit}, so that this class
 * does not reference Spring classes directly.
 *
 * @author darui.wu
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public class SpringEnv {
    /** Test class -> whether {@link #setSpringEnv(Class)} detected it as a Spring test. */
    private static final Map<Class, Boolean> CLASS_IS_SPRING = new ConcurrentHashMap<>();

    /**
     * Fully-qualified names of the annotations that mark a test class as a Spring test. They are
     * resolved reflectively so that missing ones (not on the classpath) are simply skipped.
     */
    private static final String[] SPRING_TEST_ANNOTATIONS = {
        "org.springframework.boot.test.context.SpringBootTest",
        "org.springframework.test.context.ContextConfiguration",
        "org.test4j.integration.spring.SpringContext"
    };

    /**
     * Tells whether the current test class (as returned by {@code Context.currTestClass()}) was
     * registered as a Spring test.
     *
     * @return true if the current test class was recorded as a Spring test
     */
    public static boolean isSpringEnv() {
        return isSpringEnv(Context.currTestClass());
    }

    /**
     * Detects whether {@code clazz} is a Spring test and records the result for later
     * {@link #isSpringEnv(Class)} queries. A {@code null} class is ignored.
     *
     * @param clazz test class to inspect
     */
    public static void setSpringEnv(Class clazz) {
        if (clazz == null) {
            // ConcurrentHashMap rejects null keys; isSpringEnv(null) already answers false.
            return;
        }
        CLASS_IS_SPRING.put(clazz, isSpringTest(clazz));
    }

    /**
     * Tells whether {@code clazz} was recorded as a Spring test by {@link #setSpringEnv(Class)}.
     *
     * @param clazz test class, may be null
     * @return true only if the class was registered and detected as a Spring test
     */
    public static boolean isSpringEnv(Class clazz) {
        // Explicit null check: ConcurrentHashMap.get(null) throws NullPointerException.
        return clazz != null && Boolean.TRUE.equals(CLASS_IS_SPRING.get(clazz));
    }

    /** Returns true if {@code aClass} carries any of the known Spring test annotations. */
    private static boolean isSpringTest(Class aClass) {
        if (aClass == null) {
            return false;
        }
        for (String annotation : SPRING_TEST_ANNOTATIONS) {
            if (hasAnnotation(aClass, annotation)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Cache: annotation class name -> resolved class, or {@code null} when the class could not be
     * loaded. Accessed only through {@link #getAnnotationClass(String)}, which synchronizes.
     */
    static Map<String, Class> HasAnnotation = new HashMap<>(5);

    /**
     * Returns true if the annotation named {@code annotation} is loadable and present on
     * {@code objectClass} or one of its superclasses (excluding {@code Object}).
     */
    private static boolean hasAnnotation(Class objectClass, String annotation) {
        Class annotationClass = getAnnotationClass(annotation);
        return annotationClass != null
            && getClassLevelAnnotation(annotationClass, objectClass) != null;
    }

    /**
     * Resolves an annotation class by name, caching both hits and misses.
     *
     * <p>Synchronized because the backing {@code HashMap} is not thread-safe, while the rest of
     * this class (see {@code CLASS_IS_SPRING}) is written for concurrent use.
     *
     * @param annotation fully-qualified annotation class name
     * @return the loaded class, or null if it is not on the classpath
     */
    private static synchronized Class getAnnotationClass(String annotation) {
        if (HasAnnotation.containsKey(annotation)) {
            return HasAnnotation.get(annotation);
        }
        Class clazz;
        try {
            clazz = Class.forName(annotation);
        } catch (ClassNotFoundException e) {
            // Optional dependency absent: remember the miss so the lookup is not retried.
            clazz = null;
        }
        HasAnnotation.put(annotation, clazz);
        return clazz;
    }

    /**
     * Looks for {@code annotationClass} on {@code clazz} and then on each of its superclasses,
     * stopping before {@code Object}.
     *
     * @param annotationClass annotation type to look for
     * @param clazz           class whose hierarchy is searched
     * @return the first annotation instance found, or null if none
     */
    private static <T extends Annotation> T getClassLevelAnnotation(Class<T> annotationClass, Class clazz) {
        // Stop on null as well: getSuperclass() returns null for interfaces and primitives.
        for (Class<?> current = clazz;
             current != null && !Object.class.equals(current);
             current = current.getSuperclass()) {
            // Inspect the class at the current level, not the starting class, so that
            // annotations declared on superclasses are actually found.
            T annotation = current.getAnnotation(annotationClass);
            if (annotation != null) {
                return annotation;
            }
        }
        return null;
    }

    /**
     * Returns the bean named {@code beanName} from the current test class's Spring container.
     *
     * @param beanName bean name
     * @return bean object, as returned by {@code SpringInit.getBeanByName}
     */
    public static <T> T getBeanByName(String beanName) {
        return SpringInit.getBeanByName(beanName);
    }

    /**
     * Returns the bean of type {@code beanType} from the current test class's Spring container.
     *
     * @param beanType bean type
     * @return bean object, as returned by {@code SpringInit.getBeanByType}
     */
    public static <T> T getBeanByType(Class<T> beanType) {
        return SpringInit.getBeanByType(beanType);
    }

    /**
     * Injects Spring beans into {@code testedObject}. This is a no-op unless the current test
     * class was registered as a Spring test.
     *
     * @param testedObject object receiving the beans
     */
    public static void injectSpringBeans(Object testedObject) {
        if (!SpringEnv.isSpringEnv()) {
            return;
        }
        SpringInit.injectSpringBeans(testedObject);
    }

    /**
     * Runs work that must happen before test4j initialisation, e.g. mocking that has to be in
     * place before the Spring context is loaded.
     *
     * <p>Invokes every static, no-argument method annotated with {@link BeforeSpring}. Only methods
     * returned by {@link Class#getMethods()} (public ones, including inherited) are considered.
     * Any exception raised by an invocation is handed to {@code doThrow}.
     *
     * @param testedClass test class
     */
    public static void invokeBeforeSpringMethod(Class testedClass) {
        for (Method method : testedClass.getMethods()) {
            if (method.getParameterCount() != 0 || !Modifier.isStatic(method.getModifiers())) {
                continue;
            }
            if (method.getAnnotation(BeforeSpring.class) != null) {
                method.setAccessible(true);
                try {
                    method.invoke(null);
                } catch (Exception e) {
                    doThrow(e);
                }
            }
        }
    }

    /**
     * Pure delegation to {@code SpringInit.doSpringInitial}, kept here so that SpringEnv does not
     * depend on Spring classes directly.
     *
     * @param testInstance tested object
     * @param context      junit context
     * @throws Exception exception propagated from {@code SpringInit.doSpringInitial}
     */
    public static void doSpringInitial(Object testInstance, ExtensionContext context) throws Exception {
        SpringInit.doSpringInitial(testInstance, context);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy