All Downloads are FREE. Search and download functionalities are using the official Maven repository.

software.amazon.awssdk.http.SdkHttpClientTestSuite Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package software.amazon.awssdk.http;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.absent;
import static com.github.tomakehurst.wiremock.client.WireMock.any;
import static com.github.tomakehurst.wiremock.client.WireMock.containing;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching;
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder;
import com.github.tomakehurst.wiremock.common.FatalStartupException;
import com.github.tomakehurst.wiremock.http.RequestMethod;
import com.github.tomakehurst.wiremock.junit.WireMockRule;
import com.github.tomakehurst.wiremock.matching.RequestPatternBuilder;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Random;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import software.amazon.awssdk.utils.IoUtils;
import software.amazon.awssdk.utils.Logger;

/**
 * A set of tests validating that the functionality implemented by a {@link SdkHttpClient}.
 *
 * This is used by an HTTP plugin implementation by extending this class and implementing the abstract methods to provide this
 * suite with a testable HTTP client implementation.
 */
@RunWith(MockitoJUnitRunner.class)
public abstract class SdkHttpClientTestSuite {
    private static final Logger LOG = Logger.loggerFor(SdkHttpClientTestSuite.class);

    private static final ConnectionCountingTrafficListener CONNECTION_COUNTER = new ConnectionCountingTrafficListener();

    @Rule
    public WireMockRule mockServer = createWireMockRule();

    private final Random rng = new Random();


    @Test
    public void supportsResponseCode200() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_OK);
    }

    @Test
    public void supportsResponseCode200Head() throws Exception {
        // HEAD is special due to closing of the connection immediately and streams are null
        testForResponseCode(HttpURLConnection.HTTP_FORBIDDEN, SdkHttpMethod.HEAD);
    }

    @Test
    public void supportsResponseCode202() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_ACCEPTED);
    }

    @Test
    public void supportsResponseCode403() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_FORBIDDEN);
    }

    @Test
    public void supportsResponseCode403Head() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_FORBIDDEN, SdkHttpMethod.HEAD);
    }

    @Test
    public void supportsResponseCode301() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_MOVED_PERM);
    }

    @Test
    public void supportsResponseCode302() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_MOVED_TEMP);
    }

    @Test
    public void supportsResponseCode500() throws Exception {
        testForResponseCode(HttpURLConnection.HTTP_INTERNAL_ERROR);
    }

    @Test
    public void validatesHttpsCertificateIssuer() throws Exception {
        SdkHttpClient client = createSdkHttpClient();

        SdkHttpFullRequest request = mockSdkRequest("https://localhost:" + mockServer.httpsPort(), SdkHttpMethod.POST);

        assertThatThrownBy(client.prepareRequest(HttpExecuteRequest.builder().request(request).build())::call)
                .isInstanceOf(SSLHandshakeException.class);
    }

    @Test
    public void connectionPoolingWorks() throws Exception {
        int initialOpenedConnections = CONNECTION_COUNTER.openedConnections();

        SdkHttpClientOptions httpClientOptions = new SdkHttpClientOptions();
        httpClientOptions.trustAll(true);
        SdkHttpClient client = createSdkHttpClient(httpClientOptions);

        stubForMockRequest(200);

        for (int i = 0; i < 5; i++) {
            SdkHttpFullRequest req = mockSdkRequest("http://localhost:" + mockServer.port(), SdkHttpMethod.POST);
            HttpExecuteResponse response =
                client.prepareRequest(HttpExecuteRequest.builder()
                                                        .request(req)
                                                        .contentStreamProvider(req.contentStreamProvider().orElse(null))
                                                        .build())
                      .call();
            response.responseBody().ifPresent(IoUtils::drainInputStream);
        }

        assertThat(CONNECTION_COUNTER.openedConnections()).isEqualTo(initialOpenedConnections + 1);
    }

    @Test
    public void connectionsAreNotReusedOn5xxErrors() throws Exception {
        int initialOpenedConnections = CONNECTION_COUNTER.openedConnections();

        SdkHttpClientOptions httpClientOptions = new SdkHttpClientOptions();
        httpClientOptions.trustAll(true);
        SdkHttpClient client = createSdkHttpClient(httpClientOptions);

        stubForMockRequest(503);

        for (int i = 0; i < 5; i++) {
            SdkHttpFullRequest req = mockSdkRequest("http://localhost:" + mockServer.port(), SdkHttpMethod.POST);
            HttpExecuteResponse response =
                client.prepareRequest(HttpExecuteRequest.builder()
                                                        .request(req)
                                                        .contentStreamProvider(req.contentStreamProvider().orElse(null))
                                                        .build())
                      .call();
            response.responseBody().ifPresent(IoUtils::drainInputStream);
        }

        assertThat(CONNECTION_COUNTER.openedConnections()).isEqualTo(initialOpenedConnections + 5);
    }

    @Test
    public void testCustomTlsTrustManager() throws Exception {
        WireMockServer selfSignedServer = HttpTestUtils.createSelfSignedServer();

        TrustManagerFactory managerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        managerFactory.init(HttpTestUtils.getSelfSignedKeyStore());

        SdkHttpClientOptions httpClientOptions = new SdkHttpClientOptions();
        httpClientOptions.tlsTrustManagersProvider(managerFactory::getTrustManagers);

        selfSignedServer.start();

        try {
            SdkHttpClient client = createSdkHttpClient(httpClientOptions);
            SdkHttpFullRequest request = mockSdkRequest("https://localhost:" + selfSignedServer.httpsPort(), SdkHttpMethod.POST);

            client.prepareRequest(HttpExecuteRequest.builder().request(request).build()).call();
        } finally {
            selfSignedServer.stop();
        }
    }

    @Test
    public void testTrustAllWorks() throws Exception {
        SdkHttpClientOptions httpClientOptions = new SdkHttpClientOptions();
        httpClientOptions.trustAll(true);

        testForResponseCodeUsingHttps(createSdkHttpClient(httpClientOptions), HttpURLConnection.HTTP_OK);
    }

    @Test
    public void testCustomTlsTrustManagerAndTrustAllFails() throws Exception {
        SdkHttpClientOptions httpClientOptions = new SdkHttpClientOptions();
        httpClientOptions.tlsTrustManagersProvider(() -> new TrustManager[0]);
        httpClientOptions.trustAll(true);

        assertThatThrownBy(() -> createSdkHttpClient(httpClientOptions)).isInstanceOf(IllegalArgumentException.class);
    }

    protected void testForResponseCode(int returnCode) throws Exception {
        testForResponseCode(returnCode, SdkHttpMethod.POST);
    }

    private void testForResponseCode(int returnCode, SdkHttpMethod method) throws Exception {
        SdkHttpClient client = createSdkHttpClient();

        stubForMockRequest(returnCode);

        SdkHttpFullRequest req = mockSdkRequest("http://localhost:" + mockServer.port(), method);
        HttpExecuteResponse rsp = client.prepareRequest(HttpExecuteRequest.builder()
                                                                          .request(req)
                                                                          .contentStreamProvider(req.contentStreamProvider()
                                                                                                    .orElse(null))
                                                                          .build())
                                        .call();

        validateResponse(rsp, returnCode, method);
    }

    protected void testForResponseCodeUsingHttps(SdkHttpClient client, int returnCode) throws Exception {
        SdkHttpMethod sdkHttpMethod = SdkHttpMethod.POST;
        stubForMockRequest(returnCode);

        SdkHttpFullRequest req = mockSdkRequest("https://localhost:" + mockServer.httpsPort(), sdkHttpMethod);
        HttpExecuteResponse rsp = client.prepareRequest(HttpExecuteRequest.builder()
                                                                          .request(req)
                                                                          .contentStreamProvider(req.contentStreamProvider()
                                                                                                    .orElse(null))
                                                                          .build())
                                        .call();

        validateResponse(rsp, returnCode, sdkHttpMethod);
    }

    private void stubForMockRequest(int returnCode) {
        ResponseDefinitionBuilder responseBuilder = aResponse().withStatus(returnCode)
                                                               .withHeader("Some-Header", "With Value")
                                                               .withBody("hello");

        if (returnCode >= 300 && returnCode <= 399) {
            responseBuilder.withHeader("Location", "Some New Location");
        }

        mockServer.stubFor(any(urlPathEqualTo("/")).willReturn(responseBuilder));
    }

    private void validateResponse(HttpExecuteResponse response, int returnCode, SdkHttpMethod method) throws IOException {
        RequestMethod requestMethod = RequestMethod.fromString(method.name());

        RequestPatternBuilder patternBuilder = RequestPatternBuilder.newRequestPattern(requestMethod, urlMatching("/"))
                                                                           .withHeader("Host", containing("localhost"))
                                                                           .withHeader("User-Agent", equalTo("hello-world!"));

        if (method == SdkHttpMethod.HEAD) {
            patternBuilder.withRequestBody(absent());
        } else {
            patternBuilder.withRequestBody(equalTo("Body"));
        }

        mockServer.verify(1, patternBuilder);

        if (method == SdkHttpMethod.HEAD) {
            assertThat(response.responseBody()).isEmpty();
        } else {
            assertThat(IoUtils.toUtf8String(response.responseBody().orElse(null))).isEqualTo("hello");
        }

        assertThat(response.httpResponse().firstMatchingHeader("Some-Header")).contains("With Value");
        assertThat(response.httpResponse().statusCode()).isEqualTo(returnCode);
        mockServer.resetMappings();
    }

    private SdkHttpFullRequest mockSdkRequest(String uriString, SdkHttpMethod method) {
        URI uri = URI.create(uriString);
        SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
                                                            .uri(uri)
                                                            .method(method)
                                                            .putHeader("Host", uri.getHost())
                                                            .putHeader("User-Agent", "hello-world!");
        if (method != SdkHttpMethod.HEAD) {
            requestBuilder.contentStreamProvider(() -> new ByteArrayInputStream("Body".getBytes(StandardCharsets.UTF_8)));
        }

        return requestBuilder.build();
    }

    /**
     * {@link #createSdkHttpClient(SdkHttpClientOptions)} with default options.
     */
    protected final SdkHttpClient createSdkHttpClient() {
        return createSdkHttpClient(new SdkHttpClientOptions());
    }

    /**
     * Implemented by a child class to create an HTTP client to validate based on the provided options.
     */
    protected abstract SdkHttpClient createSdkHttpClient(SdkHttpClientOptions options);

    /**
     * The options that should be considered when creating the client via {@link #createSdkHttpClient(SdkHttpClientOptions)}.
     */
    protected static final class SdkHttpClientOptions {
        private TlsTrustManagersProvider tlsTrustManagersProvider = null;
        private boolean trustAll = false;

        public TlsTrustManagersProvider tlsTrustManagersProvider() {
            return tlsTrustManagersProvider;
        }

        public void tlsTrustManagersProvider(TlsTrustManagersProvider tlsTrustManagersProvider) {
            this.tlsTrustManagersProvider = tlsTrustManagersProvider;
        }

        public boolean trustAll() {
            return trustAll;
        }

        public void trustAll(boolean trustAll) {
            this.trustAll = trustAll;
        }
    }

    private WireMockRule createWireMockRule() {
        int maxAttempts = 5;
        for (int i = 0; i < maxAttempts; ++i) {
            try {
                return new WireMockRule(wireMockConfig().dynamicPort()
                                                        .dynamicHttpsPort()
                                                        .networkTrafficListener(CONNECTION_COUNTER));
            } catch (FatalStartupException e) {
                int attemptNum = i + 1;
                LOG.debug(() -> "Was not able to start WireMock server. Attempt " + attemptNum, e);

                if (attemptNum != maxAttempts) {
                    try {
                        long sleepMillis = 1_000L + rng.nextInt(1_000);
                        Thread.sleep(sleepMillis);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException("Backoff interrupted", ie);
                    }
                }
            }
        }

        throw new RuntimeException("Unable to setup WireMock rule");
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy