
// Source listing: software.amazon.smithy.kotlin.codegen.rendering.waiters.KotlinJmespathExpressionVisitor.kt (Maven / Gradle / Ivy)
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.kotlin.codegen.rendering.waiters
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.jmespath.ExpressionVisitor
import software.amazon.smithy.jmespath.JmespathExpression
import software.amazon.smithy.jmespath.RuntimeType
import software.amazon.smithy.jmespath.ast.*
import software.amazon.smithy.kotlin.codegen.core.CodegenContext
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.isEnum
import software.amazon.smithy.kotlin.codegen.model.targetOrSelf
import software.amazon.smithy.kotlin.codegen.model.traits.OperationInput
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.kotlin.codegen.utils.toCamelCase
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
// Endless supply of suffixes used to de-duplicate generated variable names: "", "2", "3", etc.
private val suffixSequence: Sequence<String> = sequence {
    yield("")
    var next = 2
    while (true) {
        yield(next.toString())
        next++
    }
}
/**
* An [ExpressionVisitor] used for traversing a JMESPath expression to generate code for traversing an equivalent
* modeled object. This visitor is passed to [JmespathExpression.accept], at which point specific expression methods
* will be invoked.
*
* Each step of the traversal returns a [VisitedExpression]. Any intermediate code required to express the query is
* written immediately to the provided writer.
*
* @param ctx The surrounding [CodegenContext].
* @param writer The [KotlinWriter] to generate code into.
* @param shape The modeled [Shape] on which this JMESPath expression is operating.
*/
class KotlinJmespathExpressionVisitor(
val ctx: CodegenContext,
val writer: KotlinWriter,
shape: Shape,
) : ExpressionVisitor {
// names of every temp variable handed out so far; consulted to keep generated identifiers unique
private val tempVars = mutableSetOf<String>()

// tracks the current shape on which the visitor is operating
private val shapeCursor = ArrayDeque(listOf(shape))

// the shape most recently pushed onto the cursor
private val currentShape: Shape
    get() = shapeCursor.last()
/**
 * Visits an independent expression, i.e. one whose resolved scope must not persist in the outer evaluation.
 * Any shapes the visit leaves on the shape cursor are popped again before returning.
 */
private fun acceptSubexpression(expr: JmespathExpression): VisitedExpression {
    val depthBefore = shapeCursor.size
    val visited = expr.accept(this)
    // reset the shape cursor to where it was before the visit
    while (shapeCursor.size > depthBefore) {
        shapeCursor.removeLast()
    }
    return visited
}
/**
 * Writes `val <name> = <codegen>` to the writer, where `<name>` is the first unused variant of [preferredName].
 *
 * @return the name that was actually used for the variable.
 */
private fun addTempVar(preferredName: String, codegen: String): String =
    bestTempVarName(preferredName).also { uniqueName ->
        writer.write("val #L = #L", uniqueName, codegen)
    }
/**
 * Reserves and returns the first name among [preferredName], [preferredName] + "2", [preferredName] + "3", ...
 * that has not been handed out before.
 */
private fun bestTempVarName(preferredName: String): String {
    // `add` returns true only for a name not seen before, which also reserves it
    return suffixSequence
        .map { suffix -> "$preferredName$suffix" }
        .first { candidate -> tempVars.add(candidate) }
}
/**
 * Writes a `val <projection> = <left>.flatMap { ... }` block whose body evaluates [right] for each element.
 *
 * @param right the expression applied inside the lambda; a [CurrentExpression] means there is nothing to map.
 * @param leftName identifier of the collection being projected.
 * @param leftShape shape backing [leftName]; decides whether the `flatMap` call is null-guarded.
 * @param innerShape shape whose element shape the lambda body operates on; [leftShape] is used when null.
 * @return the projection identifier backed by [leftShape], with the inner result's shape as `projected`.
 */
private fun flatMappingBlock(right: JmespathExpression, leftName: String, leftShape: Shape, innerShape: Shape?): VisitedExpression {
    // nothing to map
    if (right is CurrentExpression) {
        return VisitedExpression(leftName, leftShape)
    }

    val projectionName = bestTempVarName("projection")
    val flatMapCall = ensureNullGuard(leftShape, "flatMap")
    writer.openBlock("val #L = #L#L {", projectionName, leftName, flatMapCall)

    // inside the lambda the expression operates on the list member / map value shape
    val elementShape = (innerShape ?: leftShape).targetMemberOrSelf
    shapeCursor.addLast(elementShape)
    val mapped = acceptSubexpression(right)
    shapeCursor.removeLast()

    val lambdaResult = if (right is MultiSelectListExpression || right is MultiSelectHashExpression) {
        mapped.identifier // Already a list
    } else {
        "listOfNotNull(${mapped.identifier})"
    }
    writer.write(lambdaResult)
    writer.closeBlock("}")

    return VisitedExpression(projectionName, leftShape, mapped.shape)
}
// Result of resolving a field access: the camel-cased field name, the generated accessor code, and the modeled
// member (null when the current shape has no member with that name).
private data class SubFieldData(
    val name: String,
    val codegen: String,
    val member: Shape?,
)
/**
 * Builds the code that reads the field named by [expression] from [parentName] and, when the field is a modeled
 * member of the current shape, pushes that member onto the shape cursor.
 *
 * @param isObject true when the parent is a user-created object (a hash map), in which case the field is read
 * with an index expression instead of a property access.
 * @throws CodegenException when the member targets a blob or timestamp shape.
 */
private fun subfieldLogic(expression: FieldExpression, parentName: String, isObject: Boolean = false): SubFieldData {
    val fieldName = expression.name.toCamelCase()
    val modeledMember = currentShape.targetOrSelf(ctx.model).getMember(expression.name).getOrNull()

    // User created objects are represented as hash maps in code-gen and are marked by `isObject`
    val accessor = when {
        isObject -> "[\"$fieldName\"]"
        else -> ensureNullGuard(currentShape, fieldName)
    }

    // enum-typed members (and lists/maps of enums) are unwrapped to their underlying `value`
    val unwrap: String? = modeledMember?.let { m ->
        val target = ctx.model.expectShape(m.target)
        when {
            target.isEnum -> "value"
            target.isEnumList -> "map { it.value }"
            target.isEnumMap -> "mapValues { (_, v) -> v.value }"
            target.isBlobShape || target.isTimestampShape ->
                throw CodegenException("acceptor behavior for shape type ${target.type} is undefined")
            else -> null
        }
    }

    val accessCode = "$parentName$accessor" + unwrap?.let { ensureNullGuard(modeledMember, it) }.orEmpty()

    if (modeledMember != null) {
        shapeCursor.addLast(modeledMember)
    }
    return SubFieldData(fieldName, accessCode, modeledMember)
}
/**
 * Resolves the field named by [expression] on [parentName] and stores the access in a new temp variable.
 */
private fun subfield(expression: FieldExpression, parentName: String, isObject: Boolean = false): VisitedExpression {
    val resolved = subfieldLogic(expression, parentName, isObject)
    val identifier = addTempVar(resolved.name, resolved.codegen)
    // subfieldLogic has already pushed the member (when modeled), so currentShape is read after that push
    return VisitedExpression(identifier, resolved.member, nullable = currentShape.isNullable)
}
/**
 * Like [subfield], but only returns the generated accessor code instead of writing it to a temp variable.
 */
private fun subfieldCodegen(expression: FieldExpression, parentName: String, isObject: Boolean = false): String {
    val (_, codegen, _) = subfieldLogic(expression, parentName, isObject)
    return codegen
}
/**
 * Generates code for `left && right`: the right operand when the left is truthy, otherwise the left operand.
 */
override fun visitAnd(expression: AndExpression): VisitedExpression {
    writer.addImport(RuntimeTypes.Core.Utils.truthiness)

    val lhs = acceptSubexpression(expression.left)
    val lhsTruthy = addTempVar("${lhs.identifier}Truthiness", "truthiness(${lhs.identifier})")
    val rhs = acceptSubexpression(expression.right)

    val andCode = "if ($lhsTruthy) ${rhs.identifier} else ${lhs.identifier}"
    return VisitedExpression(addTempVar("and", andCode))
}
/**
 * Generates a `compareTo`-based comparison of the two operands. When either operand may be null, the generated
 * expression yields `null` instead of comparing.
 */
override fun visitComparator(expression: ComparatorExpression): VisitedExpression {
    val lhs = acceptSubexpression(expression.left)
    val rhs = acceptSubexpression(expression.right)

    // one `x == null` check per operand that may be null, left operand first
    val nullChecks = listOf(lhs, rhs)
        .filter { it.shape?.isNullable == true || it.nullable }
        .map { "${it.identifier} == null" }
    val nullPrefix = if (nullChecks.isEmpty()) "" else "if (${nullChecks.joinToString(" || ")}) null else "
    val comparison = "${lhs.identifier}.compareTo(${rhs.identifier}) ${expression.comparator} 0"

    return VisitedExpression(addTempVar("comparison", nullPrefix + comparison))
}
// A current-node expression (`@`) is not supported when visited directly; always throws.
override fun visitCurrentNode(expression: CurrentExpression): VisitedExpression =
    throw CodegenException("Unexpected current expression outside of flatten expression: $expression")
// Expression-type expressions are not supported when visited directly; always throws.
override fun visitExpressionType(expression: ExpressionTypeExpression): VisitedExpression =
    throw CodegenException("ExpressionTypeExpression is unsupported")
// A field visited directly is read from the implicit `it` of the generated code.
override fun visitField(expression: FieldExpression): VisitedExpression {
    return subfield(expression, "it")
}
/**
 * Generates a `filter { ... }` over the left-hand collection using the comparison as predicate, then projects
 * the right-hand expression over the filtered result.
 */
override fun visitFilterProjection(expression: FilterProjectionExpression): VisitedExpression {
    val source = expression.left.accept(this)
    val sourceShape = requireNotNull(source.shape) { "filter projection is operating on nothing?" }

    val filteredName = bestTempVarName("${source.identifier}Filtered")
    val filterCall = ensureNullGuard(sourceShape, "filter")
    writer.withBlock("val #L = #L#L {", "}", filteredName, source.identifier, filterCall) {
        // the predicate is evaluated against the element shape of the collection
        shapeCursor.addLast(sourceShape.targetMemberOrSelf)
        val predicate = acceptSubexpression(expression.comparison)
        shapeCursor.removeLast()
        write("#L == true", predicate.identifier)
    }

    return flatMappingBlock(expression.right, filteredName, sourceShape, source.projected)
}
/**
 * Generates a `flattenIfPossible()` call on the visited inner expression and stores the result in a temp variable.
 */
override fun visitFlatten(expression: FlattenExpression): VisitedExpression {
    writer.addImport(RuntimeTypes.Core.Utils.flattenIfPossible)

    val operand = expression.expression.accept(this)
    val flattenCall = ensureNullGuard(currentShape, "flattenIfPossible()")
    val flattenedName = addTempVar("${operand.identifier}OrEmpty", "${operand.identifier}$flattenCall")

    return VisitedExpression(flattenedName, currentShape, operand.projected)
}
// Visits the first argument of this function call as an independent subexpression.
private fun FunctionExpression.singleArg(): VisitedExpression {
    val firstArgument = arguments[0]
    return acceptSubexpression(firstArgument)
}
// Visits the first two arguments of this function call, in order, each as an independent subexpression.
private fun FunctionExpression.twoArgs(): Pair<VisitedExpression, VisitedExpression> =
    acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1])
// Visits every argument of this function call, in order, each as an independent subexpression.
private fun FunctionExpression.args(): List<VisitedExpression> =
    this.arguments.map { acceptSubexpression(it) }
/**
 * Generates `<receiver>.<expr>` into a temp variable named after the JMESPath function.
 *
 * @param expression the function expression being generated; its name determines the temp variable name.
 * @param expr the call or property to append to the receiver's identifier.
 * @param elvisExpr optional fallback appended with `?:` when the call ends up null-guarded.
 * @param isObject propagated to the returned [VisitedExpression].
 * @param ensureNullGuard when false, a plain `.` call is generated regardless of the receiver's nullability.
 */
private fun VisitedExpression.dotFunction(
    expression: FunctionExpression,
    expr: String,
    elvisExpr: String? = null,
    isObject: Boolean = false,
    ensureNullGuard: Boolean = true,
): VisitedExpression {
    val call = when {
        ensureNullGuard -> ensureNullGuard(shape, expr, elvisExpr)
        else -> ".$expr"
    }
    val resultName = addTempVar(expression.name.toCamelCase(), "$identifier$call")
    // the receiver's shape (if any) is pushed onto the shape cursor
    if (shape != null) {
        shapeCursor.addLast(shape)
    }
    return VisitedExpression(resultName, shape, isObject = isObject)
}
/**
 * Generates code for a JMESPath function call by dispatching on the function name.
 *
 * @throws CodegenException when the function name is not one of the handled names.
 */
override fun visitFunction(expression: FunctionExpression): VisitedExpression {
    val functionName = expression.name
    return when (functionName) {
        "contains" -> expression.twoArgs().let { (subject, search) ->
            subject.dotFunction(expression, "contains(${search.identifier})", "false")
        }

        "length" -> {
            writer.addImport(RuntimeTypes.Core.Utils.length)
            expression.singleArg().dotFunction(expression, "length", "0")
        }

        // these JMESPath names map directly onto the same-named kotlin.math functions
        "abs", "floor", "ceil" ->
            expression.singleArg().dotFunction(expression, "let { kotlin.math.${functionName}(it.toDouble()) }")

        "sum" -> expression.singleArg().dotFunction(expression, "sum()")

        "avg" -> expression.singleArg().dotFunction(expression, "average()")

        "join" -> expression.twoArgs().let { (glue, list) ->
            list.dotFunction(expression, "joinToString(${glue.identifier})")
        }

        "starts_with" -> expression.twoArgs().let { (subject, prefix) ->
            subject.dotFunction(expression, "startsWith(${prefix.identifier})")
        }

        "ends_with" -> expression.twoArgs().let { (subject, suffix) ->
            subject.dotFunction(expression, "endsWith(${suffix.identifier})")
        }

        "keys" -> VisitedExpression(addTempVar("keys", expression.singleArg().getKeys()))

        "values" -> VisitedExpression(addTempVar("values", expression.singleArg().getValues()))

        "merge" -> VisitedExpression(addTempVar("merge", expression.args().mergeProperties()), isObject = true)

        "max" -> expression.singleArg().dotFunction(expression, "maxOrNull()")

        "min" -> expression.singleArg().dotFunction(expression, "minOrNull()")

        "reverse" -> expression.singleArg().dotFunction(expression, "reversed()")

        "not_null" -> VisitedExpression(addTempVar("notNull", expression.args().getNotNull()))

        "to_array" -> VisitedExpression(addTempVar("toArray", expression.singleArg().toArray()))

        "to_string" -> VisitedExpression(addTempVar("toString", expression.singleArg().jmesPathToString()))

        "to_number" -> {
            writer.addImport(RuntimeTypes.Core.Utils.toNumber)
            expression.singleArg().dotFunction(expression, "toNumber()")
        }

        "type" -> {
            writer.addImport(RuntimeTypes.Core.Utils.type)
            expression.singleArg().dotFunction(expression, "type()", ensureNullGuard = false)
        }

        "sort" -> expression.singleArg().dotFunction(expression, "sorted()")

        // the *_by functions take the list first and the key expression second
        "sort_by" -> expression.arguments[0].accept(this)
            .applyFunction(functionName.toCamelCase(), "sortedBy", expression.arguments[1])

        "max_by" -> expression.arguments[0].accept(this)
            .applyFunction(functionName.toCamelCase(), "maxBy", expression.arguments[1])

        "min_by" -> expression.arguments[0].accept(this)
            .applyFunction(functionName.toCamelCase(), "minBy", expression.arguments[1])

        // map takes the expression first and the list second
        "map" -> expression.arguments[1].accept(this)
            .applyFunction(functionName.toCamelCase(), "map", expression.arguments[0])

        else -> throw CodegenException("Unknown function type in $expression")
    }
}
// Index expressions are not supported when visited directly; always throws.
override fun visitIndex(expression: IndexExpression): VisitedExpression =
    throw CodegenException("IndexExpression is unsupported")
/**
 * Generates a temp variable holding a string, number or boolean literal. A null literal is represented by the
 * identifier `null` without writing any code.
 *
 * @throws CodegenException for any other literal type.
 */
override fun visitLiteral(expression: LiteralExpression): VisitedExpression {
    val literalType = expression.type
    if (literalType == RuntimeType.NULL) {
        return VisitedExpression("null")
    }

    val (preferredName, literalCode) = when (literalType) {
        RuntimeType.STRING -> "string" to expression.expectStringValue().dq()
        RuntimeType.NUMBER -> "number" to expression.expectNumberValue().toString()
        RuntimeType.BOOLEAN -> "bool" to expression.expectBooleanValue().toString()
        else -> throw CodegenException("Expression type $expression is unsupported")
    }
    return VisitedExpression(addTempVar(preferredName, literalCode))
}
/**
 * Generates a local generic `Selection` class with one property per key of the multi-select hash, then a
 * single-element list holding a `Selection` built from the visited value expressions.
 */
override fun visitMultiSelectHash(expression: MultiSelectHashExpression): VisitedExpression {
    // every property is typed `T`, so the generated class must declare that type parameter
    val properties = expression.expressions.keys.joinToString { "val $it: T" }
    writer.write("class Selection<T>($properties)")

    val listName = bestTempVarName("multiSelect")
    writer.withBlock("val $listName = listOfNotNull(", ")") {
        withBlock("run {", "}") {
            // each value expression is written to a temp variable named after its key
            val identifiers = expression.expressions.toList().joinToString { (key, valueExpression) ->
                addTempVar(key, valueExpression.accept(this@KotlinJmespathExpressionVisitor).identifier)
            }
            write("Selection($identifiers)")
        }
    }
    return VisitedExpression(listName, currentShape)
}
/**
 * Generates a nested `listOfNotNull(listOfNotNull(...))` whose inner list holds one `run { ... }` block per
 * selected expression.
 */
override fun visitMultiSelectList(expression: MultiSelectListExpression): VisitedExpression {
    val resultName = bestTempVarName("multiSelect")

    writer.openBlock("val #L = listOfNotNull(", resultName)
    writer.openBlock("listOfNotNull(")
    for (element in expression.expressions) {
        // the code for each element is written inside its own `run` block, ending with its identifier
        writer.openBlock("run {")
        writer.write(element.accept(this).identifier)
        writer.closeBlock("},")
    }
    writer.closeBlock(")")
    writer.closeBlock(")")

    return VisitedExpression(resultName, currentShape)
}
/**
 * Generates the negation of the operand's truthiness.
 */
override fun visitNot(expression: NotExpression): VisitedExpression {
    writer.addImport(RuntimeTypes.Core.Utils.truthiness)

    val operandId = acceptSubexpression(expression.expression).identifier
    val truthyName = addTempVar("${operandId}Truthiness", "truthiness($operandId)")
    val negatedName = "not" + operandId.replaceFirstChar { it.uppercaseChar() }

    return VisitedExpression(addTempVar(negatedName, "!$truthyName"))
}
/**
 * Generates a projection over the `values` of the left-hand expression.
 */
override fun visitObjectProjection(expression: ObjectProjectionExpression): VisitedExpression {
    val source = acceptSubexpression(expression.left)
    val sourceShape = requireNotNull(source.shape) { "object projection is operating on nothing?" }

    val valuesCall = ensureNullGuard(sourceShape, "values")
    val valuesName = addTempVar("${source.identifier}Values", "${source.identifier}$valuesCall")

    return flatMappingBlock(expression.right, valuesName, sourceShape, source.projected)
}
/**
 * Generates code for `left || right`: the left operand when it is truthy, otherwise the right operand.
 */
override fun visitOr(expression: OrExpression): VisitedExpression {
    writer.addImport(RuntimeTypes.Core.Utils.truthiness)

    val lhs = acceptSubexpression(expression.left)
    val lhsTruthy = addTempVar("${lhs.identifier}Truthiness", "truthiness(${lhs.identifier})")
    val rhs = acceptSubexpression(expression.right)

    val orCode = "if ($lhsTruthy) ${lhs.identifier} else ${rhs.identifier}"
    return VisitedExpression(addTempVar("or", orCode))
}
/**
 * Visits the left-hand expression and projects the right-hand expression over its result.
 */
override fun visitProjection(expression: ProjectionExpression): VisitedExpression {
    val source = expression.left.accept(this)
    val sourceShape = requireNotNull(source.shape) { "projection is operating on nothing?" }
    return flatMappingBlock(expression.right, source.identifier, sourceShape, source.projected)
}
/**
 * Projection appearing on the right of a subexpression. A slice on the left is applied to [parentName];
 * any other left-hand expression is visited normally.
 */
private fun projection(expression: ProjectionExpression, parentName: String): VisitedExpression {
    val source = when (val lhs = expression.left) {
        is SliceExpression -> slice(lhs, parentName)
        else -> lhs.accept(this)
    }
    val sourceShape = requireNotNull(source.shape) { "projection is operating on nothing" }
    return flatMappingBlock(expression.right, source.identifier, sourceShape, source.projected)
}
// Slice expressions are not supported when visited directly; always throws.
override fun visitSlice(expression: SliceExpression): VisitedExpression =
    throw CodegenException("SliceExpression is unsupported")
/**
 * Generates a `slice(start..<stop step n)` call on [parentName] and stores the result in a temp variable.
 * A missing start defaults to `0` and a missing stop to the parent's `size`.
 */
private fun slice(expression: SliceExpression, parentName: String): VisitedExpression {
    val start = expression.start
    val stop = expression.stop

    // a negative bound already carries its minus sign, so appending it to `size` yields e.g. `list.size-1`
    val startIndex: Any = when {
        !start.isPresent -> "0"
        start.asInt < 0 -> "$parentName.size${start.asInt}"
        else -> start.asInt
    }
    val stopIndex: Any = when {
        !stop.isPresent -> "$parentName.size"
        stop.asInt < 0 -> "$parentName.size${stop.asInt}"
        else -> stop.asInt
    }

    // the closing parenthesis of `slice(` is appended when the temp variable is written below
    val sliceCall = ensureNullGuard(currentShape, "slice($startIndex..<$stopIndex step ${expression.step}")
    writer.write("@OptIn(ExperimentalStdlibApi::class)")
    val slicedName = addTempVar("slicedList", "$parentName$sliceCall)")

    return VisitedExpression(slicedName, currentShape)
}
/**
 * Visits the left-hand expression, then resolves the right-hand expression against its result.
 */
override fun visitSubexpression(expression: Subexpression): VisitedExpression {
    val visitedLeft = expression.left.accept(this)
    return processRightSubexpression(
        expression.right,
        leftName = visitedLeft.identifier,
        isObject = visitedLeft.isObject,
    )
}
/**
 * Subexpression nested on the right of another subexpression; its left side is resolved against [parentName].
 *
 * @throws CodegenException when the left side is neither a field nor a subexpression.
 */
private fun subexpression(expression: Subexpression, parentName: String): VisitedExpression {
    val lhs = expression.left
    val visitedLeft = when (lhs) {
        is FieldExpression -> subfield(lhs, parentName)
        is Subexpression -> subexpression(lhs, parentName)
        else -> throw CodegenException("Subexpression type $lhs is unsupported")
    }
    return processRightSubexpression(expression.right, visitedLeft.identifier, visitedLeft.isObject)
}
/**
 * Resolves the right-hand side of a subexpression against the already generated left-hand identifier.
 *
 * @throws CodegenException when [expression] is not a field, index, subexpression or projection.
 */
private fun processRightSubexpression(expression: JmespathExpression, leftName: String, isObject: Boolean = false): VisitedExpression {
    return when (expression) {
        is FieldExpression -> subfield(expression, leftName, isObject)
        is IndexExpression -> index(expression, leftName)
        is Subexpression -> subexpression(expression, leftName)
        is ProjectionExpression -> projection(expression, leftName)
        else -> throw CodegenException("Subexpression type $expression is unsupported")
    }
}
/**
 * Generates a `get(index)` call on [parentName]; a negative index is expressed relative to the parent's `size`.
 */
private fun index(expression: IndexExpression, parentName: String): VisitedExpression {
    val position = expression.index
    // a negative position already carries its minus sign, giving e.g. `list.size-1`
    val indexCode: Any = if (position < 0) "$parentName.size$position" else position
    val getCall = ensureNullGuard(currentShape.targetMemberOrSelf, "get($indexCode)")
    return VisitedExpression(addTempVar("index", "$parentName$getCall"), currentShape)
}
// true when this shape is a list whose member targets an enum
private val Shape.isEnumList: Boolean
    get() = (this as? ListShape)?.let { ctx.model.expectShape(it.member.target).isEnum } ?: false

// true when this shape is a map whose value targets an enum
private val Shape.isEnumMap: Boolean
    get() = (this as? MapShape)?.let { ctx.model.expectShape(it.value.target).isEnum } ?: false
/**
 * Returns the call suffix for [expr]: `?.expr` (optionally followed by ` ?: elvisExpr`) when [shape] is nullable,
 * otherwise a plain `.expr`. The [elvisExpr] is ignored for non-nullable shapes.
 */
private fun ensureNullGuard(shape: Shape?, expr: String, elvisExpr: String? = null): String {
    val needsGuard = shape?.isNullable == true
    if (!needsGuard) {
        return ".$expr"
    }
    val fallback = elvisExpr?.let { " ?: $it" }.orEmpty()
    return "?.$expr$fallback"
}
/**
 * Builds a `listOf("a", "b", ...)` expression holding the member names of this expression's shape.
 *
 * When there is no shape, or the shape has no members, an explicitly typed `listOf<String>()` is returned so the
 * generated declaration can still infer its element type.
 */
private fun VisitedExpression.getKeys(): String {
    val memberNames = shape?.targetOrSelf(ctx.model)?.allMembers?.keys ?: emptySet()
    if (memberNames.isEmpty()) {
        return "listOf<String>()"
    }
    return memberNames.joinToString(", ", "listOf(", ")") { "\"$it\"" }
}
/**
 * Builds a `listOf(x.a, x.b, ...)` expression reading every member of this expression's shape from its identifier.
 *
 * When there is no shape, or the shape has no members, an explicitly typed `listOf<Any?>()` is returned so the
 * generated declaration can still infer its element type.
 */
private fun VisitedExpression.getValues(): String {
    val memberNames = shape?.targetOrSelf(ctx.model)?.allMembers?.keys ?: emptySet()
    if (memberNames.isEmpty()) {
        return "listOf<Any?>()"
    }
    return memberNames.joinToString(", ", "listOf(", ")") { "$identifier${ensureNullGuard(shape, it)}" }
}
/**
 * Generates a hash map that is filled with the member names and values of every expression in this list, in
 * order, so later entries overwrite earlier ones with the same key.
 *
 * @return the identifier of the generated map.
 */
private fun List<VisitedExpression>.mergeProperties(): String {
    val union = addTempVar("union", "HashMap<String, Any?>()")
    forEach { obj ->
        val keys = addTempVar("keys", obj.getKeys())
        val values = addTempVar("values", obj.getValues())
        writer.withBlock("for(i in $keys.indices){", "}") {
            // write to the name addTempVar actually returned; it may carry a numeric suffix
            write("${union}[$keys[i]] = $values[i]")
        }
    }
    return union
}
// Generates a temp variable that keeps a String as-is and calls toString() on anything else; returns its name.
private fun VisitedExpression.jmesPathToString(): String {
    val id = identifier
    return addTempVar("answer", "if(${id} as Any is String) ${id} else ${id}.toString()")
}
// Generates a temp variable that keeps a List/Array value as a list and wraps anything else in listOf(...);
// returns its name.
private fun VisitedExpression.toArray(): String {
    val id = identifier
    return addTempVar(
        "answer",
        "if(${id} as Any is List<*> || ${id} as Any is Array<*>) ${id} as List<*> else listOf(${id})",
    )
}
/**
 * Generates a variable holding the first non-null value among the expressions in this list (or null when all
 * are null), via `listOfNotNull(...).firstOrNull()`.
 *
 * @return the identifier of the generated variable.
 */
private fun List<VisitedExpression>.getNotNull(): String {
    val notNull = bestTempVarName("notNull")
    writer.withBlock("val $notNull = listOfNotNull(", ").firstOrNull()") {
        for (candidate in this@getNotNull) {
            write("${candidate.identifier},")
        }
    }
    return notNull
}
/**
 * Generates `<receiver>?.<operation> { it.<field>!! }` into a variable named after [name].
 *
 * @param name preferred name for the generated variable.
 * @param operation the collection function to call on the receiver, e.g. `sortedBy`.
 * @param expression the key expression; it is cast to an [ExpressionTypeExpression] wrapping a [FieldExpression].
 */
private fun VisitedExpression.applyFunction(
    name: String,
    operation: String,
    expression: JmespathExpression,
): VisitedExpression {
    val resultName = bestTempVarName(name)
    val sourceId = identifier
    writer.withBlock("val $resultName = ${sourceId}?.$operation {", "}") {
        val selector = (expression as ExpressionTypeExpression).expression as FieldExpression
        val selectorCode = subfieldCodegen(selector, "it")
        write("$selectorCode!!")
    }
    return VisitedExpression(resultName)
}
// A shape is treated as nullable when it is a member whose target carries neither the OperationInput nor the
// OperationOutput trait.
private val Shape.isNullable: Boolean
    get() = this is MemberShape &&
        ctx.model.expectShape(target).let { !it.hasTrait<OperationInput>() && !it.hasTrait<OperationOutput>() }
// For a shape resolving to a list or map, the list's member or the map's value; otherwise the shape itself.
private val Shape.targetMemberOrSelf: Shape
    get() {
        val resolved = targetOrSelf(ctx.model)
        return when (resolved) {
            is ListShape -> resolved.member
            is MapShape -> resolved.value
            else -> this
        }
    }
}
/**
 * Contains information about the output of a visited [JmespathExpression].
 *
 * @param identifier The generated identifier in which the expression result is stored.
 * @param shape The underlying shape (if any) that the identifier represents. Not all expressions reference a
 * modeled shape, e.g. [LiteralExpression] (the value is just a literal) or [FunctionExpression]s where the
 * returned value is scalar.
 * @param projected For projections, the context of the inner shape. For example, given the expression
 * `foo[].bar[].baz.qux`, the shape that backs the identifier (and therefore determines overall nullability)
 * is `foo`, but the shape that needs to be carried through to subfield expressions in the following projection
 * is the target of `bar`, such that its subfields `baz` and `qux` can be recognized.
 * @param nullable Indicates that a visited expression is nullable. The shape usually determines this, but an
 * expression is sometimes nullable for reasons that are not shape related.
 * @param isObject Indicates that a visited expression results in an object. Objects are represented as hash maps
 * because it is not possible to construct a class at runtime.
 */
data class VisitedExpression(
    val identifier: String,
    val shape: Shape? = null,
    val projected: Shape? = null,
    val nullable: Boolean = false,
    val isObject: Boolean = false,
)
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy