All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.scalatra.swagger.Swagger.scala Maven / Gradle / Ivy

There is a newer version: 3.1.0
Show newest version
package org.scalatra
package swagger

import java.lang.reflect.Field
import java.util.concurrent.ConcurrentHashMap
import java.util.{ Date => JDate }

import org.json4s.JsonAST.JValue
import org.scalatra.swagger.reflect._
import org.scalatra.swagger.runtime.annotations.{ ApiModel, ApiModelProperty }
import org.slf4j.LoggerFactory

import scala.jdk.CollectionConverters._

/**
 * Holds the Swagger documentation of every registered API together with the global
 * metadata (versions, host, API info and authorization schemes) needed to render it.
 */
trait SwaggerEngine {
  /** Swagger specification version string. */
  def swaggerVersion: String
  /** Version of the documented API. */
  def apiVersion: String
  /** Host the API is served from. */
  def host: String
  /** Descriptive information about the API. */
  def apiInfo: ApiInfo
  /** Additional Swagger definition supplied as raw JSON, if any. */
  def extraSwaggerDefinition: Option[JValue]

  // Registered documentation keyed by path, backed by a ConcurrentHashMap.
  private[swagger] val _docs = new ConcurrentHashMap[String, Api]().asScala

  private[this] var _authorizations = List.empty[AuthorizationType]

  /** Authorization schemes added through `addAuthorization`, most recently added first. */
  def authorizations: List[AuthorizationType] = _authorizations

  /** Prepends `auth` to the list of authorization schemes. */
  def addAuthorization(auth: AuthorizationType): Unit =
    _authorizations = auth :: _authorizations

  /** Every registered API document. */
  def docs: Iterable[Api] = _docs.values

  /**
   * Configurations used by UrlGenerator when creating baseUrl.
   */
  def baseUrlIncludeContextPath: Boolean = true
  def baseUrlIncludeServletPath: Boolean = false

  /**
   * Looks up the documentation registered under `path`.
   *
   * @return the documentation, or `None` when nothing was registered for `path`
   */
  def doc(path: String): Option[Api] = _docs.get(path)

  /**
   * Registers the documentation for an API with the given path.
   */
  def register(listingPath: String, resourcePath: String, description: Option[String], s: SwaggerSupportSyntax with SwaggerSupportBase, consumes: List[String], produces: List[String], protocols: List[String], authorizations: List[String]): Unit

}

object Swagger {

  /**
   * Types that are never decomposed into Swagger models.
   *
   * Date classes only output 'date-time' in a uniform format, so there is no need to analyze
   * the structure of these objects; the remaining entries are date/time-related helpers.
   */
  val excludes: Set[java.lang.reflect.Type] = Set(
    classOf[java.util.TimeZone],
    classOf[java.util.Date],
    classOf[java.time.OffsetDateTime],
    classOf[java.time.ZonedDateTime],
    classOf[java.time.LocalDateTime],
    classOf[java.time.LocalDate],
    classOf[java.time.LocalTime],
    classOf[java.time.Instant],
    classOf[java.time.chrono.Chronology],
    classOf[java.time.ZoneOffset])

  /** Swagger specification version string. */
  val SpecVersion = "2.0"

  /**
   * Collects the Swagger models reachable from `T`.
   *
   * @param alreadyKnown models produced elsewhere; types whose simple name matches one of their ids are skipped
   */
  def collectModels[T: Manifest](alreadyKnown: Set[Model]): Set[Model] = collectModels(Reflector.scalaTypeOf[T], alreadyKnown)

  /**
   * Recursively collects models for `tpe` and for every non-primitive type it references.
   *
   * @param tpe          the type to inspect
   * @param alreadyKnown models produced elsewhere; a type whose simple name equals one of their ids is skipped
   * @param known        types already visited, or currently being visited, on this path; used to stop
   *                     recursion through self-referencing types
   * @return the models discovered from `tpe`
   */
  private[swagger] def collectModels(tpe: ScalaType, alreadyKnown: Set[Model], known: Set[ScalaType] = Set.empty): Set[Model] = {
    if (excludes contains tpe.erasure) {
      // Date classes only output 'date-time' in a uniform format, so there is no need to analyze the structure of the object.
      // Therefore, the analysis of classes specified as excludes will be skipped.
      Set.empty
    } else if (tpe.isMap) {
      // Descend into the key and value types. The outer `known` set is carried forward and types already
      // on the path are skipped, mirroring the collection branch below. The previous code replaced `known`
      // with just the type arguments, so `case class Node(children: Map[String, Node])` recursed forever.
      val keyAndValue = (tpe.typeArgs.take(1) ++ tpe.typeArgs.takeRight(1)).toSet
      val visited = known ++ tpe.typeArgs
      keyAndValue.filterNot(known.contains).flatMap(arg => collectModels(arg, alreadyKnown, visited))
    } else if ((tpe.isCollection || tpe.isOption) && tpe.typeArgs.nonEmpty) {
      val element = tpe.typeArgs.head
      if (known.contains(element)) Set.empty
      else collectModels(element, alreadyKnown, known + element)
    } else if (alreadyKnown.exists(_.id == tpe.simpleName)) {
      Set.empty
    } else {
      Reflector.describe(tpe) match {
        case descriptor: ClassDescriptor => collectClassModels(descriptor, alreadyKnown, known)
        case _ => Set.empty
      }
    }
  }

  // Builds the models of a described class and recurses into the types of its non-primitive members.
  private[this] def collectClassModels(descriptor: ClassDescriptor, alreadyKnown: Set[Model], known: Set[ScalaType]): Set[Model] = {
    val ctorModels = descriptor.mostComprehensive.filterNot(_.isPrimitive).toVector
    val propModels = descriptor.properties.filterNot(p => p.isPrimitive || ctorModels.exists(_.name == p.name))
    val subModels = (ctorModels.map(_.argType) ++ propModels.map(_.returnType)).toSet -- known
    // Containers get no model of their own; their element types are reached by the recursion below.
    val topLevel = for {
      tl <- subModels + descriptor.erasure
      if !(tl.isCollection || tl.isMap || tl.isOption)
      m <- modelToSwagger(tl)
    } yield m

    val (models, _) = subModels.foldLeft((topLevel, known + descriptor.erasure)) {
      case ((collected, visited), sub) =>
        (collected ++ collectModels(sub, alreadyKnown, visited), visited + sub)
    }
    models
  }

  import org.scalatra.util.RicherString._

  /** Builds the Swagger model for `T`, or `None` for primitive and excluded types. */
  def modelToSwagger[T](implicit mf: Manifest[T]): Option[Model] = modelToSwagger(Reflector.scalaTypeOf[T])

  /**
   * Converts one property of `descriptor` into a named `ModelProperty`.
   *
   * The constructor parameter's type, when there is one, takes precedence over the accessor's return
   * type. For an optional property the `Option` is unwrapped and the property is never required.
   */
  private[this] def toModelProperty(descriptor: ClassDescriptor, position: Option[Int] = None, required: Boolean = true, description: Option[String] = None, allowableValues: String = "", example: Option[String] = None, minimumValue: Option[Double] = None, maximumValue: Option[Double] = None, default: Option[String] = None, hidden: Boolean = false)(prop: PropertyDescriptor): (String, ModelProperty) = {
    val ctorParam = descriptor.mostComprehensive.find(_.name == prop.name)
    val declaredType = ctorParam.map(_.argType).getOrElse(prop.returnType)
    val propertyType = if (prop.returnType.isOption) declaredType.typeArgs.head else declaredType
    // An explicit non-negative position wins. Otherwise use the constructor argument index,
    // then the (negative) explicit position, and finally 0.
    val resolvedPosition = position.filter(_ >= 0).orElse(ctorParam.map(_.argIndex)).orElse(position).getOrElse(0)
    val property = ModelProperty(
      `type` = DataType.fromScalaType(propertyType),
      position = resolvedPosition,
      required = required && !prop.returnType.isOption,
      description = description.flatMap(_.blankOption),
      allowableValues = convertToAllowableValues(allowableValues),
      example = example.flatMap(_.blankOption),
      default = default.flatMap(_.blankOption),
      minimumValue = minimumValue,
      maximumValue = maximumValue,
      hidden = hidden)
    prop.name -> property
  }

  /**
   * Builds the Swagger model for `klass`.
   *
   * @return `None` for primitive or excluded types, and for types the reflector does not describe as a
   *         class (these previously failed with a `ClassCastException`)
   */
  def modelToSwagger(klass: ScalaType): Option[Model] = {
    if (Reflector.isPrimitive(klass.erasure) || Reflector.isExcluded(klass.erasure, excludes.toSeq)) None
    else Reflector.describe(klass) match {
      case descriptor: ClassDescriptor => Some(classToModel(klass, descriptor))
      case _ => None
    }
  }

  // Builds a Model from the declared fields of `klass`. `@ApiModelProperty` and `@ApiModel` annotations
  // refine the defaults.
  private[this] def classToModel(klass: ScalaType, descriptor: ClassDescriptor): Model = {
    val name = klass.simpleName
    val qualifiedName = klass.fullName.blankOption

    val properties = klass.erasure.getDeclaredFields.toList.flatMap { (field: Field) =>
      val asModelProperty = Option(field.getAnnotation(classOf[ApiModelProperty])) match {
        case Some(annotation) =>
          // Integer.MAX_VALUE and NaN act as "not set" markers on the annotation.
          val position = Some(annotation.position()).filter(_ != Integer.MAX_VALUE)
          val minimumValue = Some(annotation.minimumValue()).filterNot(_.isNaN)
          val maximumValue = Some(annotation.maximumValue()).filterNot(_.isNaN)
          toModelProperty(descriptor, position, annotation.required(), annotation.description().blankOption, annotation.allowableValues(),
            annotation.example().blankOption, minimumValue, maximumValue, annotation.defaultValue().blankOption, annotation.hidden()) _
        case None =>
          toModelProperty(descriptor) _
      }
      descriptor.properties.find(_.mangledName == field.getName).map(asModelProperty)
    }

    Option(klass.erasure.getAnnotation(classOf[ApiModel])) match {
      case Some(am) =>
        Model(
          id = name,
          name = name,
          qualifiedName = qualifiedName,
          description = am.description().blankOption,
          properties = properties,
          baseModel = am.parent.getName.blankOption,
          discriminator = am.discriminator.blankOption)
      case None =>
        Model(name, name, qualifiedName, properties = properties)
    }
  }

  /**
   * Parses the `allowableValues` attribute of `@ApiModelProperty`.
   *
   * Accepted forms are `range[min,max]` (inclusive), `rangeExclusive[min,max]` (prefixes matched
   * case-insensitively), a comma-separated list of values, or the empty string (any value).
   */
  private def convertToAllowableValues(csvString: String): AllowableValues = {
    val inclusivePrefix = "range["
    val exclusivePrefix = "rangeexclusive["
    if (startsWithIgnoreCase(csvString, inclusivePrefix)) {
      buildAllowableRangeValues(rangeBounds(csvString, inclusivePrefix), csvString, inclusive = true)
    } else if (startsWithIgnoreCase(csvString, exclusivePrefix)) {
      buildAllowableRangeValues(rangeBounds(csvString, exclusivePrefix), csvString, inclusive = false)
    } else if (csvString.isEmpty) {
      AllowableValues.AnyValue
    } else {
      AllowableValues.AllowableValuesList(csvString.split(",").toList)
    }
  }

  // Locale-independent, ASCII case-insensitive prefix test.
  private[this] def startsWithIgnoreCase(s: String, prefix: String): Boolean =
    s.regionMatches(true, 0, prefix, 0, prefix.length)

  // Splits the comma-separated bounds found between `prefix` and an optional trailing ']'.
  private[this] def rangeBounds(spec: String, prefix: String): Array[String] =
    spec.substring(prefix.length).stripSuffix("]").split(",")

  /**
   * Builds an integer range from the first two bounds in `ranges`.
   *
   * Bounds may be numbers, `Infinity` or `-Infinity`. Infinite bounds saturate at `Int.MaxValue` / `Int.MinValue`.
   *
   * @throws RuntimeException if fewer than two bounds are given
   */
  private def buildAllowableRangeValues(ranges: Array[String], inputStr: String, inclusive: Boolean): AllowableValues.AllowableRangeValues = {
    if (ranges.length < 2) {
      throw new RuntimeException(s"Allowable values format $inputStr is incorrect")
    }
    val min = parseRangeBound(ranges(0))
    val max = parseRangeBound(ranges(1))
    AllowableValues.AllowableRangeValues(if (inclusive) Range.inclusive(min.toInt, max.toInt) else Range(min.toInt, max.toInt))
  }

  // Parses one bound, accepting "Infinity" / "-Infinity" in any letter case.
  private[this] def parseRangeBound(raw: String): Float = {
    val bound = raw.trim
    if (bound.equalsIgnoreCase("Infinity")) Float.PositiveInfinity
    else if (bound.equalsIgnoreCase("-Infinity")) Float.NegativeInfinity
    else bound.toFloat
  }

}

/**
 * An instance of this class is used to hold the API documentation.
 *
 * Each call to `register` stores one `Api` document under its listing path. Registering the same
 * listing path again replaces the earlier document.
 */
class Swagger(val swaggerVersion: String, val apiVersion: String, val apiInfo: ApiInfo, val host: String = "", val extraSwaggerDefinition: Option[JValue] = None) extends SwaggerEngine {
  private[this] val logger = LoggerFactory.getLogger(getClass)

  /**
   * Registers the documentation for an API with the given path.
   *
   * The API-level `produces`, `consumes`, `protocols` and `authorizations` lists are merged with the
   * values declared on the individual operations. Duplicates are dropped and first-seen order is kept.
   *
   * @param listingPath    key under which the documentation is stored
   * @param resourcePath   path passed to `s.endpoints` to enumerate the documented endpoints
   * @param description    optional description of the API
   * @param s              the Swagger-enabled servlet/filter providing endpoints and models
   * @param consumes       API-level content types consumed
   * @param produces       API-level content types produced
   * @param protocols      API-level protocols (schemes)
   * @param authorizations API-level authorization names
   */
  def register(listingPath: String, resourcePath: String, description: Option[String], s: SwaggerSupportSyntax with SwaggerSupportBase, consumes: List[String], produces: List[String], protocols: List[String], authorizations: List[String]): Unit = {
    // The description slot previously logged resourcePath a second time.
    logger.debug(s"registering swagger api with: { listingPath: $listingPath, resourcePath: $resourcePath, description: $description, servlet: ${s.getClass} }")
    val endpoints: List[Endpoint] = s.endpoints(resourcePath) collect { case m: Endpoint => m }
    val operations = endpoints.flatMap(_.operations)
    val api = Api(
      apiVersion = apiVersion,
      swaggerVersion = swaggerVersion,
      resourcePath = resourcePath,
      description = description,
      produces = (produces ::: operations.flatMap(_.produces)).distinct,
      consumes = (consumes ::: operations.flatMap(_.consumes)).distinct,
      protocols = (protocols ::: operations.flatMap(_.schemes)).distinct,
      apis = endpoints,
      models = s.models.toMap,
      authorizations = (authorizations ::: operations.flatMap(_.authorizations)).distinct,
      position = 0)
    _docs(listingPath) = api
  }
}

/**
 * Documentation of a single API resource.
 *
 * @param apiVersion     version of the documented API
 * @param swaggerVersion Swagger specification version string
 * @param resourcePath   path of the documented resource
 * @param description    optional description of the resource
 * @param produces       content types produced
 * @param consumes       content types consumed
 * @param protocols      protocols (schemes) supported
 * @param apis           the documented endpoints
 * @param models         models referenced by the endpoints, keyed by name
 * @param authorizations names of the authorization schemes that apply
 * @param position       ordering hint
 */
case class Api(
  apiVersion: String,
  swaggerVersion: String,
  resourcePath: String,
  description: Option[String] = None,
  produces: List[String] = Nil,
  consumes: List[String] = Nil,
  protocols: List[String] = Nil,
  apis: List[Endpoint] = Nil,
  models: Map[String, Model] = Map.empty,
  authorizations: List[String] = Nil,
  position: Int = 0) {

  /** Returns the model registered under `name`, if any. */
  def model(name: String): Option[Model] =
    models.get(name)
}

/** Where an operation parameter is carried in the request. */
object ParamType extends Enumeration {
  type ParamType = Value

  /** A parameter carried in a POST body. */
  val Body: ParamType = Value("body")

  /**
   * A parameter carried on the query string.
   *
   * E.g. http://example.com/foo?param=2
   */
  val Query: ParamType = Value("query")

  /**
   * A path parameter mapped to a Scalatra route.
   *
   * E.g. http://example.com/foo/2 where there's a route like
   * get("/foo/:id").
   */
  val Path: ParamType = Value("path")

  /** A parameter carried in an HTTP header. */
  val Header: ParamType = Value("header")

  /** A file parameter (presumably a multipart upload — confirm against callers). */
  val File: ParamType = Value("file")

  /** A form parameter. */
  val Form: ParamType = Value("form")
}

/** A Swagger data type; `name` is the Swagger type name such as "string", "integer" or "List". */
sealed trait DataType {
  def name: String
}

object DataType {

  /** A scalar or model type, with an optional Swagger `format` and fully qualified class name. */
  case class ValueDataType(name: String, format: Option[String] = None, qualifiedName: Option[String] = None) extends DataType

  /** A container type (list, set, array or map) with an optional element type. */
  case class ContainerDataType(name: String, typeArg: Option[DataType] = None, uniqueItems: Boolean = false) extends DataType

  val Void = DataType("void")
  val String = DataType("string")
  val Byte = DataType("string", Some("byte"))
  val Int = DataType("integer", Some("int32"))
  val Long = DataType("integer", Some("int64"))
  val Float = DataType("number", Some("float"))
  val Double = DataType("number", Some("double"))
  val Boolean = DataType("boolean")
  val Date = DataType("string", Some("date"))
  val DateTime = DataType("string", Some("date-time"))

  object GenList {
    def apply(): DataType = ContainerDataType("List")
    def apply(v: DataType): DataType = ContainerDataType("List", Some(v))
  }

  object GenSet {
    def apply(): DataType = ContainerDataType("Set", uniqueItems = true)
    def apply(v: DataType): DataType = ContainerDataType("Set", Some(v), uniqueItems = true)
  }

  object GenArray {
    def apply(): DataType = ContainerDataType("Array")
    def apply(v: DataType): DataType = ContainerDataType("Array", Some(v))
  }

  /** Map container; only the value type is recorded. */
  object GenMap {
    def apply(): DataType = ContainerDataType("Map")
    def apply(v: DataType): DataType = ContainerDataType("Map", Some(v))
  }

  /** Creates a `ValueDataType` with the given Swagger name, format and qualified class name. */
  def apply(name: String, format: Option[String] = None, qualifiedName: Option[String] = None): ValueDataType =
    ValueDataType(name, format, qualifiedName)

  /** Derives the data type of `T` from its manifest. */
  def apply[T](implicit mf: Manifest[T]): DataType = fromManifest[T](mf)

  private[this] val StringTypes = Set[Class[_]](classOf[String], classOf[java.lang.String])
  private[this] def isString(klass: Class[_]) = StringTypes contains klass
  private[this] val BoolTypes = Set[Class[_]](classOf[Boolean], classOf[java.lang.Boolean])
  private[this] def isBool(klass: Class[_]) = BoolTypes contains klass

  private[swagger] def fromManifest[T](implicit mf: Manifest[T]): DataType = {
    fromScalaType(Reflector.scalaTypeOf[T])
  }
  private[swagger] def fromClass(klass: Class[_]): DataType = fromScalaType(Reflector.scalaTypeOf(klass))

  /**
   * Maps a reflected Scala type to its Swagger data type.
   *
   * `Option[T]` is documented as `T`. Sets, sequences, maps and other collections become containers
   * of their (value) element type. Unrecognised classes become a `ValueDataType` named after the class.
   */
  private[swagger] def fromScalaType(st: ScalaType): DataType = {
    // Unwrap Option[T] once so that every branch, including the container branches that read type
    // arguments, inspects T. Previously only the class was unwrapped, so Option[List[Int]] was
    // documented as a list of lists.
    val tpe = if (st.isOption && st.typeArgs.nonEmpty) st.typeArgs.head else st
    val klass = tpe.erasure
    if (classOf[Unit].isAssignableFrom(klass) || classOf[Void].isAssignableFrom(klass)) this.Void
    else if (isString(klass)) this.String
    else if (classOf[Byte].isAssignableFrom(klass) || classOf[java.lang.Byte].isAssignableFrom(klass)) this.Byte
    else if (classOf[Long].isAssignableFrom(klass) || classOf[java.lang.Long].isAssignableFrom(klass)) this.Long
    else if (isInt(klass)) this.Int
    else if (classOf[Float].isAssignableFrom(klass) || classOf[java.lang.Float].isAssignableFrom(klass)) this.Float
    else if (isDecimal(klass)) this.Double
    else if (isDate(klass)) this.Date
    else if (isDateTime(klass)) this.DateTime
    else if (isBool(klass)) this.Boolean
    else if (classOf[scala.collection.Set[_]].isAssignableFrom(klass) || classOf[java.util.Set[_]].isAssignableFrom(klass))
      tpe.typeArgs.headOption.fold(GenSet())(arg => GenSet(fromScalaType(arg)))
    else if (classOf[collection.Seq[_]].isAssignableFrom(klass) || classOf[java.util.List[_]].isAssignableFrom(klass))
      tpe.typeArgs.headOption.fold(GenList())(arg => GenList(fromScalaType(arg)))
    else if (tpe.isMap)
      tpe.typeArgs.lastOption.fold(GenMap())(arg => GenMap(fromScalaType(arg)))
    else if (tpe.isArray || isCollection(klass))
      tpe.typeArgs.headOption.fold(GenArray())(arg => GenArray(fromScalaType(arg)))
    else
      ValueDataType(tpe.simpleName, qualifiedName = Option(tpe.fullName))
  }

  private[this] val IntTypes =
    Set[Class[_]](classOf[Int], classOf[java.lang.Integer], classOf[Short], classOf[java.lang.Short], classOf[BigInt], classOf[java.math.BigInteger])
  private[this] def isInt(klass: Class[_]) = IntTypes.contains(klass)

  private[this] val DecimalTypes =
    Set[Class[_]](classOf[Double], classOf[java.lang.Double], classOf[BigDecimal], classOf[java.math.BigDecimal])
  private[this] def isDecimal(klass: Class[_]) = DecimalTypes contains klass

  private[this] val DateTypes =
    Set[Class[_]](classOf[java.time.LocalDate])
  private[this] def isDate(klass: Class[_]) = DateTypes.exists(_.isAssignableFrom(klass))
  private[this] val DateTimeTypes =
    Set[Class[_]](classOf[JDate], classOf[java.time.LocalDateTime], classOf[java.time.ZonedDateTime], classOf[java.time.OffsetDateTime],
      classOf[java.time.Instant])
  private[this] def isDateTime(klass: Class[_]) = DateTimeTypes.exists(_.isAssignableFrom(klass))

  // `collection.Traversable` is a deprecated alias of `collection.Iterable` since Scala 2.13.
  private[this] def isCollection(klass: Class[_]) =
    classOf[collection.Iterable[_]].isAssignableFrom(klass) ||
      classOf[java.util.Collection[_]].isAssignableFrom(klass)

}

/**
 * Descriptive information about the documented API.
 *
 * @param title             title of the API
 * @param description       free-text description of the API
 * @param termsOfServiceUrl URL of the terms of service
 * @param contact           contact details
 * @param license           license information
 */
case class ApiInfo(
  title: String,
  description: String,
  termsOfServiceUrl: String,
  contact: ContactInfo,
  license: LicenseInfo)

/** Contact details: a name, a URL and an e-mail address. */
case class ContactInfo(name: String, url: String, email: String)

/** License name and the URL where its text can be found. */
case class LicenseInfo(name: String, url: String)

/** Constraint on the values a parameter or property accepts. */
trait AllowableValues

object AllowableValues {
  /** No constraint: any value is accepted. */
  case object AnyValue extends AllowableValues
  /** Only the listed values are accepted. */
  case class AllowableValuesList[T](values: List[T]) extends AllowableValues
  /** Only values inside the given integer range are accepted. */
  case class AllowableRangeValues(values: Range) extends AllowableValues

  /** The unconstrained value set, `AnyValue`. */
  def empty: AnyValue.type = AnyValue

  def apply(): AllowableValues = empty
  def apply[T](values: T*): AllowableValues = AllowableValuesList(values.toList)
  def apply[T](values: List[T]): AllowableValues = AllowableValuesList(values)
  def apply(values: Range): AllowableValues = AllowableRangeValues(values)
}

/**
 * A documented operation parameter.
 *
 * @param name            parameter name
 * @param `type`          data type of the parameter
 * @param description     optional description
 * @param paramType       where the parameter is carried (query string by default)
 * @param defaultValue    optional default value, as text
 * @param allowableValues constraint on accepted values (unconstrained by default)
 * @param required        whether the parameter must be supplied (true by default)
 * @param position        ordering hint
 * @param example         optional example value, as text
 * @param minimumValue    optional lower bound
 * @param maximumValue    optional upper bound
 * @param hidden          whether the parameter is hidden; `Operation.getVisibleParameters` filters these out
 */
case class Parameter(
  name: String,
  `type`: DataType,
  description: Option[String] = None,
  paramType: ParamType.ParamType = ParamType.Query,
  defaultValue: Option[String] = None,
  allowableValues: AllowableValues = AllowableValues.AnyValue,
  required: Boolean = true,
  // TODO Add collectionFormat: Option[String] for Swagger 2.0
  position: Int = 0,
  example: Option[String] = None,
  minimumValue: Option[Double] = None,
  maximumValue: Option[Double] = None,
  hidden: Boolean = false)

/**
 * A single property of a [[Model]].
 *
 * @param `type`          data type of the property
 * @param position        ordering hint
 * @param required        whether the property is required (false by default)
 * @param description     optional description
 * @param allowableValues constraint on accepted values (unconstrained by default)
 * @param example         optional example value, as text
 * @param default         optional default value, as text
 * @param minimumValue    optional lower bound
 * @param maximumValue    optional upper bound
 * @param hidden          whether the property is hidden; `Model.getVisibleProperties` filters these out
 */
case class ModelProperty(
  `type`: DataType,
  position: Int = 0,
  required: Boolean = false,
  description: Option[String] = None,
  allowableValues: AllowableValues = AllowableValues.AnyValue,
  example: Option[String] = None,
  default: Option[String] = None,
  minimumValue: Option[Double] = None,
  maximumValue: Option[Double] = None,
  hidden: Boolean = false)

/**
 * A Swagger model definition.
 *
 * @param id            model identifier
 * @param name          model name
 * @param qualifiedName optional fully qualified class name
 * @param description   optional description
 * @param properties    named properties, in declaration order
 * @param baseModel     optional name of the parent model
 * @param discriminator optional discriminator property name
 */
case class Model(
  id: String,
  name: String,
  qualifiedName: Option[String] = None,
  description: Option[String] = None,
  properties: List[(String, ModelProperty)] = Nil,
  baseModel: Option[String] = None,
  discriminator: Option[String] = None) {

  /**
   * Returns a copy of this model in which `property` has its `required` flag set to `required`.
   *
   * The property is updated in place and keeps its position. Previously an updated copy was prepended
   * while the stale entry stayed in the list, so the property appeared twice.
   *
   * @throws NoSuchElementException if the model has no property named `property`
   */
  def setRequired(property: String, required: Boolean): Model = {
    if (!properties.exists(_._1 == property))
      throw new NoSuchElementException(s"Model $id has no property named $property")
    copy(properties = properties.map {
      case (`property`, prop) => property -> prop.copy(required = required)
      case other => other
    })
  }

  /** The properties whose `hidden` flag is not set. */
  def getVisibleProperties: Seq[(String, ModelProperty)] = properties.filterNot(_._2.hidden)
}

/** Login URL referenced by [[ImplicitGrant]]. */
case class LoginEndpoint(url: String)

/** Token-request URL plus the names used for the client id and client secret. */
case class TokenRequestEndpoint(url: String, clientIdName: String, clientSecretName: String)

/** Token URL plus the name of the token. */
case class TokenEndpoint(url: String, tokenName: String)

/** An authorization scheme that can be attached to the documented API. */
trait AuthorizationType {
  /** Scheme identifier: "oauth2", "apiKey" or "basic" for the implementations below. */
  def `type`: String
  def keyName: String
  def description: String
}

/** OAuth2 authorization with its scopes and supported grant types. */
case class OAuth(
  scopes: List[String],
  grantTypes: List[GrantType],
  keyName: String = "oauth2",
  description: String = "") extends AuthorizationType {
  override val `type`: String = "oauth2"
}

/** API-key authorization; `passAs` defaults to "header". */
case class ApiKey(keyName: String, passAs: String = "header", description: String = "") extends AuthorizationType {
  override val `type`: String = "apiKey"
}

/** HTTP basic authorization. */
case class BasicAuth(keyName: String, description: String = "") extends AuthorizationType {
  override val `type`: String = "basic"
}

/** An OAuth2 grant type. */
trait GrantType {
  def `type`: String
}

/** Implicit grant: a login endpoint and the name of the returned token. */
case class ImplicitGrant(
  loginEndpoint: LoginEndpoint,
  tokenName: String) extends GrantType {
  def `type`: String = "implicit"
}

/** Authorization-code grant: token-request and token endpoints. */
case class AuthorizationCodeGrant(
  tokenRequestEndpoint: TokenRequestEndpoint,
  tokenEndpoint: TokenEndpoint) extends GrantType {
  def `type`: String = "authorization_code"
}

/** Application grant: a token endpoint only. */
case class ApplicationGrant(
  tokenEndpoint: TokenEndpoint) extends GrantType {
  def `type`: String = "application"
}

/**
 * A single documented HTTP operation.
 *
 * @param method           HTTP method
 * @param operationId      operation identifier
 * @param responseClass    data type of the response
 * @param summary          short summary
 * @param position         ordering hint
 * @param description      optional longer description
 * @param deprecated       whether the operation is deprecated
 * @param parameters       all parameters, including hidden ones
 * @param responseMessages documented responses
 * @param consumes         content types consumed
 * @param produces         content types produced
 * @param schemes          protocols supported
 * @param authorizations   names of the authorization schemes that apply
 * @param tags             tags for grouping
 */
case class Operation(
  method: HttpMethod,
  operationId: String,
  responseClass: DataType,
  summary: String,
  position: Int,
  description: Option[String] = None,
  deprecated: Boolean = false,
  parameters: List[Parameter] = Nil,
  responseMessages: List[ResponseMessage] = Nil,
  consumes: List[String] = Nil,
  produces: List[String] = Nil,
  schemes: List[String] = Nil,
  authorizations: List[String] = Nil,
  tags: List[String] = Nil) {

  /** The parameters whose `hidden` flag is not set. */
  def getVisibleParameters: List[Parameter] =
    parameters.filterNot(_.hidden)
}

/** A documented path with its optional description and the operations available on it. */
case class Endpoint(path: String, description: Option[String] = None, operations: List[Operation] = Nil)

/** A documented response: status code, message and optional response model name. */
case class ResponseMessage(
  code: Int,
  message: String,
  responseModel: Option[String] = None)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy