All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.onelogin.saml2.util.Util Maven / Gradle / Ivy

There is a newer version: 2.9.0
Show newest version
package com.onelogin.saml2.util;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.StringReader;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Calendar;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TimeZone;
import java.util.UUID;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.Inflater;

import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.xml.XMLConstants;
import javax.xml.namespace.NamespaceContext;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.parsers.SAXParser;
import javax.xml.parsers.SAXParserFactory;
import javax.xml.transform.Source;
import javax.xml.transform.dom.DOMSource;
import javax.xml.validation.Schema;
import javax.xml.validation.Validator;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import javax.xml.xpath.XPathFactoryConfigurationException;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.xml.security.encryption.EncryptedData;
import org.apache.xml.security.encryption.EncryptedKey;
import org.apache.xml.security.encryption.XMLCipher;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.keys.KeyInfo;
import org.apache.xml.security.signature.XMLSignature;
import org.apache.xml.security.transforms.Transforms;
import org.apache.xml.security.utils.XMLUtils;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.Period;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.joda.time.format.ISOPeriodFormat;
import org.joda.time.format.PeriodFormatter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import com.onelogin.saml2.exception.ValidationError;
import com.onelogin.saml2.exception.XMLEntityException;
import com.onelogin.saml2.model.SamlResponseStatus;


/**
 * Util class of OneLogin's Java Toolkit.
 *
 * A class that contains several auxiliary methods related to the SAML protocol
 */
public final class Util {

	/**
	 * Logger shared by all static helpers of this class.
	 */
	private static final Logger LOGGER = LoggerFactory.getLogger(Util.class);

	// Formatter for timestamps with second precision (pattern yyyy-MM-dd'T'HH:mm:ss'Z'), pinned to UTC.
    private static final DateTimeFormatter DATE_TIME_FORMAT = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss'Z'").withZone(DateTimeZone.UTC);
	// Same as DATE_TIME_FORMAT but with millisecond precision (.SSS), also pinned to UTC.
	private static final DateTimeFormatter DATE_TIME_FORMAT_MILLS = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'").withZone(DateTimeZone.UTC);
	// Presumably prepended to the unique identifiers generated elsewhere in this class - TODO confirm against the generator.
	public static final String UNIQUE_ID_PREFIX = "ONELOGIN_";
	// XPath of a ds:Signature element that is a direct child of the samlp:Response root.
	public static final String RESPONSE_SIGNATURE_XPATH = "/samlp:Response/ds:Signature";
	// XPath of a ds:Signature element that is a direct child of a saml:Assertion under the samlp:Response root.
	public static final String ASSERTION_SIGNATURE_XPATH = "/samlp:Response/saml:Assertion/ds:Signature";
	/** Indicates if JAXP 1.5 support has been detected; computed once, at class initialization, by {@link #isJaxp15Supported()}. */
	private static boolean JAXP_15_SUPPORTED = isJaxp15Supported();

	// Class initializer: sets the "org.apache.xml.security.ignoreLineBreaks" system property
	// and then initializes the Apache XML Security library. The property is set BEFORE
	// Init.init() is called; presumably the library reads it during initialization
	// (NOTE(review): confirm against the Apache Santuario documentation).
	static {
		System.setProperty("org.apache.xml.security.ignoreLineBreaks", "true");
		org.apache.xml.security.Init.init();
	}

	/**
	 * Private constructor: this class only exposes static helpers and is not meant to be instantiated.
	 */
	private Util() {
	      //not called
	}

	/**
	 * Checks whether the JAXP implementation found on the classpath understands the JAXP 1.5
	 * properties, using the approach recommended at
	 * https://docs.oracle.com/javase/tutorial/jaxp/properties/error.html .
	 * Needed if the project which uses this library also has Xerces in its classpath.
	 *
	 * The check creates a SAX parser and tries to set the
	 * {@code http://javax.xml.XMLConstants/property/accessExternalDTD} property on it. Only a
	 * {@link SAXException} whose message states that the property is not recognized is treated
	 * as "not supported".
	 *
	 * If for whatever reason this method cannot determine if JAXP 1.5 properties are supported it will
	 * indicate the options are supported. This way we don't accidentally disable configuration options.
	 *
	 * @return false if the parser explicitly reported the property as not recognized, true otherwise
	 */
	public static boolean isJaxp15Supported() {
		boolean supported = true;

		try {
			SAXParserFactory spf = SAXParserFactory.newInstance();
			SAXParser parser = spf.newSAXParser();
			parser.setProperty("http://javax.xml.XMLConstants/property/accessExternalDTD", "file");
		} catch (SAXException ex) {
			String err = ex.getMessage();
			// A SAXException may carry no message. Guard against null so that an NPE raised
			// here (inside a catch handler, hence not covered by the catch below) cannot
			// escape into the static initializer that calls this method.
			if (err != null && err.contains("Property 'http://javax.xml.XMLConstants/property/accessExternalDTD' is not recognized.")) {
				//expected, jaxp 1.5 not supported
				supported = false;
			}
		} catch (Exception e) {
			LOGGER.info("An exception occurred while trying to determine if JAXP 1.5 options are supported.", e);
		}

		return supported;
	}

	/**
	 * This function load an XML string in a save way. Prevent XEE/XXE Attacks
	 *
	 * @param xml
	 * 				String. The XML string to be loaded.
	 *
	 * @return The result of load the XML at the Document or null if any error occurs
	 */
	public static Document loadXML(String xml) {
		try {
			if (xml.contains(" stringLength) {
				chunkSize = stringLength - i;
			}
			newStr += str.substring(i, chunkSize + i) + '\n';
		}
		return newStr;
	}


	/**
	 * Builds an X.509 certificate object from its textual representation.
	 *
	 * The text is first normalized through {@code formatCert(certString, true)} and then
	 * handed, as UTF-8 bytes, to an "X.509" {@link CertificateFactory}.
	 *
	 * @param certString
	 * 				 certificate in string format
	 *
	 * @return the loaded X509Certificate, or null if an IllegalArgumentException is raised while loading it
	 *
	 * @throws CertificateException
	 * 				 if the certificate factory cannot be obtained or the data cannot be parsed
	 */
	public static X509Certificate loadCert(String certString) throws CertificateException {
		final String pem = formatCert(certString, true);

		try {
			final CertificateFactory x509Factory = CertificateFactory.getInstance("X.509");
			final InputStream pemStream = new ByteArrayInputStream(pem.getBytes(StandardCharsets.UTF_8));
			return (X509Certificate) x509Factory.generateCertificate(pemStream);
		} catch (IllegalArgumentException e) {
			// Malformed input is reported to the caller as "no certificate".
			return null;
		}
	}

	/**
	 * Builds a private key object from its textual representation.
	 *
	 * The text is normalized with {@code formatPrivateKey(keyString, false)}, split into
	 * 64-character chunks, Base64-decoded and interpreted as a PKCS#8 key specification
	 * by an "RSA" {@link KeyFactory}.
	 *
	 * @param keyString
	 * 				 private key in string format
	 *
	 * @return the loaded PrivateKey, or null if an IllegalArgumentException is raised while decoding it
	 *
	 * @throws GeneralSecurityException
	 * 				 if the key factory is unavailable or the key specification is rejected
	 */
	public static PrivateKey loadPrivateKey(String keyString) throws GeneralSecurityException {
		final String keyBody = chunkString(formatPrivateKey(keyString, false), 64);
		final KeyFactory rsaFactory = KeyFactory.getInstance("RSA");

		try {
			final byte[] der = Base64.decodeBase64(keyBody);
			return rsaFactory.generatePrivate(new PKCS8EncodedKeySpec(der));
		} catch (IllegalArgumentException e) {
			// Undecodable input is reported to the caller as "no key".
			return null;
		}
	}

	/**
	 * Calculates the fingerprint of a x509cert with the requested digest.
	 *
	 * Accepted algorithm names are "SHA-1"/"sha1" (also used when alg is null or empty),
	 * "SHA-256"/"sha256", "SHA-384"/"sha384" and "SHA-512"/"sha512". Any other name, or any
	 * exception raised while encoding the certificate, is logged at debug level and yields
	 * an empty string.
	 *
	 * @param x509cert
	 * 				 x509 certificate
	 * @param alg
	 * 				 Digest Algorithm
	 *
	 * @return the lower-case hexadecimal fingerprint, or an empty string on failure
	 */
	public static String calculateX509Fingerprint(X509Certificate x509cert, String alg) {
		String hex = StringUtils.EMPTY;

		try {
			final byte[] der = x509cert.getEncoded();
			// A missing algorithm name means SHA-1.
			final String algorithm = (alg == null || alg.isEmpty()) ? "SHA-1" : alg;
			switch (algorithm) {
				case "SHA-1":
				case "sha1":
					hex = DigestUtils.sha1Hex(der);
					break;
				case "SHA-256":
				case "sha256":
					hex = DigestUtils.sha256Hex(der);
					break;
				case "SHA-384":
				case "sha384":
					hex = DigestUtils.sha384Hex(der);
					break;
				case "SHA-512":
				case "sha512":
					hex = DigestUtils.sha512Hex(der);
					break;
				default:
					LOGGER.debug("Error executing calculateX509Fingerprint. alg " + alg + " not supported");
					break;
			}
		} catch (Exception e) {
			LOGGER.debug("Error executing calculateX509Fingerprint: "+ e.getMessage(), e);
		}
		return hex.toLowerCase();
	}

	/**
	 * Calculates the SHA-1 fingerprint of a x509cert by delegating to
	 * {@link #calculateX509Fingerprint(X509Certificate, String)}.
	 *
	 * @param x509cert
	 * 				 x509 certificate
	 *
	 * @return the SHA-1 formatted fingerprint
	 */
	public static String calculateX509Fingerprint(X509Certificate x509cert) {
		final String defaultAlgorithm = "SHA-1";
		return calculateX509Fingerprint(x509cert, defaultAlgorithm);
	}

	/**
	 * Converts an X509Certificate into a PEM string.
	 *
	 * The DER encoding of the certificate is Base64-encoded (the encoder is created with a
	 * line length of 64) and wrapped between the BEGIN/END CERTIFICATE markers.
	 *
	 * @param certificate
	 * 				 The public certificate
	 *
	 * @return the formatted PEM string, or an empty string if the certificate could not be
	 *         encoded (the error is logged at debug level)
	 */
	public static String convertToPem(X509Certificate certificate) {
		String pemCert = "";
		try {
			Base64 encoder = new Base64(64);
			String cert_begin = "-----BEGIN CERTIFICATE-----\n";
			String end_cert = "-----END CERTIFICATE-----";

			byte[] derCert = certificate.getEncoded();
			// Decode the Base64 bytes with an explicit charset instead of the platform
			// default, so the result does not depend on the JVM's file.encoding.
			String pemCertPre = new String(encoder.encode(derCert), StandardCharsets.UTF_8);
			pemCert = cert_begin + pemCertPre + end_cert;

		} catch (Exception e) {
			LOGGER.debug("Error converting certificate on PEM format: "+ e.getMessage(), e);
		}
		return pemCert;
	}

	/**
	 * Loads a classpath resource and returns its content decoded as UTF-8.
	 *
	 * The resource is looked up through {@code Util.class.getResourceAsStream} with a
	 * leading "/" prepended to the given path.
	 *
	 * @param relativeResourcePath
	 *				Relative path of the resource
	 *
	 * @return the loaded resource in String format
	 *
	 * @throws FileNotFoundException
	 *				if no resource exists at the given path
	 * @throws IOException
	 *				if the resource cannot be read
	 */
	public static String getFileAsString(String relativeResourcePath) throws IOException {
		// try-with-resources: the stream is always closed, and a failure in close() is
		// recorded as a suppressed exception instead of masking the primary read error.
		try (InputStream is = Util.class.getResourceAsStream("/" + relativeResourcePath)) {
			if (is == null) {
				throw new FileNotFoundException(relativeResourcePath);
			}

			ByteArrayOutputStream bytes = new ByteArrayOutputStream();
			copyBytes(new BufferedInputStream(is), bytes);

			return bytes.toString("utf-8");
		}
	}

	/**
	 * Copies every remaining byte of the input stream to the output stream.
	 * Neither stream is closed by this method.
	 *
	 * @param is
	 *				stream to read from, until end of stream
	 * @param bytes
	 *				stream that receives the bytes
	 *
	 * @throws IOException
	 *				if reading or writing fails
	 */
	private static void copyBytes(InputStream is, OutputStream bytes) throws IOException {
		// Transfer in chunks rather than one byte per call.
		final byte[] buffer = new byte[8192];
		int read = is.read(buffer);
		while (read != -1) {
			bytes.write(buffer, 0, read);
			read = is.read(buffer);
		}
	}

	/**
	 * Returns String Base64 decoded and inflated (raw DEFLATE, no zlib header).
	 *
	 * The inflated bytes are collected first and decoded as UTF-8 in a single step, so
	 * multi-byte characters that straddle two 1024-byte output chunks are not corrupted.
	 * At most 150 inflate rounds of 1024 bytes each are performed; output beyond that is
	 * dropped.
	 *
	 * If inflating fails (for instance because the payload was not deflated), the Base64
	 * decoded bytes are returned as a UTF-8 string instead.
	 *
	 * @param input
	 *				String input
	 *
	 * @return the base64 decoded and inflated string; the input itself when it is empty
	 */
	public static String base64decodedInflated(String input) {
		if (input.isEmpty()) {
			return input;
		}
		// Base64 decoder
		byte[] decoded = Base64.decodeBase64(input);

		// Inflater ("nowrap" mode: raw DEFLATE stream)
		Inflater decompresser = new Inflater(true);
		try {
			decompresser.setInput(decoded);
			ByteArrayOutputStream inflated = new ByteArrayOutputStream();
			byte[] chunk = new byte[1024];
			int rounds = 0;
			// Bounded number of rounds: caps the output size and guarantees termination
			// when the stream is truncated and inflate() keeps returning 0.
			while (!decompresser.finished() && rounds < 150) {
				int chunkLength = decompresser.inflate(chunk);
				rounds++;
				inflated.write(chunk, 0, chunkLength);
			}
			return new String(inflated.toByteArray(), StandardCharsets.UTF_8);
		} catch (Exception e) {
			// Not a deflated payload: fall back to the plain decoded content.
			return new String(decoded, StandardCharsets.UTF_8);
		} finally {
			// Release the native resources on every path, including the fallback one.
			decompresser.end();
		}
	}

	/**
	 * Returns String Deflated (raw DEFLATE, no zlib header) and base64 encoded.
	 *
	 * The input is converted to bytes as UTF-8 before being compressed.
	 *
	 * @param input
	 *				String input
	 *
	 * @return the deflated and base64 encoded string
	 * @throws IOException
	 *				if writing to the deflater stream fails
	 */
	public static String deflatedBase64encoded(String input) throws IOException {
		// Deflater
		ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
		// NOTE(review): the first constructor argument is the compression level; the constant
		// Deflater.DEFLATED is passed in that position. Kept unchanged so that the produced
		// bytes stay identical to what this method has always emitted.
		Deflater deflater = new Deflater(Deflater.DEFLATED, true);
		try {
			DeflaterOutputStream deflaterStream = new DeflaterOutputStream(bytesOut, deflater);
			deflaterStream.write(input.getBytes(StandardCharsets.UTF_8));
			deflaterStream.finish();
		} finally {
			// A caller-supplied Deflater is not released by the stream; end it explicitly
			// so its native resources are freed even when writing fails.
			deflater.end();
		}
		// Base64 encoder (output is plain ASCII; decode with an explicit charset)
		return new String(Base64.encodeBase64(bytesOut.toByteArray()), StandardCharsets.UTF_8);
	}

	/**
	 * Base64-encodes a byte array and returns the result as a String
	 * (converted through {@code toStringUtf8}).
	 *
	 * @param input
	 *				bytes to encode
	 *
	 * @return the base64 encoded string
	 */
	public static String base64encoder(byte [] input) {
		final byte[] encoded = Base64.encodeBase64(input);
		return toStringUtf8(encoded);
	}

	/**
	 * Base64-encodes a String; the text is converted to bytes through
	 * {@code toBytesUtf8} before being encoded.
	 *
	 * @param input
	 * 				 String input
	 *
	 * @return the base64 encoded string
	 */
	public static String base64encoder(String input) {
		final byte[] raw = toBytesUtf8(input);
		return base64encoder(raw);
	}

	/**
	 * Decodes Base64 data supplied as bytes.
	 *
	 * @param input
	 * 				 Base64 encoded bytes
	 *
	 * @return the base64 decoded bytes
	 */
	public static byte[] base64decoder(byte [] input) {
		final byte[] decoded = Base64.decodeBase64(input);
		return decoded;
	}

	/**
	 * Decodes Base64 data supplied as a String; the text is converted to bytes
	 * through {@code toBytesUtf8} before being decoded.
	 *
	 * @param input
	 * 				 Base64 encoded String
	 *
	 * @return the base64 decoded bytes
	 */
	public static byte[] base64decoder(String input) {
		final byte[] encoded = toBytesUtf8(input);
		return base64decoder(encoded);
	}

	/**
	 * Returns String URL encoded (UTF-8).
	 *
	 * @param input
	 * 				 String input
	 *
	 * @return the URL encoded string, or null if the input is null
	 *
	 * @throws IllegalArgumentException
	 * 				 if the UTF-8 encoding is reported as unsupported; the original
	 * 				 exception is attached as the cause
	 */
	public static String urlEncoder(String input) {
		if (input == null) {
			return null;
		}
		try {
			return URLEncoder.encode(input, "UTF-8");
		} catch (UnsupportedEncodingException e) {
			LOGGER.error("URL encoder error.", e);
			// Keep the root cause attached instead of throwing an empty exception.
			throw new IllegalArgumentException("URL encoder error.", e);
		}
	}

	/**
	 * Returns String URL decoded (UTF-8).
	 *
	 * @param input
	 * 				 URL encoded input
	 *
	 * @return the URL decoded string, or null if the input is null
	 *
	 * @throws IllegalArgumentException
	 * 				 if the UTF-8 encoding is reported as unsupported; the original
	 * 				 exception is attached as the cause
	 */
	public static String urlDecoder(String input) {
		if (input == null) {
			return null;
		}
		try {
			return URLDecoder.decode(input, "UTF-8");
		} catch (UnsupportedEncodingException e) {
			LOGGER.error("URL decoder error.", e);
			// Keep the root cause attached instead of throwing an empty exception.
			throw new IllegalArgumentException("URL decoder error.", e);
		}
	}

	/**
	 * Generates a signature from a string.
	 *
	 * The text is converted to bytes as UTF-8 before being signed, so the result does not
	 * depend on the platform default charset.
	 *
	 * @param text
	 * 				 The string we should sign
	 * @param key
	 * 				 The private key to sign the string
	 * @param signAlgorithm
	 * 				 Signature algorithm method; Constants.RSA_SHA1 is used when null
	 *
	 * @return the signature
	 *
	 * @throws NoSuchAlgorithmException
	 * 				 if no provider supports the converted algorithm name
	 * @throws InvalidKeyException
	 * 				 if the key cannot be used with the algorithm
	 * @throws SignatureException
	 * 				 if the signature cannot be computed
	 */
	public static byte[] sign(String text, PrivateKey key, String signAlgorithm) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException {
		if (signAlgorithm == null) {
			signAlgorithm = Constants.RSA_SHA1;
		}

		Signature instance = Signature.getInstance(signatureAlgConversion(signAlgorithm));
		instance.initSign(key);
		// Explicit charset: the signed bytes must be the same on every JVM.
		instance.update(text.getBytes(StandardCharsets.UTF_8));

		return instance.sign();
	}

	/**
	 * Maps a signature algorithm identifier (one of the Constants.*_SHA* values) to the
	 * algorithm name understood by {@code java.security.Signature}.
	 *
	 * A null identifier, and any identifier that is not DSA_SHA1, RSA_SHA256, RSA_SHA384
	 * or RSA_SHA512, is mapped to "SHA1withRSA".
	 *
	 * @param sign
	 * 				 signature algorithm method
	 *
	 * @return the converted signature name
	 */
	public static String signatureAlgConversion(String sign) {
		if (sign != null) {
			if (sign.equals(Constants.DSA_SHA1)) {
				return "SHA1withDSA";
			}
			if (sign.equals(Constants.RSA_SHA256)) {
				return "SHA256withRSA";
			}
			if (sign.equals(Constants.RSA_SHA384)) {
				return "SHA384withRSA";
			}
			if (sign.equals(Constants.RSA_SHA512)) {
				return "SHA512withRSA";
			}
		}
		// Default for null, RSA-SHA1 and every unrecognized identifier.
		return "SHA1withRSA";
	}

	/**
	 * Validates the single signature selected by the given XPath expression.
	 *
	 * The expression must select exactly one node; that node is then checked with
	 * {@code validateSignNode}. A failure to evaluate the expression is logged as a
	 * warning and reported as an invalid signature.
	 *
	 * @param doc The document we should validate
	 * @param cert The public certificate
	 * @param fingerprint The fingerprint of the public certificate
	 * @param alg The signature algorithm method
	 * @param xpath the xpath of the ds:Signature node to validate
	 *
	 * @return True if the signature exists and is valid, false otherwise.
	 */
	public static boolean validateSign(final Document doc, final X509Certificate cert, final String fingerprint,
									   final String alg, final String xpath) {
		final NodeList signatureNodes;
		try {
			signatureNodes = query(doc, xpath);
		} catch (XPathExpressionException e) {
			LOGGER.warn("Failed to find signature nodes", e);
			return false;
		}

		if (signatureNodes.getLength() != 1) {
			return false;
		}
		return validateSignNode(signatureNodes.item(0), cert, fingerprint, alg);
	}

	/**
	 * Validates the single signature selected by the given XPath expression against a list
	 * of candidate certificates.
	 *
	 * The expression must select exactly one node. When the certificate list is null or
	 * empty the node is validated without a certificate (fingerprint based validation);
	 * otherwise the signature is accepted as soon as one certificate of the list validates it.
	 * A failure to evaluate the expression is logged as a warning and reported as an
	 * invalid signature.
	 *
	 * @param doc The document we should validate
	 * @param certList The public certificates
	 * @param fingerprint The fingerprint of the public certificate
	 * @param alg The signature algorithm method
	 * @param xpath the xpath of the ds:Signature node to validate
	 *
	 * @return True if the signature exists and is valid, false otherwise.
	 */
	public static boolean validateSign(final Document doc, final List<X509Certificate> certList, final String fingerprint,
									   final String alg, final String xpath) {
		try {
			final NodeList signatures = query(doc, xpath);

			if (signatures.getLength() == 1) {
				final Node signNode = signatures.item(0);
				if (certList == null || certList.isEmpty()) {
					return validateSignNode(signNode, null, fingerprint, alg);
				}
				for (X509Certificate cert : certList) {
					if (validateSignNode(signNode, cert, fingerprint, alg)) {
						return true;
					}
				}
			}
		} catch (XPathExpressionException e) {
			LOGGER.warn("Failed to find signature nodes", e);
		}
		return false;
	}

	/**
	 * Validate signature (Metadata).
	 *
	 * Signatures are looked up, in this order, under md:EntitiesDescriptor, then under
	 * md:EntityDescriptor, then under the md:SPSSODescriptor / md:IDPSSODescriptor children
	 * of md:EntityDescriptor. The first location that yields at least one signature is used
	 * and every signature found there must be valid.
	 *
	 * @param doc
	 *               The document we should validate
	 * @param cert
	 *               The public certificate
	 * @param fingerprint
	 *               The fingerprint of the public certificate
	 * @param alg
	 *               The signature algorithm method
	 *
	 * @return True if at least one signature was found and all of them are valid, false otherwise
	 *         (including when the XPath lookup fails, which is logged as a warning).
	 */
	public static Boolean validateMetadataSign(Document doc, X509Certificate cert, String fingerprint, String alg) {
		NodeList signNodesToValidate;
		try {
			signNodesToValidate = query(doc, "/md:EntitiesDescriptor/ds:Signature");

			if (signNodesToValidate.getLength() == 0) {
				signNodesToValidate = query(doc, "/md:EntityDescriptor/ds:Signature");

				if (signNodesToValidate.getLength() == 0) {
					// Both role descriptors need the md: prefix: an unprefixed step only
					// matches elements in no namespace, so a signed IDPSSODescriptor
					// was never found before.
					signNodesToValidate = query(doc, "/md:EntityDescriptor/md:SPSSODescriptor/ds:Signature|/md:EntityDescriptor/md:IDPSSODescriptor/ds:Signature");
				}
			}

			if (signNodesToValidate.getLength() > 0) {
				for (int i = 0; i < signNodesToValidate.getLength(); i++) {
					Node signNode =  signNodesToValidate.item(i);
					if (!validateSignNode(signNode, cert, fingerprint, alg)) {
						return false;
					}
				}
				return true;
			}
		} catch (XPathExpressionException e) {
			LOGGER.warn("Failed to find signature nodes", e);
		}
		return false;
	}

	/**
	 * Validates the signature held by a ds:Signature node.
	 *
	 * The signature method must be accepted by {@code isAlgorithmWhitelisted}. When a
	 * certificate is supplied the signature value is checked against it. Otherwise, if a
	 * fingerprint is supplied and the signature's KeyInfo contains X509 data, the embedded
	 * certificate is used, but only when its calculated fingerprint equals one of the
	 * comma-separated values of {@code fingerprint}.
	 *
	 * Any exception is logged as a warning and reported as an invalid signature.
	 *
	 * @param signNode
	 * 				 The ds:Signature node to validate
	 * @param cert
	 * 				 The public certificate
	 * @param fingerprint
	 * 				 The fingerprint (or comma-separated fingerprints) of the public certificate
	 * @param alg
	 * 				 The algorithm used to calculate the fingerprint of the embedded certificate
	 *
	 * @return True if the sign is valid, false otherwise.
	 */
	public static Boolean validateSignNode(Node signNode, X509Certificate cert, String fingerprint, String alg) {
		Boolean valid = false;
		try {
			final XMLSignature xmlSignature = new XMLSignature((Element) signNode, "", true);

			final String methodUri = xmlSignature.getSignedInfo().getSignatureMethodURI();
			if (!isAlgorithmWhitelisted(methodUri)){
				throw new Exception(methodUri + " is not a valid supported algorithm");
			}

			if (cert != null) {
				// Explicit certificate: no fingerprint matching involved.
				return xmlSignature.checkSignatureValue(cert);
			}

			final KeyInfo embeddedKeyInfo = xmlSignature.getKeyInfo();
			if (fingerprint == null || embeddedKeyInfo == null || !embeddedKeyInfo.containsX509Data()) {
				return valid;
			}

			final X509Certificate embeddedCert = embeddedKeyInfo.getX509Certificate();
			final String embeddedFingerprint = calculateX509Fingerprint(embeddedCert, alg);
			final String[] acceptedFingerprints = fingerprint.split(",");
			for (int i = 0; i < acceptedFingerprints.length; i++) {
				if (embeddedFingerprint.equals(acceptedFingerprints[i].trim())) {
					valid = xmlSignature.checkSignatureValue(embeddedCert);
				}
			}
		} catch (Exception e) {
			LOGGER.warn("Error executing validateSignNode: " + e.getMessage(), e);
		}
		return valid;
	}

	/**
	 * Tells whether an XML signature algorithm is one of the accepted ones:
	 * Constants.DSA_SHA1, RSA_SHA1, RSA_SHA256, RSA_SHA384 or RSA_SHA512.
	 *
	 * @param alg
	 * 				 The signature method algorithm URI to check (may be null)
	 *
	 * @return True if the algorithm is accepted, false otherwise.
	 */
	public static boolean isAlgorithmWhitelisted(String alg) {
		Set<String> whiteListedAlgorithm = new HashSet<String>();
		whiteListedAlgorithm.add(Constants.DSA_SHA1);
		whiteListedAlgorithm.add(Constants.RSA_SHA1);
		whiteListedAlgorithm.add(Constants.RSA_SHA256);
		whiteListedAlgorithm.add(Constants.RSA_SHA384);
		whiteListedAlgorithm.add(Constants.RSA_SHA512);

		return whiteListedAlgorithm.contains(alg);
	}

	/**
	 * Decrypt an encrypted element in place, inside its owner document.
	 *
	 * Before decrypting, any RetrievalMethod child of the first ds:KeyInfo found inside the
	 * element is replaced by the xenc:EncryptedKey it points to (see the comment in the body).
	 *
	 * This method does not propagate failures: every exception raised in the body is caught
	 * by the final catch block and only logged as a warning.
	 * NOTE(review): that appears to include the ValidationErrors thrown in the body itself
	 * (missing KeyInfo, unsupported RetrievalMethod), so callers cannot observe them - confirm
	 * this is intended.
	 *
	 * @param encryptedDataElement
	 * 				 The encrypted element.
	 * @param inputKey
	 * 				 The private key to decrypt; it is set as the key-encryption key of the cipher.
	 */
	public static void decryptElement(Element encryptedDataElement, PrivateKey inputKey) {
		try {
			XMLCipher xmlCipher = XMLCipher.getInstance();
			// No key is given at init time; the key-encryption key is supplied via setKEK below.
			xmlCipher.init(XMLCipher.DECRYPT_MODE, null);

			/* Check if we have encryptedData with a KeyInfo that contains a RetrievalMethod to obtain the EncryptedKey.
			   xmlCipher is not able to handle that so we move the EncryptedKey inside the KeyInfo element and
			   replacing the RetrievalMethod.
			*/

			NodeList keyInfoInEncData = encryptedDataElement.getElementsByTagNameNS(Constants.NS_DS, "KeyInfo");
			if (keyInfoInEncData.getLength() == 0) {
				throw new ValidationError("No KeyInfo inside EncryptedData element", ValidationError.KEYINFO_NOT_FOUND_IN_ENCRYPTED_DATA);
			}

			// Only the first KeyInfo is inspected.
			NodeList childs = keyInfoInEncData.item(0).getChildNodes();
			for (int i=0; i < childs.getLength(); i++) {
				if (childs.item(i).getLocalName() != null && childs.item(i).getLocalName().equals("RetrievalMethod")) {
					Element retrievalMethodElem = (Element)childs.item(i);
					if (!retrievalMethodElem.getAttribute("Type").equals("http://www.w3.org/2001/04/xmlenc#EncryptedKey")) {
						throw new ValidationError("Unsupported Retrieval Method found", ValidationError.UNSUPPORTED_RETRIEVAL_METHOD);
					}

					// Drops the first character of the URI attribute (presumably the leading '#'
					// of a same-document reference - TODO confirm); the remainder is compared
					// with the Id attribute of the candidate EncryptedKey elements.
					String uri = retrievalMethodElem.getAttribute("URI").substring(1);

					// Candidate keys are searched among the descendants of the PARENT of the encrypted element.
					NodeList encryptedKeyNodes = ((Element) encryptedDataElement.getParentNode()).getElementsByTagNameNS(Constants.NS_XENC, "EncryptedKey");
					for (int j=0; j < encryptedKeyNodes.getLength(); j++) {
						if (((Element)encryptedKeyNodes.item(j)).getAttribute("Id").equals(uri)) {
							// Moves the matching EncryptedKey into the KeyInfo, in place of the RetrievalMethod.
							keyInfoInEncData.item(0).replaceChild(encryptedKeyNodes.item(j), childs.item(i));
						}
					}
				}
			}

			xmlCipher.setKEK(inputKey);
			xmlCipher.doFinal(encryptedDataElement.getOwnerDocument(), encryptedDataElement, false);
		} catch (Exception e) {
			LOGGER.warn("Error executing decryption: " + e.getMessage(), e);
		}
	}

	/**
	 * Creates an independent copy of a Document.
	 *
	 * A new, empty document is created with a namespace-aware builder and the root element
	 * of the source (with its whole subtree) is imported into it.
	 *
	 * @param source
	 * 				 The Document object to be cloned.
	 *
	 * @return the clone of the Document object
	 *
	 * @throws ParserConfigurationException
	 * 				 if a document builder cannot be created
	 */
	public static Document copyDocument(Document source) throws ParserConfigurationException
	{
		final DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
		factory.setNamespaceAware(true);

		final Document clone = factory.newDocumentBuilder().newDocument();
		// Deep import: the subtree is duplicated, the source is left untouched.
		clone.appendChild(clone.importNode(source.getDocumentElement(), true));
		return clone;
	}

	/**
	 * Signs the Document with the given signature algorithm, using Constants.SHA1 as the
	 * digest algorithm. Delegates to the five-argument Document overload.
	 *
	 * @param document
	 * 				 The document to be signed
	 * @param key
	 * 				 The private key
	 * @param certificate
	 * 				 The public certificate
	 * @param signAlgorithm
	 * 				 Signature Algorithm
	 *
	 * @return the signed document in string format
	 *
	 * @throws XMLSecurityException
	 * @throws XPathExpressionException
	 */
	public static String addSign(Document document, PrivateKey key, X509Certificate certificate, String signAlgorithm) throws XMLSecurityException, XPathExpressionException {
		final String defaultDigest = Constants.SHA1;
		return addSign(document, key, certificate, signAlgorithm, defaultDigest);
	}

	/**
	 * Signs the Document using the specified signature and digest algorithms with the private key
	 * and the public certificate, and returns the signed document serialized as a string.
	 *
	 * The given document is modified: it is normalized and a ds:Signature element is inserted into it.
	 * The element that gets signed is the parent of the first saml:Issuer if there is one; otherwise
	 * the first md:EntitiesDescriptor, else the first md:EntityDescriptor, else the root element.
	 *
	 * @param document
	 * 				 The document to be signed
	 * @param key
	 * 				 The private key
	 * @param certificate
	 * 				 The public certificate, added to the signature's key info
	 * @param signAlgorithm
	 * 				 Signature Algorithm; Constants.RSA_SHA1 is used when null or empty
	 * @param digestAlgorithm
	 * 				 Digest Algorithm; Constants.SHA1 is used when null or empty
	 *
	 * @return the signed document in string format
	 *
	 * @throws IllegalArgumentException
	 * 				 if the document, its root element, the key or the certificate is null
	 * @throws XMLSecurityException
	 * @throws XPathExpressionException
	 */
	public static String addSign(Document document, PrivateKey key, X509Certificate certificate, String signAlgorithm, String digestAlgorithm) throws XMLSecurityException, XPathExpressionException {
		// Check arguments.
		if (document == null) {
			throw new IllegalArgumentException("Provided document was null");
		}

		if (document.getDocumentElement() == null) {
			throw new IllegalArgumentException("The Xml Document has no root element.");
		}

		if (key == null) {
			throw new IllegalArgumentException("Provided key was null");
		}

		if (certificate == null) {
			throw new IllegalArgumentException("Provided certificate was null");
		}

		// Fall back to the default algorithms when none is requested.
		if (signAlgorithm == null || signAlgorithm.isEmpty()) {
			signAlgorithm = Constants.RSA_SHA1;
		}
		if (digestAlgorithm == null || digestAlgorithm.isEmpty()) {
			digestAlgorithm = Constants.SHA1;
		}

		document.normalizeDocument();

		// The same canonicalization method is used for the signature and for the reference transform below.
		String c14nMethod = Constants.C14NEXC;

		// Signature object
		XMLSignature sig = new XMLSignature(document, null, signAlgorithm, c14nMethod);

		// Including the signature into the document before sign, because
		// this is an envelop signature
		Element root = document.getDocumentElement();
		document.setXmlStandalone(false);

		// If Issuer, locate Signature after Issuer, Otherwise as first child.
		NodeList issuerNodes = Util.query(document, "//saml:Issuer", null);
		Element elemToSign = null;
		if (issuerNodes.getLength() > 0) {
			Node issuer =  issuerNodes.item(0);
			// NOTE(review): the insertion is performed on root, which presumes the first
			// saml:Issuer is a direct child of the root element - confirm.
			root.insertBefore(sig.getElement(), issuer.getNextSibling());
			elemToSign = (Element) issuer.getParentNode();
		} else {
			NodeList entitiesDescriptorNodes = Util.query(document, "//md:EntitiesDescriptor", null);
			if (entitiesDescriptorNodes.getLength() > 0) {
				elemToSign = (Element)entitiesDescriptorNodes.item(0);
			} else {
				NodeList entityDescriptorNodes = Util.query(document, "//md:EntityDescriptor", null);
				if (entityDescriptorNodes.getLength() > 0) {
					elemToSign = (Element)entityDescriptorNodes.item(0);
				} else {
					elemToSign = root;
				}
			}
			// NOTE(review): also inserted through root; presumes elemToSign is the root element - confirm.
			root.insertBefore(sig.getElement(), elemToSign.getFirstChild());
		}

		String id = elemToSign.getAttribute("ID");

		// When the element has an ID, reference it as "#ID" and register the attribute as an
		// XML ID; otherwise the reference stays an empty string.
		String reference = id;
		if (!id.isEmpty()) {
			elemToSign.setIdAttributeNS(null, "ID", true);
			reference = "#" + id;
		}

		// Create the transform for the document
		Transforms transforms = new Transforms(document);
		transforms.addTransform(Constants.ENVSIG);
		transforms.addTransform(c14nMethod);
		sig.addDocument(reference, transforms, digestAlgorithm);

		// Add the certification info
		sig.addKeyInfo(certificate);

		// Sign the document
		sig.sign(key);

		return convertDocumentToString(document, true);
	}

	/**
	 * Signs a Node using the specified signature and digest algorithms with the private key
	 * and the public certificate.
	 *
	 * The node is deep-imported as the root of a new namespace-aware document, which is then
	 * signed through the Document overload; the given node itself is not modified.
	 *
	 * @param node
	 * 				 The Node to be signed
	 * @param key
	 * 				 The private key
	 * @param certificate
	 * 				 The public certificate
	 * @param signAlgorithm
	 * 				 Signature Algorithm
	 * @param digestAlgorithm
	 * 				 Digest Algorithm
	 *
	 * @return the signed document in string format
	 *
	 * @throws IllegalArgumentException
	 * 				 if the node is null
	 * @throws ParserConfigurationException
	 * @throws XMLSecurityException
	 * @throws XPathExpressionException
	 */
	public static String addSign(Node node, PrivateKey key, X509Certificate certificate, String signAlgorithm, String digestAlgorithm) throws ParserConfigurationException, XPathExpressionException, XMLSecurityException {
		// Check arguments.
		if (node == null) {
			throw new IllegalArgumentException("Provided node was null");
		}

		final DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
		factory.setNamespaceAware(true);

		final Document wrapper = factory.newDocumentBuilder().newDocument();
		wrapper.appendChild(wrapper.importNode(node, true));

		return addSign(wrapper, key, certificate, signAlgorithm, digestAlgorithm);
	}

	/**
	 * Signs a Node with the given signature algorithm, using Constants.SHA1 as the digest
	 * algorithm. Delegates to the five-argument Node overload.
	 *
	 * @param node
	 * 				 The Node to be signed
	 * @param key
	 * 				 The private key
	 * @param certificate
	 * 				 The public certificate
	 * @param signAlgorithm
	 * 				 Signature Algorithm
	 *
	 * @return the signed document in string format
	 *
	 * @throws ParserConfigurationException
	 * @throws XMLSecurityException
	 * @throws XPathExpressionException
	 */
	public static String addSign(Node node, PrivateKey key, X509Certificate certificate, String signAlgorithm) throws ParserConfigurationException, XPathExpressionException, XMLSecurityException {
		final String defaultDigest = Constants.SHA1;
		return addSign(node, key, certificate, signAlgorithm, defaultDigest);
	}

	/**
	 * Validates signed binary data (Used to validate GET Signature).
	 *
	 * The signed query is converted to bytes as UTF-8, so the verification does not depend
	 * on the platform default charset. Any exception raised while verifying is logged as a
	 * warning and reported as an invalid signature.
	 *
	 * @param signedQuery
	 * 				 The element we should validate
	 * @param signature
	 * 				 The signature that will be validate
	 * @param cert
	 * 				 The public certificate
	 * @param signAlg
	 * 				 Signature Algorithm
	 *
	 * @return True if the signature is valid for the signed query, false otherwise
	 *
	 * @throws NoSuchAlgorithmException
	 * @throws NoSuchProviderException
	 * @throws InvalidKeyException
	 * @throws SignatureException
	 */
	public static Boolean validateBinarySignature(String signedQuery, byte[] signature, X509Certificate cert, String signAlg) throws NoSuchAlgorithmException, NoSuchProviderException, InvalidKeyException, SignatureException {
		Boolean valid = false;
		try {
			String convertedSigAlg = signatureAlgConversion(signAlg);

			Signature sig = Signature.getInstance(convertedSigAlg); //, provider);
			sig.initVerify(cert.getPublicKey());
			// Explicit charset: the verified bytes must be the same on every JVM.
			sig.update(signedQuery.getBytes(StandardCharsets.UTF_8));

			valid = sig.verify(signature);
		} catch (Exception e) {
			LOGGER.warn("Error executing validateSign: " + e.getMessage(), e);
		}
		return valid;
	}

	/**
	 * Validates signed binary data (Used to validate GET Signature) against a list of
	 * candidate certificates.
	 *
	 * The signature is accepted as soon as one certificate of the list verifies it. The signed
	 * query is converted to bytes as UTF-8. A failure with one certificate is logged as a
	 * warning and the next certificate is tried. A null list yields false.
	 *
	 * @param signedQuery
	 * 				 The element we should validate
	 * @param signature
	 * 				 The signature that will be validate
	 * @param certList
	 * 				 The List of certificates
	 * @param signAlg
	 * 				 Signature Algorithm
	 *
	 * @return True if the signature is valid for at least one certificate, false otherwise
	 *
	 * @throws NoSuchAlgorithmException
	 * 				 if no provider supports the converted algorithm name
	 * @throws NoSuchProviderException
	 * @throws InvalidKeyException
	 * @throws SignatureException
	 */
	public static Boolean validateBinarySignature(String signedQuery, byte[] signature, List<X509Certificate> certList, String signAlg) throws NoSuchAlgorithmException, NoSuchProviderException, InvalidKeyException, SignatureException {
		Boolean valid = false;

		String convertedSigAlg = signatureAlgConversion(signAlg);
		Signature sig = Signature.getInstance(convertedSigAlg); //, provider);

		if (certList == null) {
			// Nothing to verify against.
			return valid;
		}

		for (X509Certificate cert : certList) {
			try {
				sig.initVerify(cert.getPublicKey());
				// Explicit charset: the verified bytes must be the same on every JVM.
				sig.update(signedQuery.getBytes(StandardCharsets.UTF_8));
				valid = sig.verify(signature);
				if (valid) {
					break;
				}
			} catch (Exception e) {
				LOGGER.warn("Error executing validateSign: " + e.getMessage(), e);
			}
		}
		return valid;
	}

	/**
	 * Get Status from a Response
	 *
	 * @param statusXpath
	 *            XPath expression used to locate the Status element in the document
	 * @param dom
	 *            The Response as XML
	 *
	 * @return SamlResponseStatus holding the status code and, when present, the
	 *         sub status code and the status message
	 *
	 * @throws IllegalArgumentException
	 *             if one of the XPath queries raises an XPathExpressionException
	 * @throws ValidationError
	 *             if the query does not match exactly one Status element, if no
	 *             StatusCode is found, or if the StatusCode has no Value attribute
	 */
	public static SamlResponseStatus getStatus(String statusXpath, Document dom) throws ValidationError {
		try {
			NodeList statusEntry = Util.query(dom, statusXpath, null);
			if (statusEntry.getLength() != 1) {
				throw new ValidationError("Missing Status on response", ValidationError.MISSING_STATUS);
			}
			Element statusElement = (Element) statusEntry.item(0);

			NodeList codeEntry = Util.query(dom, statusXpath + "/samlp:StatusCode", statusElement);
			if (codeEntry.getLength() == 0) {
				throw new ValidationError("Missing Status Code on response", ValidationError.MISSING_STATUS_CODE);
			}
			// The response is untrusted input: a StatusCode without a Value attribute
			// must surface as a validation error rather than a NullPointerException.
			if (codeEntry.item(0).getAttributes().getNamedItem("Value") == null) {
				throw new ValidationError("Missing Value attribute on Status Code of response", ValidationError.MISSING_STATUS_CODE);
			}
			String statusCode = codeEntry.item(0).getAttributes().getNamedItem("Value").getNodeValue();
			SamlResponseStatus status = new SamlResponseStatus(statusCode);

			// The sub status code is optional; only the first match is used.
			NodeList subStatusCodeEntry = Util.query(dom, statusXpath + "/samlp:StatusCode/samlp:StatusCode", statusElement);
			if (subStatusCodeEntry.getLength() > 0
					&& subStatusCodeEntry.item(0).getAttributes().getNamedItem("Value") != null) {
				String subStatusCode = subStatusCodeEntry.item(0).getAttributes().getNamedItem("Value").getNodeValue();
				status.setSubStatusCode(subStatusCode);
			}

			// The message is only taken when exactly one StatusMessage is present.
			NodeList messageEntry = Util.query(dom, statusXpath + "/samlp:StatusMessage", statusElement);
			if (messageEntry.getLength() == 1) {
				status.setStatusMessage(messageEntry.item(0).getTextContent());
			}

			return status;
		} catch (XPathExpressionException e) {
			String error = "Unexpected error in getStatus." +  e.getMessage();
			// Keep the original exception as the cause so the stack trace is not lost.
			LOGGER.error(error, e);
			throw new IllegalArgumentException(error, e);
		}
	}

	/**
	 * Generates a nameID.
	 *
	 * @param value
	 * 				 The value
	 * @param spnq
	 * 				 SP Name Qualifier (attribute omitted when null or empty)
	 * @param format
	 * 				 SP Format (attribute omitted when null or empty)
	 * @param nq
	 * 				 Name Qualifier (attribute omitted when null or empty)
	 * @param cert
	 * 				 IdP Public certificate to encrypt the nameID; when null the nameID is returned unencrypted
	 *
	 * @return the NameID as an XML string, wrapped in a {@code saml:EncryptedID} element
	 *         when a certificate is supplied; null if generation failed (the error is logged)
	 */
	public static String generateNameId(String value, String spnq, String format, String nq, X509Certificate cert) {
		String res = null;
		try {
			DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
			dbf.setNamespaceAware(true);
			Document doc = dbf.newDocumentBuilder().newDocument();
			Element nameId = doc.createElement("saml:NameID");

			// Optional attributes are only emitted when they carry a value.
			if (spnq != null && !spnq.isEmpty()) {
				nameId.setAttribute("SPNameQualifier", spnq);
			}
			if (format != null && !format.isEmpty()) {
				nameId.setAttribute("Format", format);
			}
			if (nq != null && !nq.isEmpty()) {
				nameId.setAttribute("NameQualifier", nq);
			}

			nameId.appendChild(doc.createTextNode(value));
			doc.appendChild(nameId);

			if (cert != null) {
				// We generate a symmetric key
				Key symmetricKey = generateSymmetricKey();

				// cipher for encrypt the data
				XMLCipher xmlCipher = XMLCipher.getInstance(Constants.AES128_CBC);
				xmlCipher.init(XMLCipher.ENCRYPT_MODE, symmetricKey);

				// cipher for encrypt the symmetric key
				// NOTE(review): Constants.RSA_1_5 presumably selects RSA PKCS#1 v1.5 key
				// transport - confirm whether an OAEP-based algorithm should be used instead.
				XMLCipher keyCipher = XMLCipher.getInstance(Constants.RSA_1_5);
				keyCipher.init(XMLCipher.WRAP_MODE, cert.getPublicKey());

				// encrypt the symmetric key
				EncryptedKey encryptedKey = keyCipher.encryptKey(doc, symmetricKey);

				// Add keyinfo inside the encrypted data
				EncryptedData encryptedData = xmlCipher.getEncryptedData();
				KeyInfo keyInfo = new KeyInfo(doc);
				keyInfo.add(encryptedKey);
				encryptedData.setKeyInfo(keyInfo);

				// Encrypt the actual data
				xmlCipher.doFinal(doc, nameId, false);

				// Building the result: the encrypted payload is carried inside a
				// saml:EncryptedID element.
				res = "<saml:EncryptedID>" + convertDocumentToString(doc) + "</saml:EncryptedID>";
			} else {
				res = convertDocumentToString(doc);
			}
		} catch (Exception e) {
			// Best effort: failures are logged and reported to the caller as a null result.
			LOGGER.error("Error executing generateNameId: " + e.getMessage(), e);
		}
		return res;
	}

	/**
	 * Generates a nameID without a Name Qualifier.
	 *
	 * Delegates to the five-argument overload, passing null as the Name Qualifier.
	 *
	 * @param value
	 * 				 The value
	 * @param spnq
	 * 				 SP Name Qualifier
	 * @param format
	 * 				 SP Format
	 * @param cert
	 * 				 IdP Public certificate to encrypt the nameID
	 *
	 * @return whatever the five-argument overload returns for these arguments
	 */
	public static String generateNameId(String value, String spnq, String format, X509Certificate cert) {
		final String noNameQualifier = null;
		return generateNameId(value, spnq, format, noNameQualifier, cert);
	}

	/**
	 * Generates a nameID without a certificate.
	 *
	 * Delegates to the certificate-taking overload, passing null as the certificate.
	 *
	 * @param value
	 * 				 The value
	 * @param spnq
	 * 				 SP Name Qualifier
	 * @param format
	 * 				 SP Format
	 *
	 * @return whatever the certificate-taking overload returns for a null certificate
	 */
	public static String generateNameId(String value, String spnq, String format) {
		final X509Certificate noCertificate = null;
		return generateNameId(value, spnq, format, noCertificate);
	}

	/**
	 * Generates a nameID from a value only.
	 *
	 * Delegates to the four-argument overload with a null SP Name Qualifier,
	 * a null Format and a null certificate.
	 *
	 * @param value
	 * 				 The value
	 *
	 * @return whatever the four-argument overload returns for these arguments
	 */
	public static String generateNameId(String value) {
		final String noSpNameQualifier = null;
		final String noFormat = null;
		final X509Certificate noCertificate = null;
		return generateNameId(value, noSpNameQualifier, noFormat, noCertificate);
	}

	/**
	 * Creates a fresh 128-bit AES key.
	 *
	 * @return the newly generated symmetric key
	 *
	 * @throws Exception
	 * 				 if the key generator cannot be obtained or initialised
	 */
	private static SecretKey generateSymmetricKey() throws Exception {
		final int keySizeInBits = 128;
		KeyGenerator aesGenerator = KeyGenerator.getInstance("AES");
		aesGenerator.init(keySizeInBits);
		SecretKey freshKey = aesGenerator.generateKey();
		return freshKey;
	}

	/**
	 * Builds a unique string (used for example as ID of assertions) by appending
	 * a random UUID to a prefix.
	 *
	 * @param prefix
	 *          Prefix for the Unique ID; when null or empty, Util.UNIQUE_ID_PREFIX is used.
	 *          Use property onelogin.saml2.unique_id_prefix to set this.
	 *
	 * @return the prefix followed by a random UUID
	 */
	public static String generateUniqueID(String prefix) {
		boolean prefixSupplied = prefix != null && !StringUtils.isEmpty(prefix);
		String effectivePrefix = prefixSupplied ? prefix : Util.UNIQUE_ID_PREFIX;
		return effectivePrefix + UUID.randomUUID();
	}

	/**
	 * Builds a unique string (used for example as ID of assertions).
	 *
	 * Delegates to the prefix-taking overload with a null prefix.
	 *
	 * @return A unique string
	 */
	public static String generateUniqueID() {
		final String noPrefix = null;
		return generateUniqueID(noPrefix);
	}

	/**
	 * Applies an ISO8601 duration to the current time.
	 *
	 * @param duration
	 *            The duration, as a string.
	 *
	 * @return the resulting unix timestamp in seconds, as computed by the
	 *         two-argument overload with the current time as reference
	 *
	 * @throws IllegalArgumentException
	 */
	public static long parseDuration(String duration) throws IllegalArgumentException {
		Calendar utcNow = Calendar.getInstance(DateTimeZone.UTC.toTimeZone());
		long nowInSeconds = utcNow.getTimeInMillis() / 1000;
		return parseDuration(duration, nowInSeconds);
	}

	/**
	 * Applies an ISO8601 duration to a given timestamp.
	 *
	 * A leading "-" on the duration means the period is subtracted from the
	 * timestamp instead of added to it.
	 *
	 * @param durationString
	 * 				 The duration, as a string.
	 * @param timestamp
	 *               The unix timestamp (in seconds) we should apply the duration to.
	 *
	 * @return the new timestamp, after the duration is applied, in seconds.
	 *
	 * @throws IllegalArgumentException
	 */
	public static long parseDuration(String durationString, long timestamp) throws IllegalArgumentException {
		final boolean negative = durationString.startsWith("-");
		// The sign is stripped before parsing and applied afterwards.
		final String unsignedDuration = negative ? durationString.substring(1) : durationString;

		Period period = ISOPeriodFormat.standard().withLocale(new Locale("UTC")).parsePeriod(unsignedDuration);

		DateTime reference = new DateTime(timestamp * 1000, DateTimeZone.UTC);
		DateTime shifted = negative ? reference.minus(period) : reference.plus(period);

		return shifted.getMillis() / 1000;
	}

	/**
	 * Returns the current time as a unix timestamp.
	 *
	 * @return the current time in whole seconds
	 */
	public static Long getCurrentTimeStamp() {
		long nowInMillis = new DateTime(DateTimeZone.UTC).getMillis();
		return nowInMillis / 1000;
	}

	/**
	 * Computes an expiration time from a cache duration and a valid-until date,
	 * keeping the earliest of the two.
	 *
	 * Errors while parsing are logged; the value computed so far is returned.
	 *
	 * @param cacheDuration
	 * 				 The duration, as a string (ignored when null or empty).
	 * @param validUntil
	 * 				 The valid until date, as a string (ignored when null or empty)
	 *
	 * @return the expiration time (timestamp format), or 0 when neither input yields a value.
	 */
	public static long getExpireTime(String cacheDuration, String validUntil) {
		long earliest = 0;
		try {
			boolean durationGiven = !(cacheDuration == null || StringUtils.isEmpty(cacheDuration));
			if (durationGiven) {
				earliest = parseDuration(cacheDuration);
			}

			boolean validUntilGiven = !(validUntil == null || StringUtils.isEmpty(validUntil));
			if (validUntilGiven) {
				long validUntilSeconds = Util.parseDateTime(validUntil).getMillis() / 1000;
				// 0 means "nothing computed yet", so the valid-until date wins outright.
				boolean validUntilIsEarlier = earliest == 0 || validUntilSeconds < earliest;
				if (validUntilIsEarlier) {
					earliest = validUntilSeconds;
				}
			}
		} catch (Exception e) {
			LOGGER.error("Error executing getExpireTime: " + e.getMessage(), e);
		}
		return earliest;
	}

	/**
	 * Computes an expiration time from a cache duration and a valid-until
	 * timestamp, keeping the earliest of the two.
	 *
	 * Errors while parsing the duration are logged; the value computed so far is returned.
	 *
	 * @param cacheDuration
	 * 				 The duration, as a string (ignored when null or empty).
	 * @param validUntil
	 * 				 The valid until date, as a timestamp
	 *
	 * @return the expiration time (timestamp format).
	 */
	public static long getExpireTime(String cacheDuration, long validUntil) {
		long earliest = 0;
		try {
			boolean durationGiven = !(cacheDuration == null || StringUtils.isEmpty(cacheDuration));
			if (durationGiven) {
				earliest = parseDuration(cacheDuration);
			}

			// 0 means "no duration-based value", so validUntil is taken as is.
			boolean validUntilIsEarlier = earliest == 0 || validUntil < earliest;
			if (validUntilIsEarlier) {
				earliest = validUntil;
			}
		} catch (Exception e) {
			LOGGER.error("Error executing getExpireTime: " + e.getMessage(), e);
		}
		return earliest;
	}

	/**
	 * Renders a time given in milliseconds as a string using DATE_TIME_FORMAT
	 * (documented in this file as yyyy-MM-ddTHH:mm:ssZ).
	 *
	 * @param timeInMillis
	 * 				The time in Millis
	 *
	 * @return the formatted date-time string
	 */
	public static String formatDateTime(long timeInMillis) {
		String formatted = DATE_TIME_FORMAT.print(timeInMillis);
		return formatted;
	}

	/**
	 * Renders a time as a string, choosing between the two formatters of this class.
	 *
	 * @param time
	 * 			The time
	 * @param millis
	 * 			When true DATE_TIME_FORMAT_MILLS is used, otherwise the
	 * 			single-argument overload (DATE_TIME_FORMAT) is used
	 *
	 * @return the formatted date-time string
	 */
	public static String formatDateTime(long time, boolean millis) {
		return millis ? DATE_TIME_FORMAT_MILLS.print(time) : formatDateTime(time);
	}

	/**
	 * Parses a date-time string, trying DATE_TIME_FORMAT first and falling back
	 * to DATE_TIME_FORMAT_MILLS when the first attempt fails.
	 *
	 * @param dateTime
	 * 				 string with format yyyy-MM-ddTHH:mm:ssZ // yyyy-MM-ddTHH:mm:ss.SSSZ
	 *
	 * @return datetime
	 */
	public static DateTime parseDateTime(String dateTime) {
		try {
			return DATE_TIME_FORMAT.parseDateTime(dateTime);
		} catch (Exception firstAttemptFailure) {
			// Second chance with the millisecond-aware formatter; its own failure propagates.
			return DATE_TIME_FORMAT_MILLS.parseDateTime(dateTime);
		}
	}

	/**
	 * Decodes a byte array as UTF-8 text.
	 *
	 * @param bytes
	 * 				 the bytes to decode
	 *
	 * @return the decoded string
	 */
	private static String toStringUtf8(byte[] bytes) {
		// The Charset constant needs no lookup by name, so there is no
		// UnsupportedEncodingException to handle.
		return new String(bytes, StandardCharsets.UTF_8);
	}

	/**
	 * Encodes a string as UTF-8 bytes.
	 *
	 * @param str
	 * 				 the string to encode
	 *
	 * @return the UTF-8 encoded bytes
	 */
	private static byte[] toBytesUtf8(String str) {
		// The Charset constant needs no lookup by name, so there is no
		// UnsupportedEncodingException to handle.
		return str.getBytes(StandardCharsets.UTF_8);
	}


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy