All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.coveo.saml.SamlClient Maven / Gradle / Ivy

There is a newer version: 5.0.0
Show newest version
package com.coveo.saml;

import com.sun.org.apache.xerces.internal.parsers.DOMParser;

import org.joda.time.DateTime;
import org.opensaml.DefaultBootstrap;
import org.opensaml.common.SAMLVersion;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.saml2.core.AuthnContextClassRef;
import org.opensaml.saml2.core.AuthnContextComparisonTypeEnumeration;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.Conditions;
import org.opensaml.saml2.core.Issuer;
import org.opensaml.saml2.core.NameIDPolicy;
import org.opensaml.saml2.core.RequestedAuthnContext;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.core.validator.ResponseSchemaValidator;
import org.opensaml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml2.metadata.IDPSSODescriptor;
import org.opensaml.saml2.metadata.KeyDescriptor;
import org.opensaml.saml2.metadata.SingleSignOnService;
import org.opensaml.saml2.metadata.provider.DOMMetadataProvider;
import org.opensaml.saml2.metadata.provider.MetadataProvider;
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.xml.Configuration;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.io.Marshaller;
import org.opensaml.xml.io.MarshallingException;
import org.opensaml.xml.io.UnmarshallingException;
import org.opensaml.xml.security.credential.Credential;
import org.opensaml.xml.security.credential.UsageType;
import org.opensaml.xml.security.keyinfo.KeyInfoHelper;
import org.opensaml.xml.security.x509.BasicX509Credential;
import org.opensaml.xml.security.x509.X509Util;
import org.opensaml.xml.signature.Signature;
import org.opensaml.xml.signature.SignatureValidator;
import org.opensaml.xml.signature.X509Data;
import org.opensaml.xml.util.Base64;
import org.opensaml.xml.util.XMLHelper;
import org.opensaml.xml.validation.ValidationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UnsupportedEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.namespace.QName;

public class SamlClient {
  private static final Logger logger = LoggerFactory.getLogger(SamlClient.class);
  private static boolean initializedOpenSaml = false;

  private String relyingPartyIdentifier;
  private String assertionConsumerServiceUrl;
  private String identityProviderUrl;
  private String responseIssuer;
  private Credential credential;
  private DateTime now; // used for testing only

  /**
   * Returns the url where SAML requests should be posted.
   *
   * @return the url where SAML requests should be posted.
   */
  public String getIdentityProviderUrl() {
    return identityProviderUrl;
  }

  /**
   * Sets the date that will be considered as now. This is only useful for testing.
   *
   * @param now the date to use for now.
   */
  public void setDateTimeNow(DateTime now) {
    this.now = now;
  }

  /**
   * Constructs an SAML client using explicit parameters.
   *
   * @param relyingPartyIdentifier      the identifier of the relying party.
   * @param assertionConsumerServiceUrl the url where the identity provider will post back the
   *                                    SAML response.
   * @param identityProviderUrl         the url where the SAML request will be submitted.
   * @param responseIssuer              the expected issuer ID for SAML responses.
   * @param certificate                 the base-64 encoded certificate to use to validate
   *                                    responses.
   * @throws SamlException thrown if any error occur while loading the provider information.
   */
  public SamlClient(
      String relyingPartyIdentifier,
      String assertionConsumerServiceUrl,
      String identityProviderUrl,
      String responseIssuer,
      X509Certificate certificate)
      throws SamlException {

    ensureOpenSamlIsInitialized();

    if (relyingPartyIdentifier == null) {
      throw new IllegalArgumentException("relyingPartyIdentifier");
    }
    if (identityProviderUrl == null) {
      throw new IllegalArgumentException("identityProviderUrl");
    }
    if (responseIssuer == null) {
      throw new IllegalArgumentException("responseIssuer");
    }
    if (certificate == null) {
      throw new IllegalArgumentException("certificate");
    }

    this.relyingPartyIdentifier = relyingPartyIdentifier;
    this.assertionConsumerServiceUrl = assertionConsumerServiceUrl;
    this.identityProviderUrl = identityProviderUrl;
    this.responseIssuer = responseIssuer;
    credential = getCredential(certificate);
  }

  /**
   * Builds an encoded SAML request.
   *
   * @return The base-64 encoded SAML request.
   * @throws SamlException thrown if an unexpected error occurs.
   */
  public String getSamlRequest() throws SamlException {
    AuthnRequest request = (AuthnRequest) buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
    request.setID("z" + UUID.randomUUID().toString()); // ADFS needs IDs to start with a letter

    request.setVersion(SAMLVersion.VERSION_20);
    request.setIssueInstant(DateTime.now());
    request.setProtocolBinding("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST");
    request.setAssertionConsumerServiceURL(assertionConsumerServiceUrl);

    Issuer issuer = (Issuer) buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
    issuer.setValue(relyingPartyIdentifier);
    request.setIssuer(issuer);

    NameIDPolicy nameIDPolicy = (NameIDPolicy) buildSamlObject(NameIDPolicy.DEFAULT_ELEMENT_NAME);
    nameIDPolicy.setFormat("urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified");
    request.setNameIDPolicy(nameIDPolicy);

    RequestedAuthnContext requestedAuthnContext =
        (RequestedAuthnContext) buildSamlObject(RequestedAuthnContext.DEFAULT_ELEMENT_NAME);
    requestedAuthnContext.setComparison(AuthnContextComparisonTypeEnumeration.EXACT);
    request.setRequestedAuthnContext(requestedAuthnContext);

    AuthnContextClassRef authnContextClassRef =
        (AuthnContextClassRef) buildSamlObject(AuthnContextClassRef.DEFAULT_ELEMENT_NAME);
    authnContextClassRef.setAuthnContextClassRef(
        "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport");
    requestedAuthnContext.getAuthnContextClassRefs().add(authnContextClassRef);

    StringWriter stringWriter = new StringWriter();
    try {
      Marshaller marshaller = Configuration.getMarshallerFactory().getMarshaller(request);
      Element dom = marshaller.marshall(request);
      XMLHelper.writeNode(dom, stringWriter);
    } catch (MarshallingException ex) {
      throw new SamlException("Error while marshalling SAML request to XML", ex);
    }

    logger.trace("Issuing SAML request: " + stringWriter.toString());

    try {
      return Base64.encodeBytes(stringWriter.toString().getBytes("UTF-8"));
    } catch (UnsupportedEncodingException ex) {
      throw new SamlException("Error while encoding SAML request", ex);
    }
  }

  /**
   * Decodes and validates an SAML response returned by an identity provider.
   *
   * @param encodedResponse the encoded response returned by the identity provider.
   * @return An {@link SamlResponse} object containing information decoded from the SAML response.
   * @throws SamlException if the signature is invalid, or if any other error occurs.
   */
  public SamlResponse decodeAndValidateSamlResponse(String encodedResponse) throws SamlException {
    String decodedResponse;
    try {
      decodedResponse = new String(Base64.decode(encodedResponse), "UTF-8");
    } catch (UnsupportedEncodingException ex) {
      throw new SamlException("Cannot decode base64 encoded response", ex);
    }

    logger.trace("Validating SAML response: " + decodedResponse);

    Response response;
    try {
      DOMParser parser = new DOMParser();
      parser.parse(new InputSource(new StringReader(decodedResponse)));
      response =
          (Response)
              Configuration.getUnmarshallerFactory()
                  .getUnmarshaller(parser.getDocument().getDocumentElement())
                  .unmarshall(parser.getDocument().getDocumentElement());
    } catch (IOException | SAXException | UnmarshallingException ex) {
      throw new SamlException("Cannot decode xml encoded response", ex);
    }

    validateResponse(response);
    validateAssertion(response);
    validateSignature(response);

    Assertion assertion = response.getAssertions().get(0);
    return new SamlResponse(assertion);
  }

  /**
   * Redirects an {@link HttpServletResponse} to the configured identity provider.
   *
   * @param response   The {@link HttpServletResponse}.
   * @param relayState Optional relay state that will be passed along.
   * @throws IOException   thrown if an IO error occurs.
   * @throws SamlException thrown is an unexpected error occurs.
   */
  public void redirectToIdentityProvider(HttpServletResponse response, String relayState)
      throws IOException, SamlException {
    Map values = new HashMap<>();
    values.put("SAMLRequest", getSamlRequest());
    if (relayState != null) {
      values.put("RelayState", relayState);
    }

    BrowserUtils.postUsingBrowser(identityProviderUrl, response, values);
  }

  /**
   * Processes a POST containing the SAML response.
   *
   * @param request the {@link HttpServletRequest}.
   * @return An {@link SamlResponse} object containing information decoded from the SAML response.
   * @throws SamlException thrown is an unexpected error occurs.
   */
  public SamlResponse processPostFromIdentityProvider(HttpServletRequest request)
      throws SamlException {
    String encodedResponse = request.getParameter("SAMLResponse");
    return decodeAndValidateSamlResponse(encodedResponse);
  }

  /**
   * Constructs an SAML client using XML metadata obtained from the identity provider. 

When * using Okta as an identity provider, it is possible to pass null to relyingPartyIdentifier and * assertionConsumerServiceUrl; they will be inferred from the metadata provider XML. * * @param relyingPartyIdentifier the identifier for the relying party. * @param assertionConsumerServiceUrl the url where the identity provider will post back the * SAML response. * @param metadata the XML metadata obtained from the identity provider. * @return The created {@link SamlClient}. * @throws SamlException thrown if any error occur while loading the metadata information. */ public static SamlClient fromMetadata( String relyingPartyIdentifier, String assertionConsumerServiceUrl, Reader metadata) throws SamlException { ensureOpenSamlIsInitialized(); MetadataProvider metadataProvider = createMetadataProvider(metadata); EntityDescriptor entityDescriptor = getEntityDescriptor(metadataProvider); IDPSSODescriptor idpSsoDescriptor = getIDPSSODescriptor(entityDescriptor); SingleSignOnService postBinding = getPostBinding(idpSsoDescriptor); X509Certificate x509Certificate = getCertificate(idpSsoDescriptor); boolean isOkta = entityDescriptor.getEntityID().contains(".okta.com"); if (relyingPartyIdentifier == null) { // Okta's own toolkit uses the entity ID as a relying party identifier, so if we // detect that the IDP is Okta let's tolerate a null value for this parameter. if (isOkta) { relyingPartyIdentifier = entityDescriptor.getEntityID(); } else { throw new IllegalArgumentException("relyingPartyIdentifier"); } } if (assertionConsumerServiceUrl == null && isOkta) { // Again, Okta's own toolkit uses this value for the assertion consumer url, which // kinda makes no sense since this is supposed to be a url pointing to a server // outside Okta, but it probably just straight ignores this and use the one from // it's own config anyway. assertionConsumerServiceUrl = postBinding.getLocation(); } String identityProviderUrl = postBinding.getLocation(); String responseIssuer = entityDescriptor.getEntityID(); return new SamlClient( relyingPartyIdentifier, assertionConsumerServiceUrl, identityProviderUrl, responseIssuer, x509Certificate); } private void validateResponse(Response response) throws SamlException { try { new ResponseSchemaValidator().validate(response); } catch (ValidationException ex) { throw new SamlException("The response schema validation failed", ex); } if (!response.getIssuer().getValue().equals(responseIssuer)) { throw new SamlException("The response issuer didn't match the expected value"); } String statusCode = response.getStatus().getStatusCode().getValue(); if (!statusCode.equals("urn:oasis:names:tc:SAML:2.0:status:Success")) { throw new SamlException("Invalid status code: " + statusCode); } } private void validateAssertion(Response response) throws SamlException { if (response.getAssertions().size() != 1) { throw new SamlException("The response doesn't contain exactly 1 assertion"); } Assertion assertion = response.getAssertions().get(0); if (!assertion.getIssuer().getValue().equals(responseIssuer)) { throw new SamlException("The assertion issuer didn't match the expected value"); } if (assertion.getSubject().getNameID() == null) { throw new SamlException( "The NameID value is missing from the SAML response; this is likely an IDP configuration issue"); } enforceConditions(assertion.getConditions()); } private void enforceConditions(Conditions conditions) throws SamlException { DateTime now = this.now != null ? this.now : DateTime.now(); if (now.isBefore(conditions.getNotBefore())) { throw new SamlException( "The assertion cannot be used before " + conditions.getNotBefore().toString()); } if (now.isAfter(conditions.getNotOnOrAfter())) { throw new SamlException( "The assertion cannot be used after " + conditions.getNotOnOrAfter().toString()); } } private void validateSignature(Response response) throws SamlException { // I must admit this is not yet 100% clear to me, but it seems that it's OK if either // the assertion is signed (since we take the stuff from it) OR if the response itself // is signed, since it contains the assertion. I've seen boths possibilities coming out // from ADFS and Okta. if (!validateResponseSignature(response) && !validateAssertionSignature(response)) { throw new SamlException("No signature is present in either response or assertion"); } } private boolean validateResponseSignature(Response response) throws SamlException { Signature signature = response.getSignature(); if (signature == null) { return false; } try { SignatureValidator signatureValidator = new SignatureValidator(credential); signatureValidator.validate(signature); return true; } catch (ValidationException ex) { throw new SamlException("Invalid response signature", ex); } } private boolean validateAssertionSignature(Response response) throws SamlException { Assertion assertion = response.getAssertions().get(0); Signature signature = assertion.getSignature(); if (signature == null) { return false; } try { SignatureValidator signatureValidator = new SignatureValidator(credential); signatureValidator.validate(signature); return true; } catch (ValidationException ex) { throw new SamlException("Invalid assertion signature", ex); } } private synchronized static void ensureOpenSamlIsInitialized() throws SamlException { if (!initializedOpenSaml) { try { DefaultBootstrap.bootstrap(); initializedOpenSaml = true; } catch (Throwable ex) { throw new SamlException("Error while initializing the Open SAML library", ex); } } } private static MetadataProvider createMetadataProvider(Reader metadata) throws SamlException { try { DOMParser parser = new DOMParser(); parser.parse(new InputSource(metadata)); DOMMetadataProvider provider = new DOMMetadataProvider(parser.getDocument().getDocumentElement()); provider.initialize(); return provider; } catch (IOException | SAXException | MetadataProviderException ex) { throw new SamlException("Cannot load identity provider metadata", ex); } } private static EntityDescriptor getEntityDescriptor(MetadataProvider metadataProvider) throws SamlException { EntityDescriptor descriptor; try { descriptor = (EntityDescriptor) metadataProvider.getMetadata(); } catch (MetadataProviderException ex) { throw new SamlException("Cannot retrieve the entity descriptor", ex); } if (descriptor == null) { throw new SamlException("Cannot retrieve the entity descriptor"); } return descriptor; } private static IDPSSODescriptor getIDPSSODescriptor(EntityDescriptor entityDescriptor) throws SamlException { IDPSSODescriptor idpssoDescriptor = entityDescriptor.getIDPSSODescriptor("urn:oasis:names:tc:SAML:2.0:protocol"); if (idpssoDescriptor == null) { throw new SamlException("Cannot retrieve IDP SSO descriptor"); } return idpssoDescriptor; } private static SingleSignOnService getPostBinding(IDPSSODescriptor idpSsoDescriptor) throws SamlException { return idpSsoDescriptor .getSingleSignOnServices() .stream() .filter(x -> x.getBinding().equals("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST")) .findAny() .orElseThrow(() -> new SamlException("Cannot find HTTP-POST SSO binding in metadata")); } private static X509Certificate getCertificate(IDPSSODescriptor idpSsoDescriptor) throws SamlException { KeyDescriptor keyDescriptor = idpSsoDescriptor .getKeyDescriptors() .stream() .filter(x -> x.getUse() == UsageType.SIGNING) .findAny() .orElseThrow(() -> new SamlException("Cannot find signing certificate")); X509Data data = keyDescriptor .getKeyInfo() .getX509Datas() .stream() .findFirst() .orElseThrow(() -> new SamlException("Cannot find X509 data")); org.opensaml.xml.signature.X509Certificate certificate = data.getX509Certificates() .stream() .findFirst() .orElseThrow(() -> new SamlException("Cannot find X509 certificate")); try { return KeyInfoHelper.getCertificate(certificate); } catch (CertificateException ex) { throw new SamlException("Cannot load signing certificate", ex); } } private static X509Certificate getCertificate(String certificate) throws SamlException { try { Collection certificates = X509Util.decodeCertificate(Base64.decode(certificate)); return certificates .stream() .findFirst() .orElseThrow(() -> new SamlException("Cannot load certificate")); } catch (CertificateException ex) { throw new SamlException("Cannot load certificate", ex); } } private static Credential getCredential(X509Certificate certificate) throws SamlException { BasicX509Credential credential = new BasicX509Credential(); credential.setEntityCertificate(certificate); credential.setPublicKey(certificate.getPublicKey()); credential.setCRLs(Collections.emptyList()); return credential; } private static XMLObject buildSamlObject(QName qname) { return Configuration.getBuilderFactory().getBuilder(qname).buildObject(qname); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy