/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * org.hrodberaht.injection.plugin.junit.SpringJUnit4Runner (Maven / Gradle / Ivy)
 *
 * There is a newer version: 3.0.0
 * Show newest version
 */
package org.hrodberaht.injection.plugin.junit;

import org.hrodberaht.injection.core.internal.exception.InjectRuntimeException;
import org.hrodberaht.injection.plugin.junit.spring.beans.SpringEntityManager;
import org.junit.runner.Description;
import org.junit.runner.notification.Failure;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextManager;
import org.springframework.test.context.TestExecutionListener;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.support.DependencyInjectionTestExecutionListener;
import org.springframework.test.context.transaction.TransactionalTestExecutionListener;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * JUnit 4 runner that extends Spring's {@link SpringJUnit4ClassRunner} and routes test execution
 * through a {@link JUnitContext}.
 *
 * <p>Compared to the plain Spring runner, this class:
 * <ul>
 *   <li>runs each test method via {@code JUnitContext.runChild(...)};</li>
 *   <li>passes every created test instance to {@code JUnitContext.autoWireTestObject(...)};</li>
 *   <li>drops {@link DependencyInjectionTestExecutionListener} from the registered listeners;</li>
 *   <li>replaces any {@link TransactionalTestExecutionListener} with a subclass that flushes the
 *       {@link SpringEntityManager}'s entity manager before delegating {@code afterTestMethod};</li>
 *   <li>swaps the manager's {@link TestContext} (reflectively) for a wrapper whose
 *       {@link ApplicationContext} is obtained from the {@link JUnitContext}.</li>
 * </ul>
 */
public class SpringJUnit4Runner extends SpringJUnit4ClassRunner {

    private static final Logger LOG = LoggerFactory.getLogger(SpringJUnit4Runner.class);

    private final JUnitContext jUnitContext;

    // NOTE(review): the two fields below are presumably assigned while the superclass constructor
    // is still running (via createTestContextManager) - confirm against the Spring version in use.
    // Do not give them initializers: an initializer would run after super(...) and reset the value.
    private TestContext testContext;
    private TransactionalTestExecutionListener transactionalTestExecutionListener;

    /**
     * Creates a runner for the given test class and a {@link JUnitContext} bound to the same class.
     *
     * @param clazz the test class to run
     * @throws InitializationError if the test class is malformed.
     */
    public SpringJUnit4Runner(Class<?> clazz) throws InitializationError {
        super(clazz);
        jUnitContext = new JUnitContext(clazz) {
            @Override
            void runBeforeTest(boolean activateContainer, String testName) {
                super.runBeforeTest(activateContainer, testName);
            }
        };
    }

    /**
     * Runs a single test method through the {@link JUnitContext}.
     *
     * <p>The context receives the regular Spring {@code runChild} as the work to perform and an
     * error callback that reports the throwable as a test failure followed by a test-finished
     * event for the same description.
     *
     * @param frameworkMethod the test method to run
     * @param notifier        the notifier that receives the test events
     */
    @Override
    protected void runChild(FrameworkMethod frameworkMethod, RunNotifier notifier) {
        jUnitContext.runChild(frameworkMethod, () -> super.runChild(frameworkMethod, notifier), (e) -> {
            Description description = describeChild(frameworkMethod);
            notifier.fireTestFailure(new Failure(description, e));
            notifier.fireTestFinished(description);
        });
    }

    /**
     * Flushes the entity manager held by the {@link SpringEntityManager} bean, if both exist.
     *
     * <p>This is best effort: a {@link RuntimeException} thrown by the flush is logged at debug
     * level and otherwise ignored so that entity manager problems do not fail the test.
     */
    private void flushEntityManager() {
        SpringEntityManager springEntityManager = getSpringEntityManager();
        if (springEntityManager == null || springEntityManager.getEntityManager() == null) {
            return;
        }
        try {
            springEntityManager.getEntityManager().flush();
        } catch (RuntimeException exception) {
            // Do not fail due to springEntityManager issues; keep a trace for diagnosis only.
            LOG.debug("SpringJUnitRunner info: entity manager flush failed", exception);
        }
    }

    /**
     * Looks up the {@link SpringEntityManager} bean in the {@link ApplicationContext} provided by
     * the {@link JUnitContext}.
     *
     * @return the bean, or {@code null} if the lookup throws any exception
     */
    private SpringEntityManager getSpringEntityManager() {
        try {
            return jUnitContext.get(ApplicationContext.class)
                    .getBean(SpringEntityManager.class);
        } catch (Exception ex) {
            LOG.debug("SpringJUnitRunner info: {}", ex.getMessage());
            return null;
        }
    }

    /**
     * Creates the test instance via the Spring runner and then hands it to the
     * {@link JUnitContext} for auto-wiring.
     *
     * @return the created test instance
     * @throws Exception if the superclass fails to create the instance
     */
    @Override
    protected Object createTest() throws Exception {
        Object testInstance = super.createTest();
        jUnitContext.autoWireTestObject(testInstance);
        return testInstance;
    }

    /**
     * Creates a {@link TestContextManagerLocal} for the test class and replaces its test context
     * with a {@link TestContextLocal} wrapper.
     *
     * @param clazz the test class to manage
     * @return the customized test context manager
     */
    @Override
    protected TestContextManager createTestContextManager(Class<?> clazz) {
        TestContextManager contextManager = new TestContextManagerLocal(clazz);
        replaceTestContext(contextManager);
        return contextManager;
    }

    /**
     * Reflectively reads the private {@code testContext} field of the given manager, remembers the
     * original value and overwrites the field with a {@link TestContextLocal} wrapping it.
     *
     * @param contextManager the manager whose test context is replaced
     * @throws InjectRuntimeException if the field does not exist or cannot be accessed
     */
    private void replaceTestContext(TestContextManager contextManager) {
        try {
            Field field = TestContextManager.class.getDeclaredField("testContext");
            field.setAccessible(true);
            testContext = (TestContext) field.get(contextManager);
            field.set(contextManager, new TestContextLocal(testContext));
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new InjectRuntimeException(e);
        }
    }


    /**
     * Test context manager that filters and adapts the listeners it is asked to register.
     */
    private class TestContextManagerLocal extends TestContextManager {

        public TestContextManagerLocal(Class<?> testClass) {
            super(testClass);
        }

        /**
         * Registers the given listeners, except that {@link DependencyInjectionTestExecutionListener}
         * instances are skipped and {@link TransactionalTestExecutionListener} instances are
         * replaced by a flushing variant. Relative order of the remaining listeners is kept.
         *
         * @param testExecutionListeners the listeners requested for registration
         */
        @Override
        public void registerTestExecutionListeners(List<TestExecutionListener> testExecutionListeners) {
            final List<TestExecutionListener> listeners = new ArrayList<>();
            for (TestExecutionListener listener : testExecutionListeners) {
                if (listener instanceof DependencyInjectionTestExecutionListener) {
                    // Skipped on purpose; the test instance is wired in createTest() instead.
                    continue;
                }
                listeners.add(flushingIfTransactional(listener));
            }
            super.registerTestExecutionListeners(listeners);
        }

        /**
         * Returns the listener unchanged unless it is a {@link TransactionalTestExecutionListener},
         * in which case a new transactional listener is created (and stored on the runner) that
         * flushes the entity manager before running the regular after-test-method handling.
         */
        private TestExecutionListener flushingIfTransactional(TestExecutionListener listener) {
            if (!(listener instanceof TransactionalTestExecutionListener)) {
                return listener;
            }
            transactionalTestExecutionListener = new TransactionalTestExecutionListener() {
                @Override
                public void afterTestMethod(TestContext testContext) throws Exception {
                    flushEntityManager();
                    super.afterTestMethod(testContext);
                }
            };
            return transactionalTestExecutionListener;
        }
    }

    /**
     * {@link TestContext} wrapper that delegates every call to the wrapped context except
     * {@link #getApplicationContext()}, which is answered by the runner's {@link JUnitContext}.
     */
    private class TestContextLocal implements TestContext {
        private final TestContext delegate;

        public TestContextLocal(TestContext testContext) {
            this.delegate = testContext;
        }

        /** Returns the application context held by the {@link JUnitContext}, not the delegate's. */
        @Override
        public ApplicationContext getApplicationContext() {
            return jUnitContext.get(ApplicationContext.class);
        }

        @Override
        public Class<?> getTestClass() {
            return delegate.getTestClass();
        }

        @Override
        public Object getTestInstance() {
            return delegate.getTestInstance();
        }

        @Override
        public Method getTestMethod() {
            return delegate.getTestMethod();
        }

        @Override
        public Throwable getTestException() {
            return delegate.getTestException();
        }

        @Override
        public void markApplicationContextDirty(DirtiesContext.HierarchyMode hierarchyMode) {
            delegate.markApplicationContextDirty(hierarchyMode);
        }

        @Override
        public void updateState(Object o, Method method, Throwable throwable) {
            delegate.updateState(o, method, throwable);
        }

        @Override
        public void setAttribute(String s, Object o) {
            delegate.setAttribute(s, o);
        }

        @Override
        public Object getAttribute(String s) {
            return delegate.getAttribute(s);
        }

        @Override
        public Object removeAttribute(String s) {
            return delegate.removeAttribute(s);
        }

        @Override
        public boolean hasAttribute(String s) {
            return delegate.hasAttribute(s);
        }

        @Override
        public String[] attributeNames() {
            return delegate.attributeNames();
        }
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy