
software.amazon.smithy.kotlin.codegen.rendering.UnionGenerator.kt Maven / Gradle / Ivy
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.kotlin.codegen.rendering
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.filterEventStreamErrors
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.isNullable
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.model.traits.StreamingTrait
/**
* Renders Smithy union shapes
*/
class UnionGenerator(
val model: Model,
private val symbolProvider: SymbolProvider,
private val writer: KotlinWriter,
private val shape: UnionShape,
) {
val symbol: Symbol = symbolProvider.toSymbol(shape)
/**
* Renders a Smithy union to a Kotlin sealed class
*/
fun render() {
check(!shape.allMembers.values.any { memberShape -> memberShape.memberName.equals("SdkUnknown", true) }) { "generating SdkUnknown would cause duplicate variant for union shape: $shape" }
writer.renderDocumentation(shape)
writer.renderAnnotations(shape)
writer.openBlock("public sealed class #T {", symbol)
// event streams (@streaming union) MAY have variants that target errors.
// These errors if encountered on the stream will be thrown as an exception rather
// than showing up as one of the possible events the consumer will see on the stream (Flow).
val members = shape.filterEventStreamErrors(model)
members.sortedBy { it.memberName }.forEach {
writer.renderMemberDocumentation(model, it)
writer.renderAnnotations(it)
val variantName = it.unionVariantName()
val variantSymbol = symbolProvider.toSymbol(it)
writer.withBlock("public data class #L(val value: #Q) : #Q() {", "}", variantName, variantSymbol, symbol) {
if (model.expectShape(it.target).type == ShapeType.BLOB) {
renderHashCode(model, listOf(it), symbolProvider, this)
renderEquals(model, listOf(it), variantName, this)
}
renderToString()
}
writer.write("")
}
// generate the unknown which will always be last
writer.withBlock("public object SdkUnknown : #Q() {", "}", symbol) {
renderToString()
}
members.sortedBy { it.memberName }.forEach {
val variantName = it.unionVariantName()
val variantSymbol = symbolProvider.toSymbol(it)
writer.write("")
writer.dokka {
write(
"""
Casts this [#T] as a [#L] and retrieves its [#Q] value. Throws an exception if the [#T] is not a
[#L].
""".trimIndent(),
symbol,
variantName,
variantSymbol,
symbol,
variantName,
)
}
writer.write("public fun as#L(): #Q = (this as #T.#L).value", variantName, variantSymbol, symbol, variantName)
writer.write("")
writer.dokka {
write(
"Casts this [#T] as a [#L] and retrieves its [#Q] value. Returns null if the [#T] is not a [#L].",
symbol,
variantName,
variantSymbol,
symbol,
variantName,
)
}
writer.write(
"public fun as#LOrNull(): #Q? = (this as? #T.#L)?.value",
variantName,
variantSymbol,
symbol,
variantName,
)
}
writer.closeBlock("}").write("")
}
// generate a `hashCode()` implementation
private fun renderHashCode(
model: Model,
sortedMembers: List,
symbolProvider: SymbolProvider,
writer: KotlinWriter,
) {
writer.write("")
writer.withBlock("override fun hashCode(): #Q {", "}", KotlinTypes.Int) {
write("return value#L", selectHashFunctionForShape(model, sortedMembers[0], symbolProvider))
}
}
// Return the appropriate hashCode fragment based on ShapeID of member target.
private fun selectHashFunctionForShape(model: Model, member: MemberShape, symbolProvider: SymbolProvider): String {
val targetShape = model.expectShape(member.target)
// also available already in the byMember map
val targetSymbol = symbolProvider.toSymbol(targetShape)
return when (targetShape.type) {
ShapeType.INTEGER ->
when (targetSymbol.isNullable) {
true -> " ?: 0"
else -> ""
}
ShapeType.BYTE ->
when (targetSymbol.isNullable) {
true -> ".toInt() ?: 0"
else -> ".toInt()"
}
ShapeType.BLOB ->
if (targetShape.hasTrait()) {
// ByteStream
".hashCode() ?: 0"
} else {
// ByteArray
".contentHashCode()"
}
else ->
when (targetSymbol.isNullable) {
true -> ".hashCode() ?: 0"
else -> ".hashCode()"
}
}
}
// generate a `equals()` implementation
private fun renderEquals(model: Model, sortedMembers: List, typeName: String, writer: KotlinWriter) {
writer.write("")
writer.withBlock("override fun equals(other: #Q?): #Q {", "}", KotlinTypes.Any, KotlinTypes.Boolean) {
write("if (this === other) return true")
write("if (other == null || this::class != other::class) return false")
write("")
write("other as $typeName")
write("")
for (memberShape in sortedMembers) {
val target = model.expectShape(memberShape.target)
val memberName = "value"
if (target is BlobShape && !target.hasTrait()) {
writer.write("if (!#1L.contentEquals(other.#1L)) return false", memberName)
} else {
write("if (#1L != other.#1L) return false", memberName)
}
}
write("")
write("return true")
}
}
private fun renderToString() {
if (shape.hasTrait()) {
writer.write(
"""override fun toString(): #Q = "#T(*** Sensitive Data Redacted ***)"""",
KotlinTypes.String,
symbol,
)
} // else just use regular data class toString implementation
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy