// org.test4j.module.spring.SpringInit (Maven / Gradle / Ivy)
package org.test4j.module.spring;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextManager;
import org.test4j.Context;
import java.lang.ref.WeakReference;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import static org.test4j.integration.junit5.JUnit5SpringHelper.getTestContextManager;
import static org.test4j.module.spring.SpringEnv.invokeSpringInitMethod;
/**
* 和SpringEnv分开, 无spring依赖时避免NoClassDefFoundError异常
*
* @author wudarui
*/
public class SpringInit {
/**
* key: 测试类, value:AbstractApplicationContext实例
*/
private static Map> springBeanFactories = new HashMap<>();
/**
* spring事务管理
*/
private static ThreadLocal springTestContextManager = new ThreadLocal<>();
/**
* 获取当前测试实例的spring容器
*
* @return
*/
public static Optional getSpringContext() {
WeakReference reference = springBeanFactories.get(Context.currTestClass());
if (reference == null || reference.get() == null) {
return Optional.empty();
} else {
return Optional.of(reference.get());
}
}
/**
* 获得当前测试类spring容器中名称为beanName的spring bean
*
* @param beanName
* @return
*/
static T getBeanByName(String beanName) {
Object bean = getSpringContext().map(c -> {
try {
return c.getBean(beanName);
} catch (NoSuchBeanDefinitionException e) {
return null;
}
}).orElse(null);
return (T) bean;
}
static T getBeanByType(Class beanType) {
Object bean = getSpringContext().map(c -> {
try {
return c.getBean(beanType);
} catch (NoSuchBeanDefinitionException e) {
return null;
}
}).orElse(null);
return (T) bean;
}
/**
* 设置当前测试实例的spring容器
*
* @param context
*/
public static void setSpringContext(Class testClass, ApplicationContext context) {
springBeanFactories.put(testClass, new WeakReference<>(context));
}
static void injectSpringBeans(Object testedObject) {
if (!SpringEnv.isSpringEnv()) {
return;
}
AutowireCapableBeanFactory beanFactory = getSpringContext().get().getAutowireCapableBeanFactory();
beanFactory.autowireBeanProperties(testedObject, AutowireCapableBeanFactory.AUTOWIRE_NO, false);
beanFactory.initializeBean(testedObject, testedObject.getClass().getSimpleName());
}
/**
* 在测试spring容器启动前后执行
* 1. 执行@BeforeSpringContext 方法
* 2. 初始化测试实例注入
* 3. 注册spring容器
*
* @param testInstance
* @param context
* @throws Exception
*/
public static void doSpringInitial(Object testInstance, ExtensionContext context) throws Exception {
TestContextManager contextManager = getTestContextManager(context);
SpringInit.doSpringInitial(testInstance, contextManager);
}
/**
* 在测试spring容器启动前后执行
* 1. 执行@BeforeSpringContext 方法
* 2. 初始化测试实例注入
* 3. 注册spring容器
*
* @param testInstance
* @param contextManager
* @throws Exception
*/
public static void doSpringInitial(Object testInstance, TestContextManager contextManager) throws Exception {
invokeSpringInitMethod(testInstance);
springTestContextManager.set(contextManager);
contextManager.prepareTestInstance(testInstance);
ApplicationContext applicationContext = getApplicationContext(contextManager);
SpringInit.setSpringContext(testInstance.getClass(), applicationContext);
}
/**
* 有些版本getTestContext禁止访问,所以这里反射调用
*
* @param contextManager
* @return
*/
public static ApplicationContext getApplicationContext(TestContextManager contextManager) {
try {
Method method = TestContextManager.class.getMethod("getTestContext");
method.setAccessible(true);
TestContext testContext = (TestContext) method.invoke(contextManager);
return testContext.getApplicationContext();
} catch (Exception e) {
throw new RuntimeException("get Spring Application Context error: " + e.getMessage(), e);
}
}
public static TestContextManager getSpringTestContextManager() {
return springTestContextManager.get();
}
}