All Downloads are FREE. Search and download functionalities are using the official Maven repository.

software.amazon.smithy.kotlin.codegen.rendering.protocol.HttpProtocolUnitTestRequestGenerator.kt Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.kotlin.codegen.rendering.protocol

import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.integration.SectionId
import software.amazon.smithy.kotlin.codegen.integration.SectionKey
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.model.hasStreamingMember
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.rendering.ShapeValueGenerator
import software.amazon.smithy.kotlin.codegen.rendering.endpoints.EndpointProviderGenerator
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait

/**
 * Generates HTTP protocol unit tests for `httpRequestTest` cases
 */
open class HttpProtocolUnitTestRequestGenerator protected constructor(builder: Builder) :
    HttpProtocolUnitTestGenerator(builder) {

    object ConfigureServiceClient : SectionId {
        val Test: SectionKey = SectionKey("Test")
        val Context: SectionKey = SectionKey("Context")
        val Operation: SectionKey = SectionKey("Operation")
    }

    override fun openTestFunctionBlock(): String = "= httpRequestTest {"

    override fun renderTestBody(test: HttpRequestTestCase) {
        writer.addImport(KotlinDependency.SMITHY_TEST.namespace, "*")
        writer.addImport(KotlinDependency.HTTP.namespace, "HttpMethod")
        writer.addImport(KotlinDependency.KOTLIN_TEST.namespace, "*")
        writer.dependencies.addAll(KotlinDependency.SMITHY_TEST.dependencies)
        writer.dependencies.addAll(KotlinDependency.KOTLIN_TEST.dependencies)
        renderExpectedBlock(test)
        writer.write("")
        renderOperationBlock(test)
    }

    private fun renderExpectedBlock(test: HttpRequestTestCase) {
        writer.openBlock("expected {")
            .call {
                writer
                    .write("method = HttpMethod.${test.method.uppercase()}")
                    .write("uri = #S", test.uri)
                    .call {
                        test.resolvedHost.ifPresent { expectedHost ->
                            writer.write("resolvedHost = #S", expectedHost)
                        }
                    }
                    .call { renderExpectedQueryParams(test) }
                    .call { renderExpectedListOfParams("forbiddenQueryParams", test.forbidQueryParams) }
                    .call { renderExpectedListOfParams("requiredQueryParams", test.requireQueryParams) }
                    .call { renderExpectedHeaders(test) }
                    .call { renderExpectedListOfParams("forbiddenHeaders", test.forbidHeaders) }
                    .call { renderExpectedListOfParams("requiredHeaders", test.requireHeaders) }
                    .call {
                        test.body.ifPresent { body ->
                            var bodyMediaType = ""
                            if (test.bodyMediaType.isPresent) {
                                bodyMediaType = test.bodyMediaType.get()
                                writer.write("bodyMediaType = #S", bodyMediaType)
                            }

                            if (body.isBlank()) {
                                writer.write("bodyAssert = ::assertEmptyBody")
                            } else {
                                writer.write("body = \"\"\"#L\"\"\"", body)

                                val compareFunc = when (bodyMediaType.lowercase()) {
                                    "application/json" -> "::assertJsonBodiesEqual"
                                    "application/xml" -> "::assertXmlBodiesEqual"
                                    "application/x-www-form-urlencoded" -> "::assertFormUrlBodiesEqual"
                                    // compare reader bytes
                                    else -> "::assertBytesEqual"
                                }
                                writer.write("bodyAssert = $compareFunc")
                            }
                        }
                    }
                    .write("")
            }
            .closeBlock("}")
    }

    private fun renderOperationBlock(test: HttpRequestTestCase) {
        writer.openBlock("operation { mockEngine ->")
            .call {
                var inputParamName = ""
                operation.input.ifPresent {
                    inputParamName = "input"
                    val inputShape = model.expectShape(it)

                    // invoke the DSL builder for the input type
                    writer.writeInline("\nval input = ")
                        .indent()
                        .call {
                            ShapeValueGenerator(model, symbolProvider).instantiateShapeInline(writer, inputShape, test.params)
                        }
                        .dedent()
                        .write("")
                }

                val service = symbolProvider.toSymbol(serviceShape)
                writer.openBlock("val service = #L {", service.name)
                    .call { renderConfigureServiceClient(test) }
                    .closeBlock("}")

                // last statement should be service invoke
                val opName = operation.defaultName()

                val isStreamingResponse = model.expectShape(operation.output.get()).hasStreamingMember(model)

                // streaming responses have a different operation signature that require a block to be passed to
                // process the response - add an empty block if necessary
                val block = if (isStreamingResponse) "{}" else ""
                writer.write("service.#L(#L)$block", opName, inputParamName)
            }
            .closeBlock("}")
    }

    /**
     * Configure the service client before executing the request for the test. By default this function
     * configures a mock HttpClientEngine and an idempotency token generator appropriate for protocol tests.
     */
    open fun renderConfigureServiceClient(test: HttpRequestTestCase) {
        writer.declareSection(
            ConfigureServiceClient,
            mapOf(
                ConfigureServiceClient.Test to test,
                ConfigureServiceClient.Context to ctx,
                ConfigureServiceClient.Operation to operation,
            ),
        ) {
            writer.write("httpClient = mockEngine")
            if (idempotentFieldsInModel) {
                writer.write("idempotencyTokenProvider = #T { \"00000000-0000-4000-8000-000000000000\" }", RuntimeTypes.SmithyClient.IdempotencyTokenProvider)
            }

            val hostname = test.host.getOrNull() ?: "hostname"

            // if the model doesn't have endpoint rules we have to fill it in with a default
            if (!serviceShape.hasTrait()) {
                writer.write(
                    "endpointProvider = #T { #T(#S) }",
                    EndpointProviderGenerator.getSymbol(ctx.settings),
                    RuntimeTypes.SmithyClient.Endpoints.Endpoint,
                    "https://$hostname",
                )
            }
        }
    }

    private fun renderExpectedQueryParams(test: HttpRequestTestCase) {
        if (test.queryParams.isEmpty()) return

        val queryParams = test.queryParams
            .map {
                val kvPair = it.split("=", limit = 2)
                val value = kvPair.getOrNull(1) ?: ""
                Pair(kvPair[0], value)
            }

        writer.openBlock("queryParams = listOf(")
            .call {
                queryParams.forEachIndexed { idx, (key, value) ->
                    val suffix = if (idx < queryParams.size - 1) "," else ""
                    writer.write("#S to #S$suffix", key, value)
                }
            }
            .closeBlock(")")
    }

    private fun renderExpectedHeaders(test: HttpRequestTestCase) {
        if (test.headers.isEmpty()) return
        writer.openBlock("headers = mapOf(")
            .call {
                for ((idx, hdr) in test.headers.entries.withIndex()) {
                    val suffix = if (idx < test.headers.size - 1) "," else ""
                    writer.write("#S to #S$suffix", hdr.key, hdr.value)
                }
            }
            .closeBlock(")")
    }

    private fun renderExpectedListOfParams(name: String, params: List) {
        if (params.isEmpty()) return
        val joined = params.joinToString(
            separator = ",",
            transform = { "\"$it\"" },
        )
        writer.write("$name = listOf($joined)")
    }

    open class Builder : HttpProtocolUnitTestGenerator.Builder() {
        override fun build(): HttpProtocolUnitTestGenerator =
            HttpProtocolUnitTestRequestGenerator(this)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy