All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.salesforce.einsteinbot.sdk.auth.JwtBearerOAuth Maven / Gradle / Ivy


/*
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 */

package com.salesforce.einsteinbot.sdk.auth;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTCreationException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import com.salesforce.einsteinbot.sdk.cache.Cache;
import com.salesforce.einsteinbot.sdk.exception.OAuthResponseException;
import com.salesforce.einsteinbot.sdk.util.WebClientUtil;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.apache.http.HttpHeaders;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

/**
 * Implementation of AuthMechanism interface that is used to integrate with Einstein Bots using
 * OAuth.
 */
public class JwtBearerOAuth implements AuthMechanism {

  private static final Logger logger = LoggerFactory.getLogger(JwtBearerOAuth.class);

  public static final String JWT_AUTH_TOKEN_PREFIX = "Bearer ";

  private final int jwtExpiryMinutes = 15;
  private final String cacheKeyPrefix = "bots-oAuthToken-";

  private String loginEndpoint;
  private String connectedAppId;
  private String userId;
  private PrivateKey privateKey;
  private WebClient webClient;
  private Cache cache;
  private Introspector introspector;

  public JwtBearerOAuth(String privateKeyFilePath, String loginEndpoint, String connectedAppId,
      String connectedAppSecret,
      String userId, Cache cache) {
    this(getPrivateKey(privateKeyFilePath), loginEndpoint, connectedAppId, connectedAppSecret,
        userId, cache);
  }

  public JwtBearerOAuth(PrivateKey privateKey, String loginEndpoint, String connectedAppId,
      String connectedAppSecret,
      String userId, Cache cache) {
    this.privateKey = privateKey;
    this.userId = userId;
    this.connectedAppId = connectedAppId;
    this.loginEndpoint = loginEndpoint;
    this.webClient = WebClient
        .builder()
        .baseUrl(loginEndpoint)
        .filter(WebClientUtil.createFilter(
            clientRequest -> WebClientUtil.createLoggingRequestProcessor(clientRequest),
            clientResponse -> WebClientUtil
                .createErrorResponseProcessor(clientResponse, this::mapErrorResponse)))
        .build();
    this.cache = cache;
    this.introspector = new Introspector(connectedAppId, connectedAppSecret, loginEndpoint);
  }

  private Mono mapErrorResponse(ClientResponse clientResponse) {
    return clientResponse
        .bodyToMono(String.class)
        .flatMap(errorDetails -> Mono
            .error(new OAuthResponseException(clientResponse.statusCode(), errorDetails)));
  }

  @VisibleForTesting
  void setIntrospector(Introspector introspector) {
    this.introspector = introspector;
  }

  @Override
  public String getToken() {
    Optional token = cache.get(getCacheKey());
    if (token.isPresent()) {
      logger.debug("Found cached OAuth token.");
      return token.get();
    }

    logger.debug("Did not find OAuth token in cache. Will retrieve from OAuth server.");
    Instant now = Instant.now();
    String jwt = null;

    try {
      Map headers = new HashMap();
      headers.put("alg", "RS256");
      Algorithm algorithm = Algorithm.RSA256(null, (RSAPrivateKey) privateKey);
      jwt = JWT.create()
          .withHeader(headers)
          .withAudience(loginEndpoint)
          .withExpiresAt(Date.from(now.plus(jwtExpiryMinutes, ChronoUnit.MINUTES)))
          .withIssuer(connectedAppId)
          .withSubject(userId)
          .sign(algorithm);

      logger.debug("Generated jwt: {} ", jwt);

    } catch (JWTCreationException exception) {
      //Invalid Signing configuration / Couldn't convert Claims.
      throw new RuntimeException(exception);
    }

    String response = webClient.post()
        .uri("/services/oauth2/token")
        .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
        .body(
            BodyInserters.fromFormData("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
                .with("assertion", jwt))
        .retrieve()
        .bodyToMono(String.class)
        .block();

    String oAuthToken = null;
    try {
      ObjectNode node = new ObjectMapper().readValue(response, ObjectNode.class);
      oAuthToken = node.get("access_token").asText();
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    IntrospectionResult iResult = introspector.introspect(oAuthToken);
    if (!iResult.isActive()) {
      throw new RuntimeException("OAuth token is not active.");
    }

    Instant expiry = Instant.ofEpochSecond(iResult.getExp());
    long ttl = Math.max(0, Instant.now().until(expiry, ChronoUnit.SECONDS) - 300);

    cache.set(getCacheKey(), oAuthToken, ttl);
    return oAuthToken;
  }

  @Override
  public String getAuthorizationHeader() {
    return JWT_AUTH_TOKEN_PREFIX + getToken();
  }

  private static PrivateKey getPrivateKey(String filename) {
    try {
      File f = new File(filename);
      FileInputStream fis = new FileInputStream(f);
      DataInputStream dis = new DataInputStream(fis);
      byte[] keyBytes = new byte[(int) f.length()];
      dis.readFully(keyBytes);
      dis.close();

      PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
      KeyFactory kf = KeyFactory.getInstance("RSA");
      return kf.generatePrivate(spec);
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }
  }

  private String getCacheKey() {
    return cacheKeyPrefix + connectedAppId;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy