All downloads are free. The search and download functionality uses the official Maven repository.

fi.evolver.basics.spring.messaging.sender.aws.AwsUtils Maven / Gradle / Ivy

package fi.evolver.basics.spring.messaging.sender.aws;

import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import org.apache.commons.io.input.BoundedInputStream;

import fi.evolver.basics.spring.messaging.entity.Message;
import fi.evolver.basics.spring.messaging.util.SendUtils;
import fi.evolver.utils.CommunicationException;
import fi.evolver.utils.ContextUtils;
import fi.evolver.utils.stream.FinishingInputStream;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;

/**
 * Static helpers shared by the AWS message senders: credentials and client
 * configuration, URI/ARN validation, and access to the target properties of a
 * {@link Message}.
 *
 * <p>Every property read through the {@code get*Parameter} methods is also
 * recorded in a per-message map kept in {@link ContextUtils}; see
 * {@link #getRequestParameterValueMap(Message)}.
 */
class AwsUtils {
	private static final String PROPERTY_API_CALL_TIMEOUT_MS = "ApiCallTimeoutMs";
	private static final String PROPERTY_API_CALL_ATTEMPT_TIMEOUT_MS = "ApiCallAttemptTimeoutMs";
	private static final String PROPERTY_STS_ASSUME_ROLE_ARN = "StsRoleArn";
	private static final String PROPERTY_STS_ROLE_DURATION_SECONDS = "StsRoleDurationSeconds";
	private static final String PROPERTY_STS_SESSION_NAME = "StsSessionName";
	private static final String PROPERTY_ACCESS_KEY = "AccessKey";
	private static final String PROPERTY_SECRET_ACCESS_KEY = "SecretAccessKey";

	/** Format of the context key under which the per-message parameter map is stored. */
	private static final String REQUEST_PARAMETER_MAP_KEY_FORMAT = "AwsSender-%s";

	public static final String STATUS_FAILED = "FAILED";
	public static final String STATUS_OK = "OK";

	private AwsUtils() { /* Utility class */ }

	/**
	 * Creates the credentials provider to use for sending the given message.
	 *
	 * <p>Starts from the base provider (static keys or the SDK default). If the
	 * {@code StsRoleArn} property is set, returns an STS assume-role provider for
	 * that role instead, using the {@code StsSessionName} property (default
	 * {@code "DefaultSessionName"}) and the {@code StsRoleDurationSeconds}
	 * property (default 3600).
	 *
	 * @param message the message whose target properties are read
	 * @param region the region the STS client is created for
	 * @return the credentials provider
	 */
	public static AwsCredentialsProvider createCredentialsProvider(Message message, Region region) {
		AwsCredentialsProvider result = createBaseCredentialsProvider(message);

		String stsRoleArn = getStringParameter(message, PROPERTY_STS_ASSUME_ROLE_ARN, null);
		// Read unconditionally so the value is recorded in the parameter map even without a role ARN.
		String stsSessionName = getStringParameter(message, PROPERTY_STS_SESSION_NAME, "DefaultSessionName");
		if (stsRoleArn != null) {
			// NOTE(review): the client is built without an explicit credentials provider, so the
			// base provider created above is not handed to it, and the client is never closed
			// here -- confirm both are intended.
			StsClient stsClient = StsClient.builder()
					.endpointOverride(getEndpointOverride().orElse(null))
					.region(region)
					.build();
			result = StsAssumeRoleCredentialsProvider.builder()
					.stsClient(stsClient)
					.refreshRequest(b -> b.roleArn(stsRoleArn)
							// NOTE(review): a random external id is sent on each request -- confirm intent.
							.externalId(UUID.randomUUID().toString())
							.roleSessionName(stsSessionName)
							.durationSeconds(getIntegerParameter(message, PROPERTY_STS_ROLE_DURATION_SECONDS, 3600))
							.build())
					.build();
		}
		return result;
	}


	/**
	 * Reads the optional endpoint override from the {@code AWS_ENDPOINT}
	 * environment variable.
	 *
	 * @return the endpoint as a URI, or empty if the variable is not set
	 * @throws IllegalArgumentException if the variable is set but is not a valid URI
	 */
	public static Optional<URI> getEndpointOverride()  {
		return Optional.ofNullable(System.getenv("AWS_ENDPOINT")).map(URI::create);
	}


	/**
	 * Creates static credentials when both the {@code AccessKey} and
	 * {@code SecretAccessKey} properties are present, otherwise the SDK default provider.
	 */
	private static AwsCredentialsProvider createBaseCredentialsProvider(Message message) {
		String accessKey = getStringParameter(message, PROPERTY_ACCESS_KEY, null);
		String secretAccessKey = getStringParameter(message, PROPERTY_SECRET_ACCESS_KEY, null);
		if (accessKey != null && secretAccessKey != null)
			return StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKey, secretAccessKey));
		return DefaultCredentialsProvider.create();
	}

	/**
	 * Builds the client override configuration, applying the
	 * {@code ApiCallTimeoutMs} and {@code ApiCallAttemptTimeoutMs} properties
	 * when they are present.
	 *
	 * @param message the message whose target properties are read
	 * @return the client override configuration
	 * @throws NumberFormatException if a present timeout property is not an integer
	 */
	public static ClientOverrideConfiguration buildClientConfiguration(Message message) {
		var builder = ClientOverrideConfiguration.builder();
		Optional.ofNullable(getIntegerParameter(message, PROPERTY_API_CALL_TIMEOUT_MS, null))
				.map(Duration::ofMillis)
				.ifPresent(builder::apiCallTimeout);
		Optional.ofNullable(getIntegerParameter(message, PROPERTY_API_CALL_ATTEMPT_TIMEOUT_MS, null))
				.map(Duration::ofMillis)
				.ifPresent(builder::apiCallAttemptTimeout);
		return builder.build();
	}

	/**
	 * Returns the host component of the given URI.
	 *
	 * @param uri the URI to inspect
	 * @return the non-empty host
	 * @throws CommunicationException if the URI has no host or an empty one
	 */
	public static String parseHost(URI uri) throws CommunicationException {
		String host = uri.getHost();
		if (host == null || host.isEmpty())
			throw new CommunicationException("URI (%s) missing required host information", uri);
		return host;
	}

	/**
	 * Parses an ARN and checks that it is in the {@code aws} partition and
	 * belongs to the expected service.
	 *
	 * @param arn the ARN string to parse
	 * @param expectedService the service the ARN must name
	 * @return the parsed ARN
	 * @throws IllegalArgumentException if the partition is not {@code aws} or the service differs
	 */
	public static Arn parseArn(String arn, String expectedService) {
		Arn result = Arn.fromString(arn);
		if (!"aws".equals(result.partition()) || !expectedService.equals(result.service()))
			throw new IllegalArgumentException("Invalid ARN, expected one with %s service".formatted(expectedService));
		return result;
	}

	/**
	 * Returns the byte count reported by a counting wrapper around the
	 * message's data stream once that stream has been closed.
	 *
	 * @param message the message whose data stream is measured
	 * @return the count reported by the counting stream
	 * @throws IOException if handling the data stream fails
	 */
	public static long getInputLength(Message message) throws IOException {
		BoundedInputStream counter = BoundedInputStream.builder().setInputStream(SendUtils.createDataStream(message)).get();
		// NOTE(review): presumably FinishingInputStream reads the stream to its end on close,
		// which is what makes the count below the full length -- confirm.
		try (FinishingInputStream finisher = new FinishingInputStream(counter)) {
			/* Nothing to do */
		}
		return counter.getCount();
	}

	/**
	 * Returns the map recording the parameter values read for the given
	 * message, creating and storing an empty {@link LinkedHashMap} in
	 * {@link ContextUtils} if none exists yet.
	 *
	 * <p>NOTE(review): the raw {@code Map} return type looks like it lost its
	 * type arguments; the visible code only puts {@code String} keys and values
	 * -- confirm the intended parameterization.
	 *
	 * @param message the message the map belongs to
	 * @return the parameter value map for the message
	 */
	@SuppressWarnings("unchecked")
	public static Map getRequestParameterValueMap(Message message) {
		return ContextUtils.computeIfAbsent(requestParameterMapKey(message), LinkedHashMap.class, LinkedHashMap::new);
	}

	/**
	 * Returns the parameter value map of the given message and removes it from
	 * {@link ContextUtils}.
	 *
	 * <p>The removal uses the same key as {@link #getRequestParameterValueMap(Message)};
	 * previously a different key was used, so the entry was never removed.
	 *
	 * @param message the message the map belongs to
	 * @return the map that was stored (empty if nothing had been recorded)
	 */
	public static Map removeRequestParameterValueMap(Message message) {
		Map result = getRequestParameterValueMap(message);
		ContextUtils.remove(requestParameterMapKey(message));
		return result;
	}

	/** Builds the context key under which the parameter map of the given message is stored. */
	private static String requestParameterMapKey(Message message) {
		return REQUEST_PARAMETER_MAP_KEY_FORMAT.formatted(message.getId());
	}

	/**
	 * Reads a tag-replaced target property, falling back to a default, and
	 * records any non-null result in the message's parameter value map.
	 *
	 * @param message the message whose target properties are read
	 * @param name the property name
	 * @param defaultValue the value to use when the property is absent; may be {@code null}
	 * @return the property value, the default, or {@code null} if both are absent
	 */
	public static String getStringParameter(Message message, String name, String defaultValue) {
		String result = SendUtils.getTagReplacedTargetProperty(message, name).orElse(defaultValue);
		if (result != null)
			getRequestParameterValueMap(message).put(name, result);
		return result;
	}

	/**
	 * Reads a required target property.
	 *
	 * @param message the message whose target properties are read
	 * @param name the property name
	 * @return the property value, never {@code null}
	 * @throws IllegalArgumentException if the property is missing
	 */
	public static String getStringParameter(Message message, String name) {
		String result = getStringParameter(message, name, null);
		if (result == null)
			throw new IllegalArgumentException("Missing required parameter %s".formatted(name));
		return result;
	}

	/**
	 * Reads a target property as an integer.
	 *
	 * @param message the message whose target properties are read
	 * @param name the property name
	 * @param defaultValue the value to use when the property is absent; may be {@code null}
	 * @return the parsed value, or {@code null} if the property is absent and no default was given
	 * @throws NumberFormatException if the value is not a valid integer
	 */
	public static Integer getIntegerParameter(Message message, String name, Integer defaultValue) {
		String value = getStringParameter(message, name, defaultValue == null ? null : defaultValue.toString());
		return value != null ? Integer.parseInt(value) : null;
	}

	/**
	 * Reads a target property as a boolean.
	 *
	 * @param message the message whose target properties are read
	 * @param name the property name
	 * @return {@code true} only if the value is exactly {@code "true"} (case-sensitive)
	 */
	public static boolean getBooleanParameter(Message message, String name) {
		String value = getStringParameter(message, name, null);
		return "true".equals(value);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy