All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.openmetadata.service.security.AuthenticationCodeFlowHandler Maven / Gradle / Ivy

There is a newer version: 1.5.11
Show newest version
package org.openmetadata.service.security;

import static org.openmetadata.common.utils.CommonUtil.listOrEmpty;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import static org.openmetadata.service.security.JwtFilter.EMAIL_CLAIM_KEY;
import static org.openmetadata.service.security.JwtFilter.USERNAME_CLAIM_KEY;
import static org.openmetadata.service.security.SecurityUtil.findEmailFromClaims;
import static org.openmetadata.service.security.SecurityUtil.getClaimOrObject;
import static org.openmetadata.service.security.SecurityUtil.getFirstMatchJwtClaim;
import static org.openmetadata.service.security.SecurityUtil.writeJsonResponse;
import static org.openmetadata.service.util.UserUtil.getRoleListFromUser;
import static org.pac4j.core.util.CommonHelper.assertNotNull;
import static org.pac4j.core.util.CommonHelper.isNotEmpty;

import com.fasterxml.jackson.core.type.TypeReference;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.ErrorObject;
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.ClientSecretPost;
import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallenge;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.RefreshToken;
import com.nimbusds.oauth2.sdk.util.JSONObjectUtils;
import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.AuthenticationResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser;
import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse;
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.BadJWTExceptions;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;
import java.util.TreeMap;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.ws.rs.BadRequestException;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.minidev.json.JSONObject;
import org.openmetadata.schema.api.security.AuthenticationConfiguration;
import org.openmetadata.schema.api.security.AuthorizerConfiguration;
import org.openmetadata.schema.auth.JWTAuthMechanism;
import org.openmetadata.schema.auth.ServiceTokenType;
import org.openmetadata.schema.entity.teams.User;
import org.openmetadata.schema.security.client.OidcClientConfig;
import org.openmetadata.schema.type.Include;
import org.openmetadata.service.Entity;
import org.openmetadata.service.auth.JwtResponse;
import org.openmetadata.service.security.jwt.JWTTokenGenerator;
import org.openmetadata.service.util.JsonUtils;
import org.pac4j.core.context.HttpConstants;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.core.util.HttpUtils;
import org.pac4j.oidc.client.AzureAd2Client;
import org.pac4j.oidc.client.GoogleOidcClient;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.config.AzureAd2OidcConfiguration;
import org.pac4j.oidc.config.OidcConfiguration;
import org.pac4j.oidc.config.PrivateKeyJWTClientAuthnMethodConfig;
import org.pac4j.oidc.credentials.OidcCredentials;

@Slf4j
public class AuthenticationCodeFlowHandler {
  private static final Collection SUPPORTED_METHODS =
      Arrays.asList(
          ClientAuthenticationMethod.CLIENT_SECRET_POST,
          ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
          ClientAuthenticationMethod.PRIVATE_KEY_JWT,
          ClientAuthenticationMethod.NONE);

  public static final String DEFAULT_PRINCIPAL_DOMAIN = "openmetadata.org";
  public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile";
  private final OidcClient client;
  private final List claimsOrder;
  private final Map claimsMapping;
  private final String serverUrl;
  private final ClientAuthentication clientAuthentication;
  private final String principalDomain;
  private final int tokenValidity;

  public AuthenticationCodeFlowHandler(
      AuthenticationConfiguration authenticationConfiguration,
      AuthorizerConfiguration authorizerConfiguration) {
    // Assert oidcConfig and Callback Url
    CommonHelper.assertNotNull(
        "OidcConfiguration", authenticationConfiguration.getOidcConfiguration());
    CommonHelper.assertNotBlank(
        "CallbackUrl", authenticationConfiguration.getOidcConfiguration().getCallbackUrl());
    CommonHelper.assertNotBlank(
        "ServerUrl", authenticationConfiguration.getOidcConfiguration().getServerUrl());

    // Build Required Params
    this.client = buildOidcClient(authenticationConfiguration.getOidcConfiguration());
    client.setCallbackUrl(authenticationConfiguration.getOidcConfiguration().getCallbackUrl());
    this.clientAuthentication = getClientAuthentication(client.getConfiguration());
    this.serverUrl = authenticationConfiguration.getOidcConfiguration().getServerUrl();
    this.claimsOrder = authenticationConfiguration.getJwtPrincipalClaims();
    this.claimsMapping =
        listOrEmpty(authenticationConfiguration.getJwtPrincipalClaimsMapping()).stream()
            .map(s -> s.split(":"))
            .collect(Collectors.toMap(s -> s[0], s -> s[1]));
    validatePrincipalClaimsMapping(claimsMapping);
    this.principalDomain = authorizerConfiguration.getPrincipalDomain();
    this.tokenValidity = authenticationConfiguration.getOidcConfiguration().getTokenValidity();
  }

  private OidcClient buildOidcClient(OidcClientConfig clientConfig) {
    String id = clientConfig.getId();
    String secret = clientConfig.getSecret();
    if (CommonHelper.isNotBlank(id) && CommonHelper.isNotBlank(secret)) {
      OidcConfiguration configuration = new OidcConfiguration();
      configuration.setClientId(id);

      configuration.setResponseMode("query");

      // Add Secret
      if (CommonHelper.isNotBlank(secret)) {
        configuration.setSecret(secret);
      }

      // Response Type
      String responseType = clientConfig.getResponseType();
      if (CommonHelper.isNotBlank(responseType)) {
        configuration.setResponseType(responseType);
      }

      String scope = clientConfig.getScope();
      if (CommonHelper.isNotBlank(scope)) {
        configuration.setScope(scope);
      }

      String discoveryUri = clientConfig.getDiscoveryUri();
      if (CommonHelper.isNotBlank(discoveryUri)) {
        configuration.setDiscoveryURI(discoveryUri);
      }

      String useNonce = clientConfig.getUseNonce();
      if (CommonHelper.isNotBlank(useNonce)) {
        configuration.setUseNonce(Boolean.parseBoolean(useNonce));
      }

      String jwsAlgo = clientConfig.getPreferredJwsAlgorithm();
      if (CommonHelper.isNotBlank(jwsAlgo)) {
        configuration.setPreferredJwsAlgorithm(JWSAlgorithm.parse(jwsAlgo));
      }

      String maxClockSkew = clientConfig.getMaxClockSkew();
      if (CommonHelper.isNotBlank(maxClockSkew)) {
        configuration.setMaxClockSkew(Integer.parseInt(maxClockSkew));
      }

      String clientAuthenticationMethod = clientConfig.getClientAuthenticationMethod().value();
      if (CommonHelper.isNotBlank(clientAuthenticationMethod)) {
        configuration.setClientAuthenticationMethod(
            ClientAuthenticationMethod.parse(clientAuthenticationMethod));
      }

      // Disable PKCE
      configuration.setDisablePkce(clientConfig.getDisablePkce());

      // Add Custom Params
      if (clientConfig.getCustomParams() != null) {
        for (int j = 1; j <= 5; ++j) {
          if (clientConfig.getCustomParams().containsKey(String.format("customParamKey%d", j))) {
            configuration.addCustomParam(
                clientConfig.getCustomParams().get(String.format("customParamKey%d", j)),
                clientConfig.getCustomParams().get(String.format("customParamValue%d", j)));
          }
        }
      }

      String type = clientConfig.getType();
      OidcClient oidcClient;
      if ("azure".equalsIgnoreCase(type)) {
        AzureAd2OidcConfiguration azureAdConfiguration =
            new AzureAd2OidcConfiguration(configuration);
        String tenant = clientConfig.getTenant();
        if (CommonHelper.isNotBlank(tenant)) {
          azureAdConfiguration.setTenant(tenant);
        }

        oidcClient = new AzureAd2Client(azureAdConfiguration);
      } else if ("google".equalsIgnoreCase(type)) {
        oidcClient = new GoogleOidcClient(configuration);
        // Google needs it as param
        oidcClient.getConfiguration().getCustomParams().put("access_type", "offline");
      } else {
        oidcClient = new OidcClient(configuration);
      }

      oidcClient.setName(String.format("OMOidcClient%s", oidcClient.getName()));
      return oidcClient;
    }
    throw new IllegalArgumentException(
        "Client ID and Client Secret is required to create OidcClient");
  }

  // Login
  public void handleLogin(HttpServletRequest req, HttpServletResponse resp) {
    try {
      LOG.debug("Performing Auth Login For User Session: {} ", req.getSession().getId());
      Optional credentials = getUserCredentialsFromSession(req);
      if (credentials.isPresent()) {
        LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId());
        sendRedirectWithToken(resp, credentials.get());
      } else {
        LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId());
        Map params = buildLoginParams();

        params.put(OidcConfiguration.REDIRECT_URI, client.getCallbackUrl());

        addStateAndNonceParameters(client, req, params);

        // This is always used to prompt the user to login
        if (client instanceof GoogleOidcClient) {
          params.put(OidcConfiguration.PROMPT, "consent");
        } else {
          params.put(OidcConfiguration.PROMPT, "login");
        }
        params.put(OidcConfiguration.MAX_AGE, "0");

        String location = buildLoginAuthenticationRequestUrl(params);
        LOG.debug("Authentication request url: {}", location);

        resp.sendRedirect(location);
      }
    } catch (Exception e) {
      getErrorMessage(resp, new TechnicalException(e));
    }
  }

  // Callback
  public void handleCallback(HttpServletRequest req, HttpServletResponse resp) {
    try {
      LOG.debug("Performing Auth Callback For User Session: {} ", req.getSession().getId());
      String computedCallbackUrl = client.getCallbackUrl();
      Map> parameters = retrieveCallbackParameters(req);
      AuthenticationResponse response =
          AuthenticationResponseParser.parse(new URI(computedCallbackUrl), parameters);

      if (response instanceof AuthenticationErrorResponse authenticationErrorResponse) {
        LOG.error(
            "Bad authentication response, error={}", authenticationErrorResponse.getErrorObject());
        throw new TechnicalException("Bad authentication response");
      }

      LOG.debug("Authentication response successful");
      AuthenticationSuccessResponse successResponse = (AuthenticationSuccessResponse) response;

      OIDCProviderMetadata metadata = client.getConfiguration().getProviderMetadata();
      if (metadata.supportsAuthorizationResponseIssuerParam()
          && !metadata.getIssuer().equals(successResponse.getIssuer())) {
        throw new TechnicalException("Issuer mismatch, possible mix-up attack.");
      }

      // Optional state validation
      validateStateIfRequired(req, resp, successResponse);

      // Build Credentials
      OidcCredentials credentials = buildCredentials(successResponse);

      // Validations
      validateAndSendTokenRequest(req, credentials, computedCallbackUrl);

      // Log Error if the Refresh Token is null
      if (credentials.getRefreshToken() == null) {
        LOG.error("Refresh token is null for user session: {}", req.getSession().getId());
      }

      validateNonceIfRequired(req, credentials.getIdToken().getJWTClaimsSet());

      // Put Credentials in Session
      req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);

      // Redirect
      sendRedirectWithToken(resp, credentials);
    } catch (Exception e) {
      getErrorMessage(resp, e);
    }
  }

  // Logout
  public void handleLogout(
      HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
    try {
      LOG.debug("Performing application logout");
      HttpSession session = httpServletRequest.getSession(false);
      if (session != null) {
        LOG.debug("Invalidating the session for logout");
        session.invalidate();
        httpServletResponse.sendRedirect(serverUrl);
      } else {
        LOG.error("No session store available for this web context");
      }
    } catch (Exception ex) {
      LOG.error("[Auth Logout] Error while performing logout", ex);
    }
  }

  // Refresh
  public void handleRefresh(
      HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
    try {
      LOG.debug(
          "Performing Auth Refresh For User Session: {} ", httpServletRequest.getSession().getId());
      Optional credentials = getUserCredentialsFromSession(httpServletRequest);
      if (credentials.isPresent()) {
        LOG.debug(
            "Credentials Found For User Session: {} ", httpServletRequest.getSession().getId());
        JwtResponse jwtResponse = new JwtResponse();
        jwtResponse.setAccessToken(credentials.get().getIdToken().getParsedString());
        jwtResponse.setExpiryDuration(
            credentials
                .get()
                .getIdToken()
                .getJWTClaimsSet()
                .getExpirationTime()
                .toInstant()
                .getEpochSecond());
        writeJsonResponse(httpServletResponse, JsonUtils.pojoToJson(jwtResponse));
      } else {
        LOG.debug(
            "Credentials Not Found For User Session: {}, Redirect to Logout ",
            httpServletRequest.getSession().getId());
        httpServletResponse.sendRedirect(String.format("%s/logout", serverUrl));
      }
    } catch (Exception e) {
      getErrorMessage(httpServletResponse, new TechnicalException(e));
    }
  }

  private String buildLoginAuthenticationRequestUrl(final Map params) {
    // Build authentication request query string
    String queryString;
    try {
      queryString =
          AuthenticationRequest.parse(
                  params.entrySet().stream()
                      .collect(
                          Collectors.toMap(
                              Map.Entry::getKey, e -> Collections.singletonList(e.getValue()))))
              .toQueryString();
    } catch (Exception e) {
      throw new TechnicalException(e);
    }
    return client.getConfiguration().getProviderMetadata().getAuthorizationEndpointURI().toString()
        + '?'
        + queryString;
  }

  private Map buildLoginParams() {
    Map authParams = new HashMap<>();
    authParams.put(OidcConfiguration.SCOPE, client.getConfiguration().getScope());
    authParams.put(OidcConfiguration.RESPONSE_TYPE, client.getConfiguration().getResponseType());
    authParams.put(OidcConfiguration.RESPONSE_MODE, "query");
    authParams.putAll(client.getConfiguration().getCustomParams());
    authParams.put(OidcConfiguration.CLIENT_ID, client.getConfiguration().getClientId());

    return new HashMap<>(authParams);
  }

  private Optional getUserCredentialsFromSession(HttpServletRequest request)
      throws URISyntaxException {
    OidcCredentials credentials =
        (OidcCredentials) request.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE);

    if (credentials != null && credentials.getRefreshToken() != null) {
      LOG.trace("Credentials found in session: {}", credentials);
      renewOidcCredentials(request, credentials);
      return Optional.of(credentials);
    } else {
      if (credentials == null) {
        LOG.error("No credentials found against session. ID: {}", request.getSession().getId());
      } else {
        LOG.error("No refresh token found against session. ID: {}", request.getSession().getId());
      }
    }
    return Optional.empty();
  }

  private void validateAndSendTokenRequest(
      HttpServletRequest httpServletRequest,
      OidcCredentials oidcCredentials,
      String computedCallbackUrl)
      throws IOException, com.nimbusds.oauth2.sdk.ParseException, URISyntaxException {
    if (oidcCredentials.getCode() != null) {
      LOG.debug(
          "Initiating Token Request for User Session: {} ",
          httpServletRequest.getSession().getId());
      CodeVerifier verifier =
          (CodeVerifier)
              httpServletRequest
                  .getSession()
                  .getAttribute(client.getCodeVerifierSessionAttributeName());
      // Token request
      TokenRequest request =
          createTokenRequest(
              new AuthorizationCodeGrant(
                  oidcCredentials.getCode(), new URI(computedCallbackUrl), verifier));
      executeAuthorizationCodeTokenRequest(httpServletRequest, request, oidcCredentials);
    }
  }

  private void validateStateIfRequired(
      HttpServletRequest req,
      HttpServletResponse resp,
      AuthenticationSuccessResponse successResponse) {
    if (client.getConfiguration().isWithState()) {
      // Validate state for CSRF mitigation
      State requestState =
          (State) req.getSession().getAttribute(client.getStateSessionAttributeName());
      if (requestState == null || CommonHelper.isBlank(requestState.getValue())) {
        getErrorMessage(resp, new TechnicalException("Missing state parameter"));
        return;
      }

      State responseState = successResponse.getState();
      if (responseState == null) {
        throw new TechnicalException("Missing state parameter");
      }

      LOG.debug("Request state: {}/response state: {}", requestState, responseState);
      if (!requestState.equals(responseState)) {
        throw new TechnicalException(
            "State parameter is different from the one sent in authentication request.");
      }
    }
  }

  private OidcCredentials buildCredentials(AuthenticationSuccessResponse successResponse) {
    OidcCredentials credentials = new OidcCredentials();
    // get authorization code
    AuthorizationCode code = successResponse.getAuthorizationCode();
    if (code != null) {
      credentials.setCode(code);
    }
    // get ID token
    JWT idToken = successResponse.getIDToken();
    if (idToken != null) {
      credentials.setIdToken(idToken);
    }
    // get access token
    AccessToken accessToken = successResponse.getAccessToken();
    if (accessToken != null) {
      credentials.setAccessToken(accessToken);
    }

    return credentials;
  }

  private void validateNonceIfRequired(HttpServletRequest req, JWTClaimsSet claimsSet)
      throws BadJOSEException {
    if (client.getConfiguration().isUseNonce()) {
      String expectedNonce =
          (String) req.getSession().getAttribute(client.getNonceSessionAttributeName());
      if (CommonHelper.isNotBlank(expectedNonce)) {
        String tokenNonce;
        try {
          tokenNonce = claimsSet.getStringClaim("nonce");
        } catch (java.text.ParseException var10) {
          throw new BadJWTException("Invalid JWT nonce (nonce) claim: " + var10.getMessage());
        }

        if (tokenNonce == null) {
          throw BadJWTExceptions.MISSING_NONCE_CLAIM_EXCEPTION;
        }

        if (!expectedNonce.equals(tokenNonce)) {
          throw new BadJWTException("Unexpected JWT nonce (nonce) claim: " + tokenNonce);
        }
      } else {
        throw new TechnicalException("Missing nonce parameter from Session.");
      }
    }
  }

  protected Map> retrieveCallbackParameters(HttpServletRequest request) {
    Map requestParameters = request.getParameterMap();
    Map> map = new HashMap<>();
    for (var entry : requestParameters.entrySet()) {
      map.put(entry.getKey(), Arrays.asList(entry.getValue()));
    }
    return map;
  }

  private ClientAuthentication getClientAuthentication(OidcConfiguration configuration) {
    ClientID clientID = new ClientID(configuration.getClientId());
    ClientAuthentication clientAuthenticationMechanism = null;
    if (configuration.getSecret() != null) {
      // check authentication methods
      List metadataMethods =
          configuration.findProviderMetadata().getTokenEndpointAuthMethods();

      ClientAuthenticationMethod preferredMethod = getPreferredAuthenticationMethod(configuration);

      final ClientAuthenticationMethod chosenMethod;
      if (isNotEmpty(metadataMethods)) {
        if (preferredMethod != null) {
          if (metadataMethods.contains(preferredMethod)) {
            chosenMethod = preferredMethod;
          } else {
            throw new TechnicalException(
                "Preferred authentication method ("
                    + preferredMethod
                    + ") not supported "
                    + "by provider according to provider metadata ("
                    + metadataMethods
                    + ").");
          }
        } else {
          chosenMethod = firstSupportedMethod(metadataMethods);
        }
      } else {
        chosenMethod =
            preferredMethod != null ? preferredMethod : ClientAuthenticationMethod.getDefault();
        LOG.info(
            "Provider metadata does not provide Token endpoint authentication methods. Using: {}",
            chosenMethod);
      }

      if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(chosenMethod)) {
        Secret clientSecret = new Secret(configuration.getSecret());
        clientAuthenticationMechanism = new ClientSecretPost(clientID, clientSecret);
      } else if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(chosenMethod)) {
        Secret clientSecret = new Secret(configuration.getSecret());
        clientAuthenticationMechanism = new ClientSecretBasic(clientID, clientSecret);
      } else if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(chosenMethod)) {
        PrivateKeyJWTClientAuthnMethodConfig privateKetJwtConfig =
            configuration.getPrivateKeyJWTClientAuthnMethodConfig();
        assertNotNull("privateKetJwtConfig", privateKetJwtConfig);
        JWSAlgorithm jwsAlgo = privateKetJwtConfig.getJwsAlgorithm();
        assertNotNull("privateKetJwtConfig.getJwsAlgorithm()", jwsAlgo);
        PrivateKey privateKey = privateKetJwtConfig.getPrivateKey();
        assertNotNull("privateKetJwtConfig.getPrivateKey()", privateKey);
        String keyID = privateKetJwtConfig.getKeyID();
        try {
          clientAuthenticationMechanism =
              new PrivateKeyJWT(
                  clientID,
                  configuration.findProviderMetadata().getTokenEndpointURI(),
                  jwsAlgo,
                  privateKey,
                  keyID,
                  null);
        } catch (final JOSEException e) {
          throw new TechnicalException(
              "Cannot instantiate private key JWT client authentication method", e);
        }
      }
    }

    return clientAuthenticationMechanism;
  }

  private static ClientAuthenticationMethod getPreferredAuthenticationMethod(
      OidcConfiguration config) {
    ClientAuthenticationMethod configurationMethod = config.getClientAuthenticationMethod();
    if (configurationMethod == null) {
      return null;
    }

    if (!SUPPORTED_METHODS.contains(configurationMethod)) {
      throw new TechnicalException(
          "Configured authentication method (" + configurationMethod + ") is not supported.");
    }

    return configurationMethod;
  }

  private ClientAuthenticationMethod firstSupportedMethod(
      final List metadataMethods) {
    Optional firstSupported =
        metadataMethods.stream().filter(SUPPORTED_METHODS::contains).findFirst();
    if (firstSupported.isPresent()) {
      return firstSupported.get();
    } else {
      throw new TechnicalException(
          "None of the Token endpoint provider metadata authentication methods are supported: "
              + metadataMethods);
    }
  }

  @SneakyThrows
  public static void getErrorMessage(HttpServletResponse resp, Exception e) {
    resp.setContentType("text/html; charset=UTF-8");
    LOG.error("[Auth Callback Servlet] Failed in Auth Login : {}", e.getMessage());
    resp.getOutputStream()
        .println(
            String.format(
                "

[Auth Callback Servlet] Failed in Auth Login : %s

", e.getMessage())); } private void sendRedirectWithToken(HttpServletResponse response, OidcCredentials credentials) throws ParseException, IOException { JWT jwt = credentials.getIdToken(); Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); claims.putAll(jwt.getJWTClaimsSet().getClaims()); String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims); String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, principalDomain); String url = String.format( "%s/auth/callback?id_token=%s&email=%s&name=%s", serverUrl, credentials.getIdToken().getParsedString(), email, userName); response.sendRedirect(url); } private void renewOidcCredentials(HttpServletRequest request, OidcCredentials credentials) { LOG.debug("Renewing Credentials for User Session {}", request.getSession().getId()); if (client.getConfiguration() instanceof AzureAd2OidcConfiguration azureAd2OidcConfiguration) { refreshAccessTokenAzureAd2Token(azureAd2OidcConfiguration, credentials); } else { refreshTokenRequest(request, credentials); } request.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); } public void refreshTokenRequest( final HttpServletRequest httpServletRequest, final OidcCredentials credentials) { final var refreshToken = credentials.getRefreshToken(); if (refreshToken != null) { try { final var request = createTokenRequest(new RefreshTokenGrant(refreshToken)); HTTPResponse httpResponse = executeTokenHttpRequest(request); if (httpResponse.getStatusCode() == 200) { JSONObject jsonObjectResponse = httpResponse.getContentAsJSONObject(); String idTokenKey = "id_token"; if (jsonObjectResponse.containsKey(idTokenKey)) { Object value = jsonObjectResponse.get(idTokenKey); if (value == null) { throw new com.nimbusds.oauth2.sdk.ParseException( "JSON object member with key " + idTokenKey + " has null value"); } else { LOG.info("Found a JWT token in the response, trying to parse it"); OIDCTokenResponse tokenSuccessResponse = parseTokenResponseFromHttpResponse(httpResponse); // Populate credentials populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials); } } else { // Note: since the id_token is not present, we must receive accessToken // We can do better and get userInfo from // client.getConfiguration().findProviderMetadata().getUserInfoEndpointURI() // but currently we are just return the OM created token in the response String accessToken = JSONObjectUtils.getString(jsonObjectResponse, "access_token"); LOG.info( "Found an access token in the response, trying to parse it, Value : {}", accessToken); OIDCTokenResponse tokenSuccessResponse = parseTokenResponseFromHttpResponse(httpResponse); // Populate credentials populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials); OidcCredentials storedCredentials = (OidcCredentials) httpServletRequest.getSession().getAttribute(OIDC_CREDENTIAL_PROFILE); // Get the claims from the stored credentials Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); claims.putAll(storedCredentials.getIdToken().getJWTClaimsSet().getClaims()); String username = SecurityUtil.findUserNameFromClaims(claimsMapping, claimsOrder, claims); User user = Entity.getEntityByName(Entity.USER, username, "id", Include.NON_DELETED); // Create a JWT here JWTAuthMechanism jwtAuthMechanism = JWTTokenGenerator.getInstance() .generateJWTToken( username, getRoleListFromUser(user), !nullOrEmpty(user.getIsAdmin()) && user.getIsAdmin(), user.getEmail(), tokenValidity, false, ServiceTokenType.OM_USER); // Set the access token to the new JWT token credentials.setIdToken(SignedJWT.parse(jwtAuthMechanism.getJWTToken())); } return; } else { throw new TechnicalException( String.format( "Failed to refresh id_token, response code:%s , Error : %s", httpResponse.getStatusCode(), httpResponse.getContent())); } } catch (final IOException | com.nimbusds.oauth2.sdk.ParseException e) { throw new TechnicalException(e); } catch (ParseException e) { throw new RuntimeException(e); } } throw new BadRequestException("No refresh token available"); } public static boolean isJWT(String token) { return token.split("\\.").length == 3; } private void refreshAccessTokenAzureAd2Token( AzureAd2OidcConfiguration azureConfig, OidcCredentials azureAdProfile) { HttpURLConnection connection = null; try { Map headers = new HashMap<>(); headers.put( HttpConstants.CONTENT_TYPE_HEADER, HttpConstants.APPLICATION_FORM_ENCODED_HEADER_VALUE); headers.put(HttpConstants.ACCEPT_HEADER, HttpConstants.APPLICATION_JSON); // get the token endpoint from discovery URI URL tokenEndpointURL = azureConfig.findProviderMetadata().getTokenEndpointURI().toURL(); connection = HttpUtils.openPostConnection(tokenEndpointURL, headers); BufferedWriter out = new BufferedWriter( new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8)); out.write(azureConfig.makeOauth2TokenRequest(azureAdProfile.getRefreshToken().getValue())); out.close(); int responseCode = connection.getResponseCode(); if (responseCode != 200) { throw new TechnicalException( "request for access token failed: " + HttpUtils.buildHttpErrorMessage(connection)); } var body = HttpUtils.readBody(connection); Map res = JsonUtils.readValue(body, new TypeReference<>() {}); // Populate Tokens azureAdProfile.setAccessToken(new BearerAccessToken((String) res.get("access_token"))); azureAdProfile.setRefreshToken(new RefreshToken((String) res.get("refresh_token"))); azureAdProfile.setIdToken(SignedJWT.parse((String) res.get("id_token"))); } catch (final IOException e) { throw new TechnicalException(e); } catch (ParseException e) { throw new TechnicalException(e); } finally { HttpUtils.closeConnection(connection); } } public static String findUserNameFromClaims( Map jwtPrincipalClaimsMapping, List jwtPrincipalClaimsOrder, Map claims) { if (!nullOrEmpty(jwtPrincipalClaimsMapping)) { // We have a mapping available so we will use that String usernameClaim = jwtPrincipalClaimsMapping.get(USERNAME_CLAIM_KEY); String userNameClaimValue = getClaimOrObject(claims.get(usernameClaim)); if (!nullOrEmpty(userNameClaimValue)) { return userNameClaimValue; } else { throw new AuthenticationException("Invalid JWT token, 'username' claim is not present"); } } else { String jwtClaim = getFirstMatchJwtClaim(jwtPrincipalClaimsOrder, claims); String userName; if (jwtClaim.contains("@")) { userName = jwtClaim.split("@")[0]; } else { userName = jwtClaim; } return userName; } } public static void validatePrincipalClaimsMapping(Map mapping) { if (!nullOrEmpty(mapping)) { String username = mapping.get(USERNAME_CLAIM_KEY); String email = mapping.get(EMAIL_CLAIM_KEY); if (nullOrEmpty(username) || nullOrEmpty(email)) { throw new IllegalArgumentException( "Invalid JWT Principal Claims Mapping. Both username and email should be present"); } } // If emtpy, jwtPrincipalClaims will be used so no need to validate } private HTTPResponse executeTokenHttpRequest(TokenRequest request) throws IOException { HTTPRequest tokenHttpRequest = request.toHTTPRequest(); client.getConfiguration().configureHttpRequest(tokenHttpRequest); HTTPResponse httpResponse = tokenHttpRequest.send(); LOG.debug( "Token response: status={}, content={}", httpResponse.getStatusCode(), httpResponse.getContent()); return httpResponse; } private TokenRequest createTokenRequest(final AuthorizationGrant grant) { if (clientAuthentication != null) { return new TokenRequest( client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), this.clientAuthentication, grant); } else { return new TokenRequest( client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), new ClientID(client.getConfiguration().getClientId()), grant); } } private void addStateAndNonceParameters( final OidcClient client, final HttpServletRequest request, final Map params) { // Init state for CSRF mitigation if (client.getConfiguration().isWithState()) { State state = new State(CommonHelper.randomString(10)); params.put(OidcConfiguration.STATE, state.getValue()); request.getSession().setAttribute(client.getStateSessionAttributeName(), state); } // Init nonce for replay attack mitigation if (client.getConfiguration().isUseNonce()) { Nonce nonce = new Nonce(); params.put(OidcConfiguration.NONCE, nonce.getValue()); request.getSession().setAttribute(client.getNonceSessionAttributeName(), nonce.getValue()); } CodeChallengeMethod pkceMethod = client.getConfiguration().findPkceMethod(); // Use Default PKCE method if not disabled if (pkceMethod == null && !client.getConfiguration().isDisablePkce()) { pkceMethod = CodeChallengeMethod.S256; } if (pkceMethod != null) { CodeVerifier verfifier = new CodeVerifier(CommonHelper.randomString(43)); request.getSession().setAttribute(client.getCodeVerifierSessionAttributeName(), verfifier); params.put( OidcConfiguration.CODE_CHALLENGE, CodeChallenge.compute(pkceMethod, verfifier).getValue()); params.put(OidcConfiguration.CODE_CHALLENGE_METHOD, pkceMethod.getValue()); } } @SneakyThrows private void executeAuthorizationCodeTokenRequest( HttpServletRequest httpServletRequest, TokenRequest request, OidcCredentials credentials) throws IOException, com.nimbusds.oauth2.sdk.ParseException { HTTPResponse httpResponse = executeTokenHttpRequest(request); OIDCTokenResponse tokenSuccessResponse = parseTokenResponseFromHttpResponse(httpResponse); // Populate credentials populateCredentialsFromTokenResponse(tokenSuccessResponse, credentials); // Check expiry, azure on first go itself is returning a expried token sometimes Date expirationTime = credentials.getIdToken().getJWTClaimsSet().getExpirationTime(); if (expirationTime != null && expirationTime.before(Calendar.getInstance(TimeZone.getTimeZone("UTC")).getTime())) { renewOidcCredentials(httpServletRequest, credentials); } } private void populateCredentialsFromTokenResponse( OIDCTokenResponse tokenSuccessResponse, OidcCredentials credentials) { OIDCTokens oidcTokens = tokenSuccessResponse.getOIDCTokens(); credentials.setAccessToken(oidcTokens.getAccessToken()); credentials.setRefreshToken(oidcTokens.getRefreshToken()); if (oidcTokens.getIDToken() != null) { credentials.setIdToken(oidcTokens.getIDToken()); } } private OIDCTokenResponse parseTokenResponseFromHttpResponse(HTTPResponse httpResponse) throws com.nimbusds.oauth2.sdk.ParseException { TokenResponse response = OIDCTokenResponseParser.parse(httpResponse); if (response instanceof TokenErrorResponse tokenErrorResponse) { ErrorObject errorObject = tokenErrorResponse.getErrorObject(); throw new TechnicalException( "Bad token response, error=" + errorObject.getCode() + "," + " description=" + errorObject.getDescription()); } LOG.debug("Token response successful"); return (OIDCTokenResponse) response; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy