All Downloads are FREE. Search and download functionalities are using the official Maven repository.

software.amazon.smithy.kotlin.codegen.test.CodegenTestUtils.kt Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */
package software.amazon.smithy.kotlin.codegen.test

import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.build.MockManifest
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.namespace
import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.TimestampFormatTrait

/**
 * This file contains test classes and functions for exercising the code generator (protocols, serializers, etc.).
 *
 * Items contained here should be relatively high-level, utilizing all members of codegen classes, Smithy, and
 * anything else necessary for test functionality.
 */

/**
 * Bundles the objects a codegen test needs in order to run a protocol generator.
 *
 * Properties are exposed as positional `componentN` functions (it is a data class), so keep their
 * declaration order stable for callers that destructure an instance.
 *
 * @property generationCtx the context handed to protocol and serde generators
 * @property manifest the in-memory manifest that receives generated files
 * @property generator the protocol generator under test
 */
data class TestContext(
    val generationCtx: ProtocolGenerator.GenerationContext,
    val manifest: MockManifest,
    val generator: ProtocolGenerator,
)

/**
 * Runs [renderFn] against a fresh [KotlinWriter] created for [TestModelDefault.NAMESPACE] and returns
 * everything that was written.
 *
 * @param members the member shapes forwarded unchanged to [renderFn]
 * @param renderFn the codegen routine under test; receives [members] and the writer to emit into
 * @return the writer's contents as produced by [KotlinWriter.toString]
 */
internal fun testRender(
    members: List<MemberShape>,
    renderFn: (List<MemberShape>, KotlinWriter) -> Unit
): String {
    val writer = KotlinWriter(TestModelDefault.NAMESPACE)
    renderFn(members, writer)
    return writer.toString()
}

/**
 * Renders the code [SerializeStructGenerator] produces for the request members of [shapeId] that are
 * bound to [location]. Epoch-seconds is used as the default timestamp format.
 *
 * @param model the model containing [shapeId]
 * @param shapeId shape ID string, parsed with [ShapeId.from]
 * @param location the HTTP binding location whose members are serialized (document by default)
 * @return the generated Kotlin source
 */
internal fun codegenSerializerForShape(
    model: Model,
    shapeId: String,
    location: HttpBinding.Location = HttpBinding.Location.DOCUMENT
): String {
    val testCtx = model.newTestContext()
    val genCtx = testCtx.generationCtx
    val target = genCtx.model.expectShape(ShapeId.from(shapeId))
    val selectedMembers = testCtx.requestMembers(target, location)

    return testRender(selectedMembers) { members, writer ->
        val serializer = SerializeStructGenerator(genCtx, members, writer, TimestampFormatTrait.Format.EPOCH_SECONDS)
        serializer.render()
    }
}

/**
 * Renders the code [DeserializeStructGenerator] produces for the response members of [shapeId] that are
 * bound to [location]. Epoch-seconds is used as the default timestamp format.
 *
 * @param model the model containing [shapeId]
 * @param shapeId shape ID string, parsed with [ShapeId.from]
 * @param location the HTTP binding location whose members are deserialized (document by default)
 * @return the generated Kotlin source
 */
internal fun codegenDeserializerForShape(
    model: Model,
    shapeId: String,
    location: HttpBinding.Location = HttpBinding.Location.DOCUMENT
): String {
    val testCtx = model.newTestContext()
    val genCtx = testCtx.generationCtx
    val target = genCtx.model.expectShape(ShapeId.from(shapeId))
    val selectedMembers = testCtx.responseMembers(target, location)

    return testRender(selectedMembers) { members, writer ->
        val deserializer = DeserializeStructGenerator(genCtx, members, writer, TimestampFormatTrait.Format.EPOCH_SECONDS)
        deserializer.render()
    }
}

/**
 * Renders the code [SerializeUnionGenerator] produces for the members of a union related to [shapeId].
 *
 * - If [shapeId] is an operation, the shape rendered is the target of the operation's first request
 *   binding (as returned by [HttpBindingIndex.getRequestBindings]).
 * - If [shapeId] is a structure, the structure's own members are rendered.
 *
 * In both cases members are sorted by member name first. Epoch-seconds is the default timestamp format.
 *
 * @throws RuntimeException if [shapeId] resolves to neither an operation nor a structure
 */
internal fun codegenUnionSerializerForShape(model: Model, shapeId: String): String {
    val ctx = model.newTestContext()
    val testModel = ctx.generationCtx.model

    val container: Shape = when (val shape = testModel.expectShape(ShapeId.from(shapeId))) {
        is OperationShape -> {
            val firstBinding = HttpBindingIndex.of(testModel).getRequestBindings(shape).values.first()
            testModel.expectShape(firstBinding.member.target)
        }
        is StructureShape -> shape
        else -> throw RuntimeException("unknown conversion for $shapeId")
    }
    val sortedMembers = container.members().sortedBy { it.memberName }

    return testRender(sortedMembers) { members, writer ->
        val serializer = SerializeUnionGenerator(
            ctx.generationCtx,
            members,
            writer,
            TimestampFormatTrait.Format.EPOCH_SECONDS
        )
        serializer.render()
    }
}

/**
 * Returns the members of [shape]'s HTTP response bindings whose location is [location], sorted by member
 * name. Bindings are resolved through [HttpBindingIndex], so this applies to `@http`-trait-driven protocols.
 *
 * @param shape the shape whose response bindings are inspected (typically an operation)
 * @param location the binding location to keep (document by default)
 * @return the bound member shapes, ordered by member name
 */
internal fun TestContext.responseMembers(
    shape: Shape,
    location: HttpBinding.Location = HttpBinding.Location.DOCUMENT
): List<MemberShape> {
    val bindingIndex = HttpBindingIndex.of(this.generationCtx.model)
    val responseBindings = bindingIndex.getResponseBindings(shape)

    return responseBindings.values
        .filter { it.location == location }
        .sortedBy { it.memberName }
        .map { it.member }
}

/**
 * Returns the members of [shape]'s HTTP request bindings whose location is [location], sorted by member
 * name. Bindings are resolved through [HttpBindingIndex], so this applies to `@http`-trait-driven protocols.
 *
 * @param shape the shape whose request bindings are inspected (typically an operation)
 * @param location the binding location to keep (document by default)
 * @return the bound member shapes, ordered by member name
 */
internal fun TestContext.requestMembers(
    shape: Shape,
    location: HttpBinding.Location = HttpBinding.Location.DOCUMENT
): List<MemberShape> {
    val bindingIndex = HttpBindingIndex.of(this.generationCtx.model)
    val requestBindings = bindingIndex.getRequestBindings(shape)

    return requestBindings.values
        .filter { it.location == location }
        .sortedBy { it.memberName }
        .map { it.member }
}

/**
 * Builds a [GenerationContext] from this test context's model, symbol provider and settings, using
 * [TestContext.generator] as the protocol generator.
 */
internal fun TestContext.toGenerationContext(): GenerationContext {
    val source = generationCtx
    return GenerationContext(
        source.model,
        source.symbolProvider,
        source.settings,
        generator
    )
}

/**
 * Builds a [RenderingContext] that writes through [writer] and is optionally associated with [forShape].
 * Delegates to `toRenderingContext` on the [GenerationContext] produced by [toGenerationContext].
 *
 * @param writer the writer generated code is emitted to
 * @param forShape the shape being rendered, or `null` when there is none
 */
fun <T : Shape> TestContext.toRenderingContext(writer: KotlinWriter, forShape: T? = null): RenderingContext<T> =
    toGenerationContext().toRenderingContext(writer, forShape)

/**
 * A pass-through [HttpProtocolClientGenerator] with no overrides. [MockHttpProtocolGenerator] constructs it
 * with the middleware from `getHttpMiddleware` and its binding resolver.
 *
 * @param ctx the generation context
 * @param features the protocol middleware to render into the client
 * @param httpBindingResolver resolves HTTP bindings for operations
 */
internal class TestProtocolClientGenerator(
    ctx: ProtocolGenerator.GenerationContext,
    features: List<ProtocolMiddleware>,
    httpBindingResolver: HttpBindingResolver
) : HttpProtocolClientGenerator(ctx, features, httpBindingResolver)

/**
 * An [HttpBindingProtocolGenerator] for tests. It reports the restJson1 protocol ID, defaults timestamps
 * to epoch-seconds, resolves bindings from the `@http` trait with an `application/json` content type,
 * and renders nothing for payload bodies, exceptions or protocol unit tests.
 */
internal class MockHttpProtocolGenerator : HttpBindingProtocolGenerator() {
    override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

    override val protocol: ShapeId = RestJson1Trait.ID

    override fun getProtocolHttpBindingResolver(model: Model, serviceShape: ServiceShape): HttpBindingResolver =
        HttpTraitResolver(model, serviceShape, "application/json")

    override fun getHttpProtocolClientGenerator(ctx: ProtocolGenerator.GenerationContext): HttpProtocolClientGenerator {
        val middleware = getHttpMiddleware(ctx)
        val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service)
        return TestProtocolClientGenerator(ctx, middleware, resolver)
    }

    // Every hook below is deliberately a no-op: the mock emits no protocol tests, bodies, or error handling.

    override fun generateProtocolUnitTests(ctx: ProtocolGenerator.GenerationContext) = Unit

    override fun renderSerializeOperationBody(
        ctx: ProtocolGenerator.GenerationContext,
        op: OperationShape,
        writer: KotlinWriter
    ) = Unit

    override fun renderDeserializeOperationBody(
        ctx: ProtocolGenerator.GenerationContext,
        op: OperationShape,
        writer: KotlinWriter
    ) = Unit

    override fun renderSerializeDocumentBody(
        ctx: ProtocolGenerator.GenerationContext,
        shape: Shape,
        writer: KotlinWriter
    ) = Unit

    override fun renderDeserializeDocumentBody(
        ctx: ProtocolGenerator.GenerationContext,
        shape: Shape,
        writer: KotlinWriter
    ) = Unit

    override fun renderDeserializeException(
        ctx: ProtocolGenerator.GenerationContext,
        shape: Shape,
        writer: KotlinWriter
    ) = Unit

    override fun renderThrowOperationError(
        ctx: ProtocolGenerator.GenerationContext,
        op: OperationShape,
        writer: KotlinWriter
    ) = Unit
}

/**
 * Builds a [CodegenTestHarness] for a model generated from a model [snippet].
 *
 * The snippet text is expanded by `generateTestModel` using [generator]'s protocol name, [namespace],
 * [serviceName] and [operations]. A test context for that model is then created with [generator].
 *
 * @param generator the protocol generator under test; its protocol shape name is recorded in the harness
 * @param namespace namespace/package name used for the model and the test context
 * @param serviceName name of the service in the generated model
 * @param operations operation names forwarded to `generateTestModel` (presumably the operations the
 *   generated service exposes — confirm against `generateTestModel`)
 * @param snippet supplies the model fragment text; invoked once
 */
fun codegenTestHarnessForModelSnippet(
    generator: ProtocolGenerator,
    namespace: String = TestModelDefault.NAMESPACE,
    serviceName: String = TestModelDefault.SERVICE_NAME,
    operations: List<String>,
    snippet: () -> String
): CodegenTestHarness {
    val protocol = generator.protocol.name
    val model = snippet().generateTestModel(protocol, namespace, serviceName, operations)
    val (ctx, manifest, _) = model.newTestContext(serviceName = serviceName, packageName = namespace, generator = generator)

    return CodegenTestHarness(ctx, manifest, generator, namespace, serviceName, protocol)
}

/**
 * Holds the objects needed to drive codegen and inspect its output.
 *
 * @property generationCtx context passed to the protocol generator
 * @property manifest in-memory manifest that collects generated files
 * @property generator the protocol generator under test
 * @property namespace namespace/package the model was generated under
 * @property serviceName name of the service in the model
 * @property protocol name of the protocol shape ID reported by [generator]
 */
data class CodegenTestHarness(
    val generationCtx: ProtocolGenerator.GenerationContext,
    val manifest: MockManifest,
    val generator: ProtocolGenerator,
    val namespace: String,
    val serviceName: String,
    val protocol: String,
)

/**
 * Runs serializer and then deserializer generation, flushes the delegator's writers, and returns every file
 * in [CodegenTestHarness.manifest] as a map from file name (last path segment only) to file contents.
 *
 * Note: because keys are bare file names, two files with the same name in different directories collapse
 * to one entry; the later one in `manifest.files` iteration order wins.
 */
fun CodegenTestHarness.generateDeSerializers(): Map<String, String> {
    generator.generateSerializers(generationCtx)
    generator.generateDeserializers(generationCtx)
    generationCtx.delegator.flushWriters()
    return manifest.files.associate { path -> path.fileName.toString() to manifest.expectFileString(path) }
}

/**
 * Runs [generator] against a fresh [KotlinWriter] and returns only the code it emitted, with the header
 * comment and package preamble stripped and surrounding whitespace trimmed.
 *
 * The writer is created with a sentinel package name. Everything up to and including that sentinel is
 * removed from the writer's output. If the sentinel is somehow absent, the whole (trimmed) output is
 * returned instead. Previously, in that case `indexOf` returned -1 and the function silently dropped
 * leading characters or threw.
 */
fun generateCode(generator: (KotlinWriter) -> Unit): String {
    val packageDeclaration = "some-unique-thing-that-will-never-be-codegened"
    val writer = KotlinWriter(packageDeclaration)
    generator.invoke(writer)
    val rawCodegen = writer.toString()
    return rawCodegen.substringAfter(packageDeclaration).trim()
}

/**
 * Creates a symbol provider for [model] using default test settings built from [rootNamespace], [sdkId]
 * and [serviceName].
 */
fun KotlinCodegenPlugin.Companion.createSymbolProvider(
    model: Model,
    rootNamespace: String = TestModelDefault.NAMESPACE,
    sdkId: String = TestModelDefault.SDK_ID,
    serviceName: String = TestModelDefault.SERVICE_NAME
): SymbolProvider {
    val testSettings = model.defaultSettings(
        serviceName = serviceName,
        packageName = rootNamespace,
        sdkId = sdkId
    )
    return createSymbolProvider(model, testSettings)
}

/**
 * Creates a new [KotlinWriter] for the package configured in this test context's settings.
 */
fun TestContext.newWriter(): KotlinWriter {
    val packageName = generationCtx.settings.pkg.name
    return KotlinWriter(packageName)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy