// org.test4j.module.spring.SpringEnv (Maven / Gradle / Ivy)
package org.test4j.module.spring;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.test4j.Context;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static org.test4j.mock.faking.util.ReflectUtility.doThrow;
/**
* Spring环境设置
*
* @author darui.wu
*/
public class SpringEnv {
private static Map classIsSpring = new ConcurrentHashMap<>();
private static String[] springTestAnnotations = {
"org.springframework.boot.test.context.SpringBootTest",
"org.springframework.test.context.ContextConfiguration"
};
public static boolean isSpringEnv() {
return isSpringEnv(Context.currTestClass());
}
public static void setSpringEnv(Class> clazz) {
classIsSpring.put(clazz, isSpringTest(clazz));
}
public static boolean isSpringEnv(Class clazz) {
return clazz == null || classIsSpring.get(clazz) == null ? false : classIsSpring.get(clazz);
}
private static boolean isSpringTest(Class aClass) {
if (aClass == null) {
return false;
}
for (String annotation : springTestAnnotations) {
boolean hasAnnotation = hasAnnotation(aClass, annotation);
if (hasAnnotation) {
return true;
}
}
return false;
}
static Map HasAnnotation = new HashMap<>(5);
private static boolean hasAnnotation(Class objectClass, String annotation) {
Class annotationClass = getAnnotationClass(annotation);
if (annotationClass == null) {
return false;
} else {
Annotation instance = getClassLevelAnnotation(annotationClass, objectClass);
return instance != null;
}
}
private static Class> getAnnotationClass(String annotation) {
if (!HasAnnotation.containsKey(annotation)) {
try {
Class clazz = Class.forName(annotation);
HasAnnotation.put(annotation, clazz);
return clazz;
} catch (ClassNotFoundException e) {
HasAnnotation.put(annotation, null);
return null;
}
}
return HasAnnotation.get(annotation);
}
private static T getClassLevelAnnotation(Class annotationClass, Class clazz) {
Class superClass = clazz;
while (!Object.class.equals(superClass)) {
T annotation = (T) clazz.getAnnotation(annotationClass);
if (annotation != null) {
return annotation;
}
superClass = superClass.getSuperclass();
}
return null;
}
/**
* 获得当前测试类spring容器中名称为beanName的spring bean
*
* @param beanName
* @return
*/
public static T getBeanByName(String beanName) {
return SpringInit.getBeanByName(beanName);
}
public static T getBeanByType(Class beanType) {
return SpringInit.getBeanByType(beanType);
}
public static void injectSpringBeans(Object testedObject) {
if (!SpringEnv.isSpringEnv()) {
return;
}
SpringInit.injectSpringBeans(testedObject);
}
/**
* 用来在test4j初始化之前工作
* 比如spring加载前的mock工作等
*
* @param test 测试类实例
*/
public static void invokeSpringInitMethod(Object test) {
Method[] methods = test.getClass().getDeclaredMethods();
for (Method method : methods) {
if (method.getParameterCount() == 0 && method.getAnnotation(BeforeSpringContext.class) != null) {
method.setAccessible(true);
try {
method.invoke(test);
} catch (Exception e) {
doThrow(e);
}
}
}
}
/**
* 在测试spring容器启动前后执行
* 1. 执行@BeforeSpringContext 方法
* 2. 初始化测试实例注入
* 3. 注册spring容器
*
* @param testInstance
* @param context
* @throws Exception
*/
public static void doSpringInitial(Object testInstance, ExtensionContext context) throws Exception {
SpringInit.doSpringInitial(testInstance, context);
}
}