All downloads are free. Search and download functionality uses the official Maven repository.

net.snowflake.client.core.SessionUtilKeyPair Maven / Gradle / Ivy

/*
 * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
 */
package net.snowflake.client.core;

import com.google.common.base.Strings;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import net.snowflake.client.jdbc.ErrorCode;
import org.apache.commons.codec.binary.Base64;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemReader;

import javax.crypto.EncryptedPrivateKeyInfo;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.*;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.util.Date;
import java.util.Locale;
import java.util.Optional;

/**
 * Class used to compute jwt token for key pair authentication
 * Created by hyu on 1/16/18.
 */
class SessionUtilKeyPair
{
  // user name in upper case
  final private String userName;

  // account name in upper case
  final private String accountName;

  final private PrivateKey privateKey;

  // RSA public key rebuilt from the private key's modulus and public exponent
  final private PublicKey publicKey;

  // true when a security provider named "BCFIPS" is registered
  final private boolean isFipsMode;

  // the registered "BCFIPS" provider, or null when not in FIPS mode
  final private Provider securityProvider;

  static private final String ISSUER_FMT = "%s.%s.%s";

  static private final String SUBJECT_FMT = "%s.%s";

  // time between a token's iat and exp claims
  static private final long JWT_LIFETIME_MILLIS = 60L * 1000;

  /**
   * Builds a key pair authenticator from either an in-memory private key or a PEM file.
   *
   * @param privateKey        the private key, or null when {@code privateKeyFile} is given
   * @param privateKeyFile    path to a PEM file holding a PKCS#8 key, or null/empty
   * @param privateKeyFilePwd password of an encrypted key file; null/empty means unencrypted
   * @param accountName       account name; stored upper-cased
   * @param userName          user name; stored upper-cased
   * @throws SFException if both a key and a key file are supplied, the key file cannot be
   *                     read or decoded, the key is not an {@link RSAPrivateCrtKey}, or the
   *                     public key cannot be derived
   */
  SessionUtilKeyPair(PrivateKey privateKey, String privateKeyFile, String privateKeyFilePwd,
                     String accountName,
                     String userName) throws SFException
  {
    // Locale.ROOT keeps casing locale-independent. Under a Turkish default locale,
    // "i".toUpperCase() yields a dotted capital I that would corrupt the JWT sub/iss.
    this.userName = userName.toUpperCase(Locale.ROOT);
    this.accountName = accountName.toUpperCase(Locale.ROOT);

    // FIPS state must be set before any key factory is requested below
    this.securityProvider = findFipsProvider();
    this.isFipsMode = this.securityProvider != null;

    // if there is both a file and a private key, there is a problem
    if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null)
    {
      throw new SFException(ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
                            "Cannot have both private key value and private key file.");
    }

    this.privateKey = Strings.isNullOrEmpty(privateKeyFile) ?
                      privateKey :
                      extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd);

    this.publicKey = derivePublicKey(this.privateKey);
  }

  /**
   * Looks up the BCFIPS security provider.
   *
   * @return the registered provider named "BCFIPS", or null when none is installed
   */
  private static Provider findFipsProvider()
  {
    for (Provider provider : Security.getProviders())
    {
      if ("BCFIPS".equals(provider.getName()))
      {
        return provider;
      }
    }
    return null;
  }

  /**
   * Rebuilds the RSA public key from the CRT private key's modulus and public exponent.
   *
   * @param key the private key; null or non-CRT keys are rejected
   * @return the matching RSA public key
   * @throws SFException if the key is not an {@link RSAPrivateCrtKey} or the public key
   *                     cannot be generated
   */
  private PublicKey derivePublicKey(PrivateKey key) throws SFException
  {
    // instanceof is false for null, so a missing key is rejected here as well
    if (!(key instanceof RSAPrivateCrtKey))
    {
      throw new SFException(ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY,
                            "Use java.security.interfaces.RSAPrivateCrtKey.class for the private key");
    }

    RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) key;
    RSAPublicKeySpec rsaPublicKeySpec = new RSAPublicKeySpec(
        rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());

    try
    {
      return getKeyFactoryInstance().generatePublic(rsaPublicKeySpec);
    }
    catch (NoSuchAlgorithmException | InvalidKeySpecException e)
    {
      throw new SFException(e, ErrorCode.INTERNAL_ERROR,
                            "Error retrieving public key");
    }
  }

  // Returns an RSA KeyFactory, taken from the BCFIPS provider when in FIPS mode.
  private KeyFactory getKeyFactoryInstance() throws NoSuchAlgorithmException
  {
    if (isFipsMode)
    {
      return KeyFactory.getInstance("RSA", this.securityProvider);
    }
    else
    {
      return KeyFactory.getInstance("RSA");
    }
  }

  // Returns a SecretKeyFactory for algorithm, taken from the BCFIPS provider when in FIPS mode.
  private SecretKeyFactory getSecretKeyFactory(String algorithm) throws NoSuchAlgorithmException
  {
    if (isFipsMode)
    {
      return SecretKeyFactory.getInstance(algorithm, this.securityProvider);
    }
    else
    {
      return SecretKeyFactory.getInstance(algorithm);
    }
  }

  /**
   * Loads a PKCS#8 private key from a PEM file, decrypting it when a password is given.
   *
   * @param privateKeyFile    path of the PEM file
   * @param privateKeyFilePwd password for an encrypted key; null/empty means unencrypted
   * @return the decoded private key
   * @throws SFException with INVALID_OR_UNSUPPORTED_PRIVATE_KEY on any read/decode failure;
   *                     the message is prefixed with the file path
   */
  private PrivateKey extractPrivateKeyFromFile(String privateKeyFile, String privateKeyFilePwd)
  throws SFException
  {
    try
    {
      // PEM is ASCII text; decode explicitly instead of with the platform default charset
      String privateKeyContent = new String(Files.readAllBytes(Paths.get(privateKeyFile)),
                                            StandardCharsets.UTF_8);
      byte[] decoded = readPemContent(privateKeyContent);

      PKCS8EncodedKeySpec encodedKeySpec;
      if (Strings.isNullOrEmpty(privateKeyFilePwd))
      {
        // unencrypted private key file
        encodedKeySpec = new PKCS8EncodedKeySpec(decoded);
      }
      else
      {
        // encrypted private key file
        encodedKeySpec = decryptPrivateKeyInfo(decoded, privateKeyFilePwd);
      }
      return getKeyFactoryInstance().generatePrivate(encodedKeySpec);
    }
    // NullPointerException is kept as a safety net for malformed input that trips library code
    catch (NoSuchAlgorithmException | InvalidKeySpecException |
        IOException | IllegalArgumentException | NullPointerException | InvalidKeyException e)
    {
      throw new SFException(
          e, ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, privateKeyFile + ": " + e.getMessage());
    }
  }

  /**
   * Extracts the binary payload of the first PEM object in the given text.
   *
   * @param pemText PEM-encoded text
   * @return the decoded payload bytes
   * @throws IOException if the text holds no PEM object or cannot be parsed
   */
  private static byte[] readPemContent(String pemText) throws IOException
  {
    try (PemReader reader = new PemReader(new StringReader(pemText)))
    {
      // readPemObject() returns null when no PEM object is present
      return Optional.ofNullable(reader.readPemObject())
          .map(pem -> pem.getContent())
          .orElseThrow(() -> new IOException("no PEM-encoded object found"));
    }
  }

  /**
   * Decrypts a DER-encoded EncryptedPrivateKeyInfo with a password-based key.
   *
   * @param encrypted DER bytes of the EncryptedPrivateKeyInfo
   * @param password  the key file password
   * @return the decrypted PKCS#8 key spec
   */
  private PKCS8EncodedKeySpec decryptPrivateKeyInfo(byte[] encrypted, String password)
  throws IOException, NoSuchAlgorithmException, InvalidKeySpecException, InvalidKeyException
  {
    EncryptedPrivateKeyInfo pkInfo = new EncryptedPrivateKeyInfo(encrypted);
    PBEKeySpec keySpec = new PBEKeySpec(password.toCharArray());
    try
    {
      SecretKeyFactory pbeKeyFactory = this.getSecretKeyFactory(pkInfo.getAlgName());
      return pkInfo.getKeySpec(pbeKeyFactory.generateSecret(keySpec));
    }
    finally
    {
      // wipe the password copy held by the key spec
      keySpec.clearPassword();
    }
  }

  /**
   * Issues a signed RS256 JWT for key pair authentication.
   *
   * <p>The subject is "ACCOUNT.USER". The issuer is "ACCOUNT.USER.SHA256:&lt;base64 digest of
   * the encoded public key&gt;". The token expires 60 seconds after its issue time.
   *
   * @return the compact-serialized signed JWT
   * @throws SFException if the fingerprint cannot be computed or signing fails
   */
  public String issueJwtToken() throws SFException
  {
    String sub = String.format(SUBJECT_FMT, this.accountName, this.userName);
    String iss = String.format(ISSUER_FMT, this.accountName, this.userName,
                               this.calculatePublicKeyFingerprint(this.publicKey));

    // iat is now; java.util.Date is what the JWTClaimsSet builder accepts
    Date iat = new Date(System.currentTimeMillis());
    Date exp = new Date(iat.getTime() + JWT_LIFETIME_MILLIS);

    JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
        .issuer(iss)
        .subject(sub)
        .issueTime(iat)
        .expirationTime(exp)
        .build();

    SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256),
                                        claimsSet);
    JWSSigner signer = new RSASSASigner(this.privateKey);

    try
    {
      signedJWT.sign(signer);
    }
    catch (JOSEException e)
    {
      throw new SFException(e, ErrorCode.FAILED_TO_GENERATE_JWT);
    }

    return signedJWT.serialize();
  }

  /**
   * Computes the public key fingerprint used in the JWT issuer.
   *
   * @param publicKey the key whose encoded form is hashed
   * @return "SHA256:" followed by the base64 SHA-256 digest of the encoded key
   * @throws SFException if SHA-256 is unavailable
   */
  private String calculatePublicKeyFingerprint(PublicKey publicKey)
  throws SFException
  {
    try
    {
      MessageDigest md = MessageDigest.getInstance("SHA-256");
      byte[] sha256Hash = md.digest(publicKey.getEncoded());
      return "SHA256:" + Base64.encodeBase64String(sha256Hash);
    }
    catch (NoSuchAlgorithmException e)
    {
      throw new SFException(e, ErrorCode.INTERNAL_ERROR,
                            "Error when calculating fingerprint");
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy