All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.crosstreelabs.junited.dbunit.DbUnitRule Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.crosstreelabs.junited.dbunit;

import com.crosstreelabs.junited.core.WrappingRule;
import com.crosstreelabs.junited.dbunit.annotation.DatabaseSetup;
import com.crosstreelabs.junited.dbunit.annotation.DatabaseSetups;
import com.crosstreelabs.junited.dbunit.conversion.Converter;
import com.crosstreelabs.junited.dbunit.conversion.StandardConverters;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.sql.DataSource;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.CompositeDataSet;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.ext.db2.Db2Connection;
import org.dbunit.ext.h2.H2Connection;
import org.dbunit.ext.hsqldb.HsqldbConnection;
import org.dbunit.ext.mckoi.MckoiConnection;
import org.dbunit.ext.mssql.MsSqlConnection;
import org.dbunit.ext.mysql.MySqlConnection;
import org.dbunit.ext.oracle.OracleConnection;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JUnit rule that seeds a database from DbUnit XML datasets before a test runs.
 *
 * <p>Datasets are declared with {@link DatabaseSetup} (or its container
 * {@link DatabaseSetups}) on the test class and/or the test method. Class-level
 * setups are collected first, then method-level ones. All referenced datasets
 * are merged into a single {@link CompositeDataSet}, and the {@link Operation}
 * of the <em>last</em> collected setup is executed against the database.
 * Dataset locations are resolved with {@link Class#getResourceAsStream(String)}
 * on the test class, so relative paths are relative to the test class's package.
 *
 * <p>Unless an {@link IDatabaseConnection} is supplied directly, the DbUnit
 * connection type is chosen from the JDBC product name. The supported types are
 * DB2, H2, HSQLDB, Mckoi, SQL Server, MySQL and Oracle, with a generic
 * {@link DatabaseConnection} as the fallback.
 */
public class DbUnitRule extends WrappingRule {
    private static final Logger LOGGER = LoggerFactory.getLogger(DbUnitRule.class);

    /** Source of the JDBC connection when constructed from a {@link DataSource}. */
    private DataSource datasource;
    /** Raw JDBC connection; wrapped into {@link #dbunitConnection} on first use. */
    private Connection connection;
    /** DbUnit connection, either supplied by the caller or created on first use. */
    private IDatabaseConnection dbunitConnection;
    /** Schema handed to the DbUnit connection; may be {@code null}. */
    private String schema;
    /** Test class of the description most recently passed to {@link #apply}. */
    private Class<?> testClass;
    /** Reflectively resolved test method, or {@code null} if none was found. */
    private Method testMethod;
    /** Method-level description (carries method annotations); {@code null} for class rules. */
    private Description methodDescription;

    /**
     * Creates a rule that obtains its JDBC connection from {@code datasource}.
     *
     * <p>A connection is opened immediately so its current schema can be recorded.
     * NOTE(review): this rule never closes that connection.
     *
     * @param datasource source of the JDBC connection
     * @throws RuntimeException wrapping any {@link SQLException} raised while
     *         opening the connection or reading its schema
     */
    public DbUnitRule(final DataSource datasource) {
        try {
            this.datasource = datasource;
            this.connection = datasource.getConnection();
            this.schema = connection.getSchema();
        } catch (SQLException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Creates a rule that uses an existing JDBC connection.
     *
     * @param connection connection to seed; its current schema is recorded
     * @throws RuntimeException wrapping any {@link SQLException} raised while
     *         reading the schema
     */
    public DbUnitRule(final Connection connection) {
        try {
            this.connection = connection;
            this.schema = connection.getSchema();
        } catch (SQLException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Creates a rule that uses a ready-made DbUnit connection as-is.
     *
     * @param dbunitConnection DbUnit connection used for every operation
     */
    public DbUnitRule(final IDatabaseConnection dbunitConnection) {
        this.dbunitConnection = dbunitConnection;
        this.schema = dbunitConnection.getSchema();
    }

    /**
     * Records which test class and method the rule applies to, then delegates
     * to {@link WrappingRule#apply}.
     *
     * <p>The test method is looked up by name through the class hierarchy, so
     * inherited test methods are found. When no method of that name exists
     * (e.g. parameterized tests named like {@code test[0]}), method-level
     * annotations are read from the {@link Description} instead of failing.
     */
    @Override
    public Statement apply(final Statement base,
            final Description description) {
        testClass = description.getTestClass();
        final String methodName = description.getMethodName();
        if (methodName == null) {
            // Applied as a class rule: only class-level setups are relevant.
            testMethod = null;
            methodDescription = null;
        } else {
            testMethod = testClass == null ? null : findTestMethod(testClass, methodName);
            methodDescription = description;
        }

        return super.apply(base, description);
    }

    /**
     * Loads every declared dataset and executes the applicable DbUnit operation.
     *
     * <p>Does nothing when no {@link DatabaseSetup} is declared, or when every
     * declared location is blank.
     *
     * @throws Throwable any failure while loading datasets, connecting, or
     *         executing the operation
     */
    @Override
    protected void before(final Statement statement,
            final Description description) throws Throwable {
        final List<DatabaseSetup> setups = new ArrayList<>();
        // Class-level setups first...
        if (testClass != null) {
            addSetups(setups, testClass.getAnnotation(DatabaseSetup.class),
                    testClass.getAnnotation(DatabaseSetups.class));
        }
        // ...then method-level ones, so they come last in the list.
        addSetups(setups, findMethodAnnotation(DatabaseSetup.class),
                findMethodAnnotation(DatabaseSetups.class));

        if (setups.isEmpty()) {
            return;
        }

        // The last collected setup decides the operation for the combined
        // datasets; a method-level setup, when present, therefore wins.
        final Operation op = setups.get(setups.size() - 1).operation();
        final List<IDataSet> datasets = loadDataSets(setups);
        if (datasets.isEmpty()) {
            return;
        }
        final IDatabaseConnection conn = getConnection();
        final IDataSet dataSet = new CompositeDataSet(
                datasets.toArray(new IDataSet[datasets.size()]));
        op.getOperation().execute(conn, dataSet);
    }

    /**
     * Performs no teardown of its own; delegates to {@link WrappingRule}.
     *
     * <p>NOTE(review): connections held by this rule are not closed here.
     * Confirm whether callers are expected to own their lifecycle.
     */
    @Override
    protected void after(final Statement statement,
            final Description description) {
        super.after(statement, description);
    }

    //~ Helpers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    /** Appends the single and/or container annotation to {@code target}, skipping nulls. */
    private static void addSetups(final List<DatabaseSetup> target,
            final DatabaseSetup single, final DatabaseSetups multiple) {
        if (single != null) {
            target.add(single);
        }
        if (multiple != null) {
            target.addAll(Arrays.asList(multiple.value()));
        }
    }

    /** Finds a no-arg method named {@code name} on {@code type} or a superclass; null if absent. */
    private static Method findTestMethod(final Class<?> type, final String name) {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredMethod(name);
            } catch (NoSuchMethodException ex) {
                // Not declared at this level; keep searching up the hierarchy.
            }
        }
        return null;
    }

    /** Reads a method-level annotation from the reflected method or, failing that, the description. */
    private <A extends Annotation> A findMethodAnnotation(final Class<A> type) {
        if (testMethod != null) {
            return testMethod.getAnnotation(type);
        }
        return methodDescription == null ? null : methodDescription.getAnnotation(type);
    }

    /** Loads the datasets of all {@code setups}, in declaration order. */
    private List<IDataSet> loadDataSets(final List<DatabaseSetup> setups) throws DataSetException {
        final List<IDataSet> datasets = new ArrayList<>();
        for (DatabaseSetup setup : setups) {
            datasets.addAll(loadDataSets(setup));
        }
        return datasets;
    }

    /** Loads each location of one setup, skipping blank locations. */
    private List<IDataSet> loadDataSets(final DatabaseSetup setup) throws DataSetException {
        final List<IDataSet> datasets = new ArrayList<>();
        for (String location : setup.value()) {
            LOGGER.debug("Loading dataset: {}", location);
            final IDataSet dataSet = loadDataSet(location);
            // Blank locations yield null; CompositeDataSet must not receive nulls.
            if (dataSet != null) {
                datasets.add(dataSet);
            }
        }
        return datasets;
    }

    /**
     * Loads one dataset resource relative to the test class.
     *
     * @return the parsed dataset, or {@code null} when {@code location} is blank
     * @throws DataSetException if the resource is missing, unreadable or invalid
     */
    private IDataSet loadDataSet(final String location) throws DataSetException {
        if (location == null || location.trim().isEmpty()) {
            return null;
        }

        // Buffer the resource so the underlying stream is always closed, no
        // matter when the dataset implementation reads its input.
        final byte[] content = readResource(location);
        try {
            return createDataSet(new ByteArrayInputStream(content));
        } catch (DataSetException ex) {
            throw new DataSetException("Problem with dataset "+location, ex);
        }
    }

    /** Reads the whole resource at {@code location} and closes the stream. */
    private byte[] readResource(final String location) throws DataSetException {
        try (InputStream is = testClass.getResourceAsStream(location)) {
            if (is == null) {
                throw new DataSetException("Unable to find dataset "+location);
            }
            return readFully(is);
        } catch (IOException ex) {
            throw new DataSetException("Problem with dataset "+location, ex);
        }
    }

    /** Drains {@code in} into a byte array (does not close it). */
    private static byte[] readFully(final InputStream in) throws IOException {
        final ByteArrayOutputStream out = new ByteArrayOutputStream();
        final byte[] buffer = new byte[8192];
        int read;
        while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
        }
        return out.toByteArray();
    }

    /**
     * Parses XML dataset content with the standard converters registered.
     * Static so the anonymous subclass does not capture this rule instance.
     */
    private static IDataSet createDataSet(final InputStream source) throws DataSetException {
        return new EnhancedXmlDataSet(source) {
            @Override
            public Map<String, Converter> getConverters() {
                // Delegates to a static method so no instance state of this
                // anonymous subclass is needed, whenever the superclass calls it.
                return standardConverters();
            }
        };
    }

    /** Returns a fresh map of the built-in converters keyed by "hex", "bin" and "uuid". */
    private static Map<String, Converter> standardConverters() {
        final Map<String, Converter> converters = new HashMap<>();
        converters.put("hex", new StandardConverters.HexToBinaryConverter());
        converters.put("bin", new StandardConverters.BinaryConverter());
        converters.put("uuid", new StandardConverters.UUIDConverter());
        return converters;
    }

    /**
     * Returns the DbUnit connection, creating and caching it on first use.
     *
     * <p>The vendor-specific DbUnit connection class is selected by substring
     * match on the lower-cased JDBC product name. Unknown products get a
     * generic {@link DatabaseConnection}. Every variant uses {@link #schema}.
     */
    private IDatabaseConnection getConnection() throws DatabaseUnitException, SQLException {
        if (dbunitConnection != null) {
            return dbunitConnection;
        }

        if (connection == null) {
            connection = datasource.getConnection();
        }
        // Locale.ROOT: product-name matching must not depend on the JVM locale.
        final String product = connection.getMetaData()
                .getDatabaseProductName().toLowerCase(Locale.ROOT);

        if (product.contains("db2")) {
            dbunitConnection = new Db2Connection(connection, schema);
        } else if (product.contains("h2")) {
            dbunitConnection = new H2Connection(connection, schema);
        } else if (product.contains("hsql")) {
            dbunitConnection = new HsqldbConnection(connection, schema);
        } else if (product.contains("mckoi")) {
            dbunitConnection = new MckoiConnection(connection, schema);
        } else if (product.contains("sql server")) {
            dbunitConnection = new MsSqlConnection(connection, schema);
        } else if (product.contains("mysql")) {
            dbunitConnection = new MySqlConnection(connection, schema);
        } else if (product.contains("oracle")) {
            dbunitConnection = new OracleConnection(connection, schema);
        } else {
            // The second argument is the schema, not the product name.
            dbunitConnection = new DatabaseConnection(connection, schema);
        }
        return dbunitConnection;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy