All Downloads are FREE. Search and download functionalities are using the official Maven repository.

software.amazon.smithy.kotlin.codegen.rendering.protocol.HttpProtocolTestGenerator.kt Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.kotlin.codegen.rendering.protocol

import software.amazon.smithy.kotlin.codegen.core.defaultName
import software.amazon.smithy.kotlin.codegen.model.getTrait
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.protocoltests.traits.AppliesTo
import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import java.util.*
import java.util.logging.Logger

/**
 * Controls how [TestMemberDelta.members] is interpreted when choosing which protocol test cases to generate.
 */
enum class TestContainmentMode {
    /** Only the test cases whose IDs are listed in the delta are generated. */
    RUN_TESTS,

    /** Every test case is generated except those whose IDs are listed in the delta. */
    EXCLUDE_TESTS,
}

/**
 * Specifies tests to add to, or subtract from, the complete set of protocol tests.
 *
 * @property members IDs of the protocol test cases this delta refers to
 * @property runMode how [members] is applied: with [TestContainmentMode.EXCLUDE_TESTS] (the default) the listed
 * tests are skipped; with [TestContainmentMode.RUN_TESTS] only the listed tests are generated
 */
data class TestMemberDelta(
    val members: Set<String>,
    val runMode: TestContainmentMode = TestContainmentMode.EXCLUDE_TESTS,
)

/**
 * Generates protocol unit tests for the HTTP protocol from smithy models.
 *
 * For each operation contained in the service (operations tagged `server-only` are skipped), this renders:
 *  1. `<Operation>RequestTest` from the operation's [HttpRequestTestsTrait];
 *  2. `<Operation>ResponseTest` from the operation's [HttpResponseTestsTrait];
 *  3. `<Operation><Error>Test` for each of the operation's error shapes (not tagged `server-only`) that carries an
 *     [HttpResponseTestsTrait]; all error test classes of one operation are written to `<Operation>ErrorTest.kt`.
 *
 * Only test cases that apply to clients ([AppliesTo.CLIENT]) and pass [isTestCaseAllowedForRunMode] are rendered.
 * When no test case survives that filtering, no test class is rendered (and no test file writer is opened) for it.
 *
 * @param ctx code generation context supplying the model, service, settings, symbol provider, delegator and the
 * protocol whose tests are being generated
 * @param requestTestBuilder builder used to render request test classes
 * @param responseTestBuilder builder used to render response test classes
 * @param errorTestBuilder builder used to render error test classes
 * @param testDelta test IDs to run or to skip, depending on [TestMemberDelta.runMode]
 */
class HttpProtocolTestGenerator(
    private val ctx: ProtocolGenerator.GenerationContext,
    private val requestTestBuilder: HttpProtocolUnitTestRequestGenerator.Builder,
    private val responseTestBuilder: HttpProtocolUnitTestResponseGenerator.Builder,
    private val errorTestBuilder: HttpProtocolUnitTestErrorGenerator.Builder,
    // test IDs to skip (or, in RUN_TESTS mode, the only ones to run); by default nothing is skipped
    private val testDelta: TestMemberDelta = TestMemberDelta(setOf()),
) {
    private val LOGGER = Logger.getLogger(javaClass.name)

    /**
     * Generates the API HTTP protocol tests defined in the smithy model.
     */
    fun generateProtocolTests() {
        val operationIndex: OperationIndex = OperationIndex.of(ctx.model)
        val topDownIndex: TopDownIndex = TopDownIndex.of(ctx.model)

        // TreeSet iterates the operations in sorted order, so generation order does not depend on the model's order
        for (operation in TreeSet(topDownIndex.getContainedOperations(ctx.service).filterNot(::serverOnly))) {
            // operation name with its first character upper-cased; prefix of every generated class/file name
            val opName = operation.id.name.replaceFirstChar { c -> c.uppercaseChar() }

            // 1. Generate test cases for each request.
            val requestTests = selectTestCases(
                operation.getTrait<HttpRequestTestsTrait>()?.getTestCasesFor(AppliesTo.CLIENT),
            )

            requestTests?.let { testCases ->
                val testClassName = "${opName}RequestTest"
                val testFilename = "$testClassName.kt"
                ctx.delegator.useTestFileWriter(testFilename, ctx.settings.pkg.name) { writer ->
                    LOGGER.fine("Generating request protocol test cases for ${operation.id}")

                    // import package.models.*
                    writer.addImport("${ctx.settings.pkg.name}.model", "*")

                    requestTestBuilder
                        .ctx(ctx)
                        .writer(writer)
                        .model(ctx.model)
                        .symbolProvider(ctx.symbolProvider)
                        .operation(operation)
                        .service(ctx.service)
                        .testCases(testCases)
                        .build()
                        .renderTestClass(testClassName)
                }
            }

            // 2. Generate test cases for each response.
            val responseTests = selectTestCases(
                operation.getTrait<HttpResponseTestsTrait>()?.getTestCasesFor(AppliesTo.CLIENT),
            )

            responseTests?.let { testCases ->
                val testClassName = "${opName}ResponseTest"
                val testFilename = "$testClassName.kt"
                ctx.delegator.useTestFileWriter(testFilename, ctx.settings.pkg.name) { writer ->
                    LOGGER.fine("Generating response protocol test cases for ${operation.id}")

                    writer.addImport("${ctx.settings.pkg.name}.model", "*")
                    responseTestBuilder
                        .ctx(ctx)
                        .writer(writer)
                        .model(ctx.model)
                        .symbolProvider(ctx.symbolProvider)
                        .operation(operation)
                        .service(ctx.service)
                        .testCases(testCases)
                        .build()
                        .renderTestClass(testClassName)
                }
            }

            // 3. Generate test cases for each error on each operation.
            for (error in operationIndex.getErrors(operation).filterNot(::serverOnly)) {
                val errorTests = selectTestCases(
                    error.getTrait<HttpResponseTestsTrait>()?.getTestCasesFor(AppliesTo.CLIENT),
                )

                errorTests?.let { testCases ->
                    // multiple errors (and their tests) may be associated with a single operation: they share one
                    // file named after the operation, with one class per error (operation name + error name)
                    val testFilename = "${opName}ErrorTest.kt"
                    val testClassName = "${opName}${error.defaultName(ctx.service)}Test"
                    ctx.delegator.useTestFileWriter(testFilename, ctx.settings.pkg.name) { writer ->
                        LOGGER.fine("Generating error protocol test cases for ${operation.id}")

                        writer.addImport("${ctx.settings.pkg.name}.model", "*")
                        errorTestBuilder
                            .error(error)
                            .ctx(ctx)
                            .writer(writer)
                            .model(ctx.model)
                            .symbolProvider(ctx.symbolProvider)
                            .operation(operation)
                            .service(ctx.service)
                            .testCases(testCases)
                            .build()
                            .renderTestClass(testClassName)
                    }
                }
            }
        }
    }

    /**
     * Narrows [testCases] to the ones allowed by [isTestCaseAllowedForRunMode].
     *
     * @return the surviving test cases, or `null` when [testCases] is `null` (no trait on the shape) or when no test
     * case survives filtering, so that callers do not render empty test classes
     */
    private fun <T : HttpMessageTestCase> selectTestCases(testCases: List<T>?): List<T>? =
        testCases
            ?.filter(::isTestCaseAllowedForRunMode)
            ?.takeIf { it.isNotEmpty() }

    /**
     * Whether [test] should be rendered.
     *
     * The test must target the protocol currently being generated (`ctx.protocol`). Its ID must also be absent
     * from [TestMemberDelta.members] in [TestContainmentMode.EXCLUDE_TESTS] mode, or present in it in
     * [TestContainmentMode.RUN_TESTS] mode.
     */
    private fun <T : HttpMessageTestCase> isTestCaseAllowedForRunMode(test: T): Boolean =
        test.protocol == ctx.protocol && when (testDelta.runMode) {
            TestContainmentMode.EXCLUDE_TESTS -> test.id !in testDelta.members
            TestContainmentMode.RUN_TESTS -> test.id in testDelta.members
        }
}

/**
 * Reports whether [shape] is tagged `server-only`.
 *
 * Protocol test generation filters such operations and errors out.
 */
private fun serverOnly(shape: Shape): Boolean {
    val isServerOnly = shape.hasTag("server-only")
    return isServerOnly
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy