All Downloads are FREE. Search and download functionalities are using the official Maven repository.

bsp.codegen.ir.SmithyToIR.scala Maven / Gradle / Ivy

The newest version!
package bsp.codegen.ir

import bsp.codegen.docs._
import bsp.codegen.ir.Primitive._
import bsp.codegen.ir.Type._
import bsp.traits.EnumKindTrait.EnumKind.{CLOSED, OPEN}
import bsp.traits._
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes._
import software.amazon.smithy.model.traits._

import java.util.Optional
import scala.collection.mutable.{Map => MMap}
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._

class SmithyToIR(model: Model) {

  val allDataKindAnnotated: Map[ShapeId, List[PolymorphicDataKind]] = {
    val allExtendableTypes = model.getShapesWithTrait(classOf[DataTrait]).asScala.toList
    val allExtendableTypeIds = allExtendableTypes.map(_.getId).toSet

    val dataKindInhabitants =
      model
        .getShapesWithTrait(classOf[DataKindTrait])
        .asScala
        .toList
        .map { shape =>
          val tr = shape.expectTrait(classOf[DataKindTrait])
          shape -> tr
        }

    // Validate that all data kinds extend a known extendable type.
    dataKindInhabitants.foreach(inhabitant => {
      val correct = inhabitant._2.getPolymorphicData.asScala.forall(allExtendableTypeIds.contains)
      if (!correct) {
        throw new RuntimeException(
          s"DataKindTrait on ${inhabitant._1.getId.getName} must extend a known extendable type."
        )
      }
    })

    val groupedInhabitants = dataKindInhabitants
      .flatMap { case (shape, dataKindTrait) =>
        dataKindTrait.getPolymorphicData.asScala.map((shape, dataKindTrait.getKind, _))
      }
      .groupBy { case (_, _, referencedShapeId) => referencedShapeId }
      .map { case (dataType, shapeAndTraits) =>
        dataType -> shapeAndTraits.map { case (shape, dataKind, _) =>
          PolymorphicDataKind(dataKind, shape.getId)
        }
      }

    allExtendableTypeIds.map { id =>
      val inhabitants = groupedInhabitants.getOrElse(id, List.empty).sortBy(_.kind)
      id -> inhabitants
    }.toMap
  }

  def docTree(namespace: String): DocTree = {
    val shapes =
      model.shapes().iterator().asScala.toList.filter(_.getId.getNamespace == namespace)
    val commonTag = "basic"
    val commonShapes =
      shapes
        .filter(_.getTrait(classOf[TagsTrait]).toScala.exists(_.getTags.contains(commonTag)))
        .sortBy(_.getTrait(classOf[DocsPriorityTrait]).toScala.map(_.getPriority).getOrElse(0))
        .reverse

    val docNodes = {
      val map = MMap[ShapeId, DocNode]()
      val visitor = new DocShapeVisitor(map)

      shapes.foreach(_.accept(visitor))
      map.toMap
    }

    val commonShapeIds = commonShapes.map(_.getId)
    val services = docNodes.collect({ case (id, node: ServiceDocNode) => (id, node) }).values.toList
    val operations = docNodes.collect({ case (id, node: OperationDocNode) => (id, node) })
    val structures = docNodes.collect({ case (id, node: StructureDocNode) => (id, node) })
    val dataKindInhabitants = allDataKindAnnotated
      .collect { case (extendableType, inhabitants) =>
        (extendableType, inhabitants.filter(_.shapeId.getNamespace == namespace))
      }
      .filter(_._2.nonEmpty)

    DocTree(commonShapeIds, services, operations, structures, dataKindInhabitants)
  }

  def definitions(namespace: String): List[Def] =
    model
      .shapes()
      .iterator()
      .asScala
      .filter(_.getId.getNamespace == namespace)
      .flatMap { shape =>
        if (shape.hasTrait("smithy.api#trait")) None
        else shape.accept(ToIRVisitor)
      }
      .toList

  object ToIRVisitor extends ShapeVisitor.Default[List[Def]] {
    // A lot of smithy shapes are not transformed to bsp4j data types, but are rather used
    // as type aliases and holders in Smithy. That is the case, in particular, for all
    // primitive shapes, that get handled by this default method.
    protected def getDefault(shape: Shape): List[Def] = List.empty

    def buildOperation(op: OperationShape): Option[Operation] = {
      val maybeMethod = if (op.hasTrait(classOf[JsonRequestTrait])) {
        val methodName = op.expectTrait(classOf[JsonRequestTrait]).getValue
        val methodType = JsonRPCMethodType.Request
        Some(methodName -> methodType)
      } else if (op.hasTrait(classOf[JsonNotificationTrait])) {
        val methodName = op.expectTrait(classOf[JsonNotificationTrait]).getValue
        val methodType = JsonRPCMethodType.Notification
        Some(methodName -> methodType)
      } else None
      maybeMethod.map { case (methodName, methodType) =>
        val inputType = getType(op.getInput).getOrElse(TUnit)
        val outputType = getType(op.getOutput).getOrElse(TUnit)
        val hints = getHints(op)
        Operation(op.getId, inputType, outputType, methodType, methodName, hints)
      }
    }

    override def serviceShape(shape: ServiceShape): List[Def] = {
      val operations = shape.getOperations.asScala.toList
        .map(model.expectShape(_, classOf[OperationShape]))
        .flatMap(buildOperation)
      List(Def.Service(shape.getId, operations, getHints(shape)))
    }

    def toField(member: MemberShape): Option[Field] = {
      val required = member.hasTrait(classOf[RequiredTrait])
      val jsonRenamed = member.getTrait(classOf[JsonNameTrait]).toScala.map(_.getValue)
      val name = member.getMemberName
      getType(member.getTarget).map(Field(name, _, required, jsonRenamed, getHints(member)))
    }

    def getType(shapeId: ShapeId): Option[Type] = model.expectShape(shapeId).accept(ToTypeVisitor)

    def getType(maybeShapeId: Optional[ShapeId]): Option[Type] =
      maybeShapeId.toScala.flatMap(getType)

    def dataKindShapeId(dataTypeShapeId: ShapeId): ShapeId =
      ShapeId.fromParts(dataTypeShapeId.getNamespace, dataTypeShapeId.getName + "Kind")

    override def structureShape(shape: StructureShape): List[Def] = {
      // Skip shapes that are used as mixins.
      if (shape.hasTrait(classOf[MixinTrait])) {
        return List.empty
      }

      val fields = shape.members().asScala.flatMap(toField).toList

      def fieldIsData(field: Field): Boolean =
        field.name == "data" && field.tpe.isPrimitive(Primitive.PDocument)

      def makeDiscriminatorField(dataField: Field): Field = {
        val doc =
          "Kind of data to expect in the `data` field. If this field is not set, the kind of data is not specified."
        val hints = List(Hint.Documentation(doc))
        val dataTypeShape = dataField.tpe match {
          case TPrimitive(Primitive.PDocument, dataTypeShape) => dataTypeShape
          case _ => throw new RuntimeException("Expected document type")
        }
        Field(
          "dataKind",
          TPrimitive(PString, dataKindShapeId(dataTypeShape)),
          required = false,
          None,
          hints
        )
      }

      def insertDiscriminator(fields: List[Field]): List[Field] = fields match {
        case field :: next if fieldIsData(field) =>
          val newField = makeDiscriminatorField(field)
          newField :: field :: next

        case field :: next => field :: insertDiscriminator(next)
        case Nil           => Nil
      }

      val associatedDataKinds =
        fields.filter(fieldIsData).flatMap(_.tpe.members()) match {
          case Nil => List.empty
          case dataShapeId :: Nil =>
            if (model.expectShape(dataShapeId).hasTrait(classOf[DataTrait])) {
              allDataKindAnnotated(dataShapeId)
            } else {
              throw new RuntimeException("Expected data field's shape to have data trait")
            }
          case _ => throw new RuntimeException("Expected exactly one data field")
        }

      val updatedFields = insertDiscriminator(fields)

      List(Def.Structure(shape.getId, updatedFields, getHints(shape), associatedDataKinds))
    }

    override def intEnumShape(shape: IntEnumShape): List[Def] = {
      val enumValues = shape.getEnumValues.asScala.map { case (name, value) =>
        val valueHints = getHints(shape.getAllMembers.get(name))
        EnumValue(name, value.toInt, valueHints)
      }.toList
      val hints = getHints(shape)
      shape.expectTrait(classOf[EnumKindTrait]).getEnumKind match {
        case OPEN =>
          List(Def.OpenEnum(shape.getId, EnumType.IntEnum, enumValues.sortBy(_.value), hints))
        case CLOSED => List(Def.ClosedEnum(shape.getId, EnumType.IntEnum, enumValues, hints))
      }
    }

    override def enumShape(shape: EnumShape): List[Def] = {
      val enumValues = shape.getEnumValues.asScala.map { case (name, value) =>
        val valueHints = getHints(shape.getAllMembers.get(name))
        EnumValue(name, value, valueHints)
      }.toList
      val hints = getHints(shape)
      shape.expectTrait(classOf[EnumKindTrait]).getEnumKind match {
        case OPEN =>
          List(Def.OpenEnum(shape.getId, EnumType.StringEnum, enumValues.sortBy(_.name), hints))
        case CLOSED => List(Def.ClosedEnum(shape.getId, EnumType.StringEnum, enumValues, hints))
      }
    }

    override def documentShape(shape: DocumentShape): List[Def] = {
      val hints = getHints(shape)

      // A document shape with the data trait in fact defines two structures: one for the data (a type alias for any)
      // and one for the data kind (an enum).
      if (shape.hasTrait(classOf[DataTrait])) {
        val id = shape.getId

        val allKnownInhabitants = allDataKindAnnotated(id)
        val openEnumId = dataKindShapeId(id)
        val values = allKnownInhabitants.map { case PolymorphicDataKind(disc, member) =>
          val snakeCased = disc.replace('-', '_').toUpperCase()
          val memberDoc = s"`data` field must contain a ${member.getName} object."
          EnumValue(snakeCased, disc, List(Hint.Documentation(memberDoc)))
        }

        val dataKindDef = Def.OpenEnum(openEnumId, EnumType.StringEnum, values, hints)
        val dataDef = Def.Alias(id, Type.TDocument, hints)
        List(dataKindDef, dataDef)
      } else {
        aliasShape(shape)
      }
    }

    def aliasShape(shape: Shape): List[Def] = {
      val hints = getHints(shape)
      shape
        .accept(ToTypeVisitor)
        .map(it => List(Def.Alias(shape.getId, it, hints)))
        .getOrElse(List.empty)
    }

    override def unionShape(shape: UnionShape): List[Def] = aliasShape(shape)

    override def booleanShape(shape: BooleanShape): List[Def] = aliasShape(shape)

    override def integerShape(shape: IntegerShape): List[Def] = aliasShape(shape)

    override def longShape(shape: LongShape): List[Def] = aliasShape(shape)

    override def floatShape(shape: FloatShape): List[Def] = aliasShape(shape)

    override def doubleShape(shape: DoubleShape): List[Def] = aliasShape(shape)

    override def stringShape(shape: StringShape): List[Def] = aliasShape(shape)

    override def timestampShape(shape: TimestampShape): List[Def] = aliasShape(shape)

    override def listShape(shape: ListShape): List[Def] = aliasShape(shape)

    override def mapShape(shape: MapShape): List[Def] = aliasShape(shape)
  }

  object ToTypeVisitor extends ShapeVisitor[Option[Type]] {

    def prim(primitive: Primitive, shape: Shape): Option[Type] = Some(
      TPrimitive(primitive, shape.getId)
    )

    def booleanShape(shape: BooleanShape): Option[Type] = prim(PBool, shape)

    def integerShape(shape: IntegerShape): Option[Type] = prim(PInt, shape)

    def longShape(shape: LongShape): Option[Type] = prim(PLong, shape)

    def floatShape(shape: FloatShape): Option[Type] = prim(PFloat, shape)

    def documentShape(shape: DocumentShape): Option[Type] =
      prim(PDocument, shape)

    def doubleShape(shape: DoubleShape): Option[Type] = prim(PDouble, shape)

    def stringShape(shape: StringShape): Option[Type] = prim(PString, shape)

    def structureShape(shape: StructureShape): Option[Type] = Some(TRef(shape.getId))

    def listShape(shape: ListShape): Option[Type] =
      model
        .expectShape(shape.getMember.getTarget)
        .accept(this)
        .map(memberType =>
          if (shape.hasTrait(classOf[SetTrait])) TSet(memberType)
          else TCollection(memberType)
        )

    def mapShape(shape: MapShape): Option[Type] = for {
      k <- shape.getKey.accept(this)
      v <- shape.getValue.accept(this)
    } yield TMap(k, v)

    def unionShape(shape: UnionShape): Option[Type] = Some {
      if (shape.hasTrait(classOf[UntaggedUnionTrait])) {
        val memberTypes = shape.getAllMembers.asScala.values.map(_.accept(this)).toList.flatten
        TUntaggedUnion(memberTypes)
      } else TRef(shape.getId)
    }

    override def enumShape(shape: EnumShape): Option[Type] = {
      val enumKind = shape.expectTrait(classOf[EnumKindTrait]).getEnumKind
      enumKind match {
        case OPEN   => prim(PString, shape)
        case CLOSED => Some(TRef(shape.getId))
      }
    }

    override def intEnumShape(shape: IntEnumShape): Option[Type] = {
      val enumKind = shape.expectTrait(classOf[EnumKindTrait]).getEnumKind
      enumKind match {
        case OPEN   => prim(PInt, shape)
        case CLOSED => Some(TRef(shape.getId))
      }
    }

    def memberShape(shape: MemberShape): Option[Type] =
      model.expectShape(shape.getTarget).accept(this)

    def timestampShape(shape: TimestampShape): Option[Type] = Some(
      TPrimitive(PTimestamp, shape.getId)
    )

    def shortShape(shape: ShortShape): Option[Type] = None

    def blobShape(shape: BlobShape): Option[Type] = None

    def byteShape(shape: ByteShape): Option[Type] = None

    def operationShape(shape: OperationShape): Option[Type] = None

    def resourceShape(shape: ResourceShape): Option[Type] = None

    def serviceShape(shape: ServiceShape): Option[Type] = None

    def bigIntegerShape(shape: BigIntegerShape): Option[Type] = None

    def bigDecimalShape(shape: BigDecimalShape): Option[Type] = None
  }

  def getHints(shape: Shape): List[Hint] = {
    val documentation = shape
      .getTrait(classOf[DocumentationTrait])
      .toScala
      .map(_.getValue)
      .map(Hint.Documentation)

    val deprecated = shape
      .getTrait(classOf[DeprecatedTrait])
      .toScala
      .map(_.getMessage.orElse(""))
      .map(Hint.Deprecated)

    val unstable = if (shape.hasTrait(classOf[UnstableTrait])) Some(Hint.Unstable) else None

    List(documentation, deprecated, unstable).flatten
  }

  class DocShapeVisitor(map: MMap[ShapeId, DocNode]) extends ShapeVisitor.Default[Unit] {
    protected def getDefault(shape: Shape) = {
      val id = shape.getId
      if (!map.contains(id)) {
        val docNodes = shape.accept(ToIRVisitor).map { definition =>
          StructureDocNode(definition)
        }

        docNodes.foreach(node => map.put(node.shapeId, node))
      }
    }

    override def operationShape(shape: OperationShape) = {
      val input = shape.getInput.toScala
      input.foreach(model.expectShape(_).accept(this))
      val output = shape.getOutput.toScala
      output.foreach(model.expectShape(_).accept(this))
      val operation = ToIRVisitor.buildOperation(shape).map { op: Operation =>
        OperationDocNode(op, input, output)
      }

      operation.foreach(map.put(shape.getId, _))
    }

    override def serviceShape(shape: ServiceShape) = {
      val ops = shape.getOperations.asScala.toList
      ops.foreach(model.expectShape(_).accept(this))
      map.put(shape.getId, ServiceDocNode(shape.getId, ops))
    }

  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy