All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tinkerpop.gremlin.driver.auth.Sigv4 Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.tinkerpop.gremlin.driver.auth;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.tinkerpop.gremlin.driver.HttpRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.spi.signer.SignedRequest;
import software.amazon.awssdk.utils.http.SdkHttpUtils;

import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.AUTHORIZATION;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.HOST;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_CONTENT_SHA256;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DATE;
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_SECURITY_TOKEN;

/**
 * A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that provides headers required for SigV4. Because
 * the signing process requires final header and body data, this interceptor should almost always be last.
 */
public class Sigv4 implements Auth {
    private static final Logger logger = LoggerFactory.getLogger(Sigv4.class);
    private final AwsCredentialsProvider awsCredentialsProvider;
    private final AwsV4HttpSigner aws4Signer;
    private final String serviceName;
    private final String regionName;

    public Sigv4(final String regionName, final String serviceName) {
        this(regionName, DefaultCredentialsProvider.create(), serviceName);
    }

    public Sigv4(final String regionName, final AwsCredentialsProvider awsCredentialsProvider, final String serviceName) {
        this.awsCredentialsProvider = awsCredentialsProvider;

        aws4Signer = AwsV4HttpSigner.create();
        this.regionName = regionName;
        this.serviceName = serviceName;
    }

    @Override
    public HttpRequest apply(final HttpRequest httpRequest) {
        try {
            final ContentStreamProvider content = toContentStream(httpRequest);
            // Convert Http request into an AWS SDK signable request
            final SdkHttpRequest awsSignableRequest = toSignableRequest(httpRequest);
            final AwsCredentials credentials = awsCredentialsProvider.resolveCredentials();

            // Sign the AWS SDK signable request (which internally adds some HTTP headers)
            final SignedRequest signed = aws4Signer.sign(r -> r.identity(credentials)
                    .request(awsSignableRequest)
                    .payload(content)
                    .putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, this.serviceName)
                    .putProperty(AwsV4HttpSigner.REGION_NAME, this.regionName));

            final Map headers = httpRequest.headers();
            setSignedHeaders(headers, signed);
            setSessionToken(headers, credentials);
        } catch (final Exception ex) {
            logger.error("Error signing HTTP request: {}", ex.getMessage(), ex);
            throw new AuthenticationException(ex);
        }
        return httpRequest;
    }

    private void setSessionToken(final Map headers, final AwsCredentials credentials) {
        // extract session token if temporary credentials are provided
        if ((credentials instanceof AwsSessionCredentials)) {
            final String sessionToken = ((AwsSessionCredentials) credentials).sessionToken();
            if (sessionToken != null && !sessionToken.isEmpty()) {
                headers.put(X_AMZ_SECURITY_TOKEN, sessionToken);
            }
        }
    }

    private void setSignedHeaders(final Map headers, final SignedRequest signed) {
        headers.remove(HttpRequest.Headers.HOST);
        headers.put(HOST, signed.request().host());
        final Map> signedHeaders = signed.request().headers();
        headers.put(X_AMZ_DATE, getSingleHeaderValue(signedHeaders, X_AMZ_DATE));
        headers.put(AUTHORIZATION, getSingleHeaderValue(signedHeaders, AUTHORIZATION));
        headers.put(X_AMZ_CONTENT_SHA256, getSingleHeaderValue(signedHeaders, X_AMZ_CONTENT_SHA256));
    }

    private String getSingleHeaderValue(final Map> headers, final String headerName) {
        final Set headerValues = new HashSet<>(headers.containsKey(headerName) ? headers.get(headerName) : Collections.emptySet());
        if (headerValues.size() != 1) {
            throw new IllegalArgumentException(String.format("Expected 1 header %s but found %d", headerName, headerValues.size()));
        }
        return headerValues.iterator().next();
    }

    private ContentStreamProvider toContentStream(final HttpRequest httpRequest) {
        // carry over the entity (or an empty entity, if no entity is provided)
        if (!(httpRequest.getBody() instanceof byte[])) {
            throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + httpRequest.getBody().getClass());
        }
        final byte[] body = (byte[]) httpRequest.getBody();
        return (body.length != 0) ? ContentStreamProvider.fromByteArray(body) : ContentStreamProvider.fromUtf8String("");
    }

    private SdkHttpRequest toSignableRequest(final HttpRequest request) {

        // make sure the request contains the minimal required set of information
        checkNotNull(request.getUri(), "The request URI must not be null");
        checkNotNull(request.getMethod(), "The request method must not be null");

        // convert the headers to the internal API format
        final Map headers = request.headers();
        final Map> headersInternal = new HashMap<>();

        // we don't want to add the Host header as the Signer always adds the host header.
        for (Map.Entry header : headers.entrySet()) {
            // Skip adding the Host header as the signing process will add one.
            if (!header.getKey().equalsIgnoreCase(HttpRequest.Headers.HOST)) {
                headersInternal.put(header.getKey(), Collections.singletonList(header.getValue()));
            }
        }

        // convert the parameters to the internal API format
        final URI uri = request.getUri();
        final Map> parametersInternal = extractParametersFromQueryString(uri.getQuery());

        final URI endpointUri = URI.create(uri.getScheme() + "://" + uri.getHost());

        // create the HTTP AWS SdkHttpRequest and carry over information
        return SdkHttpRequest.builder()
                .uri(endpointUri)
                .encodedPath(uri.getPath())
                .method(SdkHttpMethod.fromValue(request.getMethod()))
                .headers(headersInternal)
                .rawQueryParameters(parametersInternal)
                .build();
    }

    private HashMap> extractParametersFromQueryString(final String queryStr) {

        final HashMap> parameters = new HashMap<>();

        if (queryStr == null) {
            return parameters;
        }

        // convert the parameters to the internal API format
        for (final String queryParam : queryStr.split("&")) {

            if (!queryParam.isEmpty()) {
                final String[] keyValuePair = queryParam.split("=", 2);

                // parameters are encoded in the HTTP request, we need to decode them here
                final String key = SdkHttpUtils.urlDecode(keyValuePair[0]);
                final String value;

                if (keyValuePair.length == 2) {
                    value = SdkHttpUtils.urlDecode(keyValuePair[1]);
                } else {
                    value = "";
                }

                // insert the parameter key into the map, if not yet present
                if (!parameters.containsKey(key)) {
                    parameters.put(key, new ArrayList<>());
                }

                // append the parameter value to the list for the given key
                parameters.get(key).add(value);
            }
        }

        return parameters;
    }

    private void checkNotNull(final Object obj, final String errMsg) {
        if (obj == null) {
            throw new IllegalArgumentException(errMsg);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy