// software.amazon.awssdk.codegen.poet.rules.EndpointResolverInterceptorSpec (Maven / Gradle / Ivy)
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package software.amazon.awssdk.codegen.poet.rules;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.jr.stree.JrsBoolean;
import com.fasterxml.jackson.jr.stree.JrsString;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import java.time.Duration;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;
import javax.lang.model.element.Modifier;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.AwsS3V4Signer;
import software.amazon.awssdk.auth.signer.SignerLoader;
import software.amazon.awssdk.awscore.AwsExecutionAttribute;
import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute;
import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme;
import software.amazon.awssdk.awscore.util.SignerOverrideUtils;
import software.amazon.awssdk.codegen.internal.Utils;
import software.amazon.awssdk.codegen.model.config.customization.EndpointAuthSchemeConfig;
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.model.intermediate.OperationModel;
import software.amazon.awssdk.codegen.model.rules.endpoints.ParameterModel;
import software.amazon.awssdk.codegen.model.service.ClientContextParam;
import software.amazon.awssdk.codegen.model.service.ContextParam;
import software.amazon.awssdk.codegen.model.service.EndpointTrait;
import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor;
import software.amazon.awssdk.codegen.model.service.StaticContextParam;
import software.amazon.awssdk.codegen.poet.ClassSpec;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.ModelAuthSchemeClassesKnowledgeIndex;
import software.amazon.awssdk.codegen.poet.waiters.JmesPathAcceptorGenerator;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SelectedAuthScheme;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.core.metrics.CoreMetric;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme;
import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.RegionSet;
import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption;
import software.amazon.awssdk.identity.spi.Identity;
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.HostnameValidator;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.awssdk.utils.internal.CodegenNamingUtils;
public class EndpointResolverInterceptorSpec implements ClassSpec {
// Service model this spec generates the interceptor for.
private final IntermediateModel model;
// Helper built from the model; supplies the class names used throughout the generated code.
private final EndpointRulesSpecUtils endpointRulesSpecUtils;
// Index built from the model; used in poetSpec() to add the account-id methods when present.
private final EndpointParamsKnowledgeIndex endpointParamsKnowledgeIndex;
// Resolves model classes and the JMESPath runtime class for the service.
private final PoetExtension poetExtension;
// Turns the JMESPath expressions of operation context params into code.
private final JmesPathAcceptorGenerator jmesPathGenerator;
// True when the service's auth scheme classes include AwsV4AuthScheme or AwsV4aAuthScheme.
private final boolean dependsOnHttpAuthAws;
// Value of AuthSchemeSpecUtils#useSraAuth(); when false the pre-SRA signer field/constructor/method are generated.
private final boolean useSraAuth;
// Value of AuthSchemeSpecUtils#usesSigV4a().
private final boolean multiAuthSigv4a;
// Value of AuthSchemeSpecUtils#generateEndpointBasedParams().
private final boolean legacyAuthFromEndpointRulesService;
public EndpointResolverInterceptorSpec(IntermediateModel model) {
this.model = model;
this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model);
this.endpointParamsKnowledgeIndex = EndpointParamsKnowledgeIndex.of(model);
this.poetExtension = new PoetExtension(model);
this.jmesPathGenerator = new JmesPathAcceptorGenerator(poetExtension.jmesPathRuntimeClass());
// We need to know whether the service has a dependency on the http-auth-aws module. Because we can't check that
// directly, assume that if they're using AwsV4AuthScheme or AwsV4aAuthScheme that it's available.
Set> supportedAuthSchemes =
ModelAuthSchemeClassesKnowledgeIndex.of(model).serviceConcreteAuthSchemeClasses();
this.dependsOnHttpAuthAws = supportedAuthSchemes.contains(AwsV4AuthScheme.class) ||
supportedAuthSchemes.contains(AwsV4aAuthScheme.class);
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
this.multiAuthSigv4a = new AuthSchemeSpecUtils(model).usesSigV4a();
this.legacyAuthFromEndpointRulesService = new AuthSchemeSpecUtils(model).generateEndpointBasedParams();
}
/**
 * Assembles the generated interceptor class: a public final {@code @SdkInternalApi} type implementing
 * {@link ExecutionInterceptor}. Members are added in a fixed order; the strategy field, constructor and
 * {@code signerProvider} method are only emitted for pre-SRA clients.
 *
 * @return the type spec of the generated interceptor
 */
@Override
public TypeSpec poetSpec() {
    FieldSpec strategyField = endpointAuthSchemeStrategyFieldSpec();
    String strategyFieldName = strategyField.name;

    TypeSpec.Builder spec = PoetUtils.createClassBuilder(className())
                                     .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
                                     .addAnnotation(SdkInternalApi.class)
                                     .addSuperinterface(ExecutionInterceptor.class);

    boolean preSraClient = !useSraAuth;
    if (preSraClient) {
        spec.addField(strategyField);
        spec.addMethod(constructorMethodSpec(strategyFieldName));
    }

    spec.addMethod(modifyRequestMethod(strategyFieldName));
    spec.addMethod(modifyHttpRequestMethod());
    spec.addMethod(ruleParams());

    // Context params: the dispatcher first, then one setter per operation that needs it.
    spec.addMethod(setContextParams());
    addContextParamMethods(spec);

    // Static context params follow the same dispatcher + per-operation layout.
    spec.addMethod(setStaticContextParamsMethod());
    addStaticContextParamMethods(spec);

    spec.addMethod(authSchemeWithEndpointSignerPropertiesMethod());

    if (hasClientContextParams()) {
        spec.addMethod(setClientContextParamsMethod());
    }

    spec.addMethod(setOperationContextParams());
    addOperationContextParamMethods(spec);

    spec.addMethod(hostPrefixMethod());

    if (preSraClient) {
        spec.addMethod(signerProviderMethod());
    }

    endpointParamsKnowledgeIndex.addAccountIdMethodsIfPresent(spec);
    return spec.build();
}
/**
 * @return the name of the generated interceptor class, as provided by {@code EndpointRulesSpecUtils}
 */
@Override
public ClassName className() {
    ClassName interceptorName = endpointRulesSpecUtils.resolverInterceptorName();
    return interceptorName;
}
/**
 * Describes the {@code private final endpointAuthSchemeStrategy} field of the generated interceptor.
 *
 * @return the field spec; its name is reused by the constructor and {@code modifyRequest} generators
 */
private FieldSpec endpointAuthSchemeStrategyFieldSpec() {
    TypeName strategyType = endpointRulesSpecUtils.rulesRuntimeClassName("EndpointAuthSchemeStrategy");
    FieldSpec.Builder field = FieldSpec.builder(strategyType, "endpointAuthSchemeStrategy");
    field.addModifiers(Modifier.PRIVATE, Modifier.FINAL);
    return field.build();
}
/**
 * Builds the generated {@code modifyRequest} override.
 * <p>
 * The emitted method returns the request untouched when the endpoint came from endpoint discovery. Otherwise it
 * resolves the endpoint through the provider stored under {@code ENDPOINT_PROVIDER}, reports the resolution time
 * to the optional metric collector, applies the operation's host prefix unless injection is disabled, merges the
 * endpoint's auth settings into the selected auth scheme and stores the endpoint under {@code RESOLVED_ENDPOINT}.
 * A {@code CompletionException} is unwrapped: an {@code SdkClientException} cause is rethrown as-is, any other
 * cause is wrapped in a new {@code SdkClientException}.
 *
 * @param endpointAuthSchemeStrategyFieldName name of the strategy field referenced by the pre-SRA signer code
 * @return the method spec for {@code modifyRequest}
 */
private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldName) {
    TypeName providerUtils = endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils");
    String providerVar = "provider";

    MethodSpec.Builder b = MethodSpec.methodBuilder("modifyRequest")
                                     .addModifiers(Modifier.PUBLIC)
                                     .addAnnotation(Override.class)
                                     .returns(SdkRequest.class)
                                     .addParameter(Context.ModifyRequest.class, "context")
                                     .addParameter(ExecutionAttributes.class, "executionAttributes");

    b.addStatement("$T result = context.request()", SdkRequest.class);

    // We skip resolution if the source of the endpoint is the endpoint discovery call
    b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", providerUtils);
    b.addStatement("return result");
    b.endControlFlow();

    b.addStatement("$1T $2N = ($1T) executionAttributes.getAttribute($3T.ENDPOINT_PROVIDER)",
                   endpointRulesSpecUtils.providerInterfaceName(), providerVar, SdkInternalExecutionAttribute.class);

    b.beginControlFlow("try");

    // Resolve the endpoint and time the resolution for the metric below.
    b.addStatement("long resolveEndpointStart = $T.nanoTime()", System.class);
    b.addStatement("$T endpointParams = ruleParams(result, executionAttributes)",
                   endpointRulesSpecUtils.parametersClassName());
    b.addStatement("$T endpoint = $N.resolveEndpoint(endpointParams).join()",
                   Endpoint.class, providerVar);
    b.addStatement("$1T resolveEndpointDuration = $1T.ofNanos($2T.nanoTime() - resolveEndpointStart)", Duration.class,
                   System.class);
    b.addStatement("$T metricCollector = executionAttributes.getOptionalAttribute($T.API_CALL_METRIC_COLLECTOR)",
                   ParameterizedTypeName.get(Optional.class, MetricCollector.class), SdkExecutionAttribute.class);
    b.addStatement("metricCollector.ifPresent(mc -> mc.reportMetric($T.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration))",
                   CoreMetric.class);

    // Host prefix injection, unless disabled through the execution attributes.
    b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", providerUtils);
    b.addStatement("$T hostPrefix = hostPrefix(executionAttributes.getAttribute($T.OPERATION_NAME), result)",
                   ParameterizedTypeName.get(Optional.class, String.class), SdkExecutionAttribute.class);
    b.beginControlFlow("if (hostPrefix.isPresent())");
    b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", providerUtils);
    b.endControlFlow();
    b.endControlFlow();

    // If the endpoint resolver returns auth settings, use them as signer properties.
    // This effectively works to set the preSRA Signer ExecutionAttributes, so it is not conditional on useSraAuth.
    b.addStatement("$T<$T> endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)",
                   List.class, EndpointAuthScheme.class, AwsEndpointAttribute.class);
    // The wildcard is required: "$T>" alone would emit the uncompilable "SelectedAuthScheme> selectedAuthScheme".
    b.addStatement("$T<?> selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)",
                   SelectedAuthScheme.class, SdkInternalExecutionAttribute.class);
    b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)");
    b.addStatement("selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)");

    if (multiAuthSigv4a || legacyAuthFromEndpointRulesService) {
        b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications");
        b.beginControlFlow("if(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) "
                           + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)",
                           AwsV4aAuthScheme.class, AwsV4aHttpSigner.class);
        b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()",
                       AuthSchemeOption.Builder.class);
        b.addStatement("$T regionSet = $T.create(endpointParams.region().id())",
                       RegionSet.class, RegionSet.class);
        b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class);
        b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), "
                       + "optionBuilder.build())", SelectedAuthScheme.class);
        b.endControlFlow();
    }

    b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)",
                   SdkInternalExecutionAttribute.class);
    b.endControlFlow();

    // For pre SRA client, use Signer as determined by endpoint resolved auth scheme
    if (!useSraAuth) {
        b.beginControlFlow("if (endpointAuthSchemes != null)");
        b.addStatement("$T chosenAuthScheme = $N.chooseAuthScheme(endpointAuthSchemes)", EndpointAuthScheme.class,
                       endpointAuthSchemeStrategyFieldName);
        b.addStatement("$T<$T> signerProvider = signerProvider(chosenAuthScheme)", Supplier.class, Signer.class);
        b.addStatement("result = $T.overrideSignerIfNotOverridden(result, executionAttributes, signerProvider)",
                       SignerOverrideUtils.class);
        b.endControlFlow();
    }

    // Uses $T like the neighbouring statements so the reference never depends on an import added elsewhere.
    b.addStatement("executionAttributes.putAttribute($T.RESOLVED_ENDPOINT, endpoint)",
                   SdkInternalExecutionAttribute.class);
    b.addStatement("return result");
    b.endControlFlow();

    // join() wraps failures in CompletionException; surface the underlying client exception.
    b.beginControlFlow("catch ($T e)", CompletionException.class);
    b.addStatement("$T cause = e.getCause()", Throwable.class);
    b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class);
    b.addStatement("throw ($T) cause", SdkClientException.class);
    b.endControlFlow();
    b.beginControlFlow("else");
    b.addStatement("throw $T.create($S, cause)", SdkClientException.class, "Endpoint resolution failed");
    b.endControlFlow();
    b.endControlFlow();

    return b.build();
}
/**
 * Builds the generated {@code modifyHttpRequest} override, which appends every header carried by the endpoint
 * stored under {@code RESOLVED_ENDPOINT} to the outgoing HTTP request. When the endpoint has no headers the
 * emitted method returns the request unchanged.
 *
 * @return the method spec for {@code modifyHttpRequest}
 */
private MethodSpec modifyHttpRequestMethod() {
    MethodSpec.Builder method = MethodSpec.methodBuilder("modifyHttpRequest");
    method.addModifiers(Modifier.PUBLIC);
    method.addAnnotation(Override.class);
    method.returns(SdkHttpRequest.class);
    method.addParameter(Context.ModifyHttpRequest.class, "context");
    method.addParameter(ExecutionAttributes.class, "executionAttributes");

    method.addStatement("$T resolvedEndpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)",
                        Endpoint.class, SdkInternalExecutionAttribute.class);

    // Fast path: nothing to add.
    method.beginControlFlow("if (resolvedEndpoint.headers().isEmpty())")
          .addStatement("return context.httpRequest()")
          .endControlFlow();

    method.addStatement("$T httpRequestBuilder = context.httpRequest().toBuilder()", SdkHttpRequest.Builder.class);

    // The lambda braces are emitted as raw code around a single statement.
    method.addCode("resolvedEndpoint.headers().forEach((name, values) -> {")
          .addStatement("values.forEach(v -> httpRequestBuilder.appendHeader(name, v))")
          .addCode("});");

    method.addStatement("return httpRequestBuilder.build()");
    return method.build();
}
/**
 * Builds the generated public static {@code ruleParams} method, which creates the endpoint parameters for a
 * request: one setter call per supported built-in declared in the endpoint rule set, followed by the client,
 * member, static and operation context params.
 *
 * @return the method spec for {@code ruleParams}
 * @throws RuntimeException if the rule set declares a built-in this generator does not know how to set
 */
private MethodSpec ruleParams() {
    MethodSpec.Builder b = MethodSpec.methodBuilder("ruleParams")
                                     .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
                                     .returns(endpointRulesSpecUtils.parametersClassName())
                                     .addParameter(SdkRequest.class, "request")
                                     .addParameter(ExecutionAttributes.class, "executionAttributes");

    b.addStatement("$T builder = $T.builder()", paramsBuilderClass(), endpointRulesSpecUtils.parametersClassName());

    // Typed map: with a raw Map the lambda parameters would be Object and getBuiltInEnum() would not resolve.
    Map<String, ParameterModel> parameters = model.getEndpointRuleSetModel().getParameters();

    parameters.forEach((n, m) -> {
        // Only built-ins are handled here; other parameters come from the context-param setters below.
        if (m.getBuiltInEnum() == null) {
            return;
        }

        String setter = Utils.unCapitalize(CodegenNamingUtils.pascalCase(n));
        switch (m.getBuiltInEnum()) {
            case AWS_REGION:
                b.addStatement(endpointProviderUtilsSetter("regionBuiltIn", setter));
                break;
            case AWS_USE_DUAL_STACK:
                b.addStatement(endpointProviderUtilsSetter("dualStackEnabledBuiltIn", setter));
                break;
            case AWS_USE_FIPS:
                b.addStatement(endpointProviderUtilsSetter("fipsEnabledBuiltIn", setter));
                break;
            case SDK_ENDPOINT:
                b.addStatement(endpointProviderUtilsSetter("endpointBuiltIn", setter));
                break;
            case AWS_AUTH_ACCOUNT_ID:
                b.addStatement("builder.$N(resolveAndRecordAccountIdFromIdentity(executionAttributes))", setter);
                break;
            case AWS_AUTH_ACCOUNT_ID_ENDPOINT_MODE:
                b.addStatement("builder.$N(recordAccountIdEndpointMode(executionAttributes))", setter);
                break;
            case AWS_S3_USE_GLOBAL_ENDPOINT:
                b.addStatement("builder.$N(executionAttributes.getAttribute($T.$N))",
                               setter, AwsExecutionAttribute.class, model.getNamingStrategy().getEnumValueName(n));
                break;
            // The S3 specific built-ins are set through the existing S3Configuration and set through client context params
            case AWS_S3_ACCELERATE:
            case AWS_S3_DISABLE_MULTI_REGION_ACCESS_POINTS:
            case AWS_S3_FORCE_PATH_STYLE:
            case AWS_S3_USE_ARN_REGION:
            case AWS_S3_CONTROL_USE_ARN_REGION:
                // end of S3 specific builtins

                // V2 doesn't support this, only regional endpoints
            case AWS_STS_USE_GLOBAL_ENDPOINT:
                return;
            default:
                throw new RuntimeException("Don't know how to set built-in " + m.getBuiltInEnum());
        }
    });

    if (hasClientContextParams()) {
        b.addStatement("setClientContextParams(builder, executionAttributes)");
    }
    b.addStatement("setContextParams(builder, executionAttributes.getAttribute($T.OPERATION_NAME), request)",
                   AwsExecutionAttribute.class);
    b.addStatement("setStaticContextParams(builder, executionAttributes.getAttribute($T.OPERATION_NAME))",
                   AwsExecutionAttribute.class);
    b.addStatement("setOperationContextParams(builder, executionAttributes.getAttribute($T.OPERATION_NAME), request)",
                   AwsExecutionAttribute.class);

    b.addStatement("return builder.build()");
    return b.build();
}
/**
 * Produces {@code builder.<setter>(AwsEndpointProviderUtils.<builtInFn>(executionAttributes))}.
 *
 * @param builtInFn  name of the static helper on the runtime {@code AwsEndpointProviderUtils} class
 * @param setterName name of the params builder setter to call
 * @return the code block for the setter call
 */
private CodeBlock endpointProviderUtilsSetter(String builtInFn, String setterName) {
    TypeName providerUtils = endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils");
    CodeBlock.Builder setterCall = CodeBlock.builder();
    setterCall.add("builder.$N($T.$N(executionAttributes))", setterName, providerUtils, builtInFn);
    return setterCall.build();
}
/**
 * @return the class name of the nested {@code Builder} of the endpoint parameters class
 */
private ClassName paramsBuilderClass() {
    ClassName paramsClass = endpointRulesSpecUtils.parametersClassName();
    return paramsClass.nestedClass("Builder");
}
/**
 * Builds the per-operation private static method that sets the operation's static context params on the
 * params builder. String values are emitted as string literals, booleans as boolean literals.
 *
 * @param opModel operation whose static context params are emitted
 * @return the method spec for the operation's static-context-params setter
 * @throws RuntimeException if a value is neither a string nor a boolean
 */
private MethodSpec addStaticContextParamsMethod(OperationModel opModel) {
    MethodSpec.Builder method = MethodSpec.methodBuilder(staticContextParamsMethodName(opModel))
                                          .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
                                          .returns(void.class)
                                          .addParameter(paramsBuilderClass(), "params");

    for (Map.Entry<String, StaticContextParam> entry : opModel.getStaticContextParams().entrySet()) {
        String setter = endpointRulesSpecUtils.paramMethodName(entry.getKey());
        TreeNode literal = entry.getValue().getValue();

        switch (literal.asToken()) {
            case VALUE_STRING:
                method.addStatement("params.$N($S)", setter, ((JrsString) literal).getValue());
                break;
            case VALUE_TRUE:
            case VALUE_FALSE:
                method.addStatement("params.$N($L)", setter, ((JrsBoolean) literal).booleanValue());
                break;
            default:
                throw new RuntimeException("Don't know how to set parameter of type " + literal.asToken());
        }
    }

    return method.build();
}
/**
 * @param opModel operation the generated method belongs to
 * @return the operation's method name followed by {@code StaticContextParams}
 */
private String staticContextParamsMethodName(OperationModel opModel) {
    String operationMethodName = opModel.getMethodName();
    return operationMethodName + "StaticContextParams";
}
/**
 * @param opModel operation to inspect
 * @return {@code true} if the operation declares at least one static context param
 */
private boolean hasStaticContextParams(OperationModel opModel) {
    // Parameterized instead of a raw Map so the compiler checks the element types.
    Map<String, StaticContextParam> staticContextParams = opModel.getStaticContextParams();
    return staticContextParams != null && !staticContextParams.isEmpty();
}
/**
 * @param opModel operation to inspect
 * @return the result of {@code CollectionUtils.isNotEmpty} on the operation's operation context params
 */
private boolean hasOperationContextParams(OperationModel opModel) {
return CollectionUtils.isNotEmpty(opModel.getOperationContextParams());
}
/**
 * Adds one static-context-params setter to the class for every operation that declares static context params.
 *
 * @param classBuilder builder of the generated interceptor class
 */
private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) {
    // Typed map: a raw Map would make the lambda parameters Object and break the helper calls.
    Map<String, OperationModel> operations = model.getOperations();

    operations.forEach((n, m) -> {
        if (hasStaticContextParams(m)) {
            classBuilder.addMethod(addStaticContextParamsMethod(m));
        }
    });
}
/**
 * Adds one typed {@code setContextParams} overload to the class for every operation whose input shape has a
 * member with a context param.
 *
 * @param classBuilder builder of the generated interceptor class
 */
private void addContextParamMethods(TypeSpec.Builder classBuilder) {
    // Typed map: a raw Map would make the lambda parameters Object and break the helper calls.
    Map<String, OperationModel> operations = model.getOperations();

    operations.forEach((n, m) -> {
        if (hasContextParams(m)) {
            classBuilder.addMethod(setContextParamsMethod(m));
        }
    });
}
/**
 * Adds one typed {@code setOperationContextParams} overload to the class for every operation that declares
 * operation context params.
 *
 * @param classBuilder builder of the generated interceptor class
 */
private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) {
    // Typed map: a raw Map would make the lambda parameters Object and break the helper calls.
    Map<String, OperationModel> operations = model.getOperations();

    operations.forEach((n, m) -> {
        if (hasOperationContextParams(m)) {
            classBuilder.addMethod(setOperationContextParamsMethod(m));
        }
    });
}
/**
 * Builds the generated {@code setStaticContextParams(params, operationName)} dispatcher. When at least one
 * operation has static context params, the emitted body is a {@code switch} on the operation name that calls
 * the matching per-operation method; otherwise the body is empty.
 *
 * @return the method spec for the dispatcher
 */
private MethodSpec setStaticContextParamsMethod() {
    // Typed map: the stream and the forEach lambda below need OperationModel values.
    Map<String, OperationModel> operations = model.getOperations();

    MethodSpec.Builder b = MethodSpec.methodBuilder("setStaticContextParams")
                                     .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
                                     .addParameter(paramsBuilderClass(), "params")
                                     .addParameter(String.class, "operationName")
                                     .returns(void.class);

    boolean generateSwitch = operations.values().stream().anyMatch(this::hasStaticContextParams);

    if (generateSwitch) {
        b.beginControlFlow("switch (operationName)");

        operations.forEach((n, m) -> {
            if (!hasStaticContextParams(m)) {
                return;
            }
            b.addCode("case $S:", n);
            b.addStatement("$N(params)", staticContextParamsMethodName(m));
            b.addStatement("break");
        });
        b.addCode("default:");
        b.addStatement("break");
        b.endControlFlow();
    }

    return b.build();
}
/**
 * Builds the generated {@code setContextParams(params, operationName, request)} dispatcher. When at least one
 * operation has member context params, the emitted body is a {@code switch} on the operation name that casts
 * the request to the operation's request class and calls the typed overload; otherwise the body is empty.
 *
 * @return the method spec for the dispatcher
 */
private MethodSpec setContextParams() {
    // Typed map: the stream and the forEach lambda below need OperationModel values.
    Map<String, OperationModel> operations = model.getOperations();

    MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams")
                                     .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
                                     .addParameter(paramsBuilderClass(), "params")
                                     .addParameter(String.class, "operationName")
                                     .addParameter(SdkRequest.class, "request")
                                     .returns(void.class);

    boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams);
    if (generateSwitch) {
        b.beginControlFlow("switch (operationName)");

        operations.forEach((n, m) -> {
            if (!hasContextParams(m)) {
                return;
            }

            String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName());
            ClassName requestClass = poetExtension.getModelClass(requestClassName);

            b.addCode("case $S:", n);
            b.addStatement("setContextParams(params, ($T) request)", requestClass);
            b.addStatement("break");
        });
        b.addCode("default:");
        b.addStatement("break");
        b.endControlFlow();
    }

    return b.build();
}
/**
 * Builds the generated {@code setOperationContextParams(params, operationName, request)} dispatcher. When at
 * least one operation has operation context params, the emitted body is a {@code switch} on the operation name
 * that casts the request to the operation's request class and calls the typed overload; otherwise the body is
 * empty.
 *
 * @return the method spec for the dispatcher
 */
private MethodSpec setOperationContextParams() {
    // Typed map: the stream and the forEach lambda below need OperationModel values.
    Map<String, OperationModel> operations = model.getOperations();

    MethodSpec.Builder b = MethodSpec.methodBuilder("setOperationContextParams")
                                     .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
                                     .addParameter(paramsBuilderClass(), "params")
                                     .addParameter(String.class, "operationName")
                                     .addParameter(SdkRequest.class, "request")
                                     .returns(void.class);

    boolean generateSwitch = operations.values().stream().anyMatch(this::hasOperationContextParams);
    if (generateSwitch) {
        b.beginControlFlow("switch (operationName)");

        operations.forEach((n, m) -> {
            if (!hasOperationContextParams(m)) {
                return;
            }

            String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName());
            ClassName requestClass = poetExtension.getModelClass(requestClassName);

            b.addCode("case $S:", n);
            b.addStatement("setOperationContextParams(params, ($T) request)", requestClass);
            b.addStatement("break");
        });
        b.addCode("default:");
        b.addStatement("break");
        b.endControlFlow();
    }

    return b.build();
}
/**
 * Builds the typed {@code setContextParams(params, request)} overload for one operation. For each input member
 * carrying a context param the emitted code passes the member's fluent getter value to the matching setter.
 *
 * @param opModel operation whose input members are inspected
 * @return the method spec for the operation's typed overload
 */
private MethodSpec setContextParamsMethod(OperationModel opModel) {
    ClassName requestType = poetExtension.getModelClass(
        model.getNamingStrategy().getRequestClassName(opModel.getOperationName()));

    MethodSpec.Builder method = MethodSpec.methodBuilder("setContextParams");
    method.addModifiers(Modifier.PRIVATE, Modifier.STATIC);
    method.addParameter(paramsBuilderClass(), "params");
    method.addParameter(requestType, "request");
    method.returns(void.class);

    opModel.getInputShape().getMembers().stream()
           .filter(member -> member.getContextParam() != null)
           .forEach(member -> {
               ContextParam contextParam = member.getContextParam();
               String setter = endpointRulesSpecUtils.paramMethodName(contextParam.getName());
               method.addStatement("params.$N(request.$N())", setter, member.getFluentGetterMethodName());
           });

    return method.build();
}
/**
 * Builds the typed {@code setOperationContextParams(params, request)} overload for one operation. The emitted
 * code wraps the request in a JMESPath runtime {@code Value} and, for each operation context param, evaluates
 * its path against that value and passes the converted result to the matching setter.
 *
 * @param opModel operation whose operation context params are emitted
 * @return the method spec for the operation's typed overload
 * @throws RuntimeException if a param's path is not a JSON string
 */
private MethodSpec setOperationContextParamsMethod(OperationModel opModel) {
    ClassName requestType = poetExtension.getModelClass(
        model.getNamingStrategy().getRequestClassName(opModel.getOperationName()));

    MethodSpec.Builder method = MethodSpec.methodBuilder("setOperationContextParams");
    method.addModifiers(Modifier.PRIVATE, Modifier.STATIC);
    method.addParameter(paramsBuilderClass(), "params");
    method.addParameter(requestType, "request");
    method.returns(void.class);

    method.addStatement("$1T input = new $1T(request)", poetExtension.jmesPathRuntimeClass().nestedClass("Value"));

    opModel.getOperationContextParams().forEach((key, value) -> {
        JsonToken pathToken = Objects.requireNonNull(value.getPath().asToken());

        // Only string paths can be interpreted as JMESPath expressions.
        if (pathToken != JsonToken.VALUE_STRING) {
            throw new RuntimeException("Invalid operation context parameter path for " + opModel.getOperationName() +
                                       ". Expected VALUE_STRING, but got " + pathToken);
        }

        String setter = endpointRulesSpecUtils.paramMethodName(key);
        String expression = ((JrsString) value.getPath()).getValue();

        CodeBlock.Builder setterCall = CodeBlock.builder();
        setterCall.add("params.$N(", setter);
        setterCall.add(jmesPathGenerator.interpret(expression, "input"));
        setterCall.add(matchToParameterType(key));
        setterCall.add(")");

        method.addStatement(setterCall.build());
    });

    return method.build();
}
/**
 * Finds the endpoint rule-set parameter whose name matches {@code paramName} case-insensitively (both sides
 * lower-cased with {@code Locale.US}) and returns the accessor suffix that converts a JMESPath value to that
 * parameter's type.
 *
 * @param paramName name of the operation context param
 * @return the conversion suffix, or an empty code block when no rule-set parameter matches
 */
private CodeBlock matchToParameterType(String paramName) {
    // Typed map/optional: with raw types the stream elements are Object and e.getKey() would not resolve.
    Map<String, ParameterModel> parameters = model.getEndpointRuleSetModel().getParameters();
    Optional<ParameterModel> endpointParameter = parameters.entrySet().stream()
                                                           .filter(e -> e.getKey().toLowerCase(Locale.US)
                                                                         .equals(paramName.toLowerCase(Locale.US)))
                                                           .map(Map.Entry::getValue)
                                                           .findFirst();
    return endpointParameter.map(this::convertValueToParameterType).orElseGet(() -> CodeBlock.of(""));
}
/**
 * Maps a rule-set parameter type (compared case-insensitively) to the accessor call appended to a JMESPath
 * value expression.
 *
 * @param parameterModel rule-set parameter whose type decides the accessor
 * @return {@code .booleanValue()}, {@code .stringValue()} or {@code .stringValues()}
 * @throws UnsupportedOperationException for any other parameter type
 */
private CodeBlock convertValueToParameterType(ParameterModel parameterModel) {
    String normalizedType = parameterModel.getType().toLowerCase(Locale.US);

    if ("boolean".equals(normalizedType)) {
        return CodeBlock.of(".booleanValue()");
    }
    if ("string".equals(normalizedType)) {
        return CodeBlock.of(".stringValue()");
    }
    if ("stringarray".equals(normalizedType)) {
        return CodeBlock.of(".stringValues()");
    }

    throw new UnsupportedOperationException(
        "Supported types are boolean, string and stringarray. Given type was " + parameterModel.getType());
}
/**
 * @param opModel operation to inspect
 * @return {@code true} if any member of the operation's input shape carries a context param
 */
private boolean hasContextParams(OperationModel opModel) {
    return opModel.getInputShape().getMembers().stream()
                  .map(member -> member.getContextParam())
                  .anyMatch(Objects::nonNull);
}
/**
 * @return {@code true} if the service model declares at least one client context param
 */
private boolean hasClientContextParams() {
    // Parameterized instead of a raw Map so the compiler checks the element types.
    Map<String, ClientContextParam> clientContextParams = model.getClientContextParams();
    return clientContextParams != null && !clientContextParams.isEmpty();
}
/**
 * Builds the generated {@code setClientContextParams(params, executionAttributes)} method. The emitted code
 * reads the {@code CLIENT_CONTEXT_PARAMS} attribute map and, for every client context param of the service,
 * forwards the value to the matching params setter when it is present.
 *
 * @return the method spec for {@code setClientContextParams}
 */
private MethodSpec setClientContextParamsMethod() {
    MethodSpec.Builder b = MethodSpec.methodBuilder("setClientContextParams")
                                     .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
                                     .addParameter(paramsBuilderClass(), "params")
                                     .addParameter(ExecutionAttributes.class, "executionAttributes")
                                     .returns(void.class);

    b.addStatement("$T clientContextParams = executionAttributes.getAttribute($T.CLIENT_CONTEXT_PARAMS)",
                   AttributeMap.class, SdkInternalExecutionAttribute.class);

    ClassName paramsClass = endpointRulesSpecUtils.clientContextParamsName();
    // Typed map: with a raw Map the lambda key would be Object and could not be passed to the naming helpers.
    Map<String, ClientContextParam> params = model.getClientContextParams();

    params.forEach((n, m) -> {
        String attrName = endpointRulesSpecUtils.clientContextParamName(n);
        b.addStatement("$T.ofNullable(clientContextParams.get($T.$N)).ifPresent(params::$N)", Optional.class, paramsClass,
                       attrName,
                       endpointRulesSpecUtils.paramMethodName(n));
    });

    return b.build();
}
/**
 * Builds the generated {@code hostPrefix(operationName, request)} method. When no operation has a non-blank
 * host prefix, the emitted method simply returns an empty {@code Optional}. Otherwise it switches on the
 * operation name: a prefix without member references is returned as a literal, and a prefix with member
 * references first validates each referenced request field as a hostname component and then formats the
 * prefix with the field values.
 *
 * @return the method spec for {@code hostPrefix}
 */
private MethodSpec hostPrefixMethod() {
    MethodSpec.Builder builder = MethodSpec.methodBuilder("hostPrefix")
                                           .returns(ParameterizedTypeName.get(Optional.class, String.class))
                                           .addParameter(String.class, "operationName")
                                           .addParameter(SdkRequest.class, "request")
                                           .addModifiers(Modifier.PRIVATE, Modifier.STATIC);

    boolean generateSwitch =
        model.getOperations().values().stream().anyMatch(opModel -> StringUtils.isNotBlank(getHostPrefix(opModel)));

    if (!generateSwitch) {
        builder.addStatement("return $T.empty()", Optional.class);
    } else {
        builder.beginControlFlow("switch (operationName)");

        model.getOperations().forEach((name, opModel) -> {
            String hostPrefix = getHostPrefix(opModel);
            if (StringUtils.isBlank(hostPrefix)) {
                return;
            }

            builder.beginControlFlow("case $S:", name);
            HostPrefixProcessor processor = new HostPrefixProcessor(hostPrefix);

            if (processor.c2jNames().isEmpty()) {
                builder.addStatement("return $T.of($S)", Optional.class, processor.hostWithStringSpecifier());
            } else {
                String requestVar = opModel.getInput().getVariableName();
                processor.c2jNames().forEach(c2jName -> {
                    builder.addStatement("$1T.validateHostnameCompliant(request.getValueForField($2S, $3T.class)"
                                         + ".orElse(null), $2S, $4S)",
                                         HostnameValidator.class,
                                         c2jName,
                                         String.class,
                                         requestVar);
                });

                builder.addCode("return $T.of($T.format($S, ", Optional.class, String.class,
                                processor.hostWithStringSpecifier());
                // Parameterized iterator (was a raw Iterator); emits the format arguments separated by commas.
                Iterator<String> c2jNamesIter = processor.c2jNames().listIterator();
                while (c2jNamesIter.hasNext()) {
                    builder.addCode("request.getValueForField($S, $T.class).get()", c2jNamesIter.next(), String.class);
                    if (c2jNamesIter.hasNext()) {
                        builder.addCode(",");
                    }
                }
                builder.addStatement("))");
            }
            builder.endControlFlow();
        });

        builder.addCode("default:");
        builder.addStatement("return $T.empty()", Optional.class);
        builder.endControlFlow();
    }

    return builder.build();
}
/**
 * @param opModel operation to inspect
 * @return the host prefix of the operation's endpoint trait, or {@code null} when the operation has no trait
 */
private String getHostPrefix(OperationModel opModel) {
    EndpointTrait trait = opModel.getEndpointTrait();
    return trait == null ? null : trait.getHostPrefix();
}
/**
 * Builds the generated generic {@code authSchemeWithEndpointSignerProperties} method. The emitted code loops
 * over the endpoint auth schemes; for SRA clients, schemes whose id differs from the selected scheme are
 * skipped. The first scheme of a known type yields a new {@code SelectedAuthScheme} carrying the endpoint's
 * signer properties; a scheme of an unknown type makes the emitted code throw
 * {@code IllegalArgumentException}. If the loop finishes, the selected scheme is returned unchanged.
 *
 * @return the method spec for {@code authSchemeWithEndpointSignerProperties}
 */
private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() {
    TypeName selectedSchemeType = ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class),
                                                            TypeVariableName.get("T"));
    TypeName endpointSchemesType = ParameterizedTypeName.get(List.class, EndpointAuthScheme.class);

    MethodSpec.Builder spec = MethodSpec.methodBuilder("authSchemeWithEndpointSignerProperties");
    spec.addModifiers(Modifier.PRIVATE);
    spec.addTypeVariable(TypeVariableName.get("T", Identity.class));
    spec.returns(selectedSchemeType);
    spec.addParameter(endpointSchemesType, "endpointAuthSchemes");
    spec.addParameter(selectedSchemeType, "selectedAuthScheme");

    spec.beginControlFlow("for ($T endpointAuthScheme : endpointAuthSchemes)", EndpointAuthScheme.class);

    if (useSraAuth) {
        // Don't include signer properties for auth options that don't match our selected auth scheme
        spec.beginControlFlow("if (!endpointAuthScheme.schemeId()"
                              + ".equals(selectedAuthScheme.authSchemeOption().schemeId()))")
            .addStatement("continue")
            .endControlFlow();
    }

    spec.addStatement("$T option = selectedAuthScheme.authSchemeOption().toBuilder()", AuthSchemeOption.Builder.class);

    if (dependsOnHttpAuthAws) {
        spec.addCode(copyV4EndpointSignerPropertiesToAuth());
        spec.addCode(copyV4aEndpointSignerPropertiesToAuth());
        if (endpointRulesSpecUtils.useS3Express()) {
            spec.addCode(copyS3ExpressEndpointSignerPropertiesToAuth());
        }
    }

    // Reached in the generated code only when none of the type checks above returned.
    spec.addStatement("throw new $T(\"Endpoint auth scheme '\" + endpointAuthScheme.name() + \"' cannot be mapped to the "
                      + "SDK auth scheme. Was it declared in the service's model?\")",
                      IllegalArgumentException.class);
    spec.endControlFlow();

    spec.addStatement("return selectedAuthScheme");
    return spec.build();
}
/**
 * Emits the branch handling a {@code SigV4AuthScheme} endpoint auth scheme: the double-encoding flag (when
 * set), the signing region and the signing name (when non-null) are copied onto {@code option}, and a new
 * {@code SelectedAuthScheme} built from that option is returned.
 *
 * @return the code block for the SigV4 branch
 */
private static CodeBlock copyV4EndpointSignerPropertiesToAuth() {
    return CodeBlock.builder()
                    .beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4AuthScheme.class)
                    .addStatement("$1T v4AuthScheme = ($1T) endpointAuthScheme", SigV4AuthScheme.class)

                    .beginControlFlow("if (v4AuthScheme.isDisableDoubleEncodingSet())")
                    .addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding())",
                                  AwsV4HttpSigner.class)
                    .endControlFlow()

                    .beginControlFlow("if (v4AuthScheme.signingRegion() != null)")
                    .addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())",
                                  AwsV4HttpSigner.class)
                    .endControlFlow()

                    .beginControlFlow("if (v4AuthScheme.signingName() != null)")
                    .addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4AuthScheme.signingName())",
                                  AwsV4HttpSigner.class)
                    .endControlFlow()

                    .addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())",
                                  SelectedAuthScheme.class)
                    .endControlFlow()
                    .build();
}
/**
 * Emits the branch handling a {@code SigV4aAuthScheme} endpoint auth scheme: the double-encoding flag (when
 * set), the region set (when non-empty) and the signing name (when non-null) are copied onto {@code option},
 * and a new {@code SelectedAuthScheme} built from that option is returned. For multi-auth SigV4a or legacy
 * endpoint-based auth services, the endpoint's region set is not applied when the selected scheme is SigV4a and
 * already carries a {@code REGION_SET} signer property.
 *
 * @return the code block for the SigV4a branch
 */
private CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
    boolean keepExistingRegionSet = multiAuthSigv4a || legacyAuthFromEndpointRulesService;

    CodeBlock.Builder block = CodeBlock.builder()
        .beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4aAuthScheme.class)
        .addStatement("$1T v4aAuthScheme = ($1T) endpointAuthScheme", SigV4aAuthScheme.class)
        .beginControlFlow("if (v4aAuthScheme.isDisableDoubleEncodingSet())")
        .addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding())",
                      AwsV4aHttpSigner.class)
        .endControlFlow();

    if (!keepExistingRegionSet) {
        block.beginControlFlow("if (!$T.isNullOrEmpty(v4aAuthScheme.signingRegionSet()))", CollectionUtils.class);
    } else {
        block.beginControlFlow("if (!(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) "
                               + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) != null) "
                               + "&& !$T.isNullOrEmpty(v4aAuthScheme.signingRegionSet()))",
                               AwsV4aAuthScheme.class, AwsV4aHttpSigner.class, CollectionUtils.class);
    }

    block.addStatement("$1T regionSet = $1T.create(v4aAuthScheme.signingRegionSet())", RegionSet.class)
         .addStatement("option.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class)
         .endControlFlow();

    block.beginControlFlow("if (v4aAuthScheme.signingName() != null)")
         .addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName())",
                       AwsV4aHttpSigner.class)
         .endControlFlow();

    block.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())",
                       SelectedAuthScheme.class)
         .endControlFlow();

    return block.build();
}
/**
 * Builds the code fragment that copies S3 Express signer properties from the endpoint auth scheme onto
 * the auth scheme option.
 *
 * <p>The emitted code is a single {@code if (endpointAuthScheme instanceof S3ExpressEndpointAuthScheme)}
 * block, where the type is resolved in the {@code endpoints.authscheme} sub-package of the service's client
 * package. Inside it, {@code DOUBLE_URL_ENCODE}, {@code REGION_NAME} and {@code SERVICE_SIGNING_NAME} are
 * each set only when the scheme provides the corresponding value, after which a new
 * {@code SelectedAuthScheme} built from the updated option is returned.
 *
 * @return the generated code block
 */
private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() {
    String authSchemePackage = model.getMetadata().getFullClientPackageName() + ".endpoints.authscheme";
    ClassName s3ExpressSchemeType = ClassName.get(authSchemePackage, "S3ExpressEndpointAuthScheme");

    return CodeBlock.builder()
        .beginControlFlow("if (endpointAuthScheme instanceof $T)", s3ExpressSchemeType)
        .addStatement("$1T s3ExpressAuthScheme = ($1T) endpointAuthScheme", s3ExpressSchemeType)
        // Double URL encoding
        .beginControlFlow("if (s3ExpressAuthScheme.isDisableDoubleEncodingSet())")
        .addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !s3ExpressAuthScheme.disableDoubleEncoding())",
                      AwsV4HttpSigner.class)
        .endControlFlow()
        // Signing region
        .beginControlFlow("if (s3ExpressAuthScheme.signingRegion() != null)")
        .addStatement("option.putSignerProperty($T.REGION_NAME, s3ExpressAuthScheme.signingRegion())",
                      AwsV4HttpSigner.class)
        .endControlFlow()
        // Signing name
        .beginControlFlow("if (s3ExpressAuthScheme.signingName() != null)")
        .addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, s3ExpressAuthScheme.signingName())",
                      AwsV4HttpSigner.class)
        .endControlFlow()
        .addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())",
                      SelectedAuthScheme.class)
        .endControlFlow()
        .build();
}
/**
 * Generates the private {@code signerProvider(EndpointAuthScheme authScheme)} method.
 *
 * <p>The generated method switches on {@code authScheme.name()} and returns a {@code Supplier<Signer>}:
 * <ul>
 *     <li>{@code "sigv4"}: {@code AwsS3V4Signer::create} for S3 and S3 Control, otherwise
 *     {@code Aws4Signer::create};</li>
 *     <li>{@code "sigv4a"}: {@code SignerLoader::getS3SigV4aSigner} for S3 and S3 Control, otherwise
 *     {@code SignerLoader::getSigV4aSigner};</li>
 *     <li>any other name: falls out of the switch and throws an {@code SdkClientException}.</li>
 * </ul>
 *
 * @return the method spec for {@code signerProvider}
 */
private MethodSpec signerProviderMethod() {
    // S3 and S3 Control use their own signer variants for both schemes.
    boolean s3Family = endpointRulesSpecUtils.isS3() || endpointRulesSpecUtils.isS3Control();
    Class<?> sigV4SignerType = s3Family ? AwsS3V4Signer.class : Aws4Signer.class;
    String sigV4aReturn = s3Family ? "return $T::getS3SigV4aSigner" : "return $T::getSigV4aSigner";

    return MethodSpec.methodBuilder("signerProvider")
        .addModifiers(Modifier.PRIVATE)
        .addParameter(EndpointAuthScheme.class, "authScheme")
        .returns(ParameterizedTypeName.get(Supplier.class, Signer.class))
        .beginControlFlow("switch (authScheme.name())")
        .addCode("case $S:", "sigv4")
        .addStatement("return $T::create", sigV4SignerType)
        .addCode("case $S:", "sigv4a")
        .addStatement(sigV4aReturn, SignerLoader.class)
        .addCode("default:")
        .addStatement("break")
        .endControlFlow()
        // Reached only via the default branch of the generated switch.
        .addStatement("throw $T.create($S + authScheme.name())",
                      SdkClientException.class,
                      "Don't know how to create signer for auth scheme: ")
        .build();
}
/**
 * Generates the public constructor of the interceptor.
 *
 * <p>The generated constructor creates an {@code EndpointAuthSchemeStrategyFactory} local and assigns the
 * result of its {@code endpointAuthSchemeStrategy()} call to the given field. The factory implementation is
 * the class named by the customization config's auth scheme strategy factory class when one is configured,
 * otherwise {@code DefaultEndpointAuthSchemeStrategyFactory} from the rules runtime package.
 *
 * @param endpointAuthSchemeFieldName name of the field that receives the strategy
 * @return the constructor method spec
 */
private MethodSpec constructorMethodSpec(String endpointAuthSchemeFieldName) {
    String factoryVar = "endpointAuthSchemeStrategyFactory";
    EndpointAuthSchemeConfig authSchemeConfig = model.getCustomizationConfig().getEndpointAuthSchemeConfig();

    // Only the concrete factory type varies; the emitted statement shape is the same either way.
    boolean hasCustomFactory = authSchemeConfig != null
                               && authSchemeConfig.getAuthSchemeStrategyFactoryClass() != null;
    TypeName factoryImplType = hasCustomFactory
        ? PoetUtils.classNameFromFqcn(authSchemeConfig.getAuthSchemeStrategyFactoryClass())
        : endpointRulesSpecUtils.rulesRuntimeClassName("DefaultEndpointAuthSchemeStrategyFactory");

    return MethodSpec.constructorBuilder()
        .addModifiers(Modifier.PUBLIC)
        .addStatement("$T $N = new $T()",
                      endpointRulesSpecUtils.rulesRuntimeClassName("EndpointAuthSchemeStrategyFactory"),
                      factoryVar,
                      factoryImplType)
        .addStatement("this.$N = $N.endpointAuthSchemeStrategy()", endpointAuthSchemeFieldName, factoryVar)
        .build();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy