fi.evolver.basics.spring.messaging.sender.aws.AwsUtils Maven / Gradle / Ivy
package fi.evolver.basics.spring.messaging.sender.aws;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.io.input.BoundedInputStream;
import fi.evolver.basics.spring.messaging.entity.Message;
import fi.evolver.basics.spring.messaging.util.SendUtils;
import fi.evolver.utils.CommunicationException;
import fi.evolver.utils.ContextUtils;
import fi.evolver.utils.stream.FinishingInputStream;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
/**
 * Shared helpers for the AWS message senders: credential and client configuration, URI/ARN
 * parsing, and recording of the target parameters resolved for a message.
 *
 * <p>Parameter values resolved through {@link #getStringParameter(Message, String, String)} (and
 * the typed variants built on it) are recorded in a per-message map kept in {@link ContextUtils};
 * see {@link #getRequestParameterValueMap(Message)}.
 */
class AwsUtils {
    private static final String PROPERTY_API_CALL_TIMEOUT_MS = "ApiCallTimeoutMs";
    private static final String PROPERTY_API_CALL_ATTEMPT_TIMEOUT_MS = "ApiCallAttemptTimeoutMs";
    private static final String PROPERTY_STS_ASSUME_ROLE_ARN = "StsRoleArn";
    private static final String PROPERTY_STS_ROLE_DURATION_SECONDS = "StsRoleDurationSeconds";
    private static final String PROPERTY_STS_SESSION_NAME = "StsSessionName";
    private static final String PROPERTY_ACCESS_KEY = "AccessKey";
    private static final String PROPERTY_SECRET_ACCESS_KEY = "SecretAccessKey";

    public static final String STATUS_FAILED = "FAILED";
    public static final String STATUS_OK = "OK";

    private AwsUtils() { /* Utility class */ }

    /**
     * Creates the credentials provider to use when sending the given message.
     *
     * <p>The base credentials are static keys from the {@code AccessKey} and
     * {@code SecretAccessKey} target properties when both are present, and the SDK default
     * provider chain otherwise. When the {@code StsRoleArn} property is set, the base credentials
     * are used to assume that role through STS, and the returned provider supplies the role's
     * temporary credentials.
     *
     * @param message the message whose target properties configure the credentials
     * @param region the region of the STS client used for assuming a role
     * @return the credentials provider for the message
     */
    public static AwsCredentialsProvider createCredentialsProvider(Message message, Region region) {
        AwsCredentialsProvider baseCredentials = createBaseCredentialsProvider(message);
        String stsRoleArn = getStringParameter(message, PROPERTY_STS_ASSUME_ROLE_ARN, null);
        if (stsRoleArn == null)
            return baseCredentials;

        String stsSessionName = getStringParameter(message, PROPERTY_STS_SESSION_NAME, "DefaultSessionName");
        Integer durationSeconds = getIntegerParameter(message, PROPERTY_STS_ROLE_DURATION_SECONDS, 3600);

        // The AssumeRole call must be authenticated with the base credentials. Without this the
        // STS client falls back to the SDK default chain, and configured static keys are ignored.
        var stsClientBuilder = StsClient.builder()
                .credentialsProvider(baseCredentials)
                .region(region);
        getEndpointOverride().ifPresent(stsClientBuilder::endpointOverride);

        return StsAssumeRoleCredentialsProvider.builder()
                .stsClient(stsClientBuilder.build())
                .refreshRequest(b -> b.roleArn(stsRoleArn)
                        // NOTE(review): a freshly generated external ID will not match an
                        // sts:ExternalId condition in the role's trust policy. It is kept for
                        // compatibility; consider making it a configurable target property.
                        .externalId(UUID.randomUUID().toString())
                        .roleSessionName(stsSessionName)
                        .durationSeconds(durationSeconds))
                .build();
    }

    /**
     * Returns the endpoint override configured through the {@code AWS_ENDPOINT} environment
     * variable.
     *
     * @return the endpoint URI, or empty when the variable is unset or blank
     * @throws IllegalArgumentException if the variable's value is not a valid URI
     */
    public static Optional<URI> getEndpointOverride() {
        return Optional.ofNullable(System.getenv("AWS_ENDPOINT"))
                .map(String::strip)
                .filter(value -> !value.isEmpty())
                .map(URI::create);
    }

    /**
     * Creates the base credentials provider: static credentials when both the access key and the
     * secret access key are configured, otherwise the SDK default credentials provider.
     */
    private static AwsCredentialsProvider createBaseCredentialsProvider(Message message) {
        String accessKey = getStringParameter(message, PROPERTY_ACCESS_KEY, null);
        // The secret is read directly rather than through getStringParameter, so it is never
        // stored in the recorded request parameter map.
        String secretAccessKey = SendUtils.getTagReplacedTargetProperty(message, PROPERTY_SECRET_ACCESS_KEY).orElse(null);
        if (accessKey != null && secretAccessKey != null)
            return StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey, secretAccessKey));
        return DefaultCredentialsProvider.create();
    }

    /**
     * Builds the SDK client override configuration from the message's target properties.
     *
     * <p>{@code ApiCallTimeoutMs} and {@code ApiCallAttemptTimeoutMs} (milliseconds) set the
     * corresponding timeouts when present; unset properties leave the SDK defaults in place.
     *
     * @param message the message whose target properties configure the client
     * @return the client override configuration
     * @throws NumberFormatException if a timeout property is not a valid integer
     */
    public static ClientOverrideConfiguration buildClientConfiguration(Message message) {
        var builder = ClientOverrideConfiguration.builder();
        Optional.ofNullable(getIntegerParameter(message, PROPERTY_API_CALL_TIMEOUT_MS, null))
                .map(Duration::ofMillis)
                .ifPresent(builder::apiCallTimeout);
        Optional.ofNullable(getIntegerParameter(message, PROPERTY_API_CALL_ATTEMPT_TIMEOUT_MS, null))
                .map(Duration::ofMillis)
                .ifPresent(builder::apiCallAttemptTimeout);
        return builder.build();
    }

    /**
     * Returns the host part of the URI.
     *
     * @param uri the URI to inspect
     * @return the non-empty host
     * @throws CommunicationException if the URI has no host
     */
    public static String parseHost(URI uri) throws CommunicationException {
        String host = uri.getHost();
        if (host == null || host.isEmpty())
            throw new CommunicationException("URI (%s) missing required host information", uri);
        return host;
    }

    /**
     * Parses an ARN and checks that it belongs to the {@code aws} partition and the expected
     * service.
     *
     * @param arn the ARN string
     * @param expectedService the required service component, e.g. {@code sns}
     * @return the parsed ARN
     * @throws IllegalArgumentException if the ARN is malformed, or its partition or service does
     *         not match
     */
    public static Arn parseArn(String arn, String expectedService) {
        Arn result = Arn.fromString(arn);
        if (!"aws".equals(result.partition()) || !expectedService.equals(result.service()))
            throw new IllegalArgumentException("Invalid ARN %s, expected one with %s service in the aws partition".formatted(arn, expectedService));
        return result;
    }

    /**
     * Computes the length of the message's data stream in bytes by reading through it.
     *
     * @param message the message whose data is measured
     * @return the number of bytes in the data stream
     * @throws IOException if reading the data fails
     */
    public static long getInputLength(Message message) throws IOException {
        BoundedInputStream counter = BoundedInputStream.builder().setInputStream(SendUtils.createDataStream(message)).get();
        // The try block is intentionally empty. This relies on FinishingInputStream reading the
        // remaining data when it is closed, so that the counter has seen every byte.
        try (FinishingInputStream finisher = new FinishingInputStream(counter)) {
            /* Nothing to do */
        }
        return counter.getCount();
    }

    /** Context key under which the recorded parameter map of a message is stored. */
    private static String requestParameterContextKey(Message message) {
        return "AwsSender-%s".formatted(message.getId());
    }

    /**
     * Returns the insertion-ordered map of parameter values recorded for the message. The map is
     * created in the {@link ContextUtils} context on first use.
     *
     * @param message the message whose recorded parameters are returned
     * @return the mutable map of parameter names to resolved values
     */
    @SuppressWarnings("unchecked")
    public static Map<String, String> getRequestParameterValueMap(Message message) {
        return ContextUtils.computeIfAbsent(requestParameterContextKey(message), LinkedHashMap.class, LinkedHashMap::new);
    }

    /**
     * Returns the parameter values recorded for the message and removes them from the context.
     * When nothing has been recorded, an empty map is returned.
     *
     * @param message the message whose recorded parameters are removed
     * @return the recorded parameter map
     */
    public static Map<String, String> removeRequestParameterValueMap(Message message) {
        Map<String, String> result = getRequestParameterValueMap(message);
        // Must use the same key as getRequestParameterValueMap, or the entry would never be removed.
        ContextUtils.remove(requestParameterContextKey(message));
        return result;
    }

    /**
     * Resolves a tag-replaced target property of the message and records the resolved value.
     *
     * @param message the message whose target property is read
     * @param name the property name
     * @param defaultValue the value to use when the property is absent; may be {@code null}
     * @return the resolved value, or {@code defaultValue} when absent
     */
    public static String getStringParameter(Message message, String name, String defaultValue) {
        String result = SendUtils.getTagReplacedTargetProperty(message, name).orElse(defaultValue);
        if (result != null)
            getRequestParameterValueMap(message).put(name, result);
        return result;
    }

    /**
     * Resolves a required tag-replaced target property of the message and records its value.
     *
     * @param message the message whose target property is read
     * @param name the property name
     * @return the resolved value
     * @throws IllegalArgumentException if the property is absent
     */
    public static String getStringParameter(Message message, String name) {
        String result = getStringParameter(message, name, null);
        if (result == null)
            throw new IllegalArgumentException("Missing required parameter %s".formatted(name));
        return result;
    }

    /**
     * Resolves an integer target property of the message and records its value.
     *
     * @param message the message whose target property is read
     * @param name the property name
     * @param defaultValue the value to use when the property is absent; may be {@code null}
     * @return the parsed value, or {@code defaultValue} when absent
     * @throws NumberFormatException if the value is not a valid integer
     */
    public static Integer getIntegerParameter(Message message, String name, Integer defaultValue) {
        String value = getStringParameter(message, name, defaultValue == null ? null : defaultValue.toString());
        if (value == null)
            return null;
        try {
            return Integer.parseInt(value);
        }
        catch (NumberFormatException e) {
            // Name the offending parameter; the JDK message only mentions the input string.
            NumberFormatException detailed = new NumberFormatException("Invalid integer value for parameter %s: %s".formatted(name, value));
            detailed.initCause(e);
            throw detailed;
        }
    }

    /**
     * Resolves a boolean target property of the message and records its value.
     *
     * @param message the message whose target property is read
     * @param name the property name
     * @return {@code true} only when the value is exactly {@code "true"} (case-sensitive)
     */
    public static boolean getBooleanParameter(Message message, String name) {
        String value = getStringParameter(message, name, null);
        return "true".equals(value);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy