// All downloads are free. Search and download functionality uses the official Maven repository.

// net.optionfactory.keycloak.spid.metadata.SpidSpMetadataResourceProvider (Maven / Gradle / Ivy)

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package net.optionfactory.keycloak.spid.metadata;

import org.jboss.logging.Logger;

import org.keycloak.common.util.PemUtils;
import org.keycloak.crypto.Algorithm;
import org.keycloak.crypto.KeyStatus;
import org.keycloak.crypto.KeyUse;
import org.keycloak.dom.saml.v2.metadata.AttributeConsumingServiceType;
import org.keycloak.dom.saml.v2.metadata.EndpointType;
import org.keycloak.dom.saml.v2.metadata.EntityDescriptorType;
import org.keycloak.dom.saml.v2.metadata.IndexedEndpointType;
import org.keycloak.dom.saml.v2.metadata.LocalizedNameType;
import org.keycloak.dom.saml.v2.metadata.SPSSODescriptorType;
import org.keycloak.models.KeyManager;
import org.keycloak.models.IdentityProviderModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.saml.SPMetadataDescriptor;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.exceptions.ProcessingException;
import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.common.util.StaxUtil;
import org.keycloak.saml.common.exceptions.ConfigurationException;
import org.keycloak.saml.processing.core.saml.v2.writers.SAMLMetadataWriter;
import org.keycloak.saml.processing.api.saml.v2.sig.SAML2Signature;
import org.keycloak.protocol.saml.SamlService;
import org.keycloak.protocol.saml.mappers.SamlMetadataDescriptorUpdater;
import org.keycloak.services.resource.RealmResourceProvider;

import java.io.StringWriter;
import java.math.BigInteger;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import javax.ws.rs.GET;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;

import javax.xml.crypto.dsig.CanonicalizationMethod;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.XMLStreamWriter;
import net.optionfactory.keycloak.spid.SpidIdentityProvider;
import net.optionfactory.keycloak.spid.SpidIdentityProviderConfig;
import net.optionfactory.keycloak.spid.SpidIdentityProviderFactory;
import net.optionfactory.keycloak.spid.metadata.extensions.SpidBillingContactType;
import net.optionfactory.keycloak.spid.metadata.extensions.SpidOrganizationType;
import net.optionfactory.keycloak.spid.metadata.extensions.SpidOtherContactType;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.keycloak.broker.provider.IdentityProviderMapper;
import org.keycloak.dom.saml.v2.metadata.KeyDescriptorType;
import org.keycloak.dom.saml.v2.metadata.KeyTypes;

/**
 * Realm resource that renders the SAML Service Provider metadata document for the
 * enabled SPID identity providers of the current realm.
 *
 * <p>The configuration of the first enabled SPID identity provider (ordered by alias)
 * drives the descriptor-wide settings (bindings, entity id, signing options, ...),
 * while every enabled SPID identity provider contributes one AssertionConsumerService
 * and one SingleLogoutService endpoint.
 */
public class SpidSpMetadataResourceProvider implements RealmResourceProvider {

    protected static final Logger logger = Logger.getLogger(SpidSpMetadataResourceProvider.class);

    private final KeycloakSession session;

    /**
     * @param session the Keycloak session used to resolve the realm, request URI and keys
     */
    public SpidSpMetadataResourceProvider(KeycloakSession session) {
        this.session = session;
    }

    /**
     * Returns this instance as the resource object.
     */
    @Override
    public Object getResource() {
        return this;
    }

    /**
     * Builds the SP metadata XML for the current realm and returns it as the response body.
     *
     * @return a 200 response carrying the (optionally signed) entity descriptor
     * @throws RuntimeException wrapping any failure, including the case where the realm
     *         has no enabled SPID identity provider
     */
    @GET
    @Produces("text/xml; charset=utf-8")
    public Response get() {
        try {
            // Retrieve all enabled SPID identity providers of this realm, ordered by alias
            RealmModel realm = session.getContext().getRealm();
            List<IdentityProviderModel> lstSpidIdentityProviders = realm.getIdentityProvidersStream()
                    .filter(t -> t.getProviderId().equals(SpidIdentityProviderFactory.PROVIDER_ID)
                    && t.isEnabled())
                    .sorted((o1, o2) -> o1.getAlias().compareTo(o2.getAlias()))
                    .collect(Collectors.toList());

            if (lstSpidIdentityProviders.isEmpty()) {
                throw new IllegalStateException("No SPID providers found!");
            }

            // Create an instance of the first SPID identity provider in alphabetical order;
            // its configuration supplies all descriptor-wide settings below.
            SpidIdentityProviderFactory providerFactory = new SpidIdentityProviderFactory();
            SpidIdentityProvider firstSpidProvider = providerFactory.create(session, lstSpidIdentityProviders.get(0));
            SpidIdentityProviderConfig config = firstSpidProvider.getConfig();

            // Retrieve the context URI
            UriInfo uriInfo = session.getContext().getUri();

            // HTTP-Redirect unless the provider is configured for POST binding
            URI authnBinding = config.isPostBindingAuthnRequest()
                    ? JBossSAMLURIConstants.SAML_HTTP_POST_BINDING.getUri()
                    : JBossSAMLURIConstants.SAML_HTTP_REDIRECT_BINDING.getUri();

            URI endpoint = brokerEndpoint(uriInfo, realm, config.getAlias());

            boolean wantAuthnRequestsSigned = config.isWantAuthnRequestsSigned();
            boolean wantAssertionsSigned = config.isWantAssertionsSigned();
            boolean wantAssertionsEncrypted = config.isWantAssertionsEncrypted();
            String entityId = getEntityId(config.getEntityId(), uriInfo, realm);
            String nameIDPolicyFormat = config.getNameIDPolicyFormat();

            List<KeyDescriptorType> signingKeys = new LinkedList<>();
            List<KeyDescriptorType> encryptionKeys = new LinkedList<>();

            // Every RS256 signing key with a certificate is published as a signing key;
            // only ACTIVE ones are additionally published as encryption keys.
            session.keys().getKeysStream(realm, KeyUse.SIG, Algorithm.RS256)
                    .filter(Objects::nonNull)
                    .filter(key -> key.getCertificate() != null)
                    .sorted(SamlService::compareKeys)
                    .forEach(key -> {
                        try {
                            Element element = SPMetadataDescriptor
                                    .buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()));

                            signingKeys.add(SPMetadataDescriptor.buildKeyDescriptorType(element, KeyTypes.SIGNING, Algorithm.RS256));

                            if (key.getStatus() == KeyStatus.ACTIVE) {
                                encryptionKeys.add(SPMetadataDescriptor.buildKeyDescriptorType(element, KeyTypes.ENCRYPTION, Algorithm.RS256));
                            }
                        } catch (ParserConfigurationException e) {
                            // Logged once by the enclosing catch block
                            throw new RuntimeException(e);
                        }
                    });

            EntityDescriptorType entityDescriptor = SPMetadataDescriptor.buildSPDescriptor(
                    authnBinding, authnBinding, endpoint, endpoint,
                    wantAuthnRequestsSigned, wantAssertionsSigned, wantAssertionsEncrypted,
                    entityId, nameIDPolicyFormat, signingKeys, encryptionKeys);

            // Create the AttributeConsumingService
            AttributeConsumingServiceType attributeConsumingService = buildAttributeConsumingService(realm, config);

            // Look for the SP descriptor and add the attribute consuming service.
            // This happens before the mappers run, so they see the service on the descriptor.
            for (EntityDescriptorType.EDTChoiceType choiceType : entityDescriptor.getChoiceType()) {
                List<EntityDescriptorType.EDTDescriptorChoiceType> descriptors = choiceType.getDescriptors();

                if (descriptors != null) {
                    for (EntityDescriptorType.EDTDescriptorChoiceType descriptor : descriptors) {
                        if (descriptor.getSpDescriptor() != null) {
                            descriptor.getSpDescriptor().addAttributeConsumerService(attributeConsumingService);
                        }
                    }
                }
            }

            // Add the attribute mappers
            realm.getIdentityProviderMappersByAliasStream(config.getAlias())
                    .forEach(mapper -> {
                        IdentityProviderMapper target = (IdentityProviderMapper) session.getKeycloakSessionFactory().getProviderFactory(IdentityProviderMapper.class, mapper.getIdentityProviderMapper());
                        if (target instanceof SamlMetadataDescriptorUpdater) {
                            SamlMetadataDescriptorUpdater metadataAttrProvider = (SamlMetadataDescriptorUpdater) target;
                            metadataAttrProvider.updateMetadata(mapper, entityDescriptor);
                        }
                    });

            // Additional EntityDescriptor customizations
            customizeEntityDescriptor(entityDescriptor, config);

            // Additional SPSSODescriptor customizations: one broker endpoint per enabled SPID
            // provider, used both as assertion consumer and as single logout endpoint.
            List<URI> brokerEndpoints = lstSpidIdentityProviders.stream()
                    .map(t -> brokerEndpoint(uriInfo, realm, t.getAlias()))
                    .collect(Collectors.toList());

            for (EntityDescriptorType.EDTChoiceType choiceType : entityDescriptor.getChoiceType()) {
                List<EntityDescriptorType.EDTDescriptorChoiceType> descriptors = choiceType.getDescriptors();

                if (descriptors != null) {
                    for (EntityDescriptorType.EDTDescriptorChoiceType descriptor : descriptors) {
                        SPSSODescriptorType spDescriptor = descriptor.getSpDescriptor();

                        if (spDescriptor != null) {
                            customizeSpDescriptor(spDescriptor,
                                    authnBinding, authnBinding,
                                    brokerEndpoints, brokerEndpoints);
                        }
                    }
                }
            }

            String descriptor = writeEntityDescriptorWithConsistentID(entityDescriptor);

            // Metadata signing
            if (config.isSignSpMetadata()) {
                KeyManager.ActiveRsaKey activeKey = session.keys().getActiveRsaKey(realm);
                String keyName = config.getXmlSigKeyInfoKeyNameTransformer().getKeyName(activeKey.getKid(), activeKey.getCertificate());
                KeyPair keyPair = new KeyPair(activeKey.getPublicKey(), activeKey.getPrivateKey());

                Document metadataDocument = DocumentUtil.getDocument(descriptor);
                SAML2Signature signatureHelper = new SAML2Signature();
                signatureHelper.setSignatureMethod(firstSpidProvider.getSignatureAlgorithm().getXmlSignatureMethod());
                signatureHelper.setDigestMethod(firstSpidProvider.getSignatureAlgorithm().getXmlSignatureDigestMethod());

                // Insert the signature before the first child of the root element
                Node nextSibling = metadataDocument.getDocumentElement().getFirstChild();
                signatureHelper.setNextSibling(nextSibling);

                signatureHelper.signSAMLDocument(metadataDocument, keyName, keyPair, CanonicalizationMethod.EXCLUSIVE);

                descriptor = DocumentUtil.getDocumentAsString(metadataDocument);
            }

            return Response.ok(descriptor, MediaType.APPLICATION_XML_TYPE).build();
        } catch (Exception e) {
            logger.warn("Failed to export SAML SP Metadata!", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the broker endpoint URI {@code <base>/realms/<realm>/broker/<alias>/endpoint}.
     */
    private static URI brokerEndpoint(UriInfo uriInfo, RealmModel realm, String alias) {
        return uriInfo.getBaseUriBuilder()
                .path("realms").path(realm.getName())
                .path("broker")
                .path(alias)
                .path("endpoint")
                .build();
    }

    /**
     * Creates the default AttributeConsumingService from the provider configuration.
     *
     * <p>The configured service name is a comma-separated list; each entry is either
     * {@code locale|name} or a bare {@code name}, in which case the realm default locale
     * (or {@code "en"} when the realm has none) is used. The index defaults to 1 when
     * not configured.
     */
    private static AttributeConsumingServiceType buildAttributeConsumingService(RealmModel realm,
            SpidIdentityProviderConfig config) {
        int index = config.getAttributeConsumingServiceIndex() != null ? config.getAttributeConsumingServiceIndex() : 1;
        AttributeConsumingServiceType service = new AttributeConsumingServiceType(index);
        service.setIsDefault(true);

        String configuredNames = config.getAttributeConsumingServiceName();
        if (configuredNames == null) {
            return service;
        }

        String fallbackLocale = realm.getDefaultLocale() == null ? "en" : realm.getDefaultLocale();
        for (String entry : configuredNames.split(",")) {
            String[] parsedName = entry.split("\\|", 2);
            boolean hasLocale = parsedName.length >= 2;

            LocalizedNameType serviceName = new LocalizedNameType(hasLocale ? parsedName[0] : fallbackLocale);
            serviceName.setValue(hasLocale ? parsedName[1] : entry);
            service.addServiceName(serviceName);
        }
        return service;
    }

    /**
     * Serializes the descriptor with an ID derived from its own content, so that repeated
     * metadata requests yield the same XML as long as the configuration is unchanged.
     *
     * <p>Side effect: the ID of {@code entityDescriptor} is overwritten.
     */
    private String writeEntityDescriptorWithConsistentID(final EntityDescriptorType entityDescriptor) throws ProcessingException {
        entityDescriptor.setID("ID_"); // Fixed placeholder so the hash does not depend on a previous ID
        String data = entityDescriptorAsString(entityDescriptor);
        String hash = md5hex(data);
        entityDescriptor.setID("ID_" + hash); // Replace the placeholder with the content hash
        return entityDescriptorAsString(entityDescriptor);
    }

    /**
     * Serializes the entity descriptor to an XML string.
     */
    private String entityDescriptorAsString(final EntityDescriptorType entityDescriptor) throws ProcessingException {
        StringWriter sw = new StringWriter();
        XMLStreamWriter writer = StaxUtil.getXMLStreamWriter(sw);
        SAMLMetadataWriter metadataWriter = new SAMLMetadataWriter(writer);
        metadataWriter.writeEntityDescriptor(entityDescriptor);
        return sw.toString();
    }

    /**
     * Returns the configured entity id, or {@code <base>/realms/<realm>} when none is configured.
     */
    private String getEntityId(String configEntityId, UriInfo uriInfo, RealmModel realm) {
        if (configEntityId == null || configEntityId.isEmpty()) {
            return UriBuilder.fromUri(uriInfo.getBaseUri()).path("realms").path(realm.getName()).build().toString();
        } else {
            return configEntityId;
        }
    }

    /**
     * Adds the optional Organization and ContactPerson (other, billing) elements built
     * from the provider configuration to the entity descriptor.
     */
    private static void customizeEntityDescriptor(EntityDescriptorType entityDescriptor,
            SpidIdentityProviderConfig config)
            throws ConfigurationException {
        // Organization
        SpidOrganizationType.build(config).ifPresent(entityDescriptor::setOrganization);

        // ContactPerson type=OTHER
        SpidOtherContactType.build(config).ifPresent(entityDescriptor::addContactPerson);

        // ContactPerson type=BILLING
        SpidBillingContactType.build(config).ifPresent(entityDescriptor::addContactPerson);
    }

    /**
     * Replaces the SingleLogoutService and AssertionConsumerService endpoints of the
     * SP descriptor with the given ones.
     *
     * <p>Assertion consumer endpoints are indexed from 0 in list order; index 0 is
     * flagged as the default.
     */
    private static void customizeSpDescriptor(SPSSODescriptorType spDescriptor,
            URI loginBinding, URI logoutBinding,
            List<URI> assertionEndpoints, List<URI> logoutEndpoints) {
        // Remove any existing SingleLogoutService endpoints (backwards, so indices stay valid)
        List<EndpointType> lstSingleLogoutService = spDescriptor.getSingleLogoutService();
        for (int i = lstSingleLogoutService.size() - 1; i >= 0; --i) {
            spDescriptor.removeSingleLogoutService(lstSingleLogoutService.get(i));
        }

        // Add the new SingleLogoutService endpoints
        for (URI logoutEndpoint : logoutEndpoints) {
            spDescriptor.addSingleLogoutService(new EndpointType(logoutBinding, logoutEndpoint));
        }

        // Remove any existing AssertionConsumerService endpoints (backwards, so indices stay valid)
        List<IndexedEndpointType> lstAssertionConsumerService = spDescriptor.getAssertionConsumerService();
        for (int i = lstAssertionConsumerService.size() - 1; i >= 0; --i) {
            spDescriptor.removeAssertionConsumerService(lstAssertionConsumerService.get(i));
        }

        // Add the new AssertionConsumerService endpoints
        int assertionEndpointIndex = 0;
        for (URI assertionEndpoint : assertionEndpoints) {
            IndexedEndpointType assertionConsumerEndpoint = new IndexedEndpointType(loginBinding, assertionEndpoint);
            if (assertionEndpointIndex == 0) {
                assertionConsumerEndpoint.setIsDefault(true);
            }
            assertionConsumerEndpoint.setIndex(assertionEndpointIndex);

            spDescriptor.addAssertionConsumerService(assertionConsumerEndpoint);
            assertionEndpointIndex++;
        }
    }

    /**
     * Returns the MD5 digest of the UTF-8 bytes of {@code data} as lowercase hexadecimal.
     *
     * <p>NOTE(review): the digest is rendered through {@link BigInteger#toString(int)}, so
     * leading zero digits are not emitted and the result can be shorter than 32 characters.
     * This is kept as-is so that already published metadata IDs do not change.
     */
    private static String md5hex(String data) {
        try {
            byte[] bytes = MessageDigest.getInstance("MD5").digest(data.getBytes(StandardCharsets.UTF_8));
            return new BigInteger(1, bytes).toString(16);
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * No resources are held by this provider, so there is nothing to release.
     */
    @Override
    public void close() {
    }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy