// org.test4j.module.spring.SpringEnv (Maven / Gradle / Ivy)
package org.test4j.module.spring;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.test4j.Context;
import org.test4j.annotations.BeforeSpring;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static org.test4j.mock.faking.util.ReflectUtility.doThrow;
/**
* Spring环境设置
*
* @author darui.wu
*/
@SuppressWarnings({"rawtypes", "unchecked"})
public class SpringEnv {
private static final Map classIsSpring = new ConcurrentHashMap<>();
private static final String[] springTestAnnotations = {
"org.springframework.boot.test.context.SpringBootTest",
"org.springframework.test.context.ContextConfiguration",
"org.test4j.integration.spring.SpringContext"
};
public static boolean isSpringEnv() {
return isSpringEnv(Context.currTestClass());
}
public static void setSpringEnv(Class> clazz) {
classIsSpring.put(clazz, isSpringTest(clazz));
}
public static boolean isSpringEnv(Class clazz) {
return clazz != null && classIsSpring.get(clazz) != null && classIsSpring.get(clazz);
}
private static boolean isSpringTest(Class aClass) {
if (aClass == null) {
return false;
}
for (String annotation : springTestAnnotations) {
boolean hasAnnotation = hasAnnotation(aClass, annotation);
if (hasAnnotation) {
return true;
}
}
return false;
}
static Map HasAnnotation = new HashMap<>(5);
private static boolean hasAnnotation(Class objectClass, String annotation) {
Class annotationClass = getAnnotationClass(annotation);
if (annotationClass == null) {
return false;
} else {
Annotation instance = getClassLevelAnnotation(annotationClass, objectClass);
return instance != null;
}
}
private static Class> getAnnotationClass(String annotation) {
if (!HasAnnotation.containsKey(annotation)) {
try {
Class clazz = Class.forName(annotation);
HasAnnotation.put(annotation, clazz);
return clazz;
} catch (ClassNotFoundException e) {
HasAnnotation.put(annotation, null);
return null;
}
}
return HasAnnotation.get(annotation);
}
private static T getClassLevelAnnotation(Class annotationClass, Class clazz) {
Class superClass = clazz;
while (!Object.class.equals(superClass)) {
T annotation = (T) clazz.getAnnotation(annotationClass);
if (annotation != null) {
return annotation;
}
superClass = superClass.getSuperclass();
}
return null;
}
/**
* 获得当前测试类spring容器中名称为beanName的spring bean
*
* @param beanName bean name
* @return bean object
*/
public static T getBeanByName(String beanName) {
return SpringInit.getBeanByName(beanName);
}
public static T getBeanByType(Class beanType) {
return SpringInit.getBeanByType(beanType);
}
public static void injectSpringBeans(Object testedObject) {
if (!SpringEnv.isSpringEnv()) {
return;
}
SpringInit.injectSpringBeans(testedObject);
}
/**
* 用来在test4j初始化之前工作
* 比如spring加载前的mock工作等
*
* @param testedClass 测试类
*/
public static void invokeBeforeSpringMethod(Class testedClass) {
Method[] methods = testedClass.getMethods();
for (Method method : methods) {
if (method.getParameterCount() != 0 || !Modifier.isStatic(method.getModifiers())) {
continue;
}
if (method.getAnnotation(BeforeSpring.class) != null) {
method.setAccessible(true);
try {
method.invoke(null);
} catch (Exception e) {
doThrow(e);
}
}
}
}
/**
* 仅仅是转调用, 避免SpringEnv直接依赖spring class
*
* @param testInstance tested object
* @param context junit context
* @throws Exception exception
*/
public static void doSpringInitial(Object testInstance, ExtensionContext context) throws Exception {
SpringInit.doSpringInitial(testInstance, context);
}
}