All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.server.security.oauth2.OidcDiscovery Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.server.security.oauth2;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Inject;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderConfigurationRequest;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;
import io.airlift.json.ObjectMapperProvider;
import io.airlift.log.Logger;

import java.net.URI;
import java.time.Duration;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static io.airlift.http.client.HttpStatus.OK;
import static io.airlift.http.client.HttpStatus.REQUEST_TIMEOUT;
import static io.airlift.http.client.HttpStatus.TOO_MANY_REQUESTS;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.ACCESS_TOKEN_ISSUER;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.AUTH_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.END_SESSION_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.JWKS_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.TOKEN_URL;
import static io.trino.server.security.oauth2.StaticOAuth2ServerConfiguration.USERINFO_URL;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Resolves the OAuth 2.0 server configuration by fetching the issuer's OpenID Provider
 * Metadata document (via {@link OIDCProviderConfigurationRequest}).
 * <p>
 * Endpoint values set explicitly in {@link OidcDiscoveryConfig} take precedence over the values
 * published in the metadata document; a warning is logged whenever a configured value is used
 * while the metadata document also provides one.
 */
public class OidcDiscovery
        implements OAuth2ServerConfigProvider
{
    private static final Logger LOG = Logger.get(OidcDiscovery.class);

    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapperProvider().get();
    private final Issuer issuer;
    private final Duration discoveryTimeout;
    private final boolean userinfoEndpointEnabled;
    private final Optional<String> accessTokenIssuer;
    private final Optional<String> authUrl;
    private final Optional<String> tokenUrl;
    private final Optional<String> jwksUrl;
    private final Optional<String> userinfoUrl;
    private final NimbusHttpClient httpClient;

    @Inject
    public OidcDiscovery(OAuth2Config oauthConfig, OidcDiscoveryConfig oidcConfig, NimbusHttpClient httpClient)
    {
        issuer = new Issuer(requireNonNull(oauthConfig.getIssuer(), "issuer is null"));
        userinfoEndpointEnabled = oidcConfig.isUserinfoEndpointEnabled();
        // convert the configured timeout into a java.time.Duration (millisecond precision) for Failsafe
        discoveryTimeout = Duration.ofMillis(requireNonNull(oidcConfig.getDiscoveryTimeout(), "discoveryTimeout is null").toMillis());
        accessTokenIssuer = requireNonNull(oidcConfig.getAccessTokenIssuer(), "accessTokenIssuer is null");
        authUrl = requireNonNull(oidcConfig.getAuthUrl(), "authUrl is null");
        tokenUrl = requireNonNull(oidcConfig.getTokenUrl(), "tokenUrl is null");
        jwksUrl = requireNonNull(oidcConfig.getJwksUrl(), "jwksUrl is null");
        userinfoUrl = requireNonNull(oidcConfig.getUserinfoUrl(), "userinfoUrl is null");
        this.httpClient = requireNonNull(httpClient, "httpClient is null");
    }

    /**
     * Fetches and parses the OpenID Provider Metadata, retrying failed attempts.
     * <p>
     * Attempts are unlimited in number, spaced one second apart, and bounded in total by
     * {@code discoveryTimeout}. An {@link IllegalStateException} (a non-retryable HTTP status or
     * an invalid metadata document) aborts retrying immediately.
     */
    @Override
    public OAuth2ServerConfig get()
    {
        return Failsafe.with(RetryPolicy.builder()
                        .withMaxAttempts(-1)
                        .withMaxDuration(discoveryTimeout)
                        .withDelay(Duration.ofSeconds(1))
                        .abortOn(IllegalStateException.class)
                        .onFailedAttempt(attempt -> LOG.debug("OpenID Connect Metadata read failed: %s", attempt.getLastException()))
                        .build())
                .get(() -> httpClient.execute(new OIDCProviderConfigurationRequest(issuer), this::parseConfigurationResponse));
    }

    /**
     * Validates the HTTP status of the metadata response and parses its body.
     *
     * @throws RuntimeException for retryable non-OK statuses (the retry policy tries again)
     * @throws IllegalStateException for non-retryable statuses (the retry policy aborts)
     */
    private OAuth2ServerConfig parseConfigurationResponse(HTTPResponse response)
            throws ParseException
    {
        int statusCode = response.getStatusCode();
        if (statusCode != OK.code()) {
            if (isRetryableStatus(statusCode)) {
                // plain RuntimeException is not in abortOn, so Failsafe will retry
                throw new RuntimeException("Invalid response from OpenID Metadata endpoint: " + statusCode);
            }
            // IllegalStateException is in abortOn, so discovery stops immediately
            throw new IllegalStateException(format("Invalid response from OpenID Metadata endpoint. Expected response code to be %s, but was %s", OK.code(), statusCode));
        }
        return readConfiguration(response.getBody());
    }

    /**
     * Returns whether a non-OK status should be retried: every status except client errors (4xx),
     * with REQUEST_TIMEOUT and TOO_MANY_REQUESTS treated as transient and therefore retried.
     */
    private static boolean isRetryableStatus(int statusCode)
    {
        boolean clientError = statusCode >= 400 && statusCode < 500;
        return !clientError || statusCode == REQUEST_TIMEOUT.code() || statusCode == TOO_MANY_REQUESTS.code();
    }

    /**
     * Parses the metadata document and merges it with the configured endpoint overrides.
     *
     * @throws ParseException if the body is not valid provider metadata or valid JSON
     * @throws IllegalStateException if the issuer does not match or a required endpoint is missing
     */
    private OAuth2ServerConfig readConfiguration(String body)
            throws ParseException
    {
        OIDCProviderMetadata metadata = OIDCProviderMetadata.parse(body);
        checkMetadataState(issuer.equals(metadata.getIssuer()), "The value of the \"issuer\" claim in Metadata document different than the Issuer URL used for the Configuration Request.");
        try {
            JsonNode metadataJson = OBJECT_MAPPER.readTree(body);
            Optional<String> userinfoEndpoint;
            if (userinfoEndpointEnabled) {
                userinfoEndpoint = getOptionalField("userinfo_endpoint", Optional.ofNullable(metadata.getUserInfoEndpointURI()).map(URI::toString), USERINFO_URL, userinfoUrl);
            }
            else {
                userinfoEndpoint = Optional.empty();
            }
            // "end_session_endpoint" is only published by providers supporting RP-Initiated Logout,
            // so its absence must not fail discovery. There is no configuration override for it.
            Optional<URI> endSessionEndpoint = getOptionalField("end_session_endpoint", Optional.ofNullable(metadata.getEndSessionEndpointURI()).map(URI::toString), END_SESSION_URL, Optional.empty())
                    .map(URI::create);
            return new OAuth2ServerConfig(
                    // AD FS servers can include an "access_token_issuer" field in the OpenID Provider Metadata.
                    // It is not part of the OIDC standard, so it has to be read from the raw JSON separately.
                    // see: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/f629647a-4825-465b-80bb-32c7e9cec2c8
                    getOptionalField("access_token_issuer", Optional.ofNullable(metadataJson.get("access_token_issuer")).map(JsonNode::textValue), ACCESS_TOKEN_ISSUER, accessTokenIssuer),
                    getRequiredField("authorization_endpoint", metadata.getAuthorizationEndpointURI(), AUTH_URL, authUrl),
                    getRequiredField("token_endpoint", metadata.getTokenEndpointURI(), TOKEN_URL, tokenUrl),
                    getRequiredField("jwks_uri", metadata.getJWKSetURI(), JWKS_URL, jwksUrl),
                    userinfoEndpoint.map(URI::create),
                    endSessionEndpoint);
        }
        catch (JsonProcessingException e) {
            throw new ParseException("Invalid JSON value", e);
        }
    }

    /**
     * Resolves an endpoint that must be present either in the metadata document or in configuration.
     *
     * @throws IllegalStateException if neither source provides a value
     */
    private static URI getRequiredField(String metadataField, URI metadataValue, String configurationField, Optional<String> configurationValue)
    {
        Optional<String> uri = getOptionalField(metadataField, Optional.ofNullable(metadataValue).map(URI::toString), configurationField, configurationValue);
        checkMetadataState(uri.isPresent(), "Missing required \"%s\" property.", metadataField);
        return URI.create(uri.get());
    }

    /**
     * Merges a metadata value with its configuration override. The configured value wins when
     * present. A warning is logged if both are present, whether they differ or are redundant.
     */
    private static Optional<String> getOptionalField(String metadataField, Optional<String> metadataValue, String configurationField, Optional<String> configurationValue)
    {
        if (configurationValue.isEmpty()) {
            return metadataValue;
        }

        if (metadataValue.isEmpty()) {
            return configurationValue;
        }

        if (!configurationValue.equals(metadataValue)) {
            LOG.warn("Overriding \"%s=%s\" from OpenID metadata document with value \"%s=%s\" defined in configuration",
                    metadataField, metadataValue.orElse(""), configurationField, configurationValue.orElse(""));
        }
        else {
            LOG.warn("Provided redundant configuration property \"%s\" with the same value as \"%s\" field in OpenID metadata document",
                    configurationField, metadataField);
        }
        return configurationValue;
    }

    // Throws IllegalStateException (which aborts retries in get()) with a common message prefix.
    private static void checkMetadataState(boolean expression, String additionalMessage, String... additionalMessageArgs)
    {
        checkState(expression, "Invalid response from OpenID Metadata endpoint. " + additionalMessage, (Object[]) additionalMessageArgs);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy