All downloads are free. Search and download functionality uses the official Maven repository.

tscfg.generators.scala.ScalaGen.scala Maven / Gradle / Ivy

The newest version!
package tscfg.generators.scala

import tscfg.{ModelBuilder, Namespace, model, util}
import tscfg.generators._
import tscfg.model._
import tscfg.util.escapeString
import tscfg.codeDefs.scalaDef


/** Scala code generator: turns a tscfg model into `final case class`es (plus
  * `sealed abstract class`es and enum traits) with companion objects whose
  * `apply` reads the values from a `com.typesafe.config.Config`.
  *
  * Generated class names and field types are accumulated in `genResults`
  * (not declared in this class; presumably a `var` inherited from Generator).
  *
  * @param genOpts generation options (package/class name, backticks, durations, ...)
  */
class ScalaGen(genOpts: GenOpts) extends Generator(genOpts) {
  val accessors = new Accessors

  import defs._

  implicit val methodNames: MethodNames = MethodNames()
  val getter = Getter(genOpts, hasPath, accessors, methodNames)

  import methodNames._

  val scalaUtil: ScalaUtil = new ScalaUtil(useBackticks = genOpts.useBackticks)

  import scalaUtil.{scalaIdentifier, getClassName}

  /** Length of the longest Scala identifier derived from `symbols`, or 0 when there are none. */
  def padScalaIdLength(implicit symbols: List[String]): Int =
    if (symbols.isEmpty) 0 else
      symbols.map(scalaIdentifier).maxBy(_.length).length

  /** Right-pads `id` with spaces up to `padScalaIdLength` so generated members line up. */
  def padId(id: String)(implicit symbols: List[String]): String = id + (" " * (padScalaIdLength - id.length))

  /** Generates the complete source (package clause + root class and companion) for `objectType`.
    *
    * Resets `genResults` first, so every call starts from an empty result.
    *
    * @return the accumulated result with `code` set to the (optionally formatted) source
    */
  def generate(objectType: ObjectType): GenResult = {
    genResults = GenResult()

    checkUserSymbol(className)
    val res = generateForObj(objectType, className = className, isRoot = true)

    val packageStr = s"package ${genOpts.packageName}\n\n"

    val definition = (packageStr + res.definition).trim
    genResults.copy(code = {
      if (util.doFormatting)
        formatter.format(genOpts.packageName, definition)
      else
        definition
    })
  }

  /** Dispatches code generation on the kind of model type. */
  private def generate(typ: Type,
                       classNamesPrefix: List[String],
                       className: String,
                       parentClassName: Option[String] = None,
                       parentClassMembers: Option[Map[String, model.AnnType]] = None
                      ): Res = typ match {

    case et: EnumObjectType => generateForEnum(et, classNamesPrefix, className)

    case ot: ObjectType =>
      generateForObj(ot, classNamesPrefix, className, parentClassName = parentClassName, parentClassMembers = parentClassMembers)

    case aot: AbstractObjectType =>
      generateForAbstractObj(aot, classNamesPrefix, className, parentClassName, parentClassMembers)

    case ort: ObjectRefType => generateForObjRef(ort, classNamesPrefix)

    case lt: ListType => generateForList(lt, classNamesPrefix, className)

    case bt: BasicType => generateForBasic(bt)
  }

  /** Renders the constructor parameter list of a generated class, one member per line.
    *
    * Defines (`a.isDefine`) are skipped. Each rendered member is also recorded in
    * `genResults.fields` (Scala identifier -> type).
    *
    * @param classData       (symbol, generated result, annotated type, inherited-from-parent flag)
    * @param padId           aligns member names
    * @param isAbstractClass members of an abstract class are emitted as `val`s
    */
  def buildClassMembersString(classData: List[(String, Res, AnnType, Boolean)],
                              padId: String => String,
                              isAbstractClass: Boolean = false): String = {
    classData.flatMap { case (symbol, res, a, isChildClass) =>
      if (a.isDefine) None
      else Some {
        val memberType = res.scalaType.toString
        val typ =
          if (a.optional && a.default.isEmpty) s"scala.Option[$memberType]" else memberType
        val scalaId = scalaIdentifier(symbol)
        val modifiers = (isChildClass, isAbstractClass) match {
          case (true, _) =>
            /* The field will be overridden, therefore it needs override modifier and "val" keyword */
            "override val "
          case (false, true) =>
            /* The field does not override anything, but is part of abstract class => needs "val" keyword */
            "val "
          case _ => ""
        }
        val type_ = a.t match {
          case ot: ObjectRefType =>
            // NOTE(review): this replaces `typ`, so an optional ref is declared without
            // scala.Option, while Getter presumably wraps optional objects in scala.Some — confirm.
            getClassNameForObjectRefType(ot)
          case _ => typ
        }
        genResults = genResults.copy(fields = genResults.fields + (scalaId -> type_))
        modifiers + padId(scalaId) + " : " + type_ + dbg("E")
      }
    }.mkString(",\n  ")
  }

  /** Generates code for every entry of `members`, in sorted symbol order.
    *
    * All symbols are checked with `checkUserSymbol` before any of them is generated.
    * `inherited` becomes the 4th tuple element (true = declared by a parent abstract class).
    */
  private def genMemberResults(members: collection.Map[String, AnnType],
                               classNamesPrefix: List[String],
                               className: String,
                               inherited: Boolean): List[(String, Res, AnnType, Boolean)] = {
    val sortedSymbols = members.keys.toList.sorted
    sortedSymbols.foreach(checkUserSymbol)
    sortedSymbols.map { symbol =>
      val a = members(symbol)
      val res = generate(a.t,
        classNamesPrefix = className + "." :: classNamesPrefix,
        className = getClassName(symbol),
        parentClassName = a.abstractClass,
        parentClassMembers = a.parentClassMembers)
      (symbol, res, a, inherited)
    }
  }

  /** Results for the members inherited from a parent abstract class (empty when there is none). */
  private def genParentMemberResults(parentClassMembers: Option[Map[String, model.AnnType]],
                                     classNamesPrefix: List[String],
                                     className: String): List[(String, Res, AnnType, Boolean)] =
    parentClassMembers.fold(List.empty[(String, Res, AnnType, Boolean)])(
      genMemberResults(_, classNamesPrefix, className, inherited = true))

  /** Builds the ` extends Parent(a,b)` suffix, or "" when there is no parent.
    *
    * Arguments go through `scalaIdentifier`, so they match the constructor parameter
    * names declared by `buildClassMembersString` (raw symbols may not be valid identifiers).
    */
  private def parentExtendsClause(parentClassName: Option[String],
                                  parentClassMemberResults: List[(String, Res, AnnType, Boolean)]): String =
    parentClassName.fold("") { parent =>
      val args = parentClassMemberResults.map { case (symbol, _, _, _) => scalaIdentifier(symbol) }
      s" extends $parent(${args.mkString(",")})"
    }

  /** Generates a `final case class` plus its companion object for an object type.
    *
    * The companion's `apply` reads every member from the Config. For the root object it also
    * creates the validator, calls `validate()`, and includes root-level list accessors and the
    * `$TsCfgValidator` definition.
    */
  private def generateForObj(ot: ObjectType,
                             classNamesPrefix: List[String] = List.empty,
                             className: String,
                             isRoot: Boolean = false,
                             parentClassName: Option[String] = None,
                             parentClassMembers: Option[Map[String, model.AnnType]] = None
                            ): Res = {

    genResults = genResults.copy(classNames = genResults.classNames + className)

    // Sorted member symbols; also the implicit list padId aligns against.
    implicit val symbols: List[String] = ot.members.keys.toList.sorted

    val results = genMemberResults(ot.members, classNamesPrefix, className, inherited = false)
    val parentClassMemberResults = genParentMemberResults(parentClassMembers, classNamesPrefix, className)

    // Inherited members come first in the case class parameter list.
    val classMembersStr = buildClassMembersString(parentClassMemberResults ++ results, padId)

    val parentClassString = parentExtendsClause(parentClassName, parentClassMemberResults)

    val classStr =
      s"""final case class $className(
         |  $classMembersStr
         |)$parentClassString
         |""".stripMargin


    // Helper methods the getters below need; filled in by getter.instance.
    implicit val listAccessors: collection.mutable.LinkedHashMap[String, String] =
      collection.mutable.LinkedHashMap[String, String]()

    // Rendered as named arguments (`id = expr`), so the order here does not matter.
    val objectMembersStr = (results ++ parentClassMemberResults).flatMap { case (symbol, res, a, _) =>
      if (a.isDefine) None
      else Some {
        val scalaId = scalaIdentifier(symbol)
        padId(scalaId) + " = " + getter.instance(a, res, path = escapeString(symbol))
      }
    }.mkString(",\n      ")

    val innerClassesStr = {
      val innerDefs = results.map(_._2.definition).filter(_.nonEmpty)
      if (innerDefs.isEmpty) "" else {
        "\n  " + innerDefs.mkString("\n").replaceAll("\n", "\n  ")
      }
    }

    val elemAccessorsStr = {
      val objOnes = if (listAccessors.isEmpty) "" else {
        "\n" + listAccessors.keys.toList.sorted.map { methodName =>
          listAccessors(methodName)
        }.mkString("\n")
      }
      val rootOnes = if (!isRoot) "" else {
        if (accessors.rootListAccessors.isEmpty) "" else {
          "\n\n" + accessors.rootListAccessors.keys.toList.sorted.map { methodName =>
            accessors.rootListAccessors(methodName)
          }.mkString("\n")
        }
      }
      objOnes + rootOnes
    }

    val rootAuxClasses = if (isRoot) scalaDef("$TsCfgValidator") else ""

    val (ctorParams, errHandlingDecl, errHandlingDispatch) = if (isRoot) {
      ("c: com.typesafe.config.Config",
        """val $tsCfgValidator: $TsCfgValidator = new $TsCfgValidator()
          |    val parentPath: java.lang.String = ""
          |    val $result = """.stripMargin,

        s"""
           |    $$tsCfgValidator.validate()
           |    $$result""".stripMargin
      )
    }
    else (
      "c: com.typesafe.config.Config, parentPath: java.lang.String, $tsCfgValidator: $TsCfgValidator",
      "",
      ""
    )

    val fullClassName = classNamesPrefix.reverse.mkString + className
    val objectString = {
      s"""object $className {$innerClassesStr
         |  def apply($ctorParams): $fullClassName = {
         |    $errHandlingDecl$fullClassName(
         |      $objectMembersStr
         |    )$errHandlingDispatch
         |  }$elemAccessorsStr
         |$rootAuxClasses}
      """.stripMargin
    }

    Res(ot,
      scalaType = BaseScalaType(fullClassName),
      definition = classStr + objectString
    )
  }

  /** Generates a `sealed abstract class` (members as `val`s) for an abstract object type.
    * No companion object is produced, and definitions generated for members are not emitted.
    */
  private def generateForAbstractObj(aot: AbstractObjectType,
                                     classNamesPrefix: List[String],
                                     className: String,
                                     parentClassName: Option[String] = None,
                                     parentClassMembers: Option[Map[String, model.AnnType]] = None): Res = {

    genResults = genResults.copy(classNames = genResults.classNames + className)

    // Sorted member symbols; also the implicit list padId aligns against.
    implicit val symbols: List[String] = aot.members.keys.toList.sorted

    val results = genMemberResults(aot.members, classNamesPrefix, className, inherited = false)

    /* Consider parent abstract classes */
    val parentClassMemberResults = genParentMemberResults(parentClassMembers, classNamesPrefix, className)

    val abstractClassMembersStr =
      buildClassMembersString(parentClassMemberResults ++ results, padId, isAbstractClass = true)

    val parentClassString = parentExtendsClause(parentClassName, parentClassMemberResults)

    val abstractClassStr =
      s"""sealed abstract class $className (
         | $abstractClassMembersStr
         |)$parentClassString
         |""".stripMargin

    val baseType = classNamesPrefix.reverse.mkString + className

    Res(aot,
      scalaType = BaseScalaType(baseType),
      definition = abstractClassStr
    )
  }

  /** A reference to a defined object/enum: only records the class name; no definition is emitted. */
  private def generateForObjRef(ort: ObjectRefType,
                                classNamesPrefix: List[String]): Res = {

    val className = getClassName(ort.simpleName)
    genResults = genResults.copy(classNames = genResults.classNames + className)

    val fullScalaName = getClassNameForObjectRefType(ort)

    Res(ort,
      scalaType = BaseScalaType(fullScalaName + dbg(""))
    )
  }

  /** Fully qualified Scala name of a referenced define: nested in the root class for the
    * root namespace, otherwise the namespace path followed by the class name.
    */
  private def getClassNameForObjectRefType(ot: ObjectRefType): String = {
    val className = getClassName(ot.simpleName)
    val namespace = Namespace.resolve(ot.namespace)
    val fullScalaName = if (namespace.isRoot)
      s"${genOpts.className}.$className"
    else
      (namespace.getPath.map(getClassName) ++ Seq(className)).mkString(".")

    scribe.debug(s"getClassNameForObjectRefType:" +
      s" simpleName=${ot.simpleName}" +
      s" className=$className fullScalaName=$fullScalaName")

    fullScalaName
  }

  /** List type: generates the element type under a `$Elm`-suffixed class name (added at most once). */
  private def generateForList(lt: ListType,
                              classNamesPrefix: List[String],
                              className: String
                             ): Res = {
    val className2 = className + (if (className.endsWith("$Elm")) "" else "$Elm")
    val elem = generate(lt.t, classNamesPrefix, className2)
    Res(lt,
      scalaType = ListScalaType(elem.scalaType),
      definition = elem.definition
    )
  }

  /** Maps a basic model type to its Scala type; durations become `java.time.Duration` only when enabled. */
  private def generateForBasic(b: BasicType): Res = {
    Res(b, scalaType = BaseScalaType(name = b match {
      case STRING => "java.lang.String"
      case INTEGER => "scala.Int"
      case LONG => "scala.Long"
      case DOUBLE => "scala.Double"
      case BOOLEAN => "scala.Boolean"
      case SIZE => "scala.Long"
      case DURATION(_) => if (genOpts.useDurations) "java.time.Duration" else "scala.Long"
    }))
  }

  /** Generates a sealed trait with one object per enum member, plus a `$resEnum` resolver that
    * reports unknown values to the validator and returns null for them.
    */
  private def generateForEnum(et: EnumObjectType,
                              classNamesPrefix: List[String] = List.empty,
                              className: String
                             ): Res = {
    scribe.debug(s"generateForEnum: className=$className classNamesPrefix=$classNamesPrefix")
    genResults = genResults.copy(classNames = genResults.classNames + className)

    /// Example:
    //  sealed trait FruitType
    //  object FruitType {
    //    object apple extends FruitType
    //    object banana extends FruitType
    //    object pineapple extends FruitType
    //    def $resEnum(name: java.lang.String, path: java.lang.String, $tsCfgValidator: $TsCfgValidator): FruitType = name match {
    //      case "apple" => FruitType.apple
    //      case "banana" => FruitType.banana
    //      case "pineapple" => FruitType.pineapple
    //      case v => $tsCfgValidator.addInvalidEnumValue(path, v, "FruitType")
    //                null
    //    }
    //  }

    val resolve =
      s"""def $$resEnum(name: java.lang.String, path: java.lang.String, $$tsCfgValidator: $$TsCfgValidator): $className = name match {
         |  ${et.members.map(m => s"""case "$m" => $className.$m""").mkString("\n  ")}
         |  case v => $$tsCfgValidator.addInvalidEnumValue(path, v, "$className")
         |            null
         |}""".stripMargin

    val str =
      s"""|sealed trait $className
         |object $className {
         |  ${et.members.map(m => s"object $m extends $className").mkString("\n  ")}
         |  ${resolve.replaceAll("\n", "\n  ")}
         |}""".stripMargin

    val baseType = classNamesPrefix.reverse.mkString + className
    Res(et,
      scalaType = BaseScalaType(baseType + dbg("")),
      definition = str
    )
  }
}

object ScalaGen {

  import _root_.java.io.{File, PrintWriter, FileWriter}

  import tscfg.util

  // $COVERAGE-OFF$
  /** Development helper: generates Scala code for a spec under `src/main/tscfg/` and writes
    * it to `src/test/scala/tscfg/example/<ClassName>.scala`.
    *
    * The class name is `"Scala" + UpperFirst(<file name up to its first '.', minus a leading "def.">) + "Cfg"`.
    *
    * @param filename spec path relative to `src/main/tscfg/`
    * @param showOut  print the source, the built model and any warnings to stdout
    * @return the generation result (also written to the destination file)
    */
  def generate(filename: String,
               j7: Boolean = false,
               assumeAllRequired: Boolean = false,
               showOut: Boolean = false,
               s12: Boolean = false,
               useBackticks: Boolean = false
              ): GenResult = {
    val file = new File("src/main/tscfg/" + filename)
    val source = io.Source.fromFile(file)
    // close the source even if reading fails
    val sourceStr = try source.mkString.trim finally source.close()

    if (showOut)
      println("source:\n  |" + sourceStr.replaceAll("\n", "\n  |"))

    val className = "Scala" + {
      val noPath = filename.substring(filename.lastIndexOf('/') + 1)
      val noDef = noPath.replaceAll("""^def\.""", "")
      val dot = noDef.indexOf('.')
      // a name without any '.' is used whole instead of failing in substring
      val symbol = if (dot < 0) noDef else noDef.substring(0, dot)
      util.upperFirst(symbol) + "Cfg"
    }

    val buildResult = ModelBuilder(sourceStr, assumeAllRequired = assumeAllRequired)
    val objectType = buildResult.objectType
    if (showOut) {
      println("\nobjectType:\n  |" + model.util.format(objectType).replaceAll("\n", "\n  |"))
      if (buildResult.warnings.nonEmpty) {
        println("warnings:")
        buildResult.warnings.foreach(w => println(s"   line ${w.line}: ${w.source}: ${w.message}"))
      }
    }

    val genOpts = GenOpts("tscfg.example", className, j7 = j7,
      useBackticks = useBackticks, s12 = s12)

    val generator = new ScalaGen(genOpts)

    val results = generator.generate(objectType)

    val destFilename = s"src/test/scala/tscfg/example/$className.scala"
    val destFile = new File(destFilename)
    val out = new PrintWriter(new FileWriter(destFile), true)
    // release the file handle once the code is written
    try out.println(results.code) finally out.close()
    results
  }

  /** Command-line entry point: `args(0)` is the spec filename relative to `src/main/tscfg/`. */
  def main(args: Array[String]): Unit = {
    val filename = args(0)
    val results = generate(filename, showOut = true)
    println(
      s"""classNames: ${results.classNames}
         |fields    : ${results.fields}
      """.stripMargin)
  }

  // $COVERAGE-ON$
}

/** Intermediate representations shared by the Scala generator. */
private[scala] object defs {

  /** A Scala type of a generated member; `toString` yields its source form. */
  abstract sealed class ScalaType

  /** Non-list type rendered verbatim, e.g. `scala.Int` or `Cfg.Foo`. */
  case class BaseScalaType(name: String) extends ScalaType {
    override def toString: String = this.name
  }

  /** List type rendered as `scala.List[<element type>]`. */
  case class ListScalaType(st: ScalaType) extends ScalaType {
    override def toString: String = "scala.List[" + st + "]"
  }

  /** Generation result for one model type.
    *
    * @param typ        the model type this result was generated from
    * @param scalaType  its Scala type
    * @param definition source of any class/object definitions it requires ("" if none)
    */
  case class Res(typ: Type, scalaType: ScalaType, definition: String = "")

}

/** Names of helper methods emitted into generated code, plus their source definitions.
  *
  * Every name starts with `$_`; `checkUserSymbol` warns when a spec symbol uses that prefix.
  */
private[scala] case class MethodNames() {
  // element accessors for lists of basic types
  val strA = "$_str"
  val intA = "$_int"
  val lngA = "$_lng"
  val dblA = "$_dbl"
  val blnA = "$_bln"
  val durA = "$_dur"
  val sizA = "$_siz"

  // other helpers
  val expE = "$_expE"
  val listPrefix = "$_L"
  val requireName = "$_require"

  /** Prints a warning to stdout when `symbol` starts with the reserved `$_` prefix. */
  def checkUserSymbol(symbol: String): Unit =
    if (symbol.startsWith("$_")) {
      val warning =
        s"""
           |WARNING: Symbol $symbol may cause conflict with generated code.
           |         Avoid the $$_ prefix in your spec's identifiers.
         """.stripMargin
      println(warning)
    }

  /** Accessor name -> source definition for list elements of basic type.
    * Note: `durA` has no entry here.
    */
  val basicElemAccessDefinition: Map[String, String] =
    Seq(strA, intA, lngA, dblA, blnA, sizA).map(name => (name, scalaDef(name))).toMap

  /** Source definition of the `expE` helper. */
  val expEDef: String = scalaDef(expE)

  /** Source definition of the `require` helper. */
  val requireDef: String = scalaDef(requireName)
}

/** Produces the Scala source expression that reads one member from the Config `c`
  * inside a generated `apply` method.
  *
  * @param genOpts     generation options (durations, assumeAllRequired, s12)
  * @param hasPath     name of the Config method used to test whether a path is present
  * @param accessors   registry of list-conversion helper methods
  * @param methodNames names/definitions of generated helper methods
  */
private[scala] case class Getter(genOpts: GenOpts, hasPath: String, accessors: Accessors, implicit val methodNames: MethodNames) {

  import defs._

  /** Expression that reads member `a` (already generated as `res`) at config key `path`.
    *
    * Enum references are resolved first; everything else dispatches on the member's type.
    * Helper methods the expression needs are registered into `listAccessors`.
    */
  def instance(a: AnnType, res: Res, path: String)
              (implicit listAccessors: collection.mutable.LinkedHashMap[String, String]): String = {

    val objRefResolution: Option[String] = a.t match {
      case ort: ObjectRefType => objectRefInstance(ort, res, path)
      case _ => None
    }

    objRefResolution.getOrElse {
      a.t match {
        case bt: BasicType => basicInstance(a, bt, path)
        case _: ObjectAbsType => objectInstance(a, res, path)
        case lt: ListType => listInstance(a, lt, res, path)
      }
    }
  }

  /** `Some(enum lookup)` when `ort` refers to an enum define, otherwise None. */
  private def objectRefInstance(ort: ObjectRefType, res: Res, path: String): Option[String] =
    Namespace.resolve(ort.namespace).getDefine(ort.simpleName).collect {
      case _: EnumObjectType => enumInstance(res, path)
    }

  private def enumInstance(res: Res, path: String): String = {
    val className = res.scalaType.toString

    //// Example:
    // fruit = FruitType.$resEnum(c.getString("fruit"), parentPath + "fruit", $tsCfgValidator)

    s"""$className.$$resEnum(c.getString("$path"), parentPath + "$path", $$tsCfgValidator)"""
  }

  /** Constructs a nested object via its companion `apply`.
    *
    * A non-optional object missing from the config is built from an empty config.
    */
  private def objectInstance(a: AnnType, res: Res, path: String)
                            (implicit listAccessors: collection.mutable.Map[String, String]): String = {
    val className = res.scalaType.toString

    val ppArg = s""", parentPath + "$path.", $$tsCfgValidator"""

    def reqConfigCall = {
      val methodName = "$_reqConfig"
      listAccessors += methodName -> scalaDef(methodName)
      s"""$methodName(parentPath, c, "$path", $$tsCfgValidator)"""
    }

    if (genOpts.assumeAllRequired)
      s"""$className($reqConfigCall$ppArg)"""

    else if (a.optional) {
      s"""if(c.$hasPath("$path")) scala.Some($className(c.getConfig("$path")$ppArg)) else None"""
    }
    else {
      // TODO revisit #33 handling of object as always optional
      s"""$className(if(c.$hasPath("$path")) c.getConfig("$path") else com.typesafe.config.ConfigFactory.parseString("$path{}")$ppArg)"""
    }
  }

  /** Calls the list accessor for `lt`, wrapped in scala.Some/None when the member is optional. */
  private def listInstance(a: AnnType, lt: ListType, res: Res, path: String)
                          (implicit listAccessors: collection.mutable.Map[String, String]
                          ): String = {
    val scalaType: ListScalaType = res.scalaType.asInstanceOf[ListScalaType]
    val base = accessors.listMethodName(scalaType, lt, path, genOpts.s12)
    if (a.optional) {
      s"""if(c.$hasPath("$path")) scala.Some($base) else None"""
    }
    else base
  }

  /** Reads a basic value, applying its default (if any) or Option-wrapping when optional. */
  private def basicInstance(a: AnnType, bt: BasicType, path: String)
                           (implicit listAccessors: collection.mutable.Map[String, String]): String = {
    val getter = tsConfigUtil.basicGetter(bt, path, genOpts.useDurations)

    a.default match {
      case Some(v) =>
        val value = tsConfigUtil.basicValue(a.t, v, useDurations = genOpts.useDurations)
        (bt, value) match {
          // boolean defaults collapse into a short-circuit expression
          case (BOOLEAN, "true") => s"""!c.$hasPath("$path") || c.$getter"""
          case (BOOLEAN, "false") => s"""c.$hasPath("$path") && c.$getter"""
          case (DURATION(_), duration) if genOpts.useDurations => s"""if(c.$hasPath("$path")) c.$getter else java.time.Duration.parse("$duration")"""
          case _ => s"""if(c.$hasPath("$path")) c.$getter else $value"""
        }

      case None if a.optional =>
        // Qualified like the other generated Option expressions, so a generated class
        // named `Some` cannot shadow it.
        s"""if(c.$hasPath("$path")) scala.Some(c.$getter) else None"""

      case _ =>
        bt match {
          case DURATION(_) => s"""c.$getter"""
          case _ =>
            val (methodName, methodCall) = tsConfigUtil.basicRequiredGetter(bt, path, genOpts.useDurations)
            listAccessors += methodName -> scalaDef(methodName)
            methodCall
        }
    }
  }
}

/** Generates and records helper methods that convert a `com.typesafe.config.ConfigList`
  * into a `scala.List` for list-typed members.
  *
  * Helpers for lists whose innermost element is a basic type go into `rootListAccessors`
  * (ScalaGen emits these only in the root object). All others go into the caller-supplied
  * `listAccessors` of the enclosing object.
  */
private[scala] class Accessors {

  import defs._

  val rootListAccessors = collection.mutable.LinkedHashMap[String, String]()

  /** Registers the helper methods needed for list type `lt` and returns the call that
    * converts config list `path` (read from `c`) into a `scala.List`.
    */
  def listMethodName(scalaType: ListScalaType,
                     lt: ListType,
                     path: String,
                     s12: Boolean
                    )
                    (implicit listAccessors: collection.mutable.Map[String, String],
                     methodNames: MethodNames
                    ): String = {

    val (_, methodName) = rec(scalaType, lt, s12)
    methodName + s"""(c.getList("$path"), parentPath, $$tsCfgValidator)"""
  }

  /** Registers the accessor for `lst` and, recursively, for any nested list element types.
    *
    * @return (innermost element is a basic type, name of the accessor for `lst`)
    */
  private def rec(lst: ListScalaType, lt: ListType, s12: Boolean
                 )(implicit listAccessors: collection.mutable.Map[String, String],
                   methodNames: MethodNames
                 ): (Boolean, String) = {

    val (isBasic, elemMethodName) = lst.st match {
      case bst: BaseScalaType =>
        val basic = lt.t.isInstanceOf[BasicType]
        val methodName = baseName(lt.t, bst.toString)
        if (basic) {
          // NOTE(review): basicElemAccessDefinition has no entry for durA, so a list of
          // durations would throw here — confirm such lists are rejected earlier.
          rootListAccessors += methodName -> methodNames.basicElemAccessDefinition(methodName)
          rootListAccessors += methodNames.expE -> methodNames.expEDef
        }
        (basic, methodName)

      case nested: ListScalaType =>
        rec(nested, lt.t.asInstanceOf[ListType], s12)
    }

    val (methodName, methodBody) = listMethodDefinition(elemMethodName, lst.st, s12, lt)

    if (isBasic)
      rootListAccessors += methodName -> methodBody
    else
      listAccessors += methodName -> methodBody

    (isBasic, methodName)
  }

  /** Element accessor name: a fixed helper name for basic types, or the element's Scala type
    * name with '.' replaced by '_' for object types.
    */
  private def baseName(t: Type, name: String)
                      (implicit methodNames: MethodNames): String = t match {
    case STRING => methodNames.strA
    case INTEGER => methodNames.intA
    case LONG => methodNames.lngA
    case DOUBLE => methodNames.dblA
    case BOOLEAN => methodNames.blnA
    case SIZE => methodNames.sizA
    case DURATION(_) => methodNames.durA

    case _: ObjectAbsType => name.replace('.', '_')

    case _: ListType => throw new AssertionError()
  }

  /** Builds the `$_L<elem>` method that maps each ConfigValue `cv` of a ConfigList to
    * `scalaType`.
    *
    * @return (method name, method source)
    */
  def listMethodDefinition(elemMethodName: String, scalaType: ScalaType, s12: Boolean, lt: ListType)
                          (implicit methodNames: MethodNames): (String, String) = {

    val elem = if (elemMethodName.startsWith(methodNames.listPrefix)) {
      s"$elemMethodName(cv.asInstanceOf[com.typesafe.config.ConfigList], parentPath, $$tsCfgValidator)"
    }
    else if (elemMethodName.startsWith("$")) {
      s"$elemMethodName(cv)"
    }
    else {
      // Fully qualified name of the element's class/enum companion, taken directly from the
      // element type. elemMethodName was derived from it by replacing '.' with '_', which
      // cannot be reversed when the name itself contains '_'.
      val adjusted = scalaType.toString
      val objRefResolution = lt.t match {
        case ort: ObjectRefType =>
          Namespace.resolve(ort.namespace).getDefine(ort.simpleName).collect {
            // TODO some more useful path (for now just "?" below)
            case _: EnumObjectType =>
              s"""$adjusted.$$resEnum(cv.unwrapped().toString, "?", $$tsCfgValidator)"""
          }

        case _ => None
      }
      objRefResolution.getOrElse {
        s"$adjusted(cv.asInstanceOf[com.typesafe.config.ConfigObject].toConfig, parentPath, $$tsCfgValidator)"
      }
    }

    val methodName = methodNames.listPrefix + elemMethodName
    // JavaConverters for Scala 2.12 output, CollectionConverters otherwise
    val scalaCollectionConverter = if (s12) "scala.collection.JavaConverters._" else "scala.jdk.CollectionConverters._"
    val methodDef =
      s"""  private def $methodName(cl:com.typesafe.config.ConfigList, parentPath: java.lang.String, $$tsCfgValidator: $$TsCfgValidator): scala.List[$scalaType] = {
         |    import $scalaCollectionConverter
         |    cl.asScala.map(cv => $elem).toList
         |  }""".stripMargin
    (methodName, methodDef)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy