All downloads are free. Search and download functionality uses the official Maven repository.

com.sap.cds.feature.postgresql.PostgreSqlSSLFactory Maven / Gradle / Ivy

The newest version!
/**************************************************************************
 * (C) 2019-2024 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.feature.postgresql;

import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Base64.Decoder;
import java.util.List;
import java.util.Properties;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import com.sap.cds.services.ServiceException;
import com.sap.cds.services.utils.StringUtils;

import org.postgresql.ssl.WrappedFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * PostgreSQL JDBC SSL socket factory (a {@link WrappedFactory}) whose {@link SSLContext} is built from Base64-encoded
 * PEM material passed in the connection {@link Properties}.
 * <ul>
 * <li>{@code sslrootcertbase64} (required): certificate(s) used as the only trust anchors.</li>
 * <li>{@code sslprivatekeybase64} (optional): RSA private key. When present, mTLS is enabled.</li>
 * <li>{@code sslclientcertbase64} (required with a private key): client certificate. The first certificate is
 * paired with the private key.</li>
 * </ul>
 * All three values are decoded with the URL-safe Base64 alphabet.
 */
public class PostgreSqlSSLFactory extends WrappedFactory {

	private static final Logger log = LoggerFactory.getLogger(PostgreSqlSSLFactory.class);

	/**
	 * Builds the TLS socket factory from the given connection properties.
	 *
	 * @param props connection properties containing the {@code ssl*base64} entries described on the class
	 * @throws ServiceException wrapping the cause if a required property is missing, the key or certificates
	 *                          cannot be parsed, or the {@link SSLContext} cannot be initialized
	 */
	public PostgreSqlSSLFactory(Properties props) {
		String rootCertBase64 = props.getProperty("sslrootcertbase64");
		String privateKeyBase64 = props.getProperty("sslprivatekeybase64");
		String clientCertBase64 = props.getProperty("sslclientcertbase64");

		// the property values use the URL-safe Base64 alphabet
		Decoder decoder = Base64.getUrlDecoder();

		try {
			if (StringUtils.isEmpty(rootCertBase64)) {
				throw new IllegalArgumentException("Missing required property 'sslrootcertbase64'");
			}

			// can be null if mTLS is not used: the SSLContext then presents no client certificate
			KeyManager[] keyManagers = null;

			// check if mTLS is enabled
			if (!StringUtils.isEmpty(privateKeyBase64)) {
				if (StringUtils.isEmpty(clientCertBase64)) {
					throw new IllegalArgumentException(
							"Property 'sslclientcertbase64' is required when 'sslprivatekeybase64' is set");
				}
				String sslPrivateKey = new String(decoder.decode(privateKeyBase64), UTF_8);
				PrivateKey privateKey = generatePrivateKey(sslPrivateKey);
				Certificate clientCert = toX509Certificates(decoder.decode(clientCertBase64)).get(0);
				keyManagers = createKeyManagers(privateKey, clientCert);
			}

			List<X509Certificate> rootCerts = toX509Certificates(decoder.decode(rootCertBase64));
			TrustManager[] trustManagers = createTrustManagers(rootCerts);

			SSLContext sslContext = SSLContext.getInstance("TLS"); // TLS, TLSv1.1, TLSv1.2, TLSv1.3
			// Default SecureRandom: getInstanceStrong() may select a blocking entropy source and stall the
			// handshake. It is intended for long-lived secrets, not for TLS sessions.
			sslContext.init(keyManagers, trustManagers, new SecureRandom());

			this.factory = sslContext.getSocketFactory();
		} catch (Exception e) {
			throw new ServiceException(e);
		}
	}

	/** Creates key managers that present {@code clientCert} together with {@code privateKey} (mTLS). */
	private static KeyManager[] createKeyManagers(PrivateKey privateKey, Certificate clientCert)
			throws GeneralSecurityException, IOException {
		// A dedicated store: the JSSE trust manager would also trust the leaf certificate of key entries
		// if the key entry shared a store with the trust anchors.
		KeyStore keyStore = newEmptyKeyStore();
		keyStore.setKeyEntry("client-key", privateKey, null, new Certificate[] { clientCert });

		KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
		keyManagerFactory.init(keyStore, null);
		return keyManagerFactory.getKeyManagers();
	}

	/** Creates trust managers whose trust anchors are the given root certificates. */
	private static TrustManager[] createTrustManagers(List<X509Certificate> rootCerts)
			throws GeneralSecurityException, IOException {
		KeyStore trustStore = newEmptyKeyStore();
		for (int i = 0; i < rootCerts.size(); i++) {
			try {
				// Index-based alias: subject DNs need not be unique (e.g. re-issued CAs). A repeated alias
				// would silently replace the previously added certificate.
				trustStore.setCertificateEntry("root-cert-" + i, rootCerts.get(i));
			} catch (KeyStoreException e) {
				// best effort: keep the remaining certificates usable
				log.warn("Error adding certificate to keystore.", e);
			}
		}

		TrustManagerFactory trustManagerFactory = TrustManagerFactory
				.getInstance(TrustManagerFactory.getDefaultAlgorithm());
		trustManagerFactory.init(trustStore);
		return trustManagerFactory.getTrustManagers();
	}

	/** Creates an empty keystore of the platform default type (loaded from no stream). */
	private static KeyStore newEmptyKeyStore() throws GeneralSecurityException, IOException {
		KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
		keyStore.load(null, null);
		return keyStore;
	}

	/**
	 * Reads the PEM encoded RSA key from the given {@link String} and creates {@link PrivateKey} from {@link KeySpec}.
	 *
	 * @param privateKeyString - PEM encoded RSA (PKCS#1) private key. Something like:
	 *
	 *                         <pre>
	 * -----BEGIN RSA PRIVATE KEY-----
	 * MII...
	 * ...
	 * ...==
	 * -----END RSA PRIVATE KEY-----
	 *                         </pre>
	 *
	 * @return Java object ({@link PrivateKey}) representation of given {@link String}.
	 * @throws InvalidKeySpecException  In case something isn't as expected in the given {@link String} to created new
	 *                                  {@link RSAPrivateCrtKeySpec}.
	 * @throws NoSuchAlgorithmException In case RSA isn't available (see:
	 *                                  {@link KeyFactory#getInstance(String)}).
	 * @throws InvalidKeyException      if the given key cannot be processed by the RSA key factory.
	 */
	private static final PrivateKey generatePrivateKey(String privateKeyString)
			throws InvalidKeySpecException, NoSuchAlgorithmException, InvalidKeyException {
		// Reluctant quantifier: strips the BEGIN and END markers separately, even if the PEM has no line breaks.
		String pem = privateKeyString.replaceAll("-----.+?KEY-----", "").replaceAll("\\s+", "");
		byte[] decoded = Base64.getDecoder().decode(pem.getBytes(UTF_8));
		KeyFactory keyFactory = KeyFactory.getInstance("RSA");
		// The PKCS#1 bytes reach the RSA KeyFactory through a minimal Key wrapper. The result is then normalized
		// via an RSAPrivateCrtKeySpec into a regular provider key.
		PrivateKey key = (PrivateKey) keyFactory.translateKey(new PKCS1PrivateKey(decoded));
		KeySpec rsaSpec = keyFactory.getKeySpec(key, RSAPrivateCrtKeySpec.class);
		return keyFactory.generatePrivate(rsaSpec);
	}

	/**
	 * Parses all X.509 certificates contained in the given bytes, in any encoding accepted by
	 * {@link CertificateFactory#generateCertificates(java.io.InputStream)} (e.g. one or more concatenated PEM blocks).
	 *
	 * @param encodedCertificates the decoded property value
	 * @return the certificates in the order returned by the certificate factory; never empty
	 * @throws CertificateException if the data cannot be parsed or contains no certificate
	 */
	private static List<X509Certificate> toX509Certificates(byte[] encodedCertificates) throws CertificateException {
		CertificateFactory factory = CertificateFactory.getInstance("X.509");
		List<X509Certificate> certs = new ArrayList<>();
		// Unlike a generateCertificate() loop on available(), this tolerates trailing blank lines after the last block.
		for (Certificate cert : factory.generateCertificates(new ByteArrayInputStream(encodedCertificates))) {
			certs.add((X509Certificate) cert);
		}
		if (certs.isEmpty()) {
			throw new CertificateException("No X.509 certificate found");
		}
		return certs;
	}

	/**
	 * Minimal {@link PrivateKey} carrying raw (Base64-decoded) PKCS#1 RSA key bytes.
	 * It lets those bytes be passed to {@link KeyFactory#translateKey(java.security.Key)}.
	 */
	static class PKCS1PrivateKey implements PrivateKey {

		private static final long serialVersionUID = -880683996554878059L;

		private final byte[] encoded;

		public PKCS1PrivateKey(byte[] encoded) {
			this.encoded = encoded;
		}

		@Override
		public String getAlgorithm() {
			return "RSA";
		}

		@Override
		public String getFormat() {
			return "PKCS#1";
		}

		@Override
		public byte[] getEncoded() {
			// return a copy so callers cannot mutate the key material held by this object
			return this.encoded.clone();
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy