All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.linecorp.armeria.server.saml.SamlServiceProviderBuilder Maven / Gradle / Ivy

Go to download

Asynchronous HTTP/2 RPC/REST client/server library built on top of Java 8, Netty, Thrift and gRPC (armeria-saml)

The newest version!
/*
 * Copyright 2018 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.linecorp.armeria.server.saml;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.linecorp.armeria.server.saml.HttpRedirectBindingUtil.responseWithLocation;
import static com.linecorp.armeria.server.saml.SamlEndpoint.ofHttpPost;
import static com.linecorp.armeria.server.saml.SamlEndpoint.ofHttpRedirect;
import static java.util.Objects.requireNonNull;

import java.io.UnsupportedEncodingException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.Signature;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;

import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.saml.common.messaging.context.SAMLBindingContext;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.LogoutRequest;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialResolver;
import org.opensaml.xmlsec.algorithm.AlgorithmSupport;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.shibboleth.utilities.java.support.resolver.CriteriaSet;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerPort;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.auth.Authorizer;

/**
 * A builder which builds a {@link SamlServiceProvider}.
 */
public final class SamlServiceProviderBuilder {
    private static final Logger logger = LoggerFactory.getLogger(SamlServiceProviderBuilder.class);

    /**
     * The maximum length of a requested path which the default {@link SamlSingleSignOnHandler} relays
     * as a {@code RelayState}. The SAML 2.0 bindings specification limits {@code RelayState} to 80 bytes,
     * so longer paths are not relayed at all.
     */
    private static final int MAX_RELAY_STATE_LENGTH = 80;

    private final List<SamlIdentityProviderConfigBuilder> idpConfigBuilders = new ArrayList<>();
    private final List<SamlAssertionConsumerConfigBuilder> acsConfigBuilders = new ArrayList<>();

    private final List<SamlEndpoint> sloEndpoints = new ArrayList<>();

    private final SamlPortConfigBuilder hostConfigBuilder = new SamlPortConfigBuilder();

    @Nullable
    private String entityId;
    @Nullable
    private String hostname;
    @Nullable
    private Authorizer authorizer;

    @Nullable
    private CredentialResolverAdapter credentialResolver;
    private String signingKey = "signing";
    private String encryptionKey = "encryption";

    private String signatureAlgorithm = SignatureConstants.ALGO_ID_SIGNATURE_DSA;

    private String metadataPath = "/saml/metadata";

    @Nullable
    private SamlIdentityProviderConfigSelector idpConfigSelector;

    @Nullable
    private SamlRequestIdManager requestIdManager;

    private SamlSingleSignOnHandler ssoHandler = new SamlSingleSignOnHandler() {
        @Override
        public CompletionStage<Void> beforeInitiatingSso(ServiceRequestContext ctx, HttpRequest req,
                                                         MessageContext<AuthnRequest> message,
                                                         SamlIdentityProviderConfig idpConfig) {
            final String requestedPath = req.path();
            if (requestedPath.length() <= MAX_RELAY_STATE_LENGTH) {
                // Relay the requested path by default so that the user is redirected back to it
                // once the login succeeds.
                final SAMLBindingContext sub = message.getSubcontext(SAMLBindingContext.class, true);
                assert sub != null : "SAMLBindingContext";
                sub.setRelayState(requestedPath);
            }
            return CompletableFuture.completedFuture(null);
        }

        @Override
        public HttpResponse loginSucceeded(ServiceRequestContext ctx, AggregatedHttpRequest req,
                                           MessageContext<Response> message, @Nullable String sessionIndex,
                                           @Nullable String relayState) {
            // Redirect to the relayed path, or to the root if nothing was relayed.
            return responseWithLocation(firstNonNull(relayState, "/"));
        }

        @Override
        public HttpResponse loginFailed(ServiceRequestContext ctx, AggregatedHttpRequest req,
                                        @Nullable MessageContext<Response> message, Throwable cause) {
            logger.warn("{} SAML SSO failed", ctx, cause);
            return responseWithLocation("/error");
        }
    };

    private SamlSingleLogoutHandler sloHandler = new SamlSingleLogoutHandler() {
        @Override
        public CompletionStage<Void> logoutSucceeded(ServiceRequestContext ctx, AggregatedHttpRequest req,
                                                     MessageContext<LogoutRequest> message) {
            return CompletableFuture.completedFuture(null);
        }

        @Override
        public CompletionStage<Void> logoutFailed(ServiceRequestContext ctx, AggregatedHttpRequest req,
                                                  Throwable cause) {
            logger.warn("{} SAML SLO failed", ctx, cause);
            return CompletableFuture.completedFuture(null);
        }
    };

    SamlServiceProviderBuilder() {}

    /**
     * Sets an {@link Authorizer} which is used for this service provider's authentication.
     */
    public SamlServiceProviderBuilder authorizer(Authorizer authorizer) {
        this.authorizer = requireNonNull(authorizer, "authorizer");
        return this;
    }

    /**
     * Sets an entity ID for this service provider.
     */
    public SamlServiceProviderBuilder entityId(String entityId) {
        this.entityId = requireNonNull(entityId, "entityId");
        return this;
    }

    /**
     * Sets a {@link CredentialResolver} for this service provider.
     */
    public SamlServiceProviderBuilder credentialResolver(CredentialResolver credentialResolver) {
        this.credentialResolver =
                new CredentialResolverAdapter(requireNonNull(credentialResolver, "credentialResolver"));
        return this;
    }

    /**
     * Sets a {@code signing} key name for this service provider. {@code "signing"} is used if not specified.
     */
    public SamlServiceProviderBuilder signingKey(String signingKey) {
        this.signingKey = requireNonNull(signingKey, "signingKey");
        return this;
    }

    /**
     * Sets an {@code encryption} key name for this service provider. {@code "encryption"} is used
     * if not specified.
     */
    public SamlServiceProviderBuilder encryptionKey(String encryptionKey) {
        this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey");
        return this;
    }

    /**
     * Sets a signature algorithm which is used for signing by this service provider.
     */
    public SamlServiceProviderBuilder signatureAlgorithm(String signatureAlgorithm) {
        this.signatureAlgorithm = requireNonNull(signatureAlgorithm, "signatureAlgorithm");
        return this;
    }

    /**
     * Sets a hostname of this service provider.
     */
    public SamlServiceProviderBuilder hostname(String hostname) {
        this.hostname = requireNonNull(hostname, "hostname");
        return this;
    }

    /**
     * Sets a protocol scheme of this service provider.
     */
    public SamlServiceProviderBuilder scheme(SessionProtocol scheme) {
        hostConfigBuilder.setSchemeIfAbsent(requireNonNull(scheme, "scheme"));
        return this;
    }

    /**
     * Sets a port of this service provider.
     */
    public SamlServiceProviderBuilder port(int port) {
        hostConfigBuilder.setPortIfAbsent(port);
        return this;
    }

    /**
     * Sets a {@link ServerPort} of this service provider.
     */
    public SamlServiceProviderBuilder schemeAndPort(ServerPort serverPort) {
        hostConfigBuilder.setSchemeAndPortIfAbsent(requireNonNull(serverPort, "serverPort"));
        return this;
    }

    /**
     * Sets a path for retrieving the metadata of this service provider. {@code "/saml/metadata"} is used
     * if not specified.
     */
    public SamlServiceProviderBuilder metadataPath(String metadataPath) {
        this.metadataPath = requireNonNull(metadataPath, "metadataPath");
        return this;
    }

    /**
     * Sets a {@link SamlIdentityProviderConfigSelector} which determines a suitable idp for a request.
     */
    public SamlServiceProviderBuilder idpConfigSelector(
            SamlIdentityProviderConfigSelector idpConfigSelector) {
        this.idpConfigSelector = requireNonNull(idpConfigSelector, "idpConfigSelector");
        return this;
    }

    /**
     * Adds a new single logout service endpoint of this service provider.
     */
    public SamlServiceProviderBuilder sloEndpoint(SamlEndpoint sloEndpoint) {
        sloEndpoints.add(requireNonNull(sloEndpoint, "sloEndpoint"));
        return this;
    }

    /**
     * Sets a {@link SamlRequestIdManager} which creates and validates a SAML request ID.
     */
    public SamlServiceProviderBuilder requestIdManager(SamlRequestIdManager requestIdManager) {
        this.requestIdManager = requireNonNull(requestIdManager, "requestIdManager");
        return this;
    }

    /**
     * Sets a {@link SamlSingleSignOnHandler} which handles SAML messages for a single sign-on.
     */
    public SamlServiceProviderBuilder ssoHandler(SamlSingleSignOnHandler ssoHandler) {
        this.ssoHandler = requireNonNull(ssoHandler, "ssoHandler");
        return this;
    }

    /**
     * Sets a {@link SamlSingleLogoutHandler} which handles SAML messages for a single logout.
     */
    public SamlServiceProviderBuilder sloHandler(SamlSingleLogoutHandler sloHandler) {
        this.sloHandler = requireNonNull(sloHandler, "sloHandler");
        return this;
    }

    /**
     * Returns a {@link SamlIdentityProviderConfigBuilder} to configure a new idp for authentication.
     */
    public SamlIdentityProviderConfigBuilder idp() {
        final SamlIdentityProviderConfigBuilder config = new SamlIdentityProviderConfigBuilder(this);
        idpConfigBuilders.add(config);
        return config;
    }

    /**
     * Returns a {@link SamlAssertionConsumerConfigBuilder} to configure a new assertion consumer service
     * of this service provider.
     *
     * @deprecated Use {@link #acs(SamlEndpoint)}.
     */
    @Deprecated
    public SamlAssertionConsumerConfigBuilder acs() {
        final SamlAssertionConsumerConfigBuilder config = new SamlAssertionConsumerConfigBuilder(this);
        acsConfigBuilders.add(config);
        return config;
    }

    /**
     * Returns a {@link SamlAssertionConsumerConfigBuilder} to configure a new assertion consumer service
     * of this service provider.
     */
    public SamlAssertionConsumerConfigBuilder acs(SamlEndpoint endpoint) {
        final SamlAssertionConsumerConfigBuilder config =
                new SamlAssertionConsumerConfigBuilder(this, requireNonNull(endpoint, "endpoint"));
        acsConfigBuilders.add(config);
        return config;
    }

    /**
     * Builds a {@link SamlServiceProvider} which helps a {@link Server} have a SAML-based
     * authentication.
     *
     * <p>Defaults computed here (the identity provider selector and the request ID manager) are kept
     * in local variables rather than written back to this builder, so that calling this method again
     * after changing the configuration does not reuse stale values from a previous build.
     *
     * @throws IllegalStateException if a mandatory property is missing, a credential cannot be resolved,
     *                               the signature algorithm is unusable with a credential, or the
     *                               assertion consumer service / identity provider configuration is
     *                               inconsistent
     */
    public SamlServiceProvider build() {

        // Must ensure that OpenSAML is initialized before building a SAML service provider.
        SamlInitializer.ensureAvailability();

        if (entityId == null) {
            throw new IllegalStateException("entity ID is not specified");
        }
        if (credentialResolver == null) {
            throw new IllegalStateException(CredentialResolver.class.getSimpleName() + " is not specified");
        }
        if (authorizer == null) {
            throw new IllegalStateException(Authorizer.class.getSimpleName() + " is not specified");
        }

        final Credential signingCredential = resolveCredential(credentialResolver, signingKey, "signing");
        final Credential encryptionCredential =
                resolveCredential(credentialResolver, encryptionKey, "encryption");
        validateSignatureAlgorithm(signatureAlgorithm, signingCredential);
        validateSignatureAlgorithm(signatureAlgorithm, encryptionCredential);

        final List<SamlEndpoint> sloEndpoints = buildSloEndpoints();
        final List<SamlAssertionConsumerConfig> assertionConsumerConfigs = buildAssertionConsumerConfigs();

        // Collect assertion consumer service endpoints for checking duplication and existence.
        final Set<SamlEndpoint> acsEndpoints =
                assertionConsumerConfigs.stream().map(SamlAssertionConsumerConfig::endpoint)
                                        .collect(toImmutableSet());
        if (acsEndpoints.size() != assertionConsumerConfigs.size()) {
            throw new IllegalStateException("duplicated assertion consumer services exist");
        }

        // Initialize identity provider configurations.
        if (idpConfigBuilders.isEmpty()) {
            throw new IllegalStateException("no identity provider configuration is specified");
        }
        // If there is only one IdP, it will be automatically a default IdP.
        if (idpConfigBuilders.size() == 1) {
            idpConfigBuilders.get(0).asDefault();
        }

        final ImmutableMap.Builder<String, SamlIdentityProviderConfig> idpConfigs = ImmutableMap.builder();
        SamlIdentityProviderConfig defaultIdpConfig = null;
        for (final SamlIdentityProviderConfigBuilder builder : idpConfigBuilders) {
            // An IdP may only refer to an assertion consumer service that is actually configured.
            if (builder.acsEndpoint() != null && !acsEndpoints.contains(builder.acsEndpoint())) {
                throw new IllegalStateException("unspecified assertion consumer service at " +
                                                builder.acsEndpoint());
            }

            final SamlIdentityProviderConfig config = builder.build(credentialResolver);

            validateSignatureAlgorithm(signatureAlgorithm, config.signingCredential());
            validateSignatureAlgorithm(signatureAlgorithm, config.encryptionCredential());

            idpConfigs.put(config.entityId(), config);

            if (builder.isDefault()) {
                if (defaultIdpConfig != null) {
                    throw new IllegalStateException("there has to be only one default identity provider");
                }
                defaultIdpConfig = config;
            }
        }

        final SamlIdentityProviderConfigSelector selector;
        if (idpConfigSelector != null) {
            selector = idpConfigSelector;
        } else {
            if (defaultIdpConfig == null) {
                throw new IllegalStateException("default identity provider does not exist");
            }

            // Configure a default identity provider selector which always returns a default identity provider.
            final SamlIdentityProviderConfig defaultConfig = defaultIdpConfig;
            selector = (unused1, unused2, unused3) -> CompletableFuture.completedFuture(defaultConfig);
        }

        final SamlRequestIdManager idManager;
        if (requestIdManager != null) {
            idManager = requestIdManager;
        } else {
            // entityID would be used as a secret by default. Only created when the user did not
            // specify a manager, so a failure here cannot affect user-supplied configurations.
            try {
                idManager = SamlRequestIdManager.ofJwt(entityId, entityId, 60, 5);
            } catch (UnsupportedEncodingException e) {
                throw new IllegalStateException("cannot create a " + SamlRequestIdManager.class.getSimpleName(),
                                                e);
            }
        }

        return new SamlServiceProvider(authorizer,
                                       entityId,
                                       hostname,
                                       signingCredential,
                                       encryptionCredential,
                                       signatureAlgorithm,
                                       hostConfigBuilder.toAutoFiller(),
                                       metadataPath,
                                       idpConfigs.build(),
                                       defaultIdpConfig,
                                       selector,
                                       assertionConsumerConfigs,
                                       sloEndpoints,
                                       idManager,
                                       ssoHandler,
                                       sloHandler);
    }

    /**
     * Resolves the {@link Credential} named {@code keyName}.
     *
     * @param purpose what the credential is used for; only used in the error message
     * @throws IllegalStateException if the credential cannot be resolved
     */
    private static Credential resolveCredential(CredentialResolverAdapter resolver, String keyName,
                                                String purpose) {
        final Credential credential = resolver.apply(keyName);
        if (credential == null) {
            throw new IllegalStateException("cannot resolve a " + Credential.class.getSimpleName() +
                                            " for " + purpose + ": " + keyName);
        }
        return credential;
    }

    /**
     * Returns an immutable copy of the user-specified single logout endpoints, or the default
     * HTTP-POST and HTTP-Redirect endpoints if none was specified.
     */
    private List<SamlEndpoint> buildSloEndpoints() {
        if (sloEndpoints.isEmpty()) {
            // Add two endpoints by default if there's no SLO endpoint specified by a user.
            return ImmutableList.of(ofHttpPost("/saml/slo/post"),
                                    ofHttpRedirect("/saml/slo/redirect"));
        }
        return ImmutableList.copyOf(sloEndpoints);
    }

    /**
     * Builds the user-specified assertion consumer service configurations, or the default
     * HTTP-POST (default ACS) and HTTP-Redirect configurations if none was specified.
     */
    private List<SamlAssertionConsumerConfig> buildAssertionConsumerConfigs() {
        if (acsConfigBuilders.isEmpty()) {
            // Add two endpoints by default if there's no ACS endpoint specified by a user.
            return ImmutableList.of(
                    new SamlAssertionConsumerConfigBuilder(this, ofHttpPost("/saml/acs/post"))
                            .asDefault().build(),
                    new SamlAssertionConsumerConfigBuilder(this, ofHttpRedirect("/saml/acs/redirect"))
                            .build());
        }

        // If there is only one ACS, it will be automatically a default ACS.
        if (acsConfigBuilders.size() == 1) {
            acsConfigBuilders.get(0).asDefault();
        }

        return acsConfigBuilders.stream()
                                .map(SamlAssertionConsumerConfigBuilder::build)
                                .collect(toImmutableList());
    }

    /**
     * Ensures that {@code signatureAlgorithm} is known to OpenSAML and available through JCA, and that
     * it can be initialized with the given credential. The private key is used for signing if present;
     * otherwise the public key is used for verification.
     *
     * @throws IllegalStateException if the algorithm is unsupported or the key is not usable with it
     */
    private static void validateSignatureAlgorithm(String signatureAlgorithm, Credential credential) {
        final String jcaAlgorithmID = AlgorithmSupport.getAlgorithmID(signatureAlgorithm);
        if (jcaAlgorithmID == null) {
            throw new IllegalStateException("unsupported signature algorithm: " + signatureAlgorithm);
        }
        try {
            final Signature signature = Signature.getInstance(jcaAlgorithmID);
            final PrivateKey key = credential.getPrivateKey();
            if (key != null) {
                signature.initSign(key);
            } else {
                signature.initVerify(credential.getPublicKey());
            }
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("unsupported signature algorithm: " + signatureAlgorithm, e);
        } catch (InvalidKeyException e) {
            throw new IllegalStateException("failed to initialize a signature with an algorithm: " +
                                            signatureAlgorithm, e);
        }
    }

    /**
     * An adapter for {@link CredentialResolver} which helps to resolve a {@link Credential} from
     * the specified {@code keyName}. The key name is looked up as an {@link EntityIdCriterion}.
     */
    static class CredentialResolverAdapter implements Function<String, Credential> {
        private final CredentialResolver resolver;

        CredentialResolverAdapter(CredentialResolver resolver) {
            this.resolver = requireNonNull(resolver, "resolver");
        }

        /**
         * Returns the single {@link Credential} resolved for {@code keyName}, or {@code null} if none.
         * Any exception raised by the underlying resolver is rethrown as-is, without wrapping.
         */
        @Nullable
        @Override
        public Credential apply(String keyName) {
            final CriteriaSet cs = new CriteriaSet();
            cs.add(new EntityIdCriterion(keyName));
            try {
                return resolver.resolveSingle(cs);
            } catch (Throwable cause) {
                return Exceptions.throwUnsafely(cause);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy