All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.snowflake.client.core.SessionUtilKeyPair Maven / Gradle / Ivy

/*
 * Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
 */
package net.snowflake.client.core;

import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetEnv;

import com.google.common.base.Strings;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.io.IOException;
import java.io.StringReader;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.*;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.util.Date;
import javax.crypto.EncryptedPrivateKeyInfo;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import org.apache.commons.codec.binary.Base64;
import org.bouncycastle.util.io.pem.PemReader;

/** Class used to compute jwt token for key pair authentication Created by hyu on 1/16/18. */
class SessionUtilKeyPair {

  static final SFLogger logger = SFLoggerFactory.getLogger(SessionUtilKeyPair.class);

  // user name in upper case
  private final String userName;

  // account name in upper case
  private final String accountName;

  private final PrivateKey privateKey;

  private PublicKey publicKey = null;

  private boolean isFipsMode = false;

  private Provider SecurityProvider = null;

  private SecretKeyFactory secretKeyFactory = null;

  private static final String ISSUER_FMT = "%s.%s.%s";

  private static final String SUBJECT_FMT = "%s.%s";

  private static final int JWT_DEFAULT_AUTH_TIMEOUT = 10;

  SessionUtilKeyPair(
      PrivateKey privateKey,
      String privateKeyFile,
      String privateKeyFilePwd,
      String accountName,
      String userName)
      throws SFException {
    this.userName = userName.toUpperCase();
    this.accountName = accountName.toUpperCase();

    // check if in FIPS mode
    for (Provider p : Security.getProviders()) {
      if ("BCFIPS".equals(p.getName())) {
        this.isFipsMode = true;
        this.SecurityProvider = p;
        break;
      }
    }

    // if there is both a file and a private key, there is a problem
    if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) {
      throw new SFException(
          ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
          "Cannot have both private key value and private key file.");
    } else {
      // if privateKeyFile has a value and privateKey is null
      this.privateKey =
          Strings.isNullOrEmpty(privateKeyFile)
              ? privateKey
              : extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd);
    }
    // construct public key from raw bytes
    if (this.privateKey instanceof RSAPrivateCrtKey) {
      RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) this.privateKey;
      RSAPublicKeySpec rsaPublicKeySpec =
          new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());

      try {
        this.publicKey = getKeyFactoryInstance().generatePublic(rsaPublicKeySpec);
      } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
        throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error retrieving public key");
      }
    } else {
      throw new SFException(
          ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
          "Use java.security.interfaces.RSAPrivateCrtKey.class for the private key");
    }
  }

  private KeyFactory getKeyFactoryInstance() throws NoSuchAlgorithmException {
    if (isFipsMode) {
      return KeyFactory.getInstance("RSA", this.SecurityProvider);
    } else {
      return KeyFactory.getInstance("RSA");
    }
  }

  private SecretKeyFactory getSecretKeyFactory(String algorithm) throws NoSuchAlgorithmException {
    if (isFipsMode) {
      return SecretKeyFactory.getInstance(algorithm, this.SecurityProvider);
    } else {
      return SecretKeyFactory.getInstance(algorithm);
    }
  }

  private PrivateKey extractPrivateKeyFromFile(String privateKeyFile, String privateKeyFilePwd)
      throws SFException {
    try {
      String privateKeyContent = new String(Files.readAllBytes(Paths.get(privateKeyFile)));
      if (Strings.isNullOrEmpty(privateKeyFilePwd)) {
        // unencrypted private key file
        PemReader pr = new PemReader(new StringReader(privateKeyContent));
        byte[] decoded = pr.readPemObject().getContent();
        pr.close();
        PKCS8EncodedKeySpec encodedKeySpec = new PKCS8EncodedKeySpec(decoded);
        KeyFactory keyFactory = getKeyFactoryInstance();
        return keyFactory.generatePrivate(encodedKeySpec);
      } else {
        // encrypted private key file
        PemReader pr = new PemReader(new StringReader(privateKeyContent));
        byte[] decoded = pr.readPemObject().getContent();
        pr.close();
        EncryptedPrivateKeyInfo pkInfo = new EncryptedPrivateKeyInfo(decoded);
        PBEKeySpec keySpec = new PBEKeySpec(privateKeyFilePwd.toCharArray());
        SecretKeyFactory pbeKeyFactory = this.getSecretKeyFactory(pkInfo.getAlgName());
        PKCS8EncodedKeySpec encodedKeySpec =
            pkInfo.getKeySpec(pbeKeyFactory.generateSecret(keySpec));
        KeyFactory keyFactory = getKeyFactoryInstance();
        return keyFactory.generatePrivate(encodedKeySpec);
      }
    } catch (NoSuchAlgorithmException
        | InvalidKeySpecException
        | IOException
        | IllegalArgumentException
        | NullPointerException
        | InvalidKeyException e) {
      throw new SFException(
          e, ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, privateKeyFile + ": " + e.getMessage());
    }
  }

  public String issueJwtToken() throws SFException {
    JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder();
    String sub = String.format(SUBJECT_FMT, this.accountName, this.userName);
    String iss =
        String.format(
            ISSUER_FMT,
            this.accountName,
            this.userName,
            this.calculatePublicKeyFingerprint(this.publicKey));

    // iat is now
    Date iat = new Date(System.currentTimeMillis());

    // expiration is 60 seconds later
    Date exp = new Date(iat.getTime() + 60L * 1000);

    JWTClaimsSet claimsSet =
        builder.issuer(iss).subject(sub).issueTime(iat).expirationTime(exp).build();

    SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256), claimsSet);
    JWSSigner signer = new RSASSASigner(this.privateKey);

    try {
      signedJWT.sign(signer);
    } catch (JOSEException e) {
      throw new SFException(e, ErrorCode.FAILED_TO_GENERATE_JWT);
    }
    // Log the contents of the token, displaying expiration and issue time in epoch time
    logger.debug(
        "JWT:\n'{'\niss: {}\nsub: {}\niat: {}\nexp: {}\n'}'",
        iss,
        sub,
        String.valueOf(iat.getTime() / 1000),
        String.valueOf(exp.getTime() / 1000));
    return signedJWT.serialize();
  }

  private String calculatePublicKeyFingerprint(PublicKey publicKey) throws SFException {
    try {
      MessageDigest md = MessageDigest.getInstance("SHA-256");
      byte[] sha256Hash = md.digest(publicKey.getEncoded());
      return "SHA256:" + Base64.encodeBase64String(sha256Hash);
    } catch (NoSuchAlgorithmException e) {
      throw new SFException(e, ErrorCode.INTERNAL_ERROR, "Error when calculating fingerprint");
    }
  }

  public static int getTimeout() {
    String jwtAuthTimeoutStr = systemGetEnv("JWT_AUTH_TIMEOUT");
    int jwtAuthTimeout = JWT_DEFAULT_AUTH_TIMEOUT;
    if (jwtAuthTimeoutStr != null) {
      jwtAuthTimeout = Integer.parseInt(jwtAuthTimeoutStr);
    }
    return jwtAuthTimeout;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy