All Downloads are FREE. Search and download functionalities are using the official Maven repository.

jvmTest.aws.sdk.kotlin.runtime.auth.credentials.DefaultChainCredentialsProviderTest.kt Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package aws.sdk.kotlin.runtime.auth.credentials

import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
import aws.smithy.kotlin.runtime.httptest.TestConnection
import aws.smithy.kotlin.runtime.time.Instant
import aws.smithy.kotlin.runtime.util.Filesystem
import aws.smithy.kotlin.runtime.util.OperatingSystem
import aws.smithy.kotlin.runtime.util.OsFamily
import aws.smithy.kotlin.runtime.util.PlatformProvider
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.*
import java.io.File
import java.nio.file.Paths
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

// TODO - refactor to make this work in common
/**
 * Runs the shared `default-provider-chain` test suite against [DefaultChainCredentialsProvider].
 *
 * Each test case is a directory under the `default-provider-chain` classpath resource containing:
 * - `test-case.json` (required): the expected [TestResult]
 * - `env.json` (optional): environment variables visible to the provider chain
 * - `http-traffic.json` (optional): recorded HTTP traffic replayed through a [TestConnection]
 * - `fs/` (optional): root of the fake filesystem the chain reads files from
 */
@OptIn(ExperimentalCoroutinesApi::class)
class DefaultChainCredentialsProviderTest {

    /**
     * A read-only [Filesystem] that resolves every requested path relative to [root].
     *
     * Paths that do not name a regular file under [root] read as `null`; writes are unsupported.
     */
    class FsRootedAt(val root: File) : Filesystem {
        override val filePathSeparator: String = "/"

        override suspend fun readFileOrNull(path: String): ByteArray? = withContext(Dispatchers.IO) {
            // both the file check and the read are blocking I/O, so keep them together on the IO dispatcher
            val realPath = Paths.get(root.path, path).toFile()
            if (realPath.isFile) realPath.readBytes() else null
        }

        override suspend fun writeFile(path: String, data: ByteArray) {
            error("not needed for test")
        }
    }

    /**
     * A fixed Linux/JVM [PlatformProvider] whose environment variables come from [env] and whose
     * file reads are delegated to [fs]. System properties are always empty.
     */
    class DefaultChainPlatformProvider(
        private val env: Map<String, String>,
        private val fs: Filesystem,
    ) : PlatformProvider, Filesystem by fs {
        override fun osInfo(): OperatingSystem = OperatingSystem(OsFamily.Linux, "test")
        override val isJvm: Boolean = true
        override val isAndroid: Boolean = false
        override val isBrowser: Boolean = false
        override val isNode: Boolean = false
        override val isNative: Boolean = false

        override fun getAllProperties(): Map<String, String> = emptyMap()
        override fun getProperty(key: String): String? = null
        override fun getAllEnvVars(): Map<String, String> = env
        override fun getenv(key: String): String? = env[key]
    }

    /**
     * The expected outcome of a single test case, parsed from its `test-case.json`.
     */
    sealed class TestResult {
        abstract val name: String
        abstract val docs: String

        /** Resolution is expected to succeed with [creds]. */
        data class Ok(
            override val name: String,
            override val docs: String,
            val creds: Credentials,
        ) : TestResult()

        /** Resolution is expected to fail with an error whose message contains [message]. */
        data class ErrorContains(
            override val name: String,
            override val docs: String,
            val message: String,
        ) : TestResult()

        companion object {

            /**
             * Parse a `test-case.json` [payload].
             *
             * The payload must have string `name` and `docs` fields and a `result` object holding either
             * an `Ok` credentials object or an `ErrorContains` string.
             *
             * @throws IllegalStateException if a required field is missing or the result kind is not recognized
             */
            fun fromJson(payload: String): TestResult {
                val obj = Json.parseToJsonElement(payload).jsonObject
                val name = checkNotNull(obj["name"]) { "test case missing `name`" }.jsonPrimitive.content
                val docs = checkNotNull(obj["docs"]) { "test case missing `docs`" }.jsonPrimitive.content
                val result = checkNotNull(obj["result"]) { "test case missing `result`" }.jsonObject
                return when {
                    "Ok" in result -> {
                        val o = checkNotNull(result["Ok"]).jsonObject
                        val creds = Credentials(
                            checkNotNull(o["access_key_id"]) { "`Ok` result missing `access_key_id`" }.jsonPrimitive.content,
                            checkNotNull(o["secret_access_key"]) { "`Ok` result missing `secret_access_key`" }.jsonPrimitive.content,
                            // contentOrNull maps an explicit JSON null to null rather than the string "null"
                            o["session_token"]?.jsonPrimitive?.contentOrNull,
                            o["expiry"]?.jsonPrimitive?.longOrNull?.let { Instant.fromEpochSeconds(it) },
                        )
                        Ok(name, docs, creds)
                    }
                    "ErrorContains" in result -> ErrorContains(name, docs, checkNotNull(result["ErrorContains"]).jsonPrimitive.content)
                    else -> error("unrecognized result object: $result")
                }
            }
        }
    }

    /** Everything needed to run one test case: the expectation plus the fake platform and HTTP engine. */
    data class TestCase(
        val expected: TestResult,
        val testPlatform: PlatformProvider,
        val testEngine: TestConnection,
    )

    /**
     * Load the test case directory [name] from the `default-provider-chain` classpath resource.
     *
     * @throws IllegalStateException if the suite resource, the test directory, or its `test-case.json` is missing
     */
    fun makeTest(name: String): TestCase {
        val url = this::class.java.classLoader.getResource("default-provider-chain")
            ?: error("failed to load default-provider-chain test suite resource")
        val testDir = Paths.get(url.toURI()).toFile().resolve(name)
        check(testDir.exists()) { "$testDir does not exist" }

        val testCaseFile = testDir.resolve("test-case.json")
        check(testCaseFile.exists()) { "no test-case.json in $testDir" }
        val testResult = TestResult.fromJson(testCaseFile.readText())

        val envFile = testDir.resolve("env.json")
        val env: Map<String, String> = if (envFile.exists()) {
            Json.parseToJsonElement(envFile.readText()).jsonObject.mapValues { it.value.jsonPrimitive.content }
        } else {
            emptyMap()
        }

        val httpTrafficFile = testDir.resolve("http-traffic.json")
        val testEngine = if (httpTrafficFile.exists()) {
            TestConnection.fromJson(httpTrafficFile.readText())
        } else {
            TestConnection()
        }

        val fs = FsRootedAt(testDir.resolve("fs"))
        // TODO - support for system props
        val testProvider = DefaultChainPlatformProvider(env, fs)

        return TestCase(testResult, testProvider, testEngine)
    }

    /**
     * Execute a test from the default chain test suite
     * @param name The name of root directory for the test (from common/test-resources/default-provider-chain)
     */
    fun executeTest(name: String) = runTest {
        val test = makeTest(name)
        val provider = DefaultChainCredentialsProvider(platformProvider = test.testPlatform, httpClient = test.testEngine)
        val actual = runCatching { provider.resolve() }

        when (val expected = test.expected) {
            is TestResult.Ok -> {
                val actualCreds = actual.getOrElse { cause ->
                    // attach the provider failure as the cause so its stack trace is reported
                    throw IllegalStateException("expected success, got error: $actual", cause)
                }

                // The caching provider attaches an expiration even to static credentials, so only compare
                // expirations when the test case specifies one. The provider name is not part of the expectation.
                val sanitizedExpiration = if (expected.creds.expiration == null) null else actualCreds.expiration
                val creds = actualCreds.copy(providerName = null, expiration = sanitizedExpiration)
                assertEquals(expected.creds, creds)

                // assert http traffic to the extent we can. These tests do not have specific timestamps they
                // were signed with and some lack enough context to even assert a body (e.g. incorrect content-type).
                // They would require additional metadata to make use of `testEngine.assertRequests()`.
                test.testEngine.requests().forEach { call ->
                    call.expected?.let { expectedCall ->
                        assertEquals(expectedCall.url.host, call.actual.url.host)
                    }
                }
            }
            is TestResult.ErrorContains -> {
                val ex = actual.exceptionOrNull() ?: error("expected error, succeeded: $actual")
                // the chain contains a generic exception with the list of providers tried but it
                // contains all of the suppressed exceptions along the way. Inspect them all and their causes.
                // In particular a test case only looks to verify a specific behavior and even though it
                // may fail at the correct spot, later providers may still be tried and also fail.
                // Exceptions without a message are skipped rather than failing the test with an NPE.
                val needle = expected.message
                val haystack = listOfNotNull(ex.message) +
                    ex.suppressed.mapNotNull { it.message } +
                    ex.suppressed.mapNotNull { it.cause?.message }
                val expectedErrorFound = haystack.any { it.contains(needle) }
                assertTrue(expectedErrorFound, "`$needle` not found in any of the chain exception messages: $haystack")
            }
        }
    }

    @Test
    fun testEcsAssumeRole() = executeTest("ecs_assume_role")

    @Test
    fun testEcsCredentials() = executeTest("ecs_credentials")

    @Test
    fun testImdsAssumeRole() = executeTest("imds_assume_role")

    @Test
    fun testImdsConfigWithNoCreds() = executeTest("imds_config_with_no_creds")

    @Test
    fun testImdsDefaultChainError() = executeTest("imds_default_chain_error")

    @Test
    fun testImdsDefaultChainRetries() = executeTest("imds_default_chain_retries")

    @Test
    fun testImdsDefaultChainSuccess() = executeTest("imds_default_chain_success")

    @Test
    fun testImdsDisabled() = executeTest("imds_disabled")

    @Test
    fun testImdsNoIamRole() = executeTest("imds_no_iam_role")

    @Test
    fun testImdsTokenFail() = executeTest("imds_token_fail")

    @Test
    fun testPreferEnvironment() = executeTest("prefer_environment")

    @Test
    fun testProfileName() = executeTest("profile_name")

    // FIXME - need to discuss desired behavior. This tests precedence and assumes that if a provider
    // is configured that any errors should surface and no further providers tried. Not all SDKs do this
    // though (e.g. Java keeps trying until all providers are exhausted).
    // @Test
    // fun testProfileOverridesWebIdentity() = executeTest("profile_overrides_web_identity")

    @Test
    fun testProfileStaticKeys() = executeTest("profile_static_keys")

    @Test
    fun testWebIdentitySourceProfileNoEnv() = executeTest("web_identity_source_profile_no_env")

    @Test
    fun testWebIdentityTokenEnv() = executeTest("web_identity_token_env")

    // NOTE: the `<Message>` tag here in the HTTP traffic is correctly parsed by the hand-written deserializer
    // to match error code, etc. The model and generated deserializer uses lowercase `message` though so the
    // detailed message we would actually like to check for `No OpenIDConnect provider found in your account for...`
    // is only available on the suppressed exception->cause->sdkErrorMetadata.errorMessage.
    // See https://github.com/awslabs/aws-sdk-kotlin/issues/479
    @Test
    fun testWebIdentityTokenInvalidJwt() = executeTest("web_identity_token_invalid_jwt")

    @Test
    fun testWebIdentityTokenProfile() = executeTest("web_identity_token_profile")

    @Test
    fun testWebIdentityTokenSourceProfile() = executeTest("web_identity_token_source_profile")

    @Test
    fun testLegacySsoRole() = executeTest("legacy_sso_role")

    @Test
    fun testSsoSession() = executeTest("sso_session")

    @Test
    fun testStsRetryOnError() = executeTest("retry_on_error")
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy