com.aire.ux.test.spring.servlet.ServletDefinitionExtension Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of aire-test-spring Show documentation
Sunshower UX Testing Libraries
The newest version!
package com.aire.ux.test.spring.servlet;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import jakarta.servlet.Servlet;
import jakarta.servlet.annotation.WebServlet;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.val;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;
public class ServletDefinitionExtension implements Extension, BeforeAllCallback, AfterAllCallback {
private static final String KEY = "DEFINITION_STORE";
@Override
public void afterAll(ExtensionContext context) throws Exception {
val applicationContext = SpringExtension.getApplicationContext(context);
unregisterServletDefinitions(applicationContext, context);
}
@Override
public void beforeAll(ExtensionContext context) throws Exception {
val applicationContext = SpringExtension.getApplicationContext(context);
registerServletDefinitions(applicationContext, context);
}
private void registerServletDefinitions(
ApplicationContext applicationContext, ExtensionContext context) {
val store = context.getStore(Namespace.create(applicationContext, context));
registerClient(store, applicationContext);
context
.getTestClass()
.flatMap(testClass -> Optional.ofNullable(testClass.getAnnotation(WithServlets.class)))
.ifPresent(
withServlets -> {
defineServlets(
store,
withServlets,
(ConfigurableListableBeanFactory)
applicationContext.getAutowireCapableBeanFactory());
});
postProcessBeanFactory(
store,
((ConfigurableListableBeanFactory) applicationContext.getAutowireCapableBeanFactory()));
}
@SuppressFBWarnings
@SuppressWarnings("unchecked")
private void registerClient(Store store, ApplicationContext applicationContext) {
val beanDefinitionRegistry =
(BeanDefinitionRegistry) applicationContext.getAutowireCapableBeanFactory();
val beanDefinition =
BeanDefinitionBuilder.rootBeanDefinition(DefaultClient.class)
.addConstructorArgValue(applicationContext)
.getBeanDefinition();
val definitions =
(List)
store.getOrComputeIfAbsent(KEY, (k) -> new ArrayList());
definitions.add(beanDefinition);
beanDefinitionRegistry.registerBeanDefinition(
Objects.requireNonNull(beanDefinition.getBeanClassName()), beanDefinition);
}
private void postProcessBeanFactory(Store store, ConfigurableListableBeanFactory beanFactory)
throws BeansException {
if (!(beanFactory instanceof BeanDefinitionRegistry)) {
throw new IllegalStateException(
String.format(
"Wrong type of bean factory (unsupported context): %s", beanFactory.getClass()));
}
val names = beanFactory.getBeanNamesForAnnotation(WithServlets.class);
for (val name : names) {
scanBeanType(store, beanFactory.getBeanDefinition(name), beanFactory);
}
}
private void scanBeanType(
Store store, BeanDefinition beanDefinition, ConfigurableListableBeanFactory beanFactory) {
try {
val classloader = beanFactory.getBeanClassLoader();
val actualType = Class.forName(beanDefinition.getBeanClassName(), false, classloader);
val withServletsAnnotation = actualType.getAnnotation(WithServlets.class);
defineServlets(store, withServletsAnnotation, beanFactory);
} catch (ClassNotFoundException e) {
throw new BeanInitializationException(e.getMessage(), e);
}
}
private void defineServlets(
Store store,
WithServlets withServletsAnnotation,
ConfigurableListableBeanFactory beanFactory) {
for (val servletDefinition : withServletsAnnotation.servlets()) {
if (!Servlet.class.equals(servletDefinition.type())) {
defineServlet(store, servletDefinition.type(), beanFactory, servletDefinition.paths());
}
}
for (val servlet : withServletsAnnotation.value()) {
if (!Servlet.class.equals(servlet)) {
defineServlet(store, servlet, beanFactory, getRequestMappings(servlet));
}
}
}
@SuppressFBWarnings
@SuppressWarnings("unchecked")
private void defineServlet(
Store store,
Class extends Servlet> servlet,
ConfigurableListableBeanFactory beanFactory,
String[] names) {
val definitions =
(List)
store.getOrComputeIfAbsent(KEY, (k) -> new ArrayList());
/** register servlet class */
val definition = BeanDefinitionBuilder.rootBeanDefinition(servlet).getBeanDefinition();
((BeanDefinitionRegistry) beanFactory)
.registerBeanDefinition(Objects.requireNonNull(definition.getBeanClassName()), definition);
val servletRegistrationDefinition =
BeanDefinitionBuilder.rootBeanDefinition(ServletRegistrationBean.class)
.addConstructorArgReference(definition.getBeanClassName())
.addConstructorArgValue(true)
.addConstructorArgValue(names)
.setLazyInit(false)
.getBeanDefinition();
((BeanDefinitionRegistry) beanFactory)
.registerBeanDefinition(
definition.getBeanClassName() + "registration", servletRegistrationDefinition);
definitions.addAll(List.of(definition, servletRegistrationDefinition));
}
private String[] getRequestMappings(Class extends Servlet> servlet) {
if (servlet.isAnnotationPresent(WebServlet.class)) {
return servlet.getAnnotation(WebServlet.class).value();
}
throw new UnsupportedOperationException(
String.format(
"Error: must annotate '%s' with an @WebServlet containing request mappings", servlet));
}
@SuppressWarnings("unchecked")
private void unregisterServletDefinitions(
ApplicationContext applicationContext, ExtensionContext context) {
val store = context.getStore(Namespace.create(applicationContext, context));
val definitions =
new HashSet<>(
(List)
store.getOrComputeIfAbsent(KEY, (k) -> new ArrayList()));
val registry = (BeanDefinitionRegistry) applicationContext.getAutowireCapableBeanFactory();
for (val name : registry.getBeanDefinitionNames()) {
val definition = registry.getBeanDefinition(name);
if (definitions.contains(definition)) {
registry.removeBeanDefinition(name);
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy