All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.server.security.oauth2.NimbusAirliftHttpClient Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableMultimap;
import com.google.inject.Inject;
import com.nimbusds.jose.util.Resource;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.http.client.Response;
import io.airlift.http.client.ResponseHandler;
import io.airlift.http.client.ResponseHandlerUtils;
import io.airlift.http.client.StringResponseHandler;
import jakarta.ws.rs.core.UriBuilder;

import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;

import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.DELETE;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.GET;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.POST;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.PUT;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

/**
 * {@link NimbusHttpClient} implementation that sends Nimbus OAuth2 SDK requests
 * through an Airlift {@link HttpClient}.
 */
public class NimbusAirliftHttpClient
        implements NimbusHttpClient
{
    private final HttpClient httpClient;

    /**
     * @param httpClient the Airlift client bound with {@code @ForOAuth2}; must not be null
     */
    @Inject
    public NimbusAirliftHttpClient(@ForOAuth2 HttpClient httpClient)
    {
        this.httpClient = requireNonNull(httpClient, "httpClient is null");
    }

    /**
     * Fetches {@code url} with a GET request and wraps the response body together
     * with its {@code Content-Type} header in a {@link Resource}.
     *
     * @param url location of the resource to fetch
     * @return the response body and content type
     * @throws IOException declared by the overridden method
     * @throws RuntimeException if {@code url} cannot be converted to a URI
     */
    @Override
    public Resource retrieveResource(URL url)
            throws IOException
    {
        try {
            StringResponseHandler.StringResponse response = httpClient.execute(
                    prepareGet().setUri(url.toURI()).build(),
                    createStringResponseHandler());
            return new Resource(response.getBody(), response.getHeader(CONTENT_TYPE));
        }
        catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Converts a Nimbus request into an Airlift request, executes it and parses
     * the response with {@code parser}.
     *
     * @param nimbusRequest the Nimbus OAuth2 SDK request to send
     * @param parser converts the Nimbus {@link HTTPResponse} into the result type
     * @param <T> type produced by {@code parser}
     * @return the parsed response
     */
    @Override
    public <T> T execute(com.nimbusds.oauth2.sdk.Request nimbusRequest, Parser<T> parser)
    {
        HTTPRequest httpRequest = nimbusRequest.toHTTPRequest();
        HTTPRequest.Method method = httpRequest.getMethod();

        Request.Builder request = new Request.Builder()
                .setMethod(method.name())
                .setFollowRedirects(httpRequest.getFollowRedirects());

        UriBuilder url = UriBuilder.fromUri(httpRequest.getURI());
        // Only GET and DELETE get the request parameters appended to the URI;
        // POST and PUT send the request body instead (see below).
        if (method.equals(GET) || method.equals(DELETE)) {
            httpRequest.getQueryStringParameters().forEach((key, value) -> url.queryParam(key, value.toArray()));
        }

        // Carry over the fragment of the original request URL.
        url.fragment(httpRequest.getURL().getRef());

        request.setUri(url.build());

        ImmutableMultimap.Builder<String, String> headers = ImmutableMultimap.builder();
        httpRequest.getHeaderMap().forEach(headers::putAll);
        request.addHeaders(headers.build());

        if (method.equals(POST) || method.equals(PUT)) {
            String body = httpRequest.getBody();
            if (body != null) {
                request.setBodyGenerator(createStaticBodyGenerator(body, UTF_8));
            }
        }
        return httpClient.execute(request.build(), new NimbusResponseHandler<>(parser));
    }

    /**
     * Airlift {@link ResponseHandler} that rebuilds a Nimbus {@link HTTPResponse}
     * from the Airlift response and hands it to a {@link Parser}.
     *
     * @param <T> type produced by the parser
     */
    public static class NimbusResponseHandler<T>
            implements ResponseHandler<T, RuntimeException>
    {
        private final StringResponseHandler handler = createStringResponseHandler();
        private final Parser<T> parser;

        /**
         * @param parser converts the Nimbus response into the result; must not be null
         */
        public NimbusResponseHandler(Parser<T> parser)
        {
            this.parser = requireNonNull(parser, "parser is null");
        }

        /**
         * Rethrows the failure via {@link ResponseHandlerUtils#propagate}; never returns normally.
         */
        @Override
        public T handleException(Request request, Exception exception)
        {
            throw ResponseHandlerUtils.propagate(request, exception);
        }

        /**
         * Copies status code, headers and body into a Nimbus {@link HTTPResponse}
         * and parses it.
         *
         * @throws RuntimeException if the parser rejects the response
         */
        @Override
        public T handle(Request request, Response response)
        {
            // Delegate body reading to the string handler, then mirror the response in Nimbus form.
            StringResponseHandler.StringResponse stringResponse = handler.handle(request, response);
            HTTPResponse nimbusResponse = new HTTPResponse(response.getStatusCode());
            response.getHeaders().asMap().forEach((name, values) -> nimbusResponse.setHeader(name.toString(), values.toArray(new String[0])));
            nimbusResponse.setBody(stringResponse.getBody());
            try {
                return parser.parse(nimbusResponse);
            }
            catch (ParseException e) {
                throw new RuntimeException(format("Unable to parse response status=[%d], body=[%s]", stringResponse.getStatusCode(), stringResponse.getBody()), e);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy