com.github.database.rider.junit5.DBUnitExtension Maven / Gradle / Ivy
package com.github.database.rider.junit5;
import com.github.database.rider.core.RiderRunner;
import com.github.database.rider.core.RiderTestContext;
import com.github.database.rider.core.api.configuration.DBUnit;
import com.github.database.rider.core.api.configuration.DataSetMergingStrategy;
import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.dataset.ExpectedDataSet;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.junit5.util.EntityManagerProvider;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.*;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static com.github.database.rider.junit5.jdbc.ConnectionManager.getConfiguredDataSourceBeanName;
import static com.github.database.rider.junit5.jdbc.ConnectionManager.getTestConnection;
import static com.github.database.rider.junit5.util.Constants.*;
import static java.lang.String.format;
/**
* Created by pestano on 27/08/16.
*/
public class DBUnitExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback,
BeforeEachCallback, AfterEachCallback, BeforeAllCallback, AfterAllCallback {
private static final Logger LOG = LoggerFactory.getLogger(DBUnitExtension.class.getName());
@Override
public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
EntityManagerProvider.clear();
DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
final DataSetExecutor dataSetExecutor = dbUnitTestContext.getExecutor();
final DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.empty(), extensionContext.getTestMethod(), extensionContext.getRequiredTestClass());
dataSetExecutor.setDBUnitConfig(dbUnitConfig);
if (dbUnitConfig.isLeakHunter()) {
try {
LeakHunter leakHunter = LeakHunterFactory.from(dataSetExecutor.getRiderDataSource(), extensionContext.getRequiredTestMethod().getName(), dbUnitConfig.isCacheConnection());
leakHunter.measureConnectionsBeforeExecution();
dbUnitTestContext.setLeakHunter(leakHunter);
} catch (SQLException e) {
LOG.warn(format("Could not create leak hunter for test %s", extensionContext.getRequiredTestMethod().getName()), e);
}
}
RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
RiderRunner riderRunner = new RiderRunner();
riderRunner.setup(riderTestContext);
riderRunner.runBeforeTest(riderTestContext);
}
@Override
public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
final DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
final DBUnitConfig dbUnitConfig = dbUnitTestContext.getExecutor().getDBUnitConfig();
RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
RiderRunner riderRunner = new RiderRunner();
try {
riderRunner.runAfterTest(riderTestContext);
if (dbUnitConfig != null && dbUnitConfig.isLeakHunter()) {
LeakHunter leakHunter = dbUnitTestContext.getLeakHunter();
leakHunter.checkConnectionsAfterExecution();
}
} finally {
riderRunner.teardown(riderTestContext);
}
}
/**
* one test context (datasetExecutor and dbunitConfig) per test
*/
private DBUnitTestContext getTestContext(ExtensionContext context) {
Class> testClass = context.getRequiredTestClass();
Store store = context.getStore(NAMESPACE);
return store.getOrComputeIfAbsent(testClass, (tc) -> createDBUnitTestContext(context), DBUnitTestContext.class);
}
private DBUnitTestContext createDBUnitTestContext(ExtensionContext extensionContext) {
final String executorId = getExecutorId(extensionContext, null);
final ConnectionHolder connectionHolder = getTestConnection(extensionContext, executorId);
final DataSetExecutor dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder);
return new DBUnitTestContext(dataSetExecutor);
}
private Set findCallbackMethods(Class testClass, Class callback) {
final Set methods = new HashSet<>();
Stream.of(testClass.getSuperclass()
.getDeclaredMethods(), testClass.getDeclaredMethods())
.flatMap(Stream::of)
.filter(m -> m.getAnnotation(callback) != null)
.forEach(m -> methods.add((Method) m)); //do not use Collectors.toSet here: stream incompatible types
return Collections.unmodifiableSet(methods);
}
@Override
public void beforeEach(ExtensionContext extensionContext) throws Exception {
if (extensionContext.getTestClass().isPresent()) {
Set callbackMethods = findCallbackMethods(extensionContext.getTestClass().get(), BeforeEach.class);
if (!callbackMethods.isEmpty()) {
for (Method callbackMethod : callbackMethods) {
executeDataSetForCallback(extensionContext, BeforeEach.class, callbackMethod);
executeExpectedDataSetForCallback(extensionContext, BeforeEach.class, callbackMethod);
}
}
}
}
@Override
public void afterEach(ExtensionContext extensionContext) throws Exception {
if (extensionContext.getTestClass().isPresent()) {
Set callbackMethods = findCallbackMethods(extensionContext.getTestClass().get(), AfterEach.class);
if (!callbackMethods.isEmpty()) {
for (Method callbackMethod : callbackMethods) {
executeDataSetForCallback(extensionContext, AfterEach.class, callbackMethod);
executeExpectedDataSetForCallback(extensionContext, AfterEach.class, callbackMethod);
}
}
}
}
@Override
public void beforeAll(ExtensionContext extensionContext) throws Exception {
if (extensionContext.getTestClass().isPresent()) {
Set callbackMethods = findCallbackMethods(extensionContext.getTestClass().get(), BeforeAll.class);
if (!callbackMethods.isEmpty()) {
for (Method callbackMethod : callbackMethods) {
executeDataSetForCallback(extensionContext, BeforeAll.class, callbackMethod);
executeExpectedDataSetForCallback(extensionContext, BeforeAll.class, callbackMethod);
}
}
}
}
@Override
public void afterAll(ExtensionContext extensionContext) throws Exception {
if (extensionContext.getTestClass().isPresent()) {
Set callbackMethods = findCallbackMethods(extensionContext.getTestClass().get(), AfterAll.class);
if (!callbackMethods.isEmpty()) {
for (Method callbackMethod : callbackMethods) {
executeDataSetForCallback(extensionContext, AfterAll.class, callbackMethod);
executeExpectedDataSetForCallback(extensionContext, AfterAll.class, callbackMethod);
}
}
}
}
private void executeDataSetForCallback(ExtensionContext extensionContext, Class callbackAnnotation, Method callbackMethod) throws SQLException {
Class testClass = extensionContext.getTestClass().get();
// get DataSet annotation, if any
Optional dataSetAnnotation = AnnotationUtils.findAnnotation(callbackMethod, DataSet.class);
if (!dataSetAnnotation.isPresent()) {
LOG.warn("Could not find dataset annotation from callback method: " + callbackMethod);
return;
}
EntityManagerProvider.clear();
final DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
final DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.of(callbackAnnotation), Optional.of(callbackMethod), testClass);
DataSet dataSet;
if (dbUnitConfig.isMergeDataSets()) {
Optional classLevelDataSetAnnotation = AnnotationUtils.findAnnotation(testClass, DataSet.class);
dataSet = resolveDataSet(dataSetAnnotation, classLevelDataSetAnnotation, dbUnitConfig);
} else {
dataSet = dataSetAnnotation.get();
}
DataSetExecutor dataSetExecutor = dbUnitTestContext.getExecutor();
dataSetExecutor.setDBUnitConfig(dbUnitConfig);
dataSetExecutor = resetExecutorConnectionIfNeeded(extensionContext, callbackAnnotation, dbUnitConfig, dataSetExecutor);
dataSetExecutor.createDataSet(new DataSetConfig().from(dataSet));
closeConnectionForAfterCallback(dataSetExecutor, callbackAnnotation);
}
/**
* We only need to close the connection in afterCallback because the connection opened in before callback is closed after test execution ({@link RiderRunner#teardown(RiderTestContext)})
*
* @param dataSetExecutor
* @param callbackAnnotation
* @throws SQLException
*/
private void closeConnectionForAfterCallback(DataSetExecutor dataSetExecutor, Class callbackAnnotation) throws SQLException {
if (!isAfterTestCallback(callbackAnnotation)) {
return;
}
if (!dataSetExecutor.getDBUnitConfig().isCacheConnection() && !dataSetExecutor.getRiderDataSource().getDBUnitConnection().getConnection().isClosed()) {
dataSetExecutor.getRiderDataSource().getDBUnitConnection().getConnection().close();
((DataSetExecutorImpl) dataSetExecutor).clearRiderDataSource();
}
}
private void executeExpectedDataSetForCallback(ExtensionContext extensionContext, Class callbackAnnotation, Method callbackMethod) throws DatabaseUnitException, SQLException {
Class testClass = extensionContext.getTestClass().get();
// get ExpectedDataSet annotation, if any
Optional expectedDataSetAnnotation = AnnotationUtils.findAnnotation(callbackMethod, ExpectedDataSet.class);
if (!expectedDataSetAnnotation.isPresent()) {
LOG.warn("Could not find expectedDataSet annotation annotation from callback method: " + callbackMethod);
return;
}
ExpectedDataSet expectedDataSet = expectedDataSetAnnotation.get();
// Verify expected dataset
// Resolve DBUnit config from annotation or file
DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.of(callbackAnnotation), Optional.of(callbackMethod), testClass);
DataSetExecutor dataSetExecutor = getTestContext(extensionContext).getExecutor();
dataSetExecutor.setDBUnitConfig(dbUnitConfig);
dataSetExecutor = resetExecutorConnectionIfNeeded(extensionContext, callbackAnnotation, dbUnitConfig, dataSetExecutor);
dataSetExecutor.compareCurrentDataSetWith(
new DataSetConfig(expectedDataSet.value()).disableConstraints(true).datasetProvider(expectedDataSet.provider()),
expectedDataSet.ignoreCols(),
expectedDataSet.replacers(),
expectedDataSet.orderBy(),
expectedDataSet.compareOperation());
closeConnectionForAfterCallback(dataSetExecutor, callbackAnnotation);
}
private DataSetExecutor resetExecutorConnectionIfNeeded(ExtensionContext extensionContext, Class callbackAnnotation, DBUnitConfig dbUnitConfig, DataSetExecutor dataSetExecutor) {
if (!dbUnitConfig.isCacheConnection() && isAfterTestCallback(callbackAnnotation)) { //we close the connection after test execution when cache is disabled so we need a new one for the callback
final ConnectionHolder connectionHolder = getTestConnection(extensionContext, dataSetExecutor.getExecutorId());
dataSetExecutor = DataSetExecutorImpl.instance(dataSetExecutor.getExecutorId(), connectionHolder, dbUnitConfig);
}
return dataSetExecutor;
}
private boolean isAfterTestCallback(Class callbackAnnotation) {
return callbackAnnotation.equals(AfterEach.class) || callbackAnnotation.equals(AfterAll.class);
}
// Resolve DBUnit config from annotation or file
private DBUnitConfig resolveDbUnitConfig(Optional callbackAnnotation, Optional method, Class testClass) {
Optional dbUnitAnnotation = AnnotationUtils.findAnnotation(method, DBUnit.class);
if (!dbUnitAnnotation.isPresent()) {
dbUnitAnnotation = AnnotationUtils.findAnnotation(testClass, DBUnit.class);
}
if (!dbUnitAnnotation.isPresent() && callbackAnnotation.isPresent()) {
Set callbackMethods = findCallbackMethods(testClass, callbackAnnotation.get());
if (!callbackMethods.isEmpty()) {
dbUnitAnnotation = AnnotationUtils.findAnnotation(callbackMethods.iterator().next(), DBUnit.class);
}
}
if (!dbUnitAnnotation.isPresent() && testClass.getSuperclass() != null) {
dbUnitAnnotation = AnnotationUtils.findAnnotation(testClass.getSuperclass(), DBUnit.class);
}
return dbUnitAnnotation.isPresent() ? DBUnitConfig.from(dbUnitAnnotation.get()) : DBUnitConfig.fromGlobalConfig();
}
// Resolve dataSet annotation, merging class and method annotations if needed
private DataSet resolveDataSet(Optional methodLevelDataSet,
Optional classLevelDataSet, DBUnitConfig config) {
if (classLevelDataSet.isPresent()) {
if (DataSetMergingStrategy.METHOD.equals(config.getMergingStrategy())) {
return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(classLevelDataSet.get(), methodLevelDataSet.get());
} else {
return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(methodLevelDataSet.get(), classLevelDataSet.get());
}
} else {
return methodLevelDataSet.get();
}
}
private String getExecutorId(final ExtensionContext extensionContext, DataSet dataSet) {
Optional annDataSet;
if (dataSet != null) {
annDataSet = Optional.of(dataSet);
} else {
annDataSet = findDataSetAnnotation(extensionContext);
}
String dataSourceBeanName = getConfiguredDataSourceBeanName(extensionContext);
String executionIdSuffix = dataSourceBeanName.isEmpty() ? EMPTY_STRING : "-" + dataSourceBeanName;
return annDataSet
.map(DataSet::executorId)
.filter(StringUtils::isNotBlank)
.map(id -> id + executionIdSuffix)
.orElseGet(() -> JUNIT5_EXECUTOR + executionIdSuffix);
}
private Optional findDataSetAnnotation(ExtensionContext extensionContext) {
Optional testMethod = extensionContext.getTestMethod();
if (testMethod.isPresent()) {
Optional annDataSet = AnnotationUtils.findAnnotation(testMethod.get(), DataSet.class);
if (!annDataSet.isPresent()) {
annDataSet = AnnotationUtils.findAnnotation(extensionContext.getRequiredTestClass(), DataSet.class);
}
return annDataSet;
} else {
return Optional.empty();
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy