All downloads are free. The search and download functionality uses the official Maven repository.

com.databricks.jdbc.auth.JwtPrivateKeyClientCredentials Maven / Gradle / Ivy

package com.databricks.jdbc.auth;

import static com.nimbusds.jose.JWSAlgorithm.*;

import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
import com.databricks.jdbc.exception.DatabricksHttpException;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.oauth.OAuthResponse;
import com.databricks.sdk.core.oauth.RefreshableTokenSource;
import com.databricks.sdk.core.oauth.Token;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.*;
import com.nimbusds.jwt.*;
import java.io.*;
import java.io.FileReader;
import java.io.Reader;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.Security;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.RSAPrivateKey;
import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.http.HttpResponse;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.message.BasicNameValuePair;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMException;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder;
import org.bouncycastle.operator.InputDecryptorProvider;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
import org.bouncycastle.pkcs.PKCSException;

/**
 * An implementation of RefreshableTokenSource implementing the JWT client_credentials OAuth grant
 * type.
 *
 * <p>On every refresh a short-lived JWT client assertion is built, signed with the private key read
 * from the configured PEM file, and posted to the token endpoint together with the {@code
 * client_credentials} grant parameters. Instances are created through {@link Builder}.
 */
public class JwtPrivateKeyClientCredentials extends RefreshableTokenSource {

  private static final JdbcLogger LOGGER =
      JdbcLoggerFactory.getLogger(JwtPrivateKeyClientCredentials.class);

  /** JCA provider name under which Bouncy Castle is looked up and registered. */
  private static final String BOUNCY_CASTLE_PROVIDER = "BC";

  /**
   * Lifetime of the signed client assertion. The assertion is used immediately, so a short window
   * is enough; it only has to cover request latency and clock skew between client and server.
   */
  private static final long JWT_ASSERTION_LIFETIME_MILLIS = 5 * 60 * 1000L;

  /** Shared JSON mapper used to decode token endpoint responses. */
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  /** Fluent builder for {@link JwtPrivateKeyClientCredentials}. */
  public static class Builder {
    private String clientId;
    private String tokenUrl;
    private String jwtKeyFile;
    private String jwtKid;
    private String jwtKeyPassphrase;
    private String jwtAlgorithm;
    private IDatabricksHttpClient hc;
    private List<String> scopes = Collections.emptyList();

    /** Sets the OAuth client id; used as both issuer and subject of the assertion. */
    public Builder withClientId(String clientId) {
      this.clientId = clientId;
      return this;
    }

    /** Sets the token endpoint URL; also used as the audience of the assertion. */
    public Builder withTokenUrl(String tokenUrl) {
      this.tokenUrl = tokenUrl;
      return this;
    }

    /** Sets the scopes to request; a {@code null} list omits the scope parameter. */
    public Builder withScopes(List<String> scopes) {
      this.scopes = scopes;
      return this;
    }

    /** Sets the HTTP client used to call the token endpoint. */
    public Builder withHttpClient(IDatabricksHttpClient hc) {
      this.hc = hc;
      return this;
    }

    /** Sets the JWS algorithm name (e.g. {@code RS256}); unknown or null falls back to RS256. */
    public Builder withJwtAlgorithm(String jwtAlgorithm) {
      this.jwtAlgorithm = jwtAlgorithm;
      return this;
    }

    /** Sets the passphrase for an encrypted PKCS #8 key file; may be {@code null}. */
    public Builder withJwtKeyPassphrase(String jwtKeyPassphrase) {
      this.jwtKeyPassphrase = jwtKeyPassphrase;
      return this;
    }

    /** Sets the key id placed in the JWS header. */
    public Builder withJwtKid(String jwtKid) {
      this.jwtKid = jwtKid;
      return this;
    }

    /** Sets the path of the PEM file holding the private key. */
    public Builder withJwtKeyFile(String jwtKeyFile) {
      this.jwtKeyFile = jwtKeyFile;
      return this;
    }

    /**
     * Builds the token source.
     *
     * @throws NullPointerException if the client id, key file or key id was not set
     */
    public JwtPrivateKeyClientCredentials build() {
      Objects.requireNonNull(this.clientId, "clientId must be specified");
      Objects.requireNonNull(this.jwtKeyFile, "JWT key file must be specified");
      Objects.requireNonNull(this.jwtKid, "JWT KID must be specified");
      return new JwtPrivateKeyClientCredentials(
          hc, clientId, jwtKeyFile, jwtKid, jwtKeyPassphrase, jwtAlgorithm, tokenUrl, scopes);
    }
  }

  private final IDatabricksHttpClient hc;
  private final String clientId;
  private final String tokenUrl;
  private final List<String> scopes;

  private final String jwtKeyFile;
  private final String jwtKid;
  private final String jwtKeyPassphrase;
  private final JWSAlgorithm jwtAlgorithm;

  private JwtPrivateKeyClientCredentials(
      IDatabricksHttpClient hc,
      String clientId,
      String jwtKeyFile,
      String jwtKid,
      String jwtKeyPassphrase,
      String jwtAlgorithm,
      String tokenUrl,
      List<String> scopes) {
    this.hc = hc;
    this.clientId = clientId;
    this.jwtKeyFile = jwtKeyFile;
    this.jwtKid = jwtKid;
    this.jwtKeyPassphrase = jwtKeyPassphrase;
    this.jwtAlgorithm = determineSignatureAlgorithm(jwtAlgorithm);
    this.tokenUrl = tokenUrl;
    this.scopes = scopes;
  }

  /**
   * Requests a new access token using a freshly signed JWT client assertion.
   *
   * @return the token returned by the token endpoint
   * @throws DatabricksException if the key cannot be read, the JWT cannot be signed, or the token
   *     request fails
   */
  @Override
  protected Token refresh() {
    Map<String, String> params = new HashMap<>();
    params.put("grant_type", "client_credentials");
    if (scopes != null) {
      params.put("scope", String.join(" ", scopes));
    }
    params.put("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
    params.put("client_assertion", getSerialisedSignedJWT());
    return retrieveToken(hc, tokenUrl, params, new HashMap<>());
  }

  /**
   * Posts the given form parameters to the token endpoint and decodes the JSON response.
   *
   * @param hc HTTP client used to execute the request
   * @param tokenUrl token endpoint URL
   * @param params form fields sent URL-encoded (UTF-8) in the request body
   * @param headers extra request headers
   * @return the decoded token, with its expiry computed from the response's {@code expires_in}
   * @throws DatabricksException if the URL is malformed, the request fails, or the response cannot
   *     be read
   */
  @VisibleForTesting
  protected static Token retrieveToken(
      IDatabricksHttpClient hc,
      String tokenUrl,
      Map<String, String> params,
      Map<String, String> headers) {
    try {
      URIBuilder uriBuilder = new URIBuilder(tokenUrl);
      HttpPost postRequest = new HttpPost(uriBuilder.build());
      List<BasicNameValuePair> formFields =
          params.entrySet().stream()
              .map(e -> new BasicNameValuePair(e.getKey(), e.getValue()))
              .collect(Collectors.toList());
      postRequest.setEntity(new UrlEncodedFormEntity(formFields, StandardCharsets.UTF_8));
      headers.forEach(postRequest::setHeader);
      HttpResponse response = hc.execute(postRequest);
      OAuthResponse resp =
          OBJECT_MAPPER.readValue(response.getEntity().getContent(), OAuthResponse.class);
      LocalDateTime expiry = LocalDateTime.now().plus(resp.getExpiresIn(), ChronoUnit.SECONDS);
      return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry);
    } catch (IOException | URISyntaxException | DatabricksHttpException e) {
      String errorMessage = "Failed to retrieve custom M2M token: " + e.getMessage();
      LOGGER.error(errorMessage);
      throw new DatabricksException(errorMessage, e);
    }
  }

  /** Loads the private key, signs a new assertion with it and returns the compact serialization. */
  private String getSerialisedSignedJWT() {
    PrivateKey privateKey = getPrivateKey();
    SignedJWT signedJWT = fetchSignedJWT(privateKey);
    return signedJWT.serialize();
  }

  @VisibleForTesting
  String getTokenEndpoint() {
    return tokenUrl;
  }

  /**
   * Maps an algorithm name to the corresponding {@link JWSAlgorithm}.
   *
   * @param jwtAlgorithm algorithm name, may be {@code null}
   * @return the matching RSA or EC algorithm; RS256 when the name is {@code null} or not supported
   */
  @VisibleForTesting
  JWSAlgorithm determineSignatureAlgorithm(String jwtAlgorithm) {
    if (jwtAlgorithm == null) {
      jwtAlgorithm = "RS256"; // By default, we use RS256
    }
    switch (jwtAlgorithm) {
        // Following are RSA algorithms
      case "RS256":
        return RS256;
      case "RS384":
        return RS384;
      case "RS512":
        return RS512;
      case "PS256":
        return PS256;
      case "PS384":
        return PS384;
      case "PS512":
        return PS512;
        // following are EC algorithms
      case "ES256":
        return ES256;
      case "ES384":
        return ES384;
      case "ES512":
        return ES512;
      default:
        LOGGER.debug("Defaulting to RS256. Provided JWT algorithm not supported " + jwtAlgorithm);
        return RS256;
    }
  }

  /**
   * Reads the first PEM object from the configured key file and converts it to a private key.
   *
   * @throws DatabricksException if the file cannot be read or its content cannot be converted
   */
  private PrivateKey getPrivateKey() {
    // Register Bouncy Castle once instead of constructing a new provider on every refresh.
    if (Security.getProvider(BOUNCY_CASTLE_PROVIDER) == null) {
      Security.addProvider(new BouncyCastleProvider());
    }
    // Read with an explicit charset so the result does not depend on the platform default.
    try (Reader reader =
            new InputStreamReader(new FileInputStream(jwtKeyFile), StandardCharsets.UTF_8);
        PEMParser pemParser = new PEMParser(reader)) {
      return convertPrivateKey(pemParser.readObject());
    } catch (DatabricksSQLException | IOException e) {
      String errorMessage = "Failed to parse private key: " + e.getMessage();
      LOGGER.error(errorMessage);
      throw new DatabricksException(errorMessage, e);
    }
  }

  /**
   * Converts an object produced by {@link PEMParser} into a {@link PrivateKey}.
   *
   * <p>The handling is chosen from the object's actual type: an encrypted PKCS #8 key is decrypted
   * with the configured passphrase, while a key pair or plain private key info is used directly.
   *
   * @param pemObject object read from the key file
   * @return the private key
   * @throws DatabricksParsingException if the object is missing, of an unsupported type, encrypted
   *     without a passphrase being configured, or cannot be decrypted or converted
   */
  PrivateKey convertPrivateKey(Object pemObject) throws DatabricksParsingException {
    if (pemObject == null) {
      throw parsingFailure("Cannot decrypt private JWT key: no PEM object found in key file");
    }
    try {
      PrivateKeyInfo privateKeyInfo;
      if (pemObject instanceof PKCS8EncryptedPrivateKeyInfo) {
        // Decrypt and process PKCS #8 keys; this needs the JWT passphrase
        if (jwtKeyPassphrase == null) {
          throw parsingFailure(
              "Cannot decrypt private JWT key: key is encrypted but no passphrase was provided");
        }
        JceOpenSSLPKCS8DecryptorProviderBuilder decryptorProviderBuilder =
            new JceOpenSSLPKCS8DecryptorProviderBuilder();
        decryptorProviderBuilder.setProvider(BOUNCY_CASTLE_PROVIDER);
        InputDecryptorProvider decryptorProvider =
            decryptorProviderBuilder.build(jwtKeyPassphrase.toCharArray());
        privateKeyInfo =
            ((PKCS8EncryptedPrivateKeyInfo) pemObject).decryptPrivateKeyInfo(decryptorProvider);
      } else if (pemObject instanceof PEMKeyPair) {
        // Unencrypted key pair: only the private half is needed
        privateKeyInfo = ((PEMKeyPair) pemObject).getPrivateKeyInfo();
      } else if (pemObject instanceof PrivateKeyInfo) {
        // Unencrypted private key info
        privateKeyInfo = (PrivateKeyInfo) pemObject;
      } else {
        throw parsingFailure(
            "Cannot decrypt private JWT key: unsupported PEM object type "
                + pemObject.getClass().getName());
      }
      JcaPEMKeyConverter keyConverter =
          new JcaPEMKeyConverter().setProvider(BOUNCY_CASTLE_PROVIDER);
      return keyConverter.getPrivateKey(privateKeyInfo);
    } catch (OperatorCreationException | PKCSException | PEMException e) {
      throw parsingFailure("Cannot decrypt private JWT key " + e.getMessage());
    }
  }

  /** Logs the message at error level and returns a parsing exception carrying it. */
  private static DatabricksParsingException parsingFailure(String message) {
    LOGGER.error(message);
    return new DatabricksParsingException(message);
  }

  /**
   * Builds and signs the JWT client assertion.
   *
   * <p>Issuer and subject are the client id, the audience is the token endpoint, and the assertion
   * expires {@link #JWT_ASSERTION_LIFETIME_MILLIS} after it is issued.
   *
   * @param privateKey RSA or EC private key to sign with
   * @return the signed JWT
   * @throws DatabricksException if the key type is unsupported or signing fails
   */
  @VisibleForTesting
  SignedJWT fetchSignedJWT(PrivateKey privateKey) {
    try {
      JWSSigner signer;
      if (privateKey instanceof RSAPrivateKey) {
        // Use RSA Signer
        signer = new RSASSASigner(privateKey);
      } else if (privateKey instanceof ECPrivateKey) {
        // Use EC Signer
        signer = new ECDSASigner((ECPrivateKey) privateKey);
      } else {
        String errorMessage = "Unsupported private key type: " + privateKey.getClass().getName();
        LOGGER.error(errorMessage);
        throw new DatabricksException(errorMessage);
      }

      // The expiry must lie after the issue time, otherwise the assertion is already expired
      // when the server validates it.
      long nowMillis = System.currentTimeMillis();
      Timestamp issuedAt = new Timestamp(nowMillis);
      Timestamp expiresAt = new Timestamp(nowMillis + JWT_ASSERTION_LIFETIME_MILLIS);
      JWTClaimsSet claimsSet =
          new JWTClaimsSet.Builder()
              .subject(clientId)
              .issuer(clientId)
              .issueTime(issuedAt)
              .expirationTime(expiresAt)
              .jwtID(UUID.randomUUID().toString())
              .audience(this.tokenUrl)
              .build();
      JWSHeader header = new JWSHeader.Builder(this.jwtAlgorithm).keyID(this.jwtKid).build();
      SignedJWT signedJWT = new SignedJWT(header, claimsSet);
      signedJWT.sign(signer);
      return signedJWT;
    } catch (JOSEException e) {
      String errorMessage = "Error signing the JWT: " + e.getMessage();
      LOGGER.error(errorMessage);
      throw new DatabricksException(errorMessage, e);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy