All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apereo.cas.tomcat.CasTomcatServletWebServerFactoryCustomizer Maven / Gradle / Ivy

The newest version!
package org.apereo.cas.tomcat;

import org.apereo.cas.configuration.CasConfigurationProperties;
import org.apereo.cas.configuration.model.core.web.tomcat.CasEmbeddedApacheTomcatHttpProperties;
import org.apereo.cas.configuration.model.core.web.tomcat.CasEmbeddedApacheTomcatHttpProxyProperties;
import org.apereo.cas.configuration.support.Beans;
import org.apereo.cas.util.RandomUtils;
import org.apereo.cas.util.RegexUtils;
import org.apereo.cas.util.ResourceUtils;
import org.apereo.cas.util.function.FunctionUtils;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.catalina.authenticator.BasicAuthenticator;
import org.apache.catalina.connector.Connector;
import org.apache.catalina.connector.Request;
import org.apache.catalina.connector.Response;
import org.apache.catalina.realm.GenericPrincipal;
import org.apache.catalina.valves.ExtendedAccessLogValve;
import org.apache.catalina.valves.SSLValve;
import org.apache.catalina.valves.ValveBase;
import org.apache.catalina.valves.rewrite.RewriteValve;
import org.apache.commons.lang3.StringUtils;
import org.apache.coyote.ajp.AbstractAjpProtocol;
import org.apache.coyote.ajp.AjpNio2Protocol;
import org.apache.coyote.ajp.AjpNioProtocol;
import org.apache.coyote.http11.Http11Nio2Protocol;
import org.apache.coyote.http11.Http11NioProtocol;
import org.apache.coyote.http2.Http2Protocol;
import org.apache.tomcat.util.descriptor.web.LoginConfig;
import org.apache.tomcat.util.descriptor.web.SecurityCollection;
import org.apache.tomcat.util.descriptor.web.SecurityConstraint;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryCustomizer;
import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory;
import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory;
import org.springframework.util.ReflectionUtils;
import jakarta.servlet.ServletException;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Locale;

/**
 * This is {@link CasTomcatServletWebServerFactoryCustomizer}.
 *
 * @author Misagh Moayyed
 * @since 6.1.0
 */
@Slf4j
public class CasTomcatServletWebServerFactoryCustomizer extends ServletWebServerFactoryCustomizer {
    private static final int PORT_RANGE_MIN = 4000;
    private static final int PORT_RANGE_MAX = 9000;

    private final CasConfigurationProperties casProperties;

    private final ServerProperties serverProperties;

    public CasTomcatServletWebServerFactoryCustomizer(final ServerProperties serverProperties,
                                                      final CasConfigurationProperties casProperties) {
        super(serverProperties);
        this.casProperties = casProperties;
        this.serverProperties = serverProperties;
    }

    private static void configureConnectorForProtocol(final Connector connector,
                                                      final CasEmbeddedApacheTomcatHttpProxyProperties proxy) {
        val handler = ReflectionUtils.findField(connector.getClass(), "protocolHandler");
        if (handler != null) {
            ReflectionUtils.makeAccessible(handler);
            if ("HTTP/2".equalsIgnoreCase(proxy.getProtocol())) {
                connector.addUpgradeProtocol(new Http2Protocol());
            } else {
                val protocolHandlerInstance = switch (proxy.getProtocol()) {
                    case "AJP/2" -> new AjpNio2Protocol();
                    case "HTTP/1.2" -> new Http11Nio2Protocol();
                    case "HTTP/1.1" -> new Http11NioProtocol();
                    default -> new AjpNioProtocol();
                };
                if (protocolHandlerInstance instanceof final AbstractAjpProtocol ajp) {
                    ajp.setSecretRequired(proxy.isSecure());
                    ajp.setSecret(proxy.getSecret());
                }
                protocolHandlerInstance.setPort(connector.getPort());
                ReflectionUtils.setField(handler, connector, protocolHandlerInstance);
            }
            val handlerClass = ReflectionUtils.findField(connector.getClass(), "protocolHandlerClassName");
            if (handlerClass != null) {
                ReflectionUtils.makeAccessible(handlerClass);
                ReflectionUtils.setField(handlerClass, connector, connector.getProtocolHandler().getClass().getName());
            }
        }
    }

    @Override
    public void customize(final ConfigurableServletWebServerFactory factory) {
        if (factory instanceof final TomcatServletWebServerFactory tomcat) {
            configureAjp(tomcat);
            configureHttp(tomcat);
            configureHttpProxy(tomcat);
            configureExtendedAccessLogValve(tomcat);
            configureRewriteValve(tomcat);
            configureSSLValve(tomcat);
            configureBasicAuthn(tomcat);
            configureRemoteUserValve(tomcat);
            finalizeConnectors(tomcat);
        }
    }

    private void finalizeConnectors(final TomcatServletWebServerFactory tomcat) {
        tomcat.addConnectorCustomizers(connector -> {
            val tc = casProperties.getServer().getTomcat();
            connector.setProperty("Server", tc.getServerName());

            val socket = tc.getSocket();
            if (socket.getBufferPool() > 0) {
                connector.setProperty("socket.bufferPool", String.valueOf(socket.getBufferPool()));
            }
            if (socket.getAppReadBufSize() > 0) {
                connector.setProperty("socket.appReadBufSize", String.valueOf(socket.getAppReadBufSize()));
            }
            if (socket.getAppWriteBufSize() > 0) {
                connector.setProperty("socket.appWriteBufSize", String.valueOf(socket.getAppWriteBufSize()));
            }
            if (socket.getPerformanceBandwidth() >= 0) {
                connector.setProperty("socket.performanceBandwidth", String.valueOf(socket.getPerformanceBandwidth()));
            }
            if (socket.getPerformanceConnectionTime() >= 0) {
                connector.setProperty("socket.performanceConnectionTime", String.valueOf(socket.getPerformanceConnectionTime()));
            }
            if (socket.getPerformanceLatency() >= 0) {
                connector.setProperty("socket.performanceLatency", String.valueOf(socket.getPerformanceLatency()));
            }
        });
    }

    private void configureBasicAuthn(final TomcatServletWebServerFactory tomcat) {
        val basic = casProperties.getServer().getTomcat().getBasicAuthn();
        if (basic.isEnabled()) {
            tomcat.addContextCustomizers(ctx -> {
                val config = new LoginConfig();
                config.setAuthMethod("BASIC");
                ctx.setLoginConfig(config);

                basic.getSecurityRoles().forEach(ctx::addSecurityRole);

                basic.getAuthRoles().forEach(r -> {
                    val constraint = new SecurityConstraint();
                    constraint.addAuthRole(r);
                    val collection = new SecurityCollection();
                    basic.getPatterns().forEach(collection::addPattern);
                    constraint.addCollection(collection);
                    ctx.addConstraint(constraint);
                });
            });
            tomcat.addContextValves(new BasicAuthenticator());
        }
    }

    private void configureExtendedAccessLogValve(final TomcatServletWebServerFactory tomcat) {
        val ext = casProperties.getServer().getTomcat().getExtAccessLog();

        if (ext.isEnabled() && StringUtils.isNotBlank(ext.getPattern())) {
            LOGGER.debug("Creating extended access log valve configuration for the embedded tomcat container...");
            val valve = new ExtendedAccessLogValve();
            valve.setPattern(ext.getPattern());

            if (StringUtils.isBlank(ext.getDirectory())) {
                valve.setDirectory(serverProperties.getTomcat().getAccesslog().getDirectory());
            } else {
                valve.setDirectory(ext.getDirectory());
            }
            valve.setPrefix(ext.getPrefix());
            valve.setSuffix(ext.getSuffix());
            valve.setAsyncSupported(true);
            valve.setEnabled(true);
            valve.setRotatable(true);
            valve.setBuffered(true);
            tomcat.addEngineValves(valve);
        }
    }

    private void configureHttp(final TomcatServletWebServerFactory tomcat) {
        casProperties.getServer().getTomcat().getHttp()
            .stream()
            .filter(CasEmbeddedApacheTomcatHttpProperties::isEnabled)
            .forEach(http -> {
                LOGGER.debug("Creating HTTP configuration for the embedded tomcat container...");
                val connector = new Connector(http.getProtocol());
                var port = http.getPort();
                if (port <= 0) {
                    port = RandomUtils.nextInt(PORT_RANGE_MIN, PORT_RANGE_MAX);
                    LOGGER.warn("No explicit port configuration is provided to CAS. Using random port [{}]", port);
                }
                LOGGER.info("Activated embedded tomcat container HTTP port on [{}]", port);
                connector.setPort(port);
                if (http.getRedirectPort() > 0) {
                    connector.setRedirectPort(http.getRedirectPort());
                }
                connector.setScheme(http.getScheme());
                connector.setSecure(http.isSecure());

                LOGGER.debug("Configuring embedded tomcat container for HTTP2 protocol support");
                connector.addUpgradeProtocol(new Http2Protocol());

                http.getAttributes().forEach(connector::setProperty);
                tomcat.addAdditionalTomcatConnectors(connector);
            });
    }

    private void configureHttpProxy(final TomcatServletWebServerFactory tomcat) {
        val proxy = casProperties.getServer().getTomcat().getHttpProxy();
        if (proxy.isEnabled()) {
            LOGGER.debug("Customizing HTTP proxying for connector listening on port [{}]", tomcat.getPort());
            tomcat.getTomcatConnectorCustomizers().add(connector -> {
                connector.setSecure(proxy.isSecure());
                connector.setScheme(proxy.getScheme());

                if (StringUtils.isNotBlank(proxy.getProtocol())) {
                    LOGGER.debug("Setting HTTP proxying protocol to [{}]", proxy.getProtocol());
                    configureConnectorForProtocol(connector, proxy);
                }

                if (proxy.getRedirectPort() > 0) {
                    LOGGER.debug("Setting HTTP proxying redirect port to [{}]", proxy.getRedirectPort());
                    connector.setRedirectPort(proxy.getRedirectPort());
                }
                if (proxy.getProxyPort() > 0) {
                    LOGGER.debug("Setting HTTP proxying proxy port to [{}]", proxy.getProxyPort());
                    connector.setProxyPort(proxy.getProxyPort());
                }
                connector.addUpgradeProtocol(new Http2Protocol());

                proxy.getAttributes().forEach(connector::setProperty);
                LOGGER.info("Configured connector listening on port [{}]", tomcat.getPort());
            });
        } else {
            LOGGER.trace("HTTP proxying is not enabled for CAS; Connector configuration for port [{}] is not modified.", tomcat.getPort());
        }
    }

    private void configureAjp(final TomcatServletWebServerFactory tomcat) {
        val ajp = casProperties.getServer().getTomcat().getAjp();
        if (ajp.isEnabled() && ajp.getPort() > 0) {
            LOGGER.debug("Creating AJP configuration for the embedded tomcat container...");
            val ajpConnector = new Connector(ajp.getProtocol());
            ajpConnector.setPort(ajp.getPort());
            ajpConnector.setSecure(ajp.isSecure());
            ajpConnector.setAllowTrace(ajp.isAllowTrace());
            ajpConnector.setScheme(ajp.getScheme());
            ajpConnector.setAsyncTimeout(Beans.newDuration(ajp.getAsyncTimeout()).toMillis());
            ajpConnector.setEnableLookups(ajp.isEnableLookups());
            ajpConnector.setMaxPostSize(ajp.getMaxPostSize());
            ajpConnector.addUpgradeProtocol(new Http2Protocol());

            val handler = (AbstractAjpProtocol) ajpConnector.getProtocolHandler();
            if (handler != null) {
                handler.setSecretRequired(ajp.isSecure());
                handler.setSecret(ajp.getSecret());
            }

            if (ajp.getProxyPort() > 0) {
                LOGGER.debug("Set AJP proxy port to [{}]", ajp.getProxyPort());
                ajpConnector.setProxyPort(ajp.getProxyPort());
            }

            if (ajp.getRedirectPort() > 0) {
                LOGGER.debug("Set AJP redirect port to [{}]", ajp.getRedirectPort());
                ajpConnector.setRedirectPort(ajp.getRedirectPort());
            }
            ajp.getAttributes().forEach(ajpConnector::setProperty);
            tomcat.addAdditionalTomcatConnectors(ajpConnector);
        }
    }

    private void configureSSLValve(final TomcatServletWebServerFactory tomcat) {
        val valveConfig = casProperties.getServer().getTomcat().getSslValve();

        if (valveConfig.isEnabled()) {
            LOGGER.debug("Adding SSLValve to context of the embedded tomcat container...");
            val valve = new SSLValve();
            valve.setSslCipherHeader(valveConfig.getSslCipherHeader());
            valve.setSslCipherUserKeySizeHeader(valveConfig.getSslCipherUserKeySizeHeader());
            valve.setSslClientCertHeader(valveConfig.getSslClientCertHeader());
            valve.setSslSessionIdHeader(valveConfig.getSslSessionIdHeader());
            tomcat.addContextValves(valve);
        }
    }

    private void configureRewriteValve(final TomcatServletWebServerFactory tomcat) {
        val res = casProperties.getServer().getTomcat().getRewriteValve().getLocation();
        if (ResourceUtils.doesResourceExist(res)) {
            LOGGER.debug("Configuring rewrite valve at [{}]", res);

            val valve = new RewriteValve() {
                @Override
                public synchronized void startInternal() {
                    FunctionUtils.doUnchecked(__ -> {
                        super.startInternal();
                        try (val is = res.getInputStream();
                             val isr = new InputStreamReader(is, StandardCharsets.UTF_8);
                             val buffer = new BufferedReader(isr)) {
                            parse(buffer);
                        }
                    });
                }
            };
            valve.setAsyncSupported(true);
            valve.setEnabled(true);

            LOGGER.debug("Creating rewrite valve configuration for the embedded tomcat container...");
            tomcat.addContextValves(valve);
        }
    }

    private void configureRemoteUserValve(final TomcatServletWebServerFactory tomcat) {
        val valve = casProperties.getServer().getTomcat().getRemoteUserValve();
        if (StringUtils.isNotBlank(valve.getRemoteUserHeader())) {
            tomcat.addContextValves(new RemoteUserValve());
        }
    }

    private final class RemoteUserValve extends ValveBase {
        @Override
        public void invoke(final Request request, final Response response) throws IOException, ServletException {
            val valve = casProperties.getServer().getTomcat().getRemoteUserValve();
            val username = request.getHeader(valve.getRemoteUserHeader());
            LOGGER.trace("Received remote user [{}] from [{}]", username, request.getRemoteAddr());
            if (StringUtils.isNotBlank(username) && RegexUtils.matchesIpAddress(valve.getAllowedIpAddressRegex(), request.getRemoteAddr())) {
                val principal = new GenericPrincipal(username);
                request.setUserPrincipal(principal);
                response.setHeader("X-Remote-User", username);
            }
            getNext().invoke(request, response);
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy