All Downloads are FREE. Search and download functionalities are using the official Maven repository.

software.amazon.smithy.kotlin.codegen.rendering.serde.DeserializeStructGenerator.kt Maven / Gradle / Ivy

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.kotlin.codegen.rendering.serde

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.SparseTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait

/**
 * Generate deserialization for members bound to the payload.
 *
 * e.g.
 * ```
 * deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
 *    loop@while (true) {
 *        when (findNextFieldIndex()) {
 *             FIELD1_DESCRIPTOR.index -> builder.field1 = deserializeString()
 *             FIELD2_DESCRIPTOR.index -> builder.field2 = deserializeInt()
 *             null -> break@loop
 *             else -> skipValue()
 *         }
 *     }
 * }
 * ```
 */
open class DeserializeStructGenerator(
    protected val ctx: ProtocolGenerator.GenerationContext,
    protected val members: List,
    protected val writer: KotlinWriter,
    protected val defaultTimestampFormat: TimestampFormatTrait.Format,
) {
    /**
     * Enables overriding the codegen output of the final value resulting
     * from the deserialization of a non-primitive type.
     * @param memberShape [MemberShape] associated with entry
     * @param defaultCollectionName the default value produced by this class.
     */
    open fun collectionReturnExpression(memberShape: MemberShape, defaultCollectionName: String): String = defaultCollectionName

    /**
     * Enables overriding of the lhs expression into which a deserialization operation's
     * result is saved.
     */
    open fun deserializationResultName(defaultName: String): String = defaultName

    /**
     * Iterate over all supplied [MemberShape]s to generate serializers.
     */
    open fun render() {
        // inline an empty object descriptor when the struct has no members
        // otherwise use the one generated as part of the companion object
        val objDescriptor = if (members.isNotEmpty()) {
            "OBJ_DESCRIPTOR"
        } else {
            writer.addImport(RuntimeTypes.Serde.SdkObjectDescriptor)
            "SdkObjectDescriptor.build {}"
        }
        writer.withBlock("deserializer.#T($objDescriptor) {", "}", RuntimeTypes.Serde.deserializeStruct) {
            withBlock("loop@while (true) {", "}") {
                withBlock("when (findNextFieldIndex()) {", "}") {
                    members.sortedBy { it.memberName }.forEach { memberShape ->
                        renderMemberShape(memberShape)
                    }
                    write("null -> break@loop")
                    write("else -> skipValue()")
                }
            }
        }
    }

    /**
     * Deserialize top-level members.
     */
    protected open fun renderMemberShape(memberShape: MemberShape) {
        val targetShape = ctx.model.expectShape(memberShape.target)

        when (targetShape.type) {
            ShapeType.LIST,
            ShapeType.SET,
            -> renderListMemberDeserializer(memberShape, targetShape as CollectionShape)

            ShapeType.MAP -> renderMapMemberDeserializer(memberShape, targetShape as MapShape)

            ShapeType.STRUCTURE,
            ShapeType.UNION,
            ShapeType.BLOB,
            ShapeType.BOOLEAN,
            ShapeType.STRING,
            ShapeType.TIMESTAMP,
            ShapeType.BYTE,
            ShapeType.SHORT,
            ShapeType.INTEGER,
            ShapeType.LONG,
            ShapeType.FLOAT,
            ShapeType.DOUBLE,
            ShapeType.DOCUMENT,
            ShapeType.BIG_DECIMAL,
            ShapeType.BIG_INTEGER,
            ShapeType.ENUM,
            ShapeType.INT_ENUM,
            -> renderShapeDeserializer(memberShape)

            else -> error("Unexpected shape type: ${targetShape.type}")
        }
    }

    /**
     * Codegen the deserialization of a primitive value into a response type. Example:
     * ```
     * PAYLOAD_DESCRIPTOR.index -> builder.payload = deserializeString().let { Instant.fromEpochSeconds(it) }
     * ```
     */
    open fun renderShapeDeserializer(memberShape: MemberShape) {
        val memberName = ctx.symbolProvider.toMemberName(memberShape)
        val descriptorName = memberShape.descriptorName()
        val deserialize = deserializerForShape(memberShape)

        writer.write("$descriptorName.index -> builder.$memberName = $deserialize")
    }

    /**
     * Example:
     * ```
     * PAYLOAD_DESCRIPTOR.index -> builder.payload =
     * deserializer.deserializeMap(PAYLOAD_DESCRIPTOR) {
     *      ...
     * }
     */
    protected fun renderMapMemberDeserializer(memberShape: MemberShape, targetShape: MapShape) {
        val nestingLevel = 0
        val memberName = ctx.symbolProvider.toMemberName(memberShape)
        val descriptorName = memberShape.descriptorName()
        val valueCollector = deserializationResultName("builder.$memberName")
        val mutableCollectionName = nestingLevel.variableNameFor(NestedIdentifierType.MAP)
        val collectionReturnExpression = collectionReturnExpression(memberShape, mutableCollectionName)

        writer.write("$descriptorName.index -> $valueCollector = ")
            .indent()
            .withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
                write(
                    "val #L = #T()",
                    mutableCollectionName,
                    KotlinTypes.Collections.mutableMapOf,
                    ctx.symbolProvider.toSymbol(targetShape.value),
                    nullabilitySuffix(targetShape.isSparse),
                )
                withBlock("while (hasNextEntry()) {", "}") {
                    delegateMapDeserialization(memberShape, targetShape, nestingLevel, mutableCollectionName)
                }
                write(collectionReturnExpression)
            }
            .dedent()
    }

    /**
     * Delegates to other functions based on the type of value target of map.
     */
    private fun delegateMapDeserialization(
        rootMemberShape: MemberShape,
        mapShape: MapShape,
        nestingLevel: Int,
        parentMemberName: String,
    ) {
        val elementShape = ctx.model.expectShape(mapShape.value.target)
        val isSparse = mapShape.isSparse

        when (elementShape.type) {
            ShapeType.BOOLEAN,
            ShapeType.STRING,
            ShapeType.BYTE,
            ShapeType.SHORT,
            ShapeType.INTEGER,
            ShapeType.LONG,
            ShapeType.FLOAT,
            ShapeType.DOUBLE,
            ShapeType.BIG_DECIMAL,
            ShapeType.BIG_INTEGER,
            ShapeType.BLOB,
            ShapeType.DOCUMENT,
            ShapeType.TIMESTAMP,
            ShapeType.ENUM,
            ShapeType.INT_ENUM,
            -> renderEntry(elementShape, nestingLevel, isSparse, parentMemberName)

            ShapeType.SET,
            ShapeType.LIST,
            -> renderListEntry(rootMemberShape, elementShape as CollectionShape, nestingLevel, isSparse, parentMemberName)

            ShapeType.MAP -> renderMapEntry(rootMemberShape, elementShape as MapShape, nestingLevel, isSparse, parentMemberName)
            ShapeType.UNION,
            ShapeType.STRUCTURE,
            -> renderNestedStructureEntry(elementShape, nestingLevel, isSparse, parentMemberName)

            else -> error("Unhandled type ${elementShape.type}")
        }
    }

    /**
     * Renders the deserialization of a nested structure contained in a map.  Example:
     *
     * ```
     * val k0 = key()
     * val v0 = if (nextHasValue()) { deserializeString().let { Instant.fromEpochSeconds(it) } } else { deserializeNull(); continue }
     * map0[k0] = v0
     * ```
     */
    private fun renderNestedStructureEntry(
        elementShape: Shape,
        nestingLevel: Int,
        isSparse: Boolean,
        parentMemberName: String,
    ) {
        val deserializerFn = deserializerForShape(elementShape)
        val keyName = nestingLevel.variableNameFor(NestedIdentifierType.KEY)
        val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
        val populateNullValuePostfix = if (isSparse) "" else "; continue"
        if (elementShape.isStructureShape || elementShape.isUnionShape) {
            val symbol = ctx.symbolProvider.toSymbol(elementShape)
            writer.addImport(symbol)
        }

        writer.write("val $keyName = key()")
        writer.write("val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }")
        writer.write("$parentMemberName[$keyName] = $valueName")
    }

    /**
     * Render the deserialization of a map entry.  Example:
     * ```
     * val k0 = key()
     * val v0 = deserializer.deserializeMap(PAYLOAD_C0_DESCRIPTOR) {
     *      val m1 = mutableMapOf()
     *      while (hasNextEntry()) {
     *           ...
     *       }
     *      m1
     * }
     * map0[k0] = v0
     * ```
     */
    private fun renderMapEntry(
        rootMemberShape: MemberShape,
        mapShape: MapShape,
        nestingLevel: Int,
        isSparse: Boolean,
        parentMemberName: String,
    ) {
        val keyName = nestingLevel.variableNameFor(NestedIdentifierType.KEY)
        val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
        val populateNullValuePostfix = if (isSparse) "" else "; continue"
        val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
        val nextNestingLevel = nestingLevel + 1
        val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
        val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)

        writer.write("val $keyName = key()")
        writer.withBlock("val $valueName =", "") {
            withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") {
                withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
                    write(
                        "val #L = #T()",
                        memberName,
                        KotlinTypes.Collections.mutableMapOf,
                        ctx.symbolProvider.toSymbol(mapShape.value),
                        nullabilitySuffix(mapShape.isSparse),
                    )
                    withBlock("while (hasNextEntry()) {", "}") {
                        delegateMapDeserialization(rootMemberShape, mapShape, nextNestingLevel, memberName)
                    }
                    write(collectionReturnExpression)
                }
            }
        }
        writer.write("$parentMemberName[$keyName] = $valueName")
    }

    /**
     * Renders a map value of type list.  Example:
     *
     * ```
     * val k0 = key()
     * val v0 = deserializer.deserializeList(PAYLOAD_C0_DESCRIPTOR) {
     *      val m1 = mutableSetOf()
     *      while (hasNextElement()) {
     *          ...
     *      }
     *      m1
     * }
     * map0[k0] = v0
     * ```
     */
    private fun renderListEntry(
        rootMemberShape: MemberShape,
        collectionShape: CollectionShape,
        nestingLevel: Int,
        isSparse: Boolean,
        parentMemberName: String,
    ) {
        val keyName = nestingLevel.variableNameFor(NestedIdentifierType.KEY)
        val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
        val populateNullValuePostfix = if (isSparse) "" else "; continue"
        val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
        val nextNestingLevel = nestingLevel + 1
        val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
        val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)

        writer.write("val $keyName = key()")
        writer.withBlock("val $valueName =", "") {
            withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") {
                withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
                    write(
                        "val #L = #T<#T#L>()",
                        memberName,
                        KotlinTypes.Collections.mutableListOf,
                        ctx.symbolProvider.toSymbol(collectionShape.member),
                        nullabilitySuffix(collectionShape.isSparse),
                    )
                    withBlock("while (hasNextElement()) {", "}") {
                        delegateListDeserialization(rootMemberShape, collectionShape, nextNestingLevel, memberName)
                    }
                    write(collectionReturnExpression)
                }
            }
        }
        writer.write("$parentMemberName[$keyName] = $valueName")
    }

    /**
     * Example:
     * ```
     * val k0 = key()
     * val el0 = if (nextHasValue()) { deserializeString() } else { deserializeNull(); continue }
     * map0[k0] = el0
     * ```
     */
    private fun renderEntry(elementShape: Shape, nestingLevel: Int, isSparse: Boolean, parentMemberName: String) {
        val deserializerFn = deserializerForShape(elementShape)
        val keyName = nestingLevel.variableNameFor(NestedIdentifierType.KEY)
        val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
        val populateNullValuePostfix = if (isSparse) "" else "; continue"

        writer.write("val $keyName = key()")
        writer.write("val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }")
        writer.write("$parentMemberName[$keyName] = $valueName")
    }

    /**
     * Example:
     * ```
     * PAYLOAD_DESCRIPTOR.index -> builder.payload =
     *  deserializer.deserializeList(PAYLOAD_DESCRIPTOR) {
     *      val col0 = mutableListOf()
     *      while (hasNextElement()) {
     *          ...
     *      }
     *      col0
     *  }
     */
    protected fun renderListMemberDeserializer(memberShape: MemberShape, targetShape: CollectionShape) {
        val nestingLevel = 0
        val memberName = ctx.symbolProvider.toMemberName(memberShape)
        val descriptorName = memberShape.descriptorName()
        val valueCollector = deserializationResultName("builder.$memberName")
        val mutableCollectionName = nestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
        val collectionReturnExpression = collectionReturnExpression(memberShape, mutableCollectionName)

        writer.write("$descriptorName.index -> $valueCollector = ")
            .indent()
            .withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
                write(
                    "val #L = #T<#T#L>()",
                    mutableCollectionName,
                    KotlinTypes.Collections.mutableListOf,
                    ctx.symbolProvider.toSymbol(targetShape.member),
                    nullabilitySuffix(targetShape.isSparse),
                )
                withBlock("while (hasNextElement()) {", "}") {
                    delegateListDeserialization(memberShape, targetShape, nestingLevel, mutableCollectionName)
                }
                write(collectionReturnExpression)
            }
            .dedent()
    }

    /**
     * Delegates to other functions based on the type of element.
     */
    private fun delegateListDeserialization(rootMemberShape: MemberShape, listShape: CollectionShape, nestingLevel: Int, parentMemberName: String) {
        val elementShape = ctx.model.expectShape(listShape.member.target)
        val isSparse = listShape.hasTrait()

        when (elementShape.type) {
            ShapeType.BOOLEAN,
            ShapeType.STRING,
            ShapeType.BYTE,
            ShapeType.SHORT,
            ShapeType.INTEGER,
            ShapeType.LONG,
            ShapeType.FLOAT,
            ShapeType.DOUBLE,
            ShapeType.BIG_DECIMAL,
            ShapeType.BIG_INTEGER,
            ShapeType.BLOB,
            ShapeType.DOCUMENT,
            ShapeType.TIMESTAMP,
            ShapeType.ENUM,
            ShapeType.INT_ENUM,
            -> renderElement(elementShape, nestingLevel, isSparse, parentMemberName)

            ShapeType.LIST,
            ShapeType.SET,
            -> renderListElement(rootMemberShape, elementShape as CollectionShape, nestingLevel, parentMemberName)

            ShapeType.MAP -> renderMapElement(rootMemberShape, elementShape as MapShape, nestingLevel, parentMemberName)
            ShapeType.UNION,
            ShapeType.STRUCTURE,
            -> renderNestedStructureElement(elementShape, nestingLevel, isSparse, parentMemberName)

            else -> error("Unhandled type ${elementShape.type}")
        }
    }

    /**
     * Example:
     * ```
     * val el0 = if (nextHasValue()) { NestedStructureDeserializer().deserialize(deserializer) } else { deserializeNull(); continue }
     * col0.add(el0)
     * ```
     */
    private fun renderNestedStructureElement(elementShape: Shape, nestingLevel: Int, isSparse: Boolean, parentMemberName: String) {
        val deserializer = deserializerForShape(elementShape)
        val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
        val populateNullValuePostfix = if (isSparse) "" else "; continue"
        if (elementShape.isStructureShape || elementShape.isUnionShape) {
            val symbol = ctx.symbolProvider.toSymbol(elementShape)
            writer.addImport(symbol)
        }

        writer.write("val $elementName = if (nextHasValue()) { $deserializer } else { deserializeNull()$populateNullValuePostfix }")
        writer.write("$parentMemberName.add($elementName)")
    }

    /**
     * Renders the deserialization of a list element of type map.
     *
     * Example:
     * ```
     * val el0 = deserializer.deserializeMap(PAYLOAD_C0_DESCRIPTOR) {
     *      val m1 = mutableMapOf()
     *      while (hasNextEntry()) {
     *          ...
     *      }
     *      m1
     * }
     * col0.add(el0)
     * ```
     */
    private fun renderMapElement(
        rootMemberShape: MemberShape,
        mapShape: MapShape,
        nestingLevel: Int,
        parentMapMemberName: String,
    ) {
        val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
        val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
        val nextNestingLevel = nestingLevel + 1
        val mapName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
        val collectionReturnExpression = collectionReturnExpression(rootMemberShape, mapName)

        writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
            write(
                "val #L = #T()",
                mapName,
                KotlinTypes.Collections.mutableMapOf,
                ctx.symbolProvider.toSymbol(mapShape.value),
                nullabilitySuffix(mapShape.isSparse),
            )
            withBlock("while (hasNextEntry()) {", "}") {
                delegateMapDeserialization(rootMemberShape, mapShape, nextNestingLevel, mapName)
            }
            write(collectionReturnExpression)
        }
        writer.write("$parentMapMemberName.add($elementName)")
    }

    /**
     * Render a List element of type List
     *
     * Example:
     *
     * ```
     * val el0 = deserializer.deserializeList(PAYLOAD_C0_DESCRIPTOR) {
     *      val m1 = mutableListOf()
     *      while (hasNextElement()) {
     *          ...
     *      }
     *      m1
     * }
     * col0.add(el0)
     */
    private fun renderListElement(rootMemberShape: MemberShape, elementShape: CollectionShape, nestingLevel: Int, parentListMemberName: String) {
        val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
        val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
        val nextNestingLevel = nestingLevel + 1
        val listName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
        val collectionReturnExpression = collectionReturnExpression(rootMemberShape, listName)

        writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
            write(
                "val #L = #T<#T#L>()",
                listName,
                KotlinTypes.Collections.mutableListOf,
                ctx.symbolProvider.toSymbol(elementShape.member),
                nullabilitySuffix(elementShape.isSparse),
            )
            withBlock("while (hasNextElement()) {", "}") {
                delegateListDeserialization(rootMemberShape, elementShape, nextNestingLevel, listName)
            }
            write(collectionReturnExpression)
        }
        writer.write("$parentListMemberName.add($elementName)")
    }

    /**
     * Example:
     * ```
     * val el0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
     * col0.add(el0)
     * ```
     */
    private fun renderElement(elementShape: Shape, nestingLevel: Int, isSparse: Boolean, listMemberName: String) {
        val deserializerFn = deserializerForShape(elementShape)
        val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
        val populateNullValuePostfix = if (isSparse) "" else "; continue"

        writer.write("val $elementName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }")
        writer.write("$listMemberName.add($elementName)")
    }

    /**
     * Return Kotlin function that deserializes a primitive value.
     * @param shape primitive [Shape] associated with value.
     */
    protected fun deserializerForShape(shape: Shape): String {
        // target shape type to deserialize is either the shape itself or member.target
        val target = shape.targetOrSelf(ctx.model)

        return when {
            target.type == ShapeType.BOOLEAN -> "deserializeBoolean()"
            target.type == ShapeType.BYTE -> "deserializeByte()"
            target.type == ShapeType.SHORT -> "deserializeShort()"
            target.type == ShapeType.INTEGER -> "deserializeInt()"
            target.type == ShapeType.LONG -> "deserializeLong()"
            target.type == ShapeType.FLOAT -> "deserializeFloat()"
            target.type == ShapeType.DOUBLE -> "deserializeDouble()"
            target.type == ShapeType.BIG_INTEGER -> "deserializeBigInteger()"
            target.type == ShapeType.BIG_DECIMAL -> "deserializeBigDecimal()"
            target.type == ShapeType.DOCUMENT -> "deserializeDocument()"

            target.type == ShapeType.BLOB -> {
                writer.addImport(RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes)
                "deserializeString().decodeBase64Bytes()"
            }

            target.type == ShapeType.TIMESTAMP -> {
                writer.addImport(RuntimeTypes.Core.Instant)
                val trait = shape.getTrait() ?: target.getTrait()
                val tsFormat = trait?.format ?: defaultTimestampFormat

                when (tsFormat) {
                    TimestampFormatTrait.Format.EPOCH_SECONDS -> "deserializeString().let { Instant.fromEpochSeconds(it) }"
                    TimestampFormatTrait.Format.DATE_TIME -> "deserializeString().let { Instant.fromIso8601(it) }"
                    TimestampFormatTrait.Format.HTTP_DATE -> "deserializeString().let { Instant.fromRfc5322(it) }"
                    else -> throw CodegenException("unknown timestamp format: $tsFormat")
                }
            }

            target.isStringEnumShape -> {
                val enumSymbol = ctx.symbolProvider.toSymbol(target)
                writer.addImport(enumSymbol)
                "deserializeString().let { ${enumSymbol.name}.fromValue(it) }"
            }
            target.isIntEnumShape -> {
                val enumSymbol = ctx.symbolProvider.toSymbol(target)
                writer.addImport(enumSymbol)
                "deserializeInt().let { ${enumSymbol.name}.fromValue(it) }"
            }

            target.type == ShapeType.STRING -> "deserializeString()"

            target.type == ShapeType.STRUCTURE || target.type == ShapeType.UNION -> {
                val symbol = ctx.symbolProvider.toSymbol(target)
                val deserializerName = symbol.documentDeserializerName()
                "$deserializerName(deserializer)"
            }

            else -> throw CodegenException("unknown deserializer for member: $shape; target: $target")
        }
    }
}

private fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else ""




© 2015 - 2025 Weber Informatics LLC | Privacy Policy