All downloads are free. Search and download functionality uses the official Maven repository.

net.snowflake.client.core.SessionUtilKeyPair Maven / Gradle / Ivy

There is a newer version: 3.21.0
Show newest version
/*
 * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
 */
package net.snowflake.client.core;

import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetEnv;

import com.google.common.base.Strings;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.PublicKey;
import java.security.Security;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.util.Date;
import java.util.Locale;
import javax.crypto.EncryptedPrivateKeyInfo;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import org.apache.commons.codec.binary.Base64;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder;
import org.bouncycastle.operator.InputDecryptorProvider;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
import org.bouncycastle.pkcs.PKCSException;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemReader;

/**
 * Computes the JWT used for key pair authentication. Created by hyu on 1/16/18.
 *
 * <p>The private key may be supplied in exactly one of three ways: as a ready-made {@link
 * PrivateKey} object, as a path to a PEM file, or as a base64 string. File and base64 inputs are
 * parsed either with the plain JDK or with Bouncy Castle, depending on the JVM property named by
 * {@code SecurityUtil.ENABLE_BOUNCYCASTLE_PROVIDER_JVM}. The key must be an {@link
 * RSAPrivateCrtKey} so that the matching public key (needed for the fingerprint in the JWT
 * issuer) can be derived from it.
 */
class SessionUtilKeyPair {

  private static final SFLogger logger = SFLoggerFactory.getLogger(SessionUtilKeyPair.class);

  private static final String ISSUER_FMT = "%s.%s.%s";

  private static final String SUBJECT_FMT = "%s.%s";

  private static final int JWT_DEFAULT_AUTH_TIMEOUT = 10;

  // lifetime of an issued token: expiration is 60 seconds after issue time
  private static final long JWT_LIFETIME_MILLIS = 60L * 1000;

  // user name in upper case
  private final String userName;

  // account name in upper case
  private final String accountName;

  private final PrivateKey privateKey;

  // public key derived from the private key; used for the fingerprint in the issuer claim
  private final PublicKey publicKey;

  // true when a provider named SecurityUtil.BOUNCY_CASTLE_FIPS_PROVIDER is installed
  private final boolean isFipsMode;

  // the installed FIPS provider, or null when not in FIPS mode
  private final Provider securityProvider;

  // true when the JVM property asks for Bouncy Castle to parse private keys
  private final boolean isBouncyCastleProviderEnabled;

  /**
   * Builds the key pair used to sign JWTs.
   *
   * @param privateKey private key object, or null when the key comes from a file or base64 string
   * @param privateKeyFile path of a PEM private key file, or null/empty
   * @param privateKeyBase64 base64 string holding the private key material, or null/empty
   * @param privateKeyPwd passphrase of the private key; null/empty for an unencrypted key
   * @param accountName account name; stored upper-cased
   * @param userName user name; stored upper-cased
   * @throws SFException if more than one key source is given, the key cannot be read or parsed,
   *     or the resulting key is not an {@link RSAPrivateCrtKey}
   */
  SessionUtilKeyPair(
      PrivateKey privateKey,
      String privateKeyFile,
      String privateKeyBase64,
      String privateKeyPwd,
      String accountName,
      String userName)
      throws SFException {
    // Locale.ROOT keeps the upper-casing independent of the JVM default locale
    // (e.g. a Turkish locale would otherwise turn "i" into a dotted capital I).
    this.userName = userName.toUpperCase(Locale.ROOT);
    this.accountName = accountName.toUpperCase(Locale.ROOT);
    this.isBouncyCastleProviderEnabled =
        Boolean.parseBoolean(System.getProperty(SecurityUtil.ENABLE_BOUNCYCASTLE_PROVIDER_JVM));

    // check if in FIPS mode
    this.securityProvider = findFipsProvider();
    this.isFipsMode = this.securityProvider != null;

    ensurePrivateKeyProvidedInOnlyOneProperty(privateKey, privateKeyFile, privateKeyBase64);
    this.privateKey = buildPrivateKey(privateKey, privateKeyFile, privateKeyBase64, privateKeyPwd);
    this.publicKey = derivePublicKey(this.privateKey);
  }

  /** Returns the first installed provider named like the Bouncy Castle FIPS provider, or null. */
  private static Provider findFipsProvider() {
    for (Provider provider : Security.getProviders()) {
      if (SecurityUtil.BOUNCY_CASTLE_FIPS_PROVIDER.equals(provider.getName())) {
        return provider;
      }
    }
    return null;
  }

  /**
   * Constructs the public key from the modulus and public exponent of the private key.
   *
   * @throws SFException if the key is not an {@link RSAPrivateCrtKey} (this includes null) or the
   *     public key cannot be generated
   */
  private PublicKey derivePublicKey(PrivateKey key) throws SFException {
    if (!(key instanceof RSAPrivateCrtKey)) {
      throw new SFException(
          ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
          "Use java.security.interfaces.RSAPrivateCrtKey.class for the private key");
    }
    RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) key;
    RSAPublicKeySpec rsaPublicKeySpec =
        new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());
    try {
      return getKeyFactoryInstance().generatePublic(rsaPublicKeySpec);
    } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
      throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error retrieving public key");
    }
  }

  /**
   * Rejects configurations where the private key is supplied through more than one property.
   *
   * @throws SFException if any two of the three key sources are set at the same time
   */
  private static void ensurePrivateKeyProvidedInOnlyOneProperty(
      PrivateKey privateKey, String privateKeyFile, String privateKeyBase64) throws SFException {
    if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) {
      throw new SFException(
          ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
          "Cannot have both private key object and private key file.");
    }
    if (!Strings.isNullOrEmpty(privateKeyBase64) && !Strings.isNullOrEmpty(privateKeyFile)) {
      throw new SFException(
          ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
          "Cannot have both private key file and private key base64 string value.");
    }
    if (!Strings.isNullOrEmpty(privateKeyBase64) && privateKey != null) {
      throw new SFException(
          ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
          "Cannot have both private key object and private key base64 string value.");
    }
  }

  /**
   * Picks the private key from whichever source was provided: base64 string first, then file,
   * otherwise the key object itself (which may be null).
   */
  private PrivateKey buildPrivateKey(
      PrivateKey privateKey, String privateKeyFile, String privateKeyBase64, String privateKeyPwd)
      throws SFException {
    if (!Strings.isNullOrEmpty(privateKeyBase64)) {
      logger.trace("Reading private key from base64 string");
      return extractPrivateKeyFromBase64(privateKeyBase64, privateKeyPwd);
    }
    if (!Strings.isNullOrEmpty(privateKeyFile)) {
      logger.trace("Reading private key from file");
      return extractPrivateKeyFromFile(privateKeyFile, privateKeyPwd);
    }
    return privateKey;
  }

  /** Returns an RSA key factory, bound to the FIPS provider when in FIPS mode. */
  private KeyFactory getKeyFactoryInstance() throws NoSuchAlgorithmException {
    if (isFipsMode) {
      return KeyFactory.getInstance("RSA", this.securityProvider);
    } else {
      return KeyFactory.getInstance("RSA");
    }
  }

  /** Returns a secret key factory, bound to the FIPS provider when in FIPS mode. */
  private SecretKeyFactory getSecretKeyFactory(String algorithm) throws NoSuchAlgorithmException {
    if (isFipsMode) {
      return SecretKeyFactory.getInstance(algorithm, this.securityProvider);
    } else {
      return SecretKeyFactory.getInstance(algorithm);
    }
  }

  /**
   * Reads the whole key file and parses it.
   *
   * @throws SFException with {@code INVALID_PARAMETER_VALUE} if the file cannot be read
   */
  private PrivateKey extractPrivateKeyFromFile(String privateKeyFile, String privateKeyPwd)
      throws SFException {

    try {
      Path privKeyPath = Paths.get(privateKeyFile);
      FileUtil.logFileUsage(privKeyPath, "Extract private key from file", true);
      byte[] bytes = Files.readAllBytes(privKeyPath);
      return extractPrivateKeyFromBytes(bytes, privateKeyPwd);
    } catch (IOException ie) {
      logger.error("Could not read private key from file", ie);
      // Pass the message, not getCause(): the cause of an IOException is usually null.
      throw new SFException(ie, ErrorCode.INVALID_PARAMETER_VALUE, ie.getMessage());
    }
  }

  /**
   * Parses PEM bytes into a private key with Bouncy Castle or the JDK, depending on
   * configuration.
   *
   * @throws SFException with {@code INVALID_OR_UNSUPPORTED_PRIVATE_KEY} if parsing fails
   */
  private PrivateKey extractPrivateKeyFromBytes(byte[] privateKeyBytes, String privateKeyPwd)
      throws SFException {
    if (isBouncyCastleProviderEnabled) {
      try {
        return extractPrivateKeyWithBouncyCastle(privateKeyBytes, privateKeyPwd);
      } catch (IOException | PKCSException | OperatorCreationException e) {
        logger.error("Could not extract private key using Bouncy Castle provider", e);
        throw new SFException(e, ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, e.getMessage());
      }
    } else {
      try {
        return extractPrivateKeyWithJdk(privateKeyBytes, privateKeyPwd);
      } catch (NoSuchAlgorithmException
          | InvalidKeySpecException
          | IOException
          | IllegalArgumentException
          | NullPointerException
          | InvalidKeyException e) {
        logger.error(
            "Could not extract private key using standard JDK. Try setting the JVM argument: "
                + "-D{}"
                + "=TRUE",
            SecurityUtil.ENABLE_BOUNCYCASTLE_PROVIDER_JVM);
        throw new SFException(e, ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, e.getMessage());
      }
    }
  }

  /** Base64-decodes the given string and parses the result as a private key. */
  private PrivateKey extractPrivateKeyFromBase64(String privateKeyBase64, String privateKeyPwd)
      throws SFException {
    byte[] decodedKey = Base64.decodeBase64(privateKeyBase64);
    return extractPrivateKeyFromBytes(decodedKey, privateKeyPwd);
  }

  /**
   * Issues a signed JWT for this account/user.
   *
   * <p>The subject is {@code ACCOUNT.USER}, the issuer is {@code ACCOUNT.USER.<fingerprint>}
   * where the fingerprint is the SHA-256 of the encoded public key, the issue time is now and the
   * token expires 60 seconds later. The token is signed with RS256 using the private key.
   *
   * @return the serialized, signed token
   * @throws SFException if the fingerprint cannot be computed or signing fails
   */
  public String issueJwtToken() throws SFException {
    JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder();
    String sub = String.format(SUBJECT_FMT, this.accountName, this.userName);
    String iss =
        String.format(
            ISSUER_FMT,
            this.accountName,
            this.userName,
            this.calculatePublicKeyFingerprint(this.publicKey));

    // iat is now
    Date iat = new Date(System.currentTimeMillis());

    // expiration is 60 seconds later
    Date exp = new Date(iat.getTime() + JWT_LIFETIME_MILLIS);

    JWTClaimsSet claimsSet =
        builder.issuer(iss).subject(sub).issueTime(iat).expirationTime(exp).build();

    SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256), claimsSet);
    JWSSigner signer = new RSASSASigner(this.privateKey);

    try {
      signedJWT.sign(signer);
    } catch (JOSEException e) {
      throw new SFException(e, ErrorCode.FAILED_TO_GENERATE_JWT);
    }
    // Log the contents of the token, displaying expiration and issue time in epoch time
    logger.debug(
        "JWT:\n'{'\niss: {}\nsub: {}\niat: {}\nexp: {}\n'}'",
        iss,
        sub,
        String.valueOf(iat.getTime() / 1000),
        String.valueOf(exp.getTime() / 1000));
    return signedJWT.serialize();
  }

  /** Returns {@code "SHA256:"} followed by the base64 SHA-256 digest of the encoded public key. */
  private String calculatePublicKeyFingerprint(PublicKey publicKey) throws SFException {
    try {
      MessageDigest md = MessageDigest.getInstance("SHA-256");
      byte[] sha256Hash = md.digest(publicKey.getEncoded());
      return "SHA256:" + Base64.encodeBase64String(sha256Hash);
    } catch (NoSuchAlgorithmException e) {
      throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error when calculating fingerprint");
    }
  }

  /**
   * Returns the JWT authentication timeout.
   *
   * @return the integer value of the {@code JWT_AUTH_TIMEOUT} environment variable, or the
   *     default of 10 when the variable is unset or is not a valid integer
   */
  public static int getTimeout() {
    String jwtAuthTimeoutStr = systemGetEnv("JWT_AUTH_TIMEOUT");
    if (jwtAuthTimeoutStr == null) {
      return JWT_DEFAULT_AUTH_TIMEOUT;
    }
    try {
      return Integer.parseInt(jwtAuthTimeoutStr.trim());
    } catch (NumberFormatException e) {
      // A malformed environment variable should not abort authentication; use the default.
      logger.warn(
          "Invalid JWT_AUTH_TIMEOUT value '{}', falling back to default {}",
          jwtAuthTimeoutStr,
          JWT_DEFAULT_AUTH_TIMEOUT);
      return JWT_DEFAULT_AUTH_TIMEOUT;
    }
  }

  /**
   * Parses a PEM private key with Bouncy Castle. Handles encrypted PKCS#8, unencrypted PKCS#8 and
   * PKCS#1 key pairs.
   *
   * @throws PKCSException if the key is encrypted and no passphrase was given, or decryption
   *     fails
   * @throws IOException if the content holds no PEM object or one of an unsupported type
   */
  private PrivateKey extractPrivateKeyWithBouncyCastle(byte[] privateKeyBytes, String privateKeyPwd)
      throws IOException, PKCSException, OperatorCreationException {
    logger.trace("Extracting private key using Bouncy Castle provider");
    PrivateKeyInfo privateKeyInfo;
    // try-with-resources so the parser is closed even when parsing or decryption throws
    try (PEMParser pemParser =
        new PEMParser(new StringReader(new String(privateKeyBytes, StandardCharsets.UTF_8)))) {
      Object pemObject = pemParser.readObject();
      if (pemObject instanceof PKCS8EncryptedPrivateKeyInfo) {
        // Handle the case where the private key is encrypted.
        if (privateKeyPwd == null) {
          // Fail with a checked exception the caller wraps, rather than an uncaught NPE.
          throw new PKCSException("Private key is encrypted but no passphrase was provided");
        }
        PKCS8EncryptedPrivateKeyInfo encryptedPrivateKeyInfo =
            (PKCS8EncryptedPrivateKeyInfo) pemObject;
        InputDecryptorProvider pkcs8Prov =
            new JceOpenSSLPKCS8DecryptorProviderBuilder().build(privateKeyPwd.toCharArray());
        privateKeyInfo = encryptedPrivateKeyInfo.decryptPrivateKeyInfo(pkcs8Prov);
      } else if (pemObject instanceof PEMKeyPair) {
        // PKCS#1 private key
        privateKeyInfo = ((PEMKeyPair) pemObject).getPrivateKeyInfo();
      } else if (pemObject instanceof PrivateKeyInfo) {
        // Handle the case where the private key is unencrypted.
        privateKeyInfo = (PrivateKeyInfo) pemObject;
      } else {
        // Previously a null PrivateKeyInfo was handed to the converter here.
        throw new IOException(
            "Unsupported or missing PEM object in private key: "
                + (pemObject == null ? "none" : pemObject.getClass().getName()));
      }
    }
    JcaPEMKeyConverter converter =
        new JcaPEMKeyConverter()
            .setProvider(
                isFipsMode
                    ? SecurityUtil.BOUNCY_CASTLE_FIPS_PROVIDER
                    : SecurityUtil.BOUNCY_CASTLE_PROVIDER);
    return converter.getPrivateKey(privateKeyInfo);
  }

  /**
   * Parses a PEM private key with the JDK. The key is treated as encrypted exactly when a
   * non-empty passphrase is supplied.
   */
  private PrivateKey extractPrivateKeyWithJdk(byte[] privateKeyFileBytes, String privateKeyPwd)
      throws IOException, NoSuchAlgorithmException, InvalidKeySpecException, InvalidKeyException {
    logger.trace("Extracting private key using JDK");
    String privateKeyContent = new String(privateKeyFileBytes, StandardCharsets.UTF_8);
    boolean isEncrypted = !Strings.isNullOrEmpty(privateKeyPwd);
    return generatePrivateKey(isEncrypted, privateKeyContent, privateKeyPwd);
  }

  /**
   * Turns PEM text holding a PKCS#8 key into a private key.
   *
   * @param isEncrypted whether the PEM content is an encrypted PKCS#8 structure
   * @param privateKeyContent the PEM text
   * @param privateKeyPwd passphrase; only used when {@code isEncrypted} is true
   * @throws IOException if the text holds no PEM object or the encrypted structure is malformed
   */
  private PrivateKey generatePrivateKey(
      boolean isEncrypted, String privateKeyContent, String privateKeyPwd)
      throws IOException, NoSuchAlgorithmException, InvalidKeySpecException, InvalidKeyException {
    byte[] decoded = readPemContent(privateKeyContent);
    PKCS8EncodedKeySpec encodedKeySpec =
        isEncrypted ? decryptPkcs8(decoded, privateKeyPwd) : new PKCS8EncodedKeySpec(decoded);
    return getKeyFactoryInstance().generatePrivate(encodedKeySpec);
  }

  /** Reads the first PEM object from the text and returns its decoded content. */
  private static byte[] readPemContent(String privateKeyContent) throws IOException {
    try (PemReader pemReader = new PemReader(new StringReader(privateKeyContent))) {
      PemObject pemObject = pemReader.readPemObject();
      if (pemObject == null) {
        // Explicit check instead of relying on a NullPointerException being caught upstream.
        throw new IOException("No PEM object found in private key content");
      }
      return pemObject.getContent();
    }
  }

  /** Decrypts an encrypted PKCS#8 structure with a password-based key. */
  private PKCS8EncodedKeySpec decryptPkcs8(byte[] encrypted, String privateKeyPwd)
      throws IOException, NoSuchAlgorithmException, InvalidKeySpecException, InvalidKeyException {
    EncryptedPrivateKeyInfo pkInfo = new EncryptedPrivateKeyInfo(encrypted);
    PBEKeySpec keySpec = new PBEKeySpec(privateKeyPwd.toCharArray());
    try {
      SecretKeyFactory pbeKeyFactory = this.getSecretKeyFactory(pkInfo.getAlgName());
      return pkInfo.getKeySpec(pbeKeyFactory.generateSecret(keySpec));
    } finally {
      // wipe the key spec's internal copy of the passphrase once it is no longer needed
      keySpec.clearPassword();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy