Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
io.trino.server.security.oauth2.NimbusOAuth2Client Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server.security.oauth2;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.inject.Inject;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import com.nimbusds.oauth2.sdk.AccessTokenResponse;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationRequest;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.RefreshToken;
import com.nimbusds.oauth2.sdk.token.Tokens;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.UserInfoRequest;
import com.nimbusds.openid.connect.sdk.UserInfoResponse;
import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash;
import com.nimbusds.openid.connect.sdk.claims.IDTokenClaimsSet;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.AccessTokenValidator;
import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator;
import com.nimbusds.openid.connect.sdk.validators.InvalidHashException;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.server.security.oauth2.OAuth2ServerConfigProvider.OAuth2ServerConfig;
import jakarta.ws.rs.core.UriBuilder;
import java.net.MalformedURLException;
import java.net.URI;
import java.time.Instant;
import java.util.Date;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.hash.Hashing.sha256;
import static com.nimbusds.oauth2.sdk.ResponseType.CODE;
import static com.nimbusds.openid.connect.sdk.OIDCScopeValue.OPENID;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.SECONDS;
public class NimbusOAuth2Client
implements OAuth2Client
{
private static final Logger LOG = Logger.get(NimbusAirliftHttpClient.class);
private final Issuer issuer;
private final ClientID clientId;
private final ClientSecretBasic clientAuth;
private final Scope scope;
private final String principalField;
private final Set accessTokenAudiences;
private final Duration maxClockSkew;
private final Optional jwtType;
private final NimbusHttpClient httpClient;
private final OAuth2ServerConfigProvider serverConfigurationProvider;
private volatile boolean loaded;
private URI authUrl;
private URI tokenUrl;
private Optional userinfoUrl;
private Optional endSessionUrl;
private JWSKeySelector jwsKeySelector;
private JWTProcessor accessTokenProcessor;
private AuthorizationCodeFlow flow;
@Inject
public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider serverConfigurationProvider, NimbusHttpClient httpClient)
{
issuer = new Issuer(oauthConfig.getIssuer());
clientId = new ClientID(oauthConfig.getClientId());
clientAuth = new ClientSecretBasic(clientId, new Secret(oauthConfig.getClientSecret()));
scope = Scope.parse(oauthConfig.getScopes());
principalField = oauthConfig.getPrincipalField();
maxClockSkew = oauthConfig.getMaxClockSkew();
jwtType = oauthConfig.getJwtType();
accessTokenAudiences = new HashSet<>(oauthConfig.getAdditionalAudiences());
accessTokenAudiences.add(clientId.getValue());
accessTokenAudiences.add(null); // A null value in the set allows JWTs with no audience
this.serverConfigurationProvider = requireNonNull(serverConfigurationProvider, "serverConfigurationProvider is null");
this.httpClient = requireNonNull(httpClient, "httpClient is null");
}
@Override
public void load()
{
OAuth2ServerConfig config = serverConfigurationProvider.get();
this.authUrl = config.authUrl();
this.tokenUrl = config.tokenUrl();
this.userinfoUrl = config.userinfoUrl();
this.endSessionUrl = config.endSessionUrl();
try {
jwsKeySelector = new JWSVerificationKeySelector<>(
Stream.concat(JWSAlgorithm.Family.RSA.stream(), JWSAlgorithm.Family.EC.stream()).collect(toImmutableSet()),
JWKSourceBuilder.create(config.jwksUrl().toURL(), httpClient).build());
}
catch (MalformedURLException e) {
throw new RuntimeException(e);
}
DefaultJWTProcessor processor = new DefaultJWTProcessor<>();
if (jwtType.isPresent()) {
processor.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType(jwtType.get())));
}
processor.setJWSKeySelector(jwsKeySelector);
DefaultJWTClaimsVerifier accessTokenVerifier = new DefaultJWTClaimsVerifier<>(
accessTokenAudiences,
new JWTClaimsSet.Builder()
.issuer(config.accessTokenIssuer().orElse(issuer.getValue()))
.build(),
ImmutableSet.of(principalField),
ImmutableSet.of());
accessTokenVerifier.setMaxClockSkew((int) maxClockSkew.roundTo(SECONDS));
processor.setJWTClaimsSetVerifier(accessTokenVerifier);
accessTokenProcessor = processor;
flow = scope.contains(OPENID) ? new OAuth2WithOidcExtensionsCodeFlow() : new OAuth2AuthorizationCodeFlow();
loaded = true;
}
@Override
public Request createAuthorizationRequest(String state, URI callbackUri)
{
checkState(loaded, "OAuth2 client not initialized");
return flow.createAuthorizationRequest(state, callbackUri);
}
@Override
public Response getOAuth2Response(String code, URI callbackUri, Optional nonce)
throws ChallengeFailedException
{
checkState(loaded, "OAuth2 client not initialized");
return flow.getOAuth2Response(code, callbackUri, nonce);
}
@Override
public Optional> getClaims(String accessToken)
{
checkState(loaded, "OAuth2 client not initialized");
return getJWTClaimsSet(accessToken).map(JWTClaimsSet::getClaims);
}
@Override
public Response refreshTokens(String refreshToken)
throws ChallengeFailedException
{
checkState(loaded, "OAuth2 client not initialized");
return flow.refreshTokens(refreshToken);
}
@Override
public Optional getLogoutEndpoint(Optional idToken, URI callbackUrl)
{
if (endSessionUrl.isPresent()) {
UriBuilder builder = UriBuilder.fromUri(endSessionUrl.get());
idToken.ifPresent(token -> builder.queryParam("id_token_hint", token));
builder.queryParam("post_logout_redirect_uri", callbackUrl);
return Optional.of(builder.build());
}
return Optional.empty();
}
private interface AuthorizationCodeFlow
{
Request createAuthorizationRequest(String state, URI callbackUri);
Response getOAuth2Response(String code, URI callbackUri, Optional nonce)
throws ChallengeFailedException;
Response refreshTokens(String refreshToken)
throws ChallengeFailedException;
}
private class OAuth2AuthorizationCodeFlow
implements AuthorizationCodeFlow
{
@Override
public Request createAuthorizationRequest(String state, URI callbackUri)
{
return new Request(
new AuthorizationRequest.Builder(CODE, clientId)
.redirectionURI(callbackUri)
.scope(scope)
.endpointURI(authUrl)
.state(new State(state))
.build()
.toURI(),
Optional.empty());
}
@Override
public Response getOAuth2Response(String code, URI callbackUri, Optional nonce)
throws ChallengeFailedException
{
checkArgument(nonce.isEmpty(), "Unexpected nonce provided");
AccessTokenResponse tokenResponse = getTokenResponse(code, callbackUri, AccessTokenResponse::parse);
Tokens tokens = tokenResponse.toSuccessResponse().getTokens();
return toResponse(tokens, Optional.empty());
}
@Override
public Response refreshTokens(String refreshToken)
throws ChallengeFailedException
{
requireNonNull(refreshToken, "refreshToken is null");
AccessTokenResponse tokenResponse = getTokenResponse(refreshToken, AccessTokenResponse::parse);
return toResponse(tokenResponse.toSuccessResponse().getTokens(), Optional.of(refreshToken));
}
private Response toResponse(Tokens tokens, Optional existingRefreshToken)
throws ChallengeFailedException
{
AccessToken accessToken = tokens.getAccessToken();
RefreshToken refreshToken = tokens.getRefreshToken();
JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token"));
return new Response(
accessToken.getValue(),
determineExpiration(getExpiration(accessToken), claims.getExpirationTime()),
Optional.empty(),
Optional.ofNullable(refreshToken)
.map(RefreshToken::getValue)
.or(() -> existingRefreshToken));
}
}
private class OAuth2WithOidcExtensionsCodeFlow
implements AuthorizationCodeFlow
{
private final IDTokenValidator idTokenValidator;
public OAuth2WithOidcExtensionsCodeFlow()
{
idTokenValidator = new IDTokenValidator(issuer, clientId, jwsKeySelector, null);
idTokenValidator.setMaxClockSkew((int) maxClockSkew.roundTo(SECONDS));
}
@Override
public Request createAuthorizationRequest(String state, URI callbackUri)
{
String nonce = new Nonce().getValue();
return new Request(
new AuthenticationRequest.Builder(CODE, scope, clientId, callbackUri)
.endpointURI(authUrl)
.state(new State(state))
.nonce(new Nonce(hashNonce(nonce)))
.build()
.toURI(),
Optional.of(nonce));
}
@Override
public Response getOAuth2Response(String code, URI callbackUri, Optional nonce)
throws ChallengeFailedException
{
if (nonce.isEmpty()) {
throw new ChallengeFailedException("Missing nonce");
}
OIDCTokenResponse tokenResponse = getTokenResponse(code, callbackUri, OIDCTokenResponse::parse);
OIDCTokens tokens = tokenResponse.getOIDCTokens();
validateTokens(tokens, nonce);
return toResponse(tokens, Optional.empty());
}
@Override
public Response refreshTokens(String refreshToken)
throws ChallengeFailedException
{
OIDCTokenResponse tokenResponse = getTokenResponse(refreshToken, OIDCTokenResponse::parse);
OIDCTokens tokens = tokenResponse.getOIDCTokens();
validateTokens(tokens);
return toResponse(tokens, Optional.of(refreshToken));
}
private Response toResponse(OIDCTokens tokens, Optional existingRefreshToken)
throws ChallengeFailedException
{
AccessToken accessToken = tokens.getAccessToken();
RefreshToken refreshToken = tokens.getRefreshToken();
JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token"));
return new Response(
accessToken.getValue(),
determineExpiration(getExpiration(accessToken), claims.getExpirationTime()),
Optional.ofNullable(tokens.getIDTokenString()),
Optional.ofNullable(refreshToken)
.map(RefreshToken::getValue)
.or(() -> existingRefreshToken));
}
private void validateTokens(OIDCTokens tokens, Optional nonce)
throws ChallengeFailedException
{
try {
IDTokenClaimsSet idToken = idTokenValidator.validate(
tokens.getIDToken(),
nonce.map(this::hashNonce)
.map(Nonce::new)
.orElse(null));
AccessTokenHash accessTokenHash = idToken.getAccessTokenHash();
if (accessTokenHash != null) {
AccessTokenValidator.validate(tokens.getAccessToken(), ((JWSHeader) tokens.getIDToken().getHeader()).getAlgorithm(), accessTokenHash);
}
}
catch (BadJOSEException | JOSEException | InvalidHashException e) {
throw new ChallengeFailedException("Cannot validate tokens", e);
}
}
private void validateTokens(OIDCTokens tokens)
throws ChallengeFailedException
{
validateTokens(tokens, Optional.empty());
}
private String hashNonce(String nonce)
{
return sha256()
.hashString(nonce, UTF_8)
.toString();
}
}
private T getTokenResponse(String code, URI callbackUri, NimbusAirliftHttpClient.Parser parser)
throws ChallengeFailedException
{
return getTokenResponse(new TokenRequest(tokenUrl, clientAuth, new AuthorizationCodeGrant(new AuthorizationCode(code), callbackUri)), parser);
}
private T getTokenResponse(String refreshToken, NimbusAirliftHttpClient.Parser parser)
throws ChallengeFailedException
{
return getTokenResponse(new TokenRequest(tokenUrl, clientAuth, new RefreshTokenGrant(new RefreshToken(refreshToken)), scope), parser);
}
private T getTokenResponse(TokenRequest tokenRequest, NimbusAirliftHttpClient.Parser parser)
throws ChallengeFailedException
{
T tokenResponse = httpClient.execute(tokenRequest, parser);
if (!tokenResponse.indicatesSuccess()) {
throw new ChallengeFailedException("Error while fetching access token: " + tokenResponse.toErrorResponse().toHTTPResponse().getBody());
}
return tokenResponse;
}
private Optional getJWTClaimsSet(String accessToken)
{
if (userinfoUrl.isPresent()) {
return queryUserInfo(accessToken);
}
return parseAccessToken(accessToken);
}
private Optional queryUserInfo(String accessToken)
{
try {
UserInfoResponse response = httpClient.execute(new UserInfoRequest(userinfoUrl.get(), new BearerAccessToken(accessToken)), UserInfoResponse::parse);
if (!response.indicatesSuccess()) {
LOG.error("Received bad response from userinfo endpoint: %s", response.toErrorResponse().getErrorObject());
return Optional.empty();
}
return Optional.of(response.toSuccessResponse().getUserInfo().toJWTClaimsSet());
}
catch (ParseException | RuntimeException e) {
LOG.error(e, "Received bad response from userinfo endpoint");
return Optional.empty();
}
}
private Optional parseAccessToken(String accessToken)
{
try {
return Optional.of(accessTokenProcessor.process(accessToken, null));
}
catch (java.text.ParseException | BadJOSEException | JOSEException e) {
LOG.error(e, "Failed to parse JWT access token");
return Optional.empty();
}
}
private static Instant determineExpiration(Optional validUntil, Date expiration)
throws ChallengeFailedException
{
if (validUntil.isPresent()) {
if (expiration != null) {
return Ordering.natural().min(validUntil.get(), expiration.toInstant());
}
return validUntil.get();
}
if (expiration != null) {
return expiration.toInstant();
}
throw new ChallengeFailedException("no valid expiration date");
}
private static Optional getExpiration(AccessToken accessToken)
{
return accessToken.getLifetime() != 0 ? Optional.of(Instant.now().plusSeconds(accessToken.getLifetime())) : Optional.empty();
}
}