
software.amazon.smithy.kotlin.codegen.rendering.endpoints.DefaultEndpointProviderGenerator.kt Maven / Gradle / Ivy
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.kotlin.codegen.rendering.endpoints
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.defaultName
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.model.SourceLocation
import software.amazon.smithy.rulesengine.language.EndpointRuleSet
import software.amazon.smithy.rulesengine.language.syntax.Identifier
import software.amazon.smithy.rulesengine.language.syntax.ToExpression
import software.amazon.smithy.rulesengine.language.syntax.expressions.*
import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.*
import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal
import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.LiteralVisitor
import software.amazon.smithy.rulesengine.language.syntax.rule.Condition
import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule
import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule
import software.amazon.smithy.rulesengine.language.syntax.rule.Rule
import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule
/**
 * The core set of standard library functions available to the rules language.
 *
 * Keys are rules-language function IDs (the `fn` name used in the rule set, e.g. `parseURL`) and values are the
 * runtime symbols the generated code calls for them. Integration-provided external functions are merged on top of
 * this map, so an integration may override an entry by registering the same ID.
 */
internal val coreFunctions: Map<String, Symbol> = mapOf(
    "substring" to RuntimeTypes.SmithyClient.Endpoints.Functions.substring,
    "isValidHostLabel" to RuntimeTypes.SmithyClient.Endpoints.Functions.isValidHostLabel,
    "uriEncode" to RuntimeTypes.SmithyClient.Endpoints.Functions.uriEncode,
    "parseURL" to RuntimeTypes.SmithyClient.Endpoints.Functions.parseUrl,
)
/**
 * Generates code for a single endpoint rules-language [Expression].
 *
 * [DefaultEndpointProviderGenerator] implements this interface and hands itself to [EndpointPropertyRenderer]
 * callbacks so they can delegate generic rendering back to it.
 */
fun interface ExpressionRenderer {
    /**
     * Writes the generated code for [expr].
     */
    fun renderExpression(expr: Expression)
}

/**
 * Callback that renders an SDK-specific endpoint property.
 *
 * @param writer the writer that the property's code should be rendered to
 * @param expr the root expression construct for the property
 * @param renderer a generic expression renderer for deferring back to the base implementation (for example, to
 * handle generic sub-expressions or string templates for which the caller doesn't need extended behavior)
 */
typealias EndpointPropertyRenderer = (writer: KotlinWriter, expr: Expression, renderer: ExpressionRenderer) -> Unit
/**
 * Renders the default endpoint provider based on the provided rule set.
 *
 * The generated class implements the service's endpoint provider interface with a single `resolveEndpoint`
 * function. That function evaluates the rules of [rules] in order. It returns the first matching endpoint, throws
 * when an error rule matches, and throws if every rule is exhausted without a match.
 *
 * @param ctx the generation context; its integrations may contribute external rule functions and property renderers
 * @param rules the endpoint rule set to translate into code
 * @param writer the writer the provider is rendered to
 */
class DefaultEndpointProviderGenerator(
    private val ctx: ProtocolGenerator.GenerationContext,
    private val rules: EndpointRuleSet,
    private val writer: KotlinWriter,
) : ExpressionRenderer {
    companion object {
        /**
         * Returns the symbol of the generated provider: `Default<Client>EndpointProvider` in `<package>.endpoints`.
         */
        fun getSymbol(settings: KotlinSettings): Symbol =
            buildSymbol {
                val prefix = clientName(settings.sdkId)
                name = "Default${prefix}EndpointProvider"
                namespace = "${settings.pkg.name}.endpoints"
            }
    }

    private val endpointCustomizations = ctx.integrations.mapNotNull { it.customizeEndpointResolution(ctx) }

    // Contributions from all customizations are merged in integration order; on duplicate keys the later one wins.
    private val externalFunctions = endpointCustomizations
        .flatMap { it.externalFunctions.entries }
        .associate { it.key to it.value }

    private val propertyRenderers = endpointCustomizations
        .flatMap { it.propertyRenderers.entries }
        .associate { it.key to it.value }

    // external functions are added after the core ones, so they take precedence for the same function ID
    private val expressionGenerator = ExpressionGenerator(writer, rules, coreFunctions + externalFunctions)

    private val defaultProviderSymbol = getSymbol(ctx.settings)
    private val interfaceSymbol = EndpointProviderGenerator.getSymbol(ctx.settings)
    private val paramsSymbol = EndpointParametersGenerator.getSymbol(ctx.settings)

    /**
     * Writes the full default endpoint provider class to [writer].
     */
    fun render() {
        renderDocumentation()
        writer.withBlock(
            "#L class #T: #T {",
            "}",
            ctx.settings.api.visibility,
            defaultProviderSymbol,
            interfaceSymbol,
        ) {
            renderResolve()
        }
    }

    override fun renderExpression(expr: Expression) {
        expr.accept(expressionGenerator)
    }

    private fun renderDocumentation() {
        writer.dokka {
            write("The default endpoint provider as specified by the service model.")
        }
    }

    // Renders `resolveEndpoint`: every top-level rule in order, then a fallthrough throw.
    private fun renderResolve() {
        writer.withBlock(
            "public override suspend fun resolveEndpoint(params: #T): #T {",
            "}",
            paramsSymbol,
            RuntimeTypes.SmithyClient.Endpoints.Endpoint,
        ) {
            rules.rules.forEach(::renderRule)
            write("")
            write("throw #T(\"endpoint rules were exhausted without a match\")", RuntimeTypes.SmithyClient.Endpoints.EndpointProviderException)
        }
    }

    private fun renderRule(rule: Rule) {
        when (rule) {
            is EndpointRule -> renderEndpointRule(rule)
            is ErrorRule -> renderErrorRule(rule)
            is TreeRule -> renderTreeRule(rule)
            else -> throw CodegenException("unexpected rule")
        }
    }

    /**
     * Renders [block] guarded by [conditions].
     *
     * Assignment conditions are emitted first as local `val`s, wrapped in `run { }` so they stay scoped to this rule.
     * All remaining checks are joined with `&&` into a single `if` around [block]; this includes the implicit
     * non-null check synthesized for each assignment. When there are no conditions, [block] is rendered unguarded.
     */
    private fun withConditions(conditions: List<Condition>, block: () -> Unit) {
        val (assignments, expressions) = conditions.partition()
        // explicitly wrap blocks with assignments to restrict scope therein
        writer.wrapBlockIf(assignments.isNotEmpty(), "run {", "}") {
            assignments.forEach {
                writer.writeInline("val #L = ", it.result.get().defaultName())
                renderExpression(it.function)
                writer.write("")
            }
            if (expressions.isNotEmpty()) {
                writer.openBlock("if (")
                expressions.forEachIndexed { index, it ->
                    renderExpression(it.function)
                    if (!it.function.isBooleanFunction()) { // these are meant to be evaluated on "truthiness" (i.e. is the result non-null)
                        writeInline(" != null")
                    }
                    write(if (index == expressions.lastIndex) "" else " &&")
                }
                writer.closeAndOpenBlock(") {")
            }
            block()
            if (expressions.isNotEmpty()) {
                writer.closeBlock("}")
            }
        }
    }

    // Renders `return Endpoint(...)` for a matching endpoint rule: the URL, then optional headers and attributes.
    private fun renderEndpointRule(rule: EndpointRule) {
        withConditions(rule.conditions) {
            writer.withBlock("return #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
                writeInline("#T.parse(", RuntimeTypes.Core.Net.Url.Url)
                renderExpression(rule.endpoint.url)
                write("),")
                renderEndpointHeaders(rule)
                renderEndpointProperties(rule)
            }
        }
    }

    // Renders the `headers = Headers { ... }` argument; writes nothing when the endpoint declares no headers.
    private fun renderEndpointHeaders(rule: EndpointRule) {
        val headers = rule.endpoint.headers
        if (headers.isEmpty()) return
        writer.withBlock("headers = #T {", "},", RuntimeTypes.Http.Headers) {
            headers.entries.forEach { (headerName, headerValues) ->
                headerValues.forEach { headerValue ->
                    writeInline("append(#S, ", headerName)
                    renderExpression(headerValue)
                    write(")")
                }
            }
        }
    }

    // Renders the `attributes = attributesOf { ... }` argument; writes nothing when the endpoint has no properties.
    private fun renderEndpointProperties(rule: EndpointRule) {
        val properties = rule.endpoint.properties
        if (properties.isEmpty()) return
        writer.withBlock("attributes = #T {", "},", RuntimeTypes.Core.Collections.attributesOf) {
            properties.entries.forEach { (key, value) ->
                val propertyName = key.toString()
                val customRenderer = propertyRenderers[propertyName]
                if (customRenderer != null) {
                    // caller has a chance to generate their own value for a recognized property
                    customRenderer(writer, value, this@DefaultEndpointProviderGenerator)
                } else {
                    // otherwise, we just traverse the value like any other rules expression, object values will
                    // be rendered as Documents
                    writeInline("#S to ", propertyName)
                    renderExpression(value)
                    ensureNewline()
                }
            }
        }
    }

    private fun renderErrorRule(rule: ErrorRule) {
        withConditions(rule.conditions) {
            writer.writeInline("throw #T(", RuntimeTypes.SmithyClient.Endpoints.EndpointProviderException)
            renderExpression(rule.error)
            writer.write(")")
        }
    }

    private fun renderTreeRule(rule: TreeRule) {
        withConditions(rule.conditions) {
            rule.rules.forEach(::renderRule)
        }
    }
}
/**
 * Generates Kotlin source for endpoint rules-language constructs: expressions, literals, and string templates.
 *
 * All output is written inline to [writer]. References to parameters declared by [rules] are qualified with
 * `params.`. A library function call is rendered as a call to the symbol registered in [functions] under the
 * function's rules-language ID.
 */
class ExpressionGenerator(
    private val writer: KotlinWriter,
    private val rules: EndpointRuleSet,
    private val functions: Map<String, Symbol>,
) : ExpressionVisitor<Unit>, LiteralVisitor<Unit>, TemplateVisitor<Unit> {
    // Collected once at construction instead of copying the parameter list on every reference lookup.
    private val paramNames = rules.parameters.toList().map { it.name }

    override fun visitLiteral(literal: Literal) {
        literal.accept(this as LiteralVisitor<Unit>)
    }

    override fun visitRef(reference: Reference) {
        if (isParamRef(reference)) {
            writer.writeInline("params.")
        }
        writer.writeInline(reference.name.defaultName())
    }

    override fun visitGetAttr(getAttr: GetAttr) {
        getAttr.target.accept(this)
        getAttr.path.forEach {
            when (it) {
                is GetAttr.Part.Key -> writer.writeInline("?.#L", it.key().toString())
                is GetAttr.Part.Index -> writer.writeInline("?.getOrNull(#L)", it.index())
                else -> throw CodegenException("unexpected path")
            }
        }
    }

    override fun visitIsSet(target: Expression) {
        target.accept(this)
        writer.writeInline(" != null")
    }

    override fun visitNot(target: Expression) {
        writer.writeInline("!(")
        target.accept(this)
        writer.writeInline(")")
    }

    override fun visitBoolEquals(left: Expression, right: Expression) {
        visitEquals(left, right)
    }

    override fun visitStringEquals(left: Expression, right: Expression) {
        visitEquals(left, right)
    }

    private fun visitEquals(left: Expression, right: Expression) {
        left.accept(this)
        writer.writeInline(" == ")
        right.accept(this)
    }

    /**
     * Renders a call to the runtime symbol registered for [fn].
     *
     * @throws CodegenException if no symbol is registered for the function's ID
     */
    override fun visitLibraryFunction(fn: FunctionDefinition, args: MutableList<Expression>) {
        val fnSymbol = functions[fn.id]
            ?: throw CodegenException("no implementation registered for endpoint rules function `${fn.id}`")
        writer.writeInline("#T(", fnSymbol)
        args.forEachIndexed { index, it ->
            it.accept(this)
            if (index < args.lastIndex) {
                writer.writeInline(", ")
            }
        }
        writer.writeInline(")")
    }

    override fun visitInteger(value: Int) {
        writer.writeInline("#L", value)
    }

    override fun visitString(value: Template) {
        writer.writeInline("\"")
        value.accept(this).forEach {} // must "consume" the stream to actually generate everything
        writer.writeInline("\"")
    }

    override fun visitBoolean(value: Boolean) {
        writer.writeInline("#L", value)
    }

    override fun visitRecord(value: MutableMap<Identifier, Literal>) {
        writer.withInlineBlock("#T {", "}", RuntimeTypes.Core.Content.buildDocument) {
            value.entries.forEachIndexed { index, (k, v) ->
                writeInline("#S to ", k.toString())
                v.accept(this@ExpressionGenerator as LiteralVisitor<Unit>)
                if (index < value.size - 1) write("")
            }
        }
    }

    override fun visitTuple(value: MutableList<Literal>) {
        writer.withInlineBlock("listOf(", ")") {
            value.forEachIndexed { index, it ->
                it.accept(this@ExpressionGenerator as LiteralVisitor<Unit>)
                if (index < value.size - 1) write(",") else writeInline(",")
            }
        }
    }

    override fun visitStaticTemplate(value: String) = writeTemplateString(value)
    override fun visitSingleDynamicTemplate(value: Expression) = writeTemplateExpression(value)
    override fun visitStaticElement(value: String) = writeTemplateString(value)
    override fun visitDynamicElement(value: Expression) = writeTemplateExpression(value)

    // no-ops for kotlin codegen
    override fun startMultipartTemplate() {}
    override fun finishMultipartTemplate() {}

    // Writes static template text inside an already-opened Kotlin string literal. The text is passed as an
    // argument rather than as the format string, so a `#` in it is never interpreted as a formatter directive.
    private fun writeTemplateString(value: String) {
        writer.writeInline("#L", value.escapeForKotlinStringLiteral())
    }

    // Writes a `${...}` interpolation for a dynamic template part.
    private fun writeTemplateExpression(expr: Expression) {
        writer.writeInline("\${")
        expr.accept(this)
        writer.writeInline("}")
    }

    // Escapes characters that would otherwise terminate the literal, start an escape sequence, start a string
    // template (`$`), or break the line inside a double-quoted Kotlin string.
    private fun String.escapeForKotlinStringLiteral(): String = buildString {
        for (c in this@escapeForKotlinStringLiteral) {
            when (c) {
                '\\' -> append("\\\\")
                '"' -> append("\\\"")
                '$' -> append("\\\$")
                '\n' -> append("\\n")
                '\r' -> append("\\r")
                '\t' -> append("\\t")
                else -> append(c)
            }
        }
    }

    private fun isParamRef(ref: Reference): Boolean = ref.name in paramNames
}
// Splits a list of conditions into assignments (conditions that bind their result to a name) and expressions.
// For each assignment, an "implicit" isSet (x != null) check is synthesized. These checks are placed ahead of the
// original expressions in the returned expression list, so they are evaluated for each assignment.
private fun List<Condition>.partition(): Pair<List<Condition>, List<Condition>> {
    val (assignments, expressions) = partition { it.result.isPresent }
    val implicitIsSetChecks = assignments.map { it.buildResultIsSetExpression() }
    return assignments to (implicitIsSetChecks + expressions)
}
// Builds a synthetic `isSet` condition over the name that this assignment condition binds its result to,
// i.e. a check that the assigned value is non-null. Only meaningful for conditions whose result is present.
private fun Condition.buildResultIsSetExpression() =
    Reference(result.get(), SourceLocation.NONE).let { assignedRef ->
        Condition.Builder().fn(isSet(assignedRef)).build()
    }
// Constructs an `isSet(<expression>)` library function call from the rules engine's IsSet definition.
private fun isSet(expression: Expression) =
    ToExpression { expression }.let { argument ->
        IsSet.getDefinition().createFunction(FunctionNode.ofExpressions(IsSet.ID, argument))
    }
// Names of library functions whose result is a (nullable) value rather than a Boolean. When one of these is used
// directly as a condition, it is evaluated on "truthiness" and rendered with a trailing `!= null`.
// The rules-language ID of the URL parser is "parseURL" (the key used in coreFunctions); the "parseUrl" spelling
// is kept so that any function registered under that name continues to be treated the same way.
private val nonBooleanFunctionNames = setOf(
    "parseURL",
    "parseUrl",
    "substring",
    "uriEncode",
    "aws.parseArn",
    "aws.partition",
)

/**
 * Returns `true` when this expression can be used directly as a Boolean `if` condition.
 *
 * Any expression that is not a [LibraryFunction] is treated as Boolean. A library function is treated as Boolean
 * unless its name is listed in [nonBooleanFunctionNames].
 */
private fun Expression.isBooleanFunction(): Boolean =
    this !is LibraryFunction || name !in nonBooleanFunctionNames
© 2015 - 2025 Weber Informatics LLC | Privacy Policy