All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.pac4j.saml.transport.AbstractPac4jDecoder Maven / Gradle / Ivy

package org.pac4j.saml.transport;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import net.shibboleth.shared.codec.Base64Support;
import net.shibboleth.shared.component.ComponentInitializationException;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.xml.ParserPool;
import net.shibboleth.shared.xml.XMLParserException;
import org.apache.hc.core5.net.URLEncodedUtils;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.io.UnmarshallingException;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.messaging.decoder.AbstractMessageDecoder;
import org.opensaml.messaging.decoder.MessageDecodingException;
import org.opensaml.saml.common.binding.SAMLBindingSupport;
import org.opensaml.saml.common.messaging.context.SAMLBindingContext;
import org.pac4j.core.context.WebContext;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.saml.context.SAML2MessageContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

/**
 * Common decoder.
 *
 * @author Jerome Leleu
 * @since 3.4.0
 */
public abstract class AbstractPac4jDecoder extends AbstractMessageDecoder {

    static final String[] SAML_PARAMETERS = {"SAMLRequest", "SAMLResponse", "logoutRequest"};

    protected final Logger logger = LoggerFactory.getLogger(getClass());

    /** Parser pool used to deserialize the message. */
    protected ParserPool parserPool;

    protected final WebContext context;

    public AbstractPac4jDecoder(final WebContext context) {
        CommonHelper.assertNotNull("context", context);
        this.context = context;
    }

    protected byte[] getBase64DecodedMessage() throws MessageDecodingException {
        Optional encodedMessage = Optional.empty();
        for (final var parameter : SAML_PARAMETERS) {
            encodedMessage = this.context.getRequestParameter(parameter);
            if (encodedMessage.isPresent()) {
                break;
            }
        }
        if (!encodedMessage.isPresent()) {
            encodedMessage = Optional.ofNullable(this.context.getRequestContent());
            // we have a body, it may be the SAML request/response directly
            // but we also try to parse it as a list key=value where the value is the SAML request/response
            if (encodedMessage.isPresent()) {
                final var a = URLEncodedUtils.parse(encodedMessage.get(), StandardCharsets.UTF_8);
                final Multimap  paramMap = HashMultimap.create();
                for (final var p : a) {
                    paramMap.put(p.getName(), p.getValue());
                }
                for (final var parameter : SAML_PARAMETERS) {
                    final var newEncodedMessageCollection = paramMap.get(parameter);
                    if (newEncodedMessageCollection != null && !newEncodedMessageCollection.isEmpty()) {
                        encodedMessage = Optional.of(newEncodedMessageCollection.iterator().next());
                        break;
                    }
                }
            }
        }

        if (!encodedMessage.isPresent()) {
            throw new MessageDecodingException("Request did not contain either a SAMLRequest parameter, a SAMLResponse parameter, "
                + "a logoutRequest parameter or a body content");
        } else {
            if (encodedMessage.get().contains("<")) {
                logger.trace("Raw SAML message:\n{}", encodedMessage);
                return encodedMessage.get().getBytes(StandardCharsets.UTF_8);
            } else {

                try {
                    final var decodedBytes = Base64Support.decode(encodedMessage.get());
                    logger.trace("Decoded SAML message:\n{}", new String(decodedBytes, StandardCharsets.UTF_8));
                    return decodedBytes;
                } catch (final Exception e) {
                    throw new MessageDecodingException(e);
                }
            }
        }
    }

    @Override
    protected void doDestroy() {
        parserPool = null;

        super.doDestroy();
    }

    @Override
    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        logger.debug("Initialized {}", this.getClass().getSimpleName());

        if (parserPool == null) {
            throw new ComponentInitializationException("Parser pool cannot be null");
        }
    }

    /**
     * Populate the context which carries information specific to this binding.
     *
     * @param messageContext the current message context
     */
    protected void populateBindingContext(final SAML2MessageContext messageContext) {
        final var bindingContext = messageContext.getMessageContext().getSubcontext(SAMLBindingContext.class, true);
        bindingContext.setBindingUri(getBindingURI(messageContext));
        bindingContext.setHasBindingSignature(false);
        bindingContext.setIntendedDestinationEndpointURIRequired(SAMLBindingSupport.isMessageSigned(messageContext.getMessageContext()));
    }

    /**
     * Get the binding of the message context;.
     *
     * @param messageContext the message context
     * @return the binding URI
     */
    public abstract String getBindingURI(SAML2MessageContext messageContext);

    /**
     * Helper method that deserializes and unmarshalls the message from the given stream.
     *
     * @param messageStream input stream containing the message
     *
     * @return the inbound message
     *
     * @throws MessageDecodingException thrown if there is a problem deserializing and unmarshalling the message
     */
    protected XMLObject unmarshallMessage(final InputStream messageStream) throws MessageDecodingException {
        try {
            final var message = XMLObjectSupport.unmarshallFromInputStream(getParserPool(), messageStream);
            return message;
        } catch (final XMLParserException e) {
            throw new MessageDecodingException("Error unmarshalling message from input stream", e);
        } catch (final UnmarshallingException e) {
            throw new MessageDecodingException("Error unmarshalling message from input stream", e);
        }
    }

    /**
     * Gets the parser pool used to deserialize incoming messages.
     *
     * @return parser pool used to deserialize incoming messages
     */
    public ParserPool getParserPool() {
        return parserPool;
    }

    /**
     * Sets the parser pool used to deserialize incoming messages.
     *
     * @param pool parser pool used to deserialize incoming messages
     */
    public void setParserPool(final ParserPool pool) {
        Constraint.isNotNull(pool, "ParserPool cannot be null");
        parserPool = pool;
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy