All downloads are free. Search and download functionality uses the official Maven repository.

software.amazon.awssdk.codegen.poet.client.SyncClientClass Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package software.amazon.awssdk.codegen.poet.client;

import static javax.lang.model.element.Modifier.FINAL;
import static javax.lang.model.element.Modifier.PRIVATE;
import static javax.lang.model.element.Modifier.PROTECTED;
import static javax.lang.model.element.Modifier.PUBLIC;
import static javax.lang.model.element.Modifier.STATIC;
import static software.amazon.awssdk.codegen.poet.PoetUtils.classNameFromFqcn;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.addS3ArnableFieldCode;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.applySignerOverrideMethod;

import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.WildcardTypeName;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata;
import software.amazon.awssdk.awscore.internal.AwsServiceProtocol;
import software.amazon.awssdk.codegen.emitters.GeneratorTaskParams;
import software.amazon.awssdk.codegen.model.config.customization.UtilitiesMethod;
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.model.intermediate.OperationModel;
import software.amazon.awssdk.codegen.model.intermediate.Protocol;
import software.amazon.awssdk.codegen.model.service.ClientContextParam;
import software.amazon.awssdk.codegen.model.service.PreClientExecutionRequestCustomizer;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.handler.SyncClientHandler;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest;
import software.amazon.awssdk.core.metrics.CoreMetric;
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.MetricPublisher;
import software.amazon.awssdk.metrics.NoOpMetricCollector;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.CompletableFutureUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

/**
 * Generates the internal, {@code final} implementation class of a service's synchronous client.
 *
 * <p>The generated type implements the service's sync client interface, is annotated with {@link SdkInternalApi},
 * and contains one method per operation whose input and output are not event streams. Protocol-specific pieces
 * (protocol factory, client handler, response/error handlers and execution code) are delegated to a
 * {@link ProtocolSpec} selected from the service's protocol by {@link #getProtocolSpecs}.
 */
public class SyncClientClass extends SyncClientInterface {

    private final IntermediateModel model;
    private final PoetExtension poetExtensions;
    private final ClassName className;
    private final ProtocolSpec protocolSpec;
    private final ClassName serviceClientConfigurationClassName;
    private final ServiceClientConfigurationUtils configurationUtils;
    // When true, the legacy signer-override helper and its per-operation call are not generated.
    private final boolean useSraAuth;

    /**
     * @param taskParams codegen parameters supplying the intermediate model and the poet naming helpers
     */
    public SyncClientClass(GeneratorTaskParams taskParams) {
        super(taskParams.getModel());
        this.model = taskParams.getModel();
        this.poetExtensions = taskParams.getPoetExtensions();
        this.className = poetExtensions.getClientClass(model.getMetadata().getSyncClient());
        this.protocolSpec = getProtocolSpecs(poetExtensions, model);
        this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
        this.configurationUtils = new ServiceClientConfigurationUtils(model);
        this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
    }

    /**
     * Makes the generated class implement the service's sync client interface and links to it from the Javadoc.
     */
    @Override
    protected void addInterfaceClass(TypeSpec.Builder type) {
        ClassName interfaceClass = poetExtensions.getClientClass(model.getMetadata().getSyncInterface());
        type.addSuperinterface(interfaceClass)
            .addJavadoc("Internal implementation of {@link $1T}.\n\n@see $1T#builder()", interfaceClass);
    }

    @Override
    protected TypeSpec.Builder createTypeSpec() {
        return PoetUtils.createClassBuilder(className);
    }

    @Override
    protected void addAnnotations(TypeSpec.Builder type) {
        type.addAnnotation(SdkInternalApi.class);
    }

    @Override
    protected void addModifiers(TypeSpec.Builder type) {
        type.addModifiers(FINAL);
    }

    /**
     * Adds the static logger and protocol metadata, plus the client handler, protocol factory and client
     * configuration instance fields.
     */
    @Override
    protected void addFields(TypeSpec.Builder type) {
        type.addField(logger())
            .addField(protocolMetadata())
            .addField(SyncClientHandler.class, "clientHandler", PRIVATE, FINAL)
            .addField(protocolSpec.protocolFactory(model))
            .addField(SdkClientConfiguration.class, "clientConfiguration", PRIVATE, FINAL);
    }

    /**
     * Adds the constructor and all non-operation helper methods. It also adds the signer-override helper and the
     * endpoint-discovery cache field when the model requires them.
     */
    @Override
    protected void addAdditionalMethods(TypeSpec.Builder type) {
        if (!useSraAuth && model.containsRequestSigners()) {
            type.addMethod(applySignerOverrideMethod(poetExtensions, model));
        }

        model.getEndpointOperation().ifPresent(
            o -> type.addField(EndpointDiscoveryRefreshCache.class, "endpointDiscoveryCache", PRIVATE));

        type.addMethod(constructor())
            .addMethod(nameMethod())
            .addMethods(protocolSpec.additionalMethods())
            .addMethod(resolveMetricPublishersMethod());

        protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod);
        type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod());
        type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName()));
        type.addMethod(protocolSpec.initProtocolFactory(model));
    }

    /** {@code private static final Logger log = Logger.loggerFor(<generated class>.class);} */
    private FieldSpec logger() {
        return FieldSpec.builder(Logger.class, "log", PRIVATE, STATIC, FINAL)
                        .initializer("$T.loggerFor($T.class)", Logger.class, className)
                        .build();
    }

    /** Static {@link AwsProtocolMetadata} constant whose service protocol mirrors the model's protocol. */
    private FieldSpec protocolMetadata() {
        return FieldSpec.builder(AwsProtocolMetadata.class, "protocolMetadata", PRIVATE, STATIC, FINAL)
                        .initializer("$T.builder().serviceProtocol($T.$L).build()",
                                     AwsProtocolMetadata.class, AwsServiceProtocol.class, model.getMetadata().getProtocol())
                        .build();
    }

    /** Generates {@code serviceName()}, which returns the interface's {@code SERVICE_NAME} constant. */
    private MethodSpec nameMethod() {
        return MethodSpec.methodBuilder("serviceName")
                         .addAnnotation(Override.class)
                         .addModifiers(PUBLIC, FINAL)
                         .returns(String.class)
                         .addStatement("return SERVICE_NAME")
                         .build();
    }

    /**
     * Generates {@code serviceClientConfiguration()}, which builds the service configuration from a copy of the
     * client's {@link SdkClientConfiguration}.
     */
    @Override
    protected MethodSpec serviceClientConfigMethod() {
        return MethodSpec.methodBuilder("serviceClientConfiguration")
                         .addAnnotation(Override.class)
                         .addModifiers(PUBLIC, FINAL)
                         .returns(serviceClientConfigurationClassName)
                         .addStatement("return new $T(this.clientConfiguration.toBuilder()).build()",
                                       this.configurationUtils.serviceClientConfigurationBuilderClassName())
                         .build();
    }

    @Override
    public ClassName className() {
        return className;
    }

    /**
     * Generates the protected constructor. It:
     * <ul>
     *   <li>creates the client handler;</li>
     *   <li>stores a copy of the configuration that records this client under {@code SdkClientOption.SDK_CLIENT};</li>
     *   <li>initializes the protocol factory;</li>
     *   <li>for models with an endpoint operation, creates the endpoint-discovery cache when discovery is enabled.</li>
     * </ul>
     */
    private MethodSpec constructor() {
        MethodSpec.Builder builder
            = MethodSpec.constructorBuilder()
                        .addModifiers(PROTECTED)
                        .addParameter(SdkClientConfiguration.class, "clientConfiguration")
                        .addStatement("this.clientHandler = new $T(clientConfiguration)", protocolSpec.getClientHandlerClass())
                        .addStatement("this.clientConfiguration = clientConfiguration.toBuilder()"
                                      + ".option($T.SDK_CLIENT, this)"
                                      + ".build()", SdkClientOption.class);

        FieldSpec protocolFactoryField = protocolSpec.protocolFactory(model);
        if (model.getMetadata().isJsonProtocol()) {
            // JSON protocols pass a fresh factory builder through init(...) and build the result.
            builder.addStatement("this.$N = init($T.builder()).build()", protocolFactoryField.name,
                                 protocolFactoryField.type);
        } else {
            builder.addStatement("this.$N = init()", protocolFactoryField.name);
        }

        if (model.getEndpointOperation().isPresent()) {
            // $T (rather than a hard-coded simple name) guarantees SdkClientOption is imported in the generated file.
            builder.beginControlFlow("if (clientConfiguration.option($T.ENDPOINT_DISCOVERY_ENABLED))",
                                     SdkClientOption.class);
            builder.addStatement("this.endpointDiscoveryCache = $T.create($T.create(this))",
                                 EndpointDiscoveryRefreshCache.class,
                                 poetExtensions.getClientClass(model.getNamingStrategy().getServiceName() +
                                                               "EndpointDiscoveryCacheLoader"));

            if (model.getCustomizationConfig().allowEndpointOverrideForEndpointDiscoveryRequiredOperations()) {
                builder.beginControlFlow("if (clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER)"
                                         + ".isEndpointOverridden())", SdkClientOption.class);
                builder.addStatement("log.warn(() -> $S)",
                                     "Endpoint discovery is enabled for this client, and an endpoint override was also "
                                     + "specified. This will disable endpoint discovery for methods that require it, instead "
                                     + "using the specified endpoint override. This may or may not be what you intended.");
                builder.endControlFlow();
            }

            builder.endControlFlow();
        }

        return builder.build();
    }

    /** One generated method per operation; operations with event-stream input or output are skipped. */
    @Override
    protected List<MethodSpec> operations() {
        return model.getOperations().values().stream()
                    .filter(o -> !o.hasEventStreamInput())
                    .filter(o -> !o.hasEventStreamOutput())
                    .flatMap(this::operations)
                    .collect(Collectors.toList());
    }

    private Stream<MethodSpec> operations(OperationModel opModel) {
        List<MethodSpec> methods = new ArrayList<>();
        methods.add(traditionalMethod(opModel));
        return methods.stream();
    }

    /**
     * Generates the synchronous implementation of one operation. In order, the generated method:
     * <ol>
     *   <li>applies request customizers and (non-SRA) signer overrides;</li>
     *   <li>builds the response and error handlers;</li>
     *   <li>resolves a discovered endpoint if needed;</li>
     *   <li>applies request-level plugins;</li>
     *   <li>executes the call while collecting API-call metrics, and publishes those metrics in a {@code finally}
     *       block.</li>
     * </ol>
     */
    private MethodSpec traditionalMethod(OperationModel opModel) {
        MethodSpec.Builder method = SyncClientInterface.operationMethodSignature(model, opModel)
                                                       .addAnnotation(Override.class);

        addRequestModifierCode(opModel, model).ifPresent(method::addCode);
        if (!useSraAuth) {
            method.addCode(ClientClassUtils.callApplySignerOverrideMethod(opModel));
        }
        method.addCode(protocolSpec.responseHandler(model, opModel));

        protocolSpec.errorResponseHandler(opModel).ifPresent(method::addCode);

        if (opModel.getEndpointDiscovery() != null) {
            addEndpointDiscoveryCode(method, opModel);
        }

        String requestName = opModel.getInput().getVariableName();

        // The generated local shadows the field: it holds the per-request configuration after plugins ran.
        method.addStatement("$T clientConfiguration = updateSdkClientConfiguration($L, this.clientConfiguration)",
                            SdkClientConfiguration.class, requestName);
        method.addStatement("$T<$T> metricPublishers = "
                            + "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
                            List.class,
                            MetricPublisher.class,
                            requestName)
              .addStatement("$1T apiCallMetricCollector = metricPublishers.isEmpty() ? $2T.create() : $1T.create($3S)",
                            MetricCollector.class, NoOpMetricCollector.class, "ApiCall");

        method.beginControlFlow("try")
              .addStatement("apiCallMetricCollector.reportMetric($T.$L, $S)",
                            CoreMetric.class, "SERVICE_ID", model.getMetadata().getServiceId())
              .addStatement("apiCallMetricCollector.reportMetric($T.$L, $S)",
                            CoreMetric.class, "OPERATION_NAME", opModel.getOperationName());

        addS3ArnableFieldCode(opModel, model).ifPresent(method::addCode);
        method.addCode(ClientClassUtils.addEndpointTraitCode(opModel));

        method.addCode(protocolSpec.executionHandler(opModel))
              .endControlFlow()
              .beginControlFlow("finally")
              .addStatement("metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()))")
              .endControlFlow();

        return method.build();
    }

    /**
     * Emits the endpoint-discovery prologue of an operation that declares endpoint discovery. The generated code:
     * <ul>
     *   <li>validates discovery/override settings for required-discovery operations;</li>
     *   <li>when discovery is enabled, resolves the caller's credentials to obtain the access key ID;</li>
     *   <li>stores into {@code cachedEndpoint} the endpoint returned by the client's discovery cache for that key.</li>
     * </ul>
     */
    private void addEndpointDiscoveryCode(MethodSpec.Builder method, OperationModel opModel) {
        String requestName = opModel.getInput().getVariableName();

        method.addStatement("boolean endpointDiscoveryEnabled = "
                            + "clientConfiguration.option($T.ENDPOINT_DISCOVERY_ENABLED)", SdkClientOption.class);
        method.addStatement("boolean endpointOverridden = "
                            + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER)"
                            + ".isEndpointOverridden()", SdkClientOption.class);

        if (opModel.getEndpointDiscovery().isRequired()) {
            if (!model.getCustomizationConfig().allowEndpointOverrideForEndpointDiscoveryRequiredOperations()) {
                method.beginControlFlow("if (endpointOverridden)");
                method.addStatement("throw new $T($S)", IllegalStateException.class,
                                    "This operation requires endpoint discovery, but an endpoint override was specified "
                                    + "when the client was created. This is not supported.");
                method.endControlFlow();

                method.beginControlFlow("if (!endpointDiscoveryEnabled)");
                method.addStatement("throw new $T($S)", IllegalStateException.class,
                                    "This operation requires endpoint discovery, but endpoint discovery was disabled on the "
                                    + "client.");
                method.endControlFlow();
            } else {
                // An explicit endpoint override takes precedence over discovery for this customization.
                method.beginControlFlow("if (endpointOverridden)");
                method.addStatement("endpointDiscoveryEnabled = false");
                method.nextControlFlow("else if (!endpointDiscoveryEnabled)");
                method.addStatement("throw new $T($S)", IllegalStateException.class,
                                    "This operation requires endpoint discovery to be enabled, or for you to specify an "
                                    + "endpoint override when the client is created.");
                method.endControlFlow();
            }
        }

        method.addStatement("$T cachedEndpoint = null", URI.class);
        method.beginControlFlow("if (endpointDiscoveryEnabled)");

        ParameterizedTypeName identityFutureTypeName =
            ParameterizedTypeName.get(ClassName.get(CompletableFuture.class),
                                      WildcardTypeName.subtypeOf(AwsCredentialsIdentity.class));
        method.addCode("$T identityFuture = $N.overrideConfiguration()",
                       identityFutureTypeName,
                       requestName)
              .addCode("    .flatMap($T::credentialsIdentityProvider)", AwsRequestOverrideConfiguration.class)
              .addCode("    .orElseGet(() -> clientConfiguration.option($T.CREDENTIALS_IDENTITY_PROVIDER))",
                       AwsClientOption.class)
              .addCode("    .resolveIdentity();");

        method.addCode("$T key = $T.joinLikeSync(identityFuture).accessKeyId();",
                       String.class, CompletableFutureUtils.class);

        method.addCode("$1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class)
              .addCode("    .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired())
              .addCode("    .defaultEndpoint(clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER).clientEndpoint())",
                       SdkClientOption.class)
              .addCode("    .overrideConfiguration($N.overrideConfiguration().orElse(null))",
                       requestName)
              .addCode("    .build();");

        method.addStatement("cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest)");
        method.endControlFlow();
    }

    /**
     * Builds code that passes the request through the pre-client-execution customizer configured for this operation:
     * {@code request = Customizer.method(request);}.
     *
     * @return the customizer invocation, or empty when no customizer is configured for the operation
     */
    public static Optional<CodeBlock> addRequestModifierCode(OperationModel opModel, IntermediateModel model) {
        Map<String, PreClientExecutionRequestCustomizer> preClientExecutionRequestCustomizer =
            model.getCustomizationConfig().getPreClientExecutionRequestCustomizer();

        if (CollectionUtils.isNullOrEmpty(preClientExecutionRequestCustomizer)) {
            return Optional.empty();
        }

        PreClientExecutionRequestCustomizer requestCustomizer =
            preClientExecutionRequestCustomizer.get(opModel.getOperationName());
        if (requestCustomizer == null) {
            return Optional.empty();
        }

        String requestName = opModel.getInput().getVariableName();
        ClassName instanceType = classNameFromFqcn(requestCustomizer.getClassName());
        return Optional.of(CodeBlock.builder()
                                    .addStatement("$L = $T.$N($L)",
                                                  requestName,
                                                  instanceType,
                                                  requestCustomizer.getMethodName(),
                                                  requestName)
                                    .build());
    }

    /** Generates {@code close()}, which closes the client handler. */
    @Override
    protected void addCloseMethod(TypeSpec.Builder type) {
        MethodSpec method = MethodSpec.methodBuilder("close")
                                      .addAnnotation(Override.class)
                                      .addModifiers(PUBLIC)
                                      .addStatement("$N.close()", "clientHandler")
                                      .build();

        type.addMethod(method);
    }

    /**
     * Body of the customized utilities method: {@code return InstanceType.create(params...)}. The configured
     * instance type is used, falling back to the return type when none is set.
     */
    @Override
    protected MethodSpec.Builder utilitiesOperationBody(MethodSpec.Builder builder) {
        UtilitiesMethod config = model.getCustomizationConfig().getUtilitiesMethod();
        String instanceClass = config.getInstanceType();
        if (instanceClass == null) {
            instanceClass = config.getReturnType();
        }
        ClassName instanceType = classNameFromFqcn(instanceClass);

        return builder.addAnnotation(Override.class)
                      .addStatement("return $T.create($L)", instanceType,
                                    String.join(",", config.getCreateMethodParams()));
    }

    /**
     * Selects the protocol-specific code generator for the model's protocol.
     *
     * @throws RuntimeException if the protocol is not handled here
     */
    static ProtocolSpec getProtocolSpecs(PoetExtension poetExtensions, IntermediateModel model) {
        Protocol protocol = model.getMetadata().getProtocol();
        switch (protocol) {
            case QUERY:
                return new QueryProtocolSpec(model, poetExtensions);
            case REST_XML:
                return new XmlProtocolSpec(model, poetExtensions);
            case EC2:
                return new Ec2ProtocolSpec(model, poetExtensions);
            case AWS_JSON:
            case REST_JSON:
            case CBOR:
            case SMITHY_RPC_V2_CBOR:
                return new JsonProtocolSpec(poetExtensions, model);
            default:
                throw new RuntimeException("Unknown protocol: " + protocol.name());
        }
    }

    /**
     * Generates {@code resolveMetricPublishers(clientConfiguration, requestOverrideConfiguration)}. The generated
     * helper returns the request-level publishers when they are non-empty, otherwise the client-level
     * {@code METRIC_PUBLISHERS} option, and an empty list rather than {@code null}.
     */
    private MethodSpec resolveMetricPublishersMethod() {
        String clientConfigName = "clientConfiguration";
        String requestOverrideConfigName = "requestOverrideConfiguration";
        String publishersName = "publishers";
        TypeName publisherListType = ParameterizedTypeName.get(List.class, MetricPublisher.class);

        MethodSpec.Builder methodBuilder = MethodSpec.methodBuilder("resolveMetricPublishers")
                .addModifiers(PRIVATE, STATIC)
                .returns(publisherListType)
                .addParameter(SdkClientConfiguration.class, clientConfigName)
                .addParameter(RequestOverrideConfiguration.class, requestOverrideConfigName);

        methodBuilder.addStatement("$T $N = null", publisherListType, publishersName);

        methodBuilder.beginControlFlow("if ($N != null)", requestOverrideConfigName)
                .addStatement("$N = $N.metricPublishers()", publishersName, requestOverrideConfigName)
                .endControlFlow();

        methodBuilder.beginControlFlow("if ($1N == null || $1N.isEmpty())", publishersName)
                .addStatement("$N = $N.option($T.$N)",
                              publishersName,
                              clientConfigName,
                              SdkClientOption.class,
                              "METRIC_PUBLISHERS")
                .endControlFlow();

        methodBuilder.beginControlFlow("if ($1N == null)", publishersName)
                .addStatement("$N = $T.emptyList()", publishersName, Collections.class)
                .endControlFlow();

        methodBuilder.addStatement("return $N", publishersName);

        return methodBuilder.build();
    }

    /** Body of {@code waiter()}: builds the sync waiter bound to this client. */
    @Override
    protected MethodSpec.Builder waiterOperationBody(MethodSpec.Builder builder) {
        return builder.addAnnotation(Override.class)
                      .addStatement("return $T.builder().client(this).build()",
                                    poetExtensions.getSyncWaiterInterface());
    }

    /**
     * Generates {@code updateSdkClientConfiguration(request, clientConfiguration)}. The generated method applies the
     * request's override-configuration plugins to a copy of the client configuration. When plugins are present and
     * the model declares custom client context params, it also fails with {@code Validate.validState} if a plugin
     * changed any of those params.
     *
     * @param serviceClientConfigurationBuilderClassName builder type that the generated code hands to each plugin
     */
    protected MethodSpec updateSdkClientConfigurationMethod(
        TypeName serviceClientConfigurationBuilderClassName) {
        MethodSpec.Builder builder = MethodSpec.methodBuilder("updateSdkClientConfiguration")
                                               .addModifiers(PRIVATE)
                                               .addParameter(SdkRequest.class, "request")
                                               .addParameter(SdkClientConfiguration.class, "clientConfiguration")
                                               .returns(SdkClientConfiguration.class);

        builder.addStatement("$T plugins = request.overrideConfiguration()\n"
                             + ".map(c -> c.plugins()).orElse($T.emptyList())",
                             ParameterizedTypeName.get(List.class, SdkPlugin.class), Collections.class)
               .addStatement("$T configuration = clientConfiguration.toBuilder()", SdkClientConfiguration.Builder.class);

        builder.beginControlFlow("if (plugins.isEmpty())")
               .addStatement("return configuration.build()")
               .endControlFlow()
               .addStatement("$1T serviceConfigBuilder = new $1T(configuration)", serviceClientConfigurationBuilderClassName)
               .beginControlFlow("for ($T plugin : plugins)", SdkPlugin.class)
               .addStatement("plugin.configureClient(serviceConfigBuilder)")
               .endControlFlow();

        if (model.getCustomizationConfig() == null ||
            CollectionUtils.isNullOrEmpty(model.getCustomizationConfig().getCustomClientContextParams())) {
            builder.addStatement("updateRetryStrategyClientConfiguration(configuration)");
            builder.addStatement("return configuration.build()");
            return builder.build();
        }

        EndpointRulesSpecUtils endpointRulesSpecUtils = new EndpointRulesSpecUtils(this.model);
        Map<String, ClientContextParam> customClientConfigParams =
            model.getCustomizationConfig().getCustomClientContextParams();

        builder.addCode("$1T newContextParams = configuration.option($2T.CLIENT_CONTEXT_PARAMS);\n"
                        + "$1T originalContextParams = clientConfiguration.option($2T.CLIENT_CONTEXT_PARAMS);",
                        AttributeMap.class, SdkClientOption.class);

        builder.addCode("newContextParams = (newContextParams != null) ? newContextParams : $1T.empty();\n"
                        + "originalContextParams = originalContextParams != null ? originalContextParams : $1T.empty();",
                        AttributeMap.class);

        // Request-level plugins must not change any customized client context param.
        customClientConfigParams.forEach((paramName, param) -> {
            String keyName = model.getNamingStrategy().getEnumValueName(paramName);
            builder.addStatement("$1T.validState($2T.equals(originalContextParams.get($3T.$4N), newContextParams.get($3T.$4N)),"
                                 + " $5S)",
                                 Validate.class, Objects.class, endpointRulesSpecUtils.clientContextParamsName(), keyName,
                                 keyName + " cannot be modified by request level plugins");
        });
        builder.addStatement("updateRetryStrategyClientConfiguration(configuration)");
        builder.addStatement("return configuration.build()");
        return builder.build();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy