All Downloads are FREE. Search and download functionalities are using the official Maven repository.

sttp.tapir.docs.apispec.schema.TSchemaToASchema.scala Maven / Gradle / Ivy

The newest version!
package sttp.tapir.docs.apispec.schema

import sttp.apispec.{Schema => ASchema, _}
import sttp.tapir.Schema.{SName, Title, UniqueItems}
import sttp.tapir.Validator.EncodeToRaw
import sttp.tapir.docs.apispec.DocsExtensionAttribute.RichSchema
import sttp.tapir.docs.apispec.schema.TSchemaToASchema.{tDefaultToADefault, tExampleToAExample}
import sttp.tapir.docs.apispec.{DocsExtensions, exampleValue}
import sttp.tapir.internal._
import sttp.tapir.{Codec, Validator, Schema => TSchema, SchemaType => TSchemaType}

/** Converts a tapir schema to an OpenAPI/AsyncAPI schema, using `toSchemaReference` to resolve references. */
private[docs] class TSchemaToASchema(
    fallbackSchemaTitle: SName => String,
    toSchemaReference: ToSchemaReference,
    markOptionsAsNullable: Boolean
) {

  def apply[T](codec: Codec[T, _, _]): ASchema = apply(codec.schema, allowReference = true)

  /** @param allowReference
    *   Can a reference schema be generated, if this is a named schema - should be `false` for top-level component definitions (otherwise
    *   the definitions are infinitely recursive)
    */
  def apply[T](schema: TSchema[T], allowReference: Boolean, isOptionElement: Boolean = false): ASchema = {
    val nullable = markOptionsAsNullable && isOptionElement

    val result = schema.name match {
      case Some(name) if allowReference => toSchemaReference.map(schema, name)
      case _ =>
        schema.schemaType match {
          case TSchemaType.SInteger() => ASchema(SchemaType.Integer)
          case TSchemaType.SNumber()  => ASchema(SchemaType.Number)
          case TSchemaType.SBoolean() => ASchema(SchemaType.Boolean)
          case TSchemaType.SString()  => ASchema(SchemaType.String)
          case TSchemaType.SProduct(fields) if schema.attribute(TSchema.Tuple.Attribute).map(_.isTuple).getOrElse(false) =>
            ASchema(SchemaType.Array).copy(
              prefixItems = Some(fields.map(f => apply(f.schema, allowReference = true)))
            )
          case p @ TSchemaType.SProduct(fields) =>
            ASchema(SchemaType.Object).copy(
              required = p.required.map(_.encodedName),
              properties = extractProperties(fields)
            )
          case TSchemaType.SArray(el) => ASchema(SchemaType.Array).copy(items = Some(apply(el, allowReference = true)))
          case opt @ TSchemaType.SOption(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
            // #3288: in case there are multiple different customisations of the nested schema, we need to propagate the
            // metadata to properly customise the reference. These are also propagated in ToKeyedSchemas when computing
            // the initial list of schemas.
            val propagated = propagateMetadataForOption(schema, opt).element
            val ref = toSchemaReference.map(propagated, name)
            if (!markOptionsAsNullable) ref else ref.nullable
          case TSchemaType.SOption(el)    => apply(el, allowReference = true, isOptionElement = true)
          case TSchemaType.SBinary()      => ASchema(SchemaType.String).copy(format = Some(SchemaFormat.Binary))
          case TSchemaType.SDate()        => ASchema(SchemaType.String).copy(format = Some(SchemaFormat.Date))
          case TSchemaType.SDateTime()    => ASchema(SchemaType.String).copy(format = Some(SchemaFormat.DateTime))
          case TSchemaType.SRef(fullName) => toSchemaReference.mapDirect(fullName)
          case TSchemaType.SCoproduct(schemas, d) =>
            ASchema.oneOf(
              schemas
                .filterNot(_.hidden)
                .map(apply(_, allowReference = true))
                .sortBy {
                  case schema if schema.$ref.isDefined => schema.$ref.get
                  case schema =>
                    schema.`type`
                      .collect {
                        case List(t)                  => t.value
                        case List(t, SchemaType.Null) => t.value
                      }
                      .getOrElse("") + schema.toString
                },
              d.map(tDiscriminatorToADiscriminator)
            )
          case p @ TSchemaType.SOpenProduct(fields, valueSchema) =>
            ASchema(SchemaType.Object).copy(
              required = p.required.map(_.encodedName),
              properties = extractProperties(fields),
              additionalProperties = Some(apply(valueSchema, allowReference = true)).filterNot(_ => valueSchema.hidden)
            )
        }
    }

    if (result.$ref.isEmpty) {
      // only customising non-reference schemas; references might get enriched with some meta-data if there
      // are multiple different customisations of the referenced schema in ToSchemaReference (#1203)

      val primitiveValidators = schema.validator.asPrimitiveValidators
      val schemaIsWholeNumber = schema.schemaType match {
        case TSchemaType.SInteger() => true
        case _                      => false
      }

      var s = result
      s = if (nullable) s.nullable else s
      s = addMetadata(s, schema)
      s = addAttributes(s, schema)
      s = addConstraints(s, primitiveValidators, schemaIsWholeNumber)
      s
    } else result
  }

  private def extractProperties[T](fields: List[TSchemaType.SProductField[T]]) = {
    fields
      .filterNot(_.schema.hidden)
      .map(f => f.name.encodedName -> apply(f.schema, allowReference = true))
      .toListMap
  }

  private def addAttributes(oschema: ASchema, tschema: TSchema[_]): ASchema = {
    val titleFromAttr = tschema.attributes.get(Title.Attribute).map(_.value)
    // The primary motivation for using schema name as fallback title is to improve Swagger UX with
    // `oneOf` schemas in OpenAPI 3.1. See https://github.com/softwaremill/tapir/issues/3447 for details.
    def fallbackTitle = tschema.name.map(fallbackSchemaTitle)
    
    val const = tschema.attribute(TSchema.EncodedDiscriminatorValue.Attribute).map(_.v).map(v => ExampleSingleValue(v))
    
    oschema
      .copy(title = titleFromAttr.orElse(fallbackTitle))
      .copy(uniqueItems = tschema.attribute(UniqueItems.Attribute).map(_.uniqueItems))
      .copy(const = const)
  }

  private def addMetadata(oschema: ASchema, tschema: TSchema[_]): ASchema = {
    oschema.copy(
      description = tschema.description.orElse(oschema.description),
      default = tDefaultToADefault(tschema).orElse(oschema.default),
      examples = tExampleToAExample(tschema).map(List(_)).orElse(oschema.examples),
      format = tschema.format.orElse(oschema.format),
      deprecated = (if (tschema.deprecated) Some(true) else None).orElse(oschema.deprecated),
      extensions = DocsExtensions.fromIterable(tschema.docsExtensions)
    )
  }

  private def addConstraints(
      oschema: ASchema,
      vs: Seq[Validator.Primitive[_]],
      schemaIsWholeNumber: Boolean
  ): ASchema = vs.foldLeft(oschema)(addConstraints(_, _, schemaIsWholeNumber))

  private def addConstraints(aschema: ASchema, v: Validator.Primitive[_], wholeNumbers: Boolean): ASchema = {
    v match {
      case m @ Validator.Min(v, false) =>
        aschema.copy(minimum = Some(toBigDecimal(v, m.valueIsNumeric, wholeNumbers)))
      case m @ Validator.Min(v, true) =>
        aschema.copy(exclusiveMinimum = Some(toBigDecimal(v, m.valueIsNumeric, wholeNumbers)))
      case m @ Validator.Max(v, false) =>
        aschema.copy(maximum = Some(toBigDecimal(v, m.valueIsNumeric, wholeNumbers)))
      case m @ Validator.Max(v, true) =>
        aschema.copy(exclusiveMaximum = Some(toBigDecimal(v, m.valueIsNumeric, wholeNumbers)))
      case Validator.Pattern(value)                  => aschema.copy(pattern = Some(Pattern(value)))
      case Validator.MinLength(value, _)             => aschema.copy(minLength = Some(value))
      case Validator.MaxLength(value, _)             => aschema.copy(maxLength = Some(value))
      case Validator.MinSize(value)                  => aschema.copy(minItems = Some(value))
      case Validator.MaxSize(value)                  => aschema.copy(maxItems = Some(value))
      case Validator.Custom(_, _)                    => aschema
      case Validator.Enumeration(_, None, _)         => aschema
      case Validator.Enumeration(v, Some(encode), _) => addEnumeration(aschema, v, encode)
    }
  }

  private def addEnumeration[T](aschema: ASchema, v: List[T], encode: EncodeToRaw[T]): ASchema = {
    val values = v.flatMap(x => encode(x).map(ExampleSingleValue.apply))
    aschema.copy(`enum` = if (values.nonEmpty) Some(values) else None)
  }

  private def toBigDecimal[N](v: N, vIsNumeric: Numeric[N], schemaIsWholeNumber: Boolean): BigDecimal = {
    v match {
      case x: Int                   => BigDecimal(x)
      case x: Long                  => BigDecimal(x)
      case x: Float                 => BigDecimal(x.toDouble)
      case x: Double                => BigDecimal(x)
      case x: BigInt                => BigDecimal(x)
      case x: java.math.BigInteger  => BigDecimal(x)
      case x: BigDecimal            => x
      case x: java.math.BigDecimal  => BigDecimal(x)
      case _ if schemaIsWholeNumber => BigDecimal(vIsNumeric.toLong(v))
      case _                        => BigDecimal(vIsNumeric.toDouble(v))
    }
  }

  private def tDiscriminatorToADiscriminator(discriminator: TSchemaType.SDiscriminator): Discriminator = {
    val schemas = Some(
      discriminator.mapping
        .flatMap { case (k, TSchemaType.SRef(fullName)) =>
          toSchemaReference.mapDiscriminator(fullName).$ref.map(k -> _)
        }
        .toList
        .sortBy(_._1)
        .toListMap
    )
    Discriminator(discriminator.name.encodedName, schemas)
  }
}

object TSchemaToASchema {
  def tDefaultToADefault(schema: TSchema[_]): Option[ExampleValue] = schema.default.flatMap { case (_, raw) =>
    raw.flatMap(r => exampleValue(schema, r))
  }
  def tExampleToAExample(schema: TSchema[_]): Option[ExampleValue] = schema.encodedExample.flatMap(exampleValue(schema, _))
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy