scalaxb.compiler.Module.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of scalaxb_2.13 Show documentation
Show all versions of scalaxb_2.13 Show documentation
scalaxb is an XML data-binding tool for Scala that supports W3C XML Schema (xsd) and wsdl.
The newest version!
/*
* Copyright (c) 2010 e.e d3si9n
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package scalaxb.compiler
import java.net.{URI}
import scala.xml.{Node, Elem, UnprefixedAttribute, NamespaceBinding}
import scala.xml.factory.{XMLLoader}
import javax.xml.parsers.SAXParser
import java.io.{File, PrintWriter, Reader, BufferedReader}
import scala.collection.mutable
import scala.collection.mutable.{ListBuffer, ListMap}
import ConfigEntry._
/** Factory helpers for [[Snippet]]. */
object Snippet {
  /** Merges several snippets into one.
    *
    * For each input, its `companion` nodes are placed ahead of its `definition` nodes and the
    * combined sequence becomes the merged `definition`; the merged `companion` is empty.
    * `defaultFormats`, `implicitValue` and `elemToTypeClauses` are concatenated in input order.
    */
  def apply(snippets: Snippet*): Snippet = {
    val mergedDefinition = snippets.flatMap(s => s.companion ++ s.definition)
    val mergedFormats = snippets.flatMap(_.defaultFormats)
    val mergedImplicits = snippets.flatMap(_.implicitValue)
    val mergedClauses = snippets.flatMap(_.elemToTypeClauses)
    Snippet(mergedDefinition, Nil, mergedFormats, mergedImplicits, mergedClauses)
  }
}
/** A piece of generated code, kept as XML nodes split into named parts.
  *
  * @param definition        nodes of the definition part
  * @param companion         nodes of the companion part
  * @param defaultFormats    nodes of the default-formats part
  * @param implicitValue     nodes of the implicit-value part
  * @param elemToTypeClauses nodes of the element-to-type clauses part
  */
case class Snippet(definition: Seq[Node] = Nil,
  companion: Seq[Node] = Nil,
  defaultFormats: Seq[Node] = Nil,
  implicitValue: Seq[Node] = Nil,
  elemToTypeClauses: Seq[Node] = Nil)
/** Type class for output targets that generated code is written into.
  *
  * @tparam A the target type (for example an in-memory `StringWriter`)
  */
trait CanBeWriter[A] {
  /** Creates a fresh target for `fileName`, optionally associated with `packageName`. */
  def newInstance(packageName: Option[String], fileName: String): A
  /** Opens a [[PrintWriter]] that writes into `value`. */
  def toWriter(value: A): PrintWriter
}
object CanBeWriter {
  /** In-memory target: each output is a new `StringWriter`; package and file name are ignored. */
  implicit val stringWriter: CanBeWriter[java.io.StringWriter] = new CanBeWriter[java.io.StringWriter] {
    def newInstance(packageName: Option[String], fileName: String): java.io.StringWriter =
      new java.io.StringWriter()
    def toWriter(value: java.io.StringWriter): PrintWriter =
      new PrintWriter(value)
  }
}
/** Type class for inputs that can be read as a raw schema.
  *
  * @tparam A the input type (file, string, XML node, ...)
  * @tparam B the raw schema representation
  */
trait CanBeRawSchema[A, B] {
  /** The location reported for `value`. */
  def toURI(value: A): URI
  /** Reads `value` into its raw schema representation. */
  def toRawSchema(value: A): B
}
/** Helpers shared by the compiler front ends. */
object Module {
  val NL = System.getProperty("line.separator")
  /** Matches a path and captures its last extension, including the dot (e.g. `.wsdl`). */
  val FileExtension = """.*([.]\w+)$""".r

  /** Selects the front end for `file`: the WSDL 1.1 driver for names ending in `.wsdl`,
    * otherwise the XSD driver.
    */
  def moduleByFileName(file: File): Module = {
    file.toString match {
      case FileExtension(".wsdl") =>
        new scalaxb.compiler.wsdl11.Driver
      case _ =>
        new scalaxb.compiler.xsd.Driver
    }
  }

  /** Delegates to `Helper.splitQName` to split `value` using the bindings in `scope`. */
  def splitTypeName(value: String, scope: scala.xml.NamespaceBinding) = masked.scalaxb.Helper.splitQName(value, scope)

  /** Returns a string of `n` spaces. */
  def indent(n: Int) = " " * n

  /** Lower-cases the leading run of upper-case characters of `name`.
    *
    * A run of length 0 or 1, or one covering the whole name, is lower-cased entirely
    * (`Foo` -> `foo`, `ABC` -> `abc`). A longer run keeps its last capital, which starts the
    * next word (`HTMLParser` -> `htmlParser`).
    *
    * Lower-casing uses `Locale.ROOT` so generated identifiers do not depend on the JVM's default
    * locale (with a Turkish default locale, `"I".toLowerCase` yields a dotless `ı`).
    */
  def camelCase(name: String): String = {
    val (cap, rest) = name span {_.isUpper}
    cap.size match {
      case x if (x == 0) || (x == 1) || (x == name.size) =>
        cap.toLowerCase(java.util.Locale.ROOT) + rest
      case x =>
        (cap take (x - 1)).toLowerCase(java.util.Locale.ROOT) + (cap drop (x - 1)) + rest
    }
  }
}
trait Module {
// Raw schema representation produced by readerToRawSchema / nodeToRawSchema.
type RawSchema
// Schema model returned by parse (via Importable.toSchema).
type Schema
// Compilation state created by buildContext and passed through parsing and code generation.
type Context
/** Everything one compilation run generates code from.
  *
  * @param context               compilation state from `buildContext`
  * @param schemas               parsed schema of every importable, in insertion order
  * @param importables           the caller's inputs, each paired with the value it was read from
  * @param additionalImportables importables added while resolving dependencies, with their file
  * @param firstNamespace        target namespace of the first input, if any
  */
case class CompileSource[From](
  context: Context,
  schemas: ListMap[Importable, Schema],
  importables: collection.Seq[(Importable, From)],
  additionalImportables: ListMap[Importable, File],
  firstNamespace: Option[String]
)
private val logger = Log.forName("module")

/** Verbose-output switch; `false` unless overridden. */
def verbose: Boolean = false

/** Character encoding used when writing generated source files. */
val encoding: String = "UTF-8"

/** The platform line separator, read from the `line.separator` system property. */
val newline: String = System.getProperty("line.separator")
// One schema document taking part in a compilation, plus the references it makes to others.
trait Importable {
// Target namespace declared by this document, if any.
def targetNamespace: Option[String]
// Namespaces this document imports.
def importNamespaces: Seq[String]
// Locations referenced by this document's imports.
def importLocations: Seq[String]
// Locations referenced by this document's includes.
def includeLocations: Seq[String]
// The raw (unparsed) form of this document.
def raw: RawSchema
// Where this document came from; import/include locations are matched against its file name.
def location: URI
// Parses this document into the compiler's schema model.
def toSchema(context: Context): Schema
// Returns a copy that uses outerNamespace as its target namespace; called for documents that
// declare none. NOTE(review): `n` is a per-run counter, presumably used to make each copy's
// location distinct — confirm in the implementations.
def swapTargetNamespace(outerNamespace: Option[String], n: Int): Importable
}
/** Reads schema files, honouring a leading byte-order mark (see `UnicodeFileReader`).
  *
  * The reader is closed once the raw schema has been built, so the file handle is released even
  * when parsing fails. NOTE(review): this assumes `readerToRawSchema` reads all of its input
  * before returning (as an XML load does) — confirm for every driver.
  */
implicit val fileReader: CanBeRawSchema[File, RawSchema] = new CanBeRawSchema[File, RawSchema] {
  override def toRawSchema(value: File): RawSchema = {
    val reader = UnicodeFileReader.reader(value)
    try readerToRawSchema(reader)
    finally reader.close()
  }
  override def toURI(value: File): URI = value.toURI
}
/** Reads a schema held in a string; every input reports the same placeholder location. */
val stringReader: CanBeRawSchema[String, RawSchema] = new CanBeRawSchema[String, RawSchema] {
  override def toURI(value: String): URI = new URI("file://C:/temp.txt")
  override def toRawSchema(value: String): RawSchema =
    readerToRawSchema(new java.io.StringReader(value))
}

/** Reads a schema already held as an XML node; every input reports the same placeholder location. */
val nodeReader: CanBeRawSchema[Node, RawSchema] = new CanBeRawSchema[Node, RawSchema] {
  override def toURI(value: Node): URI = new URI("file://C:/temp.txt")
  override def toRawSchema(value: Node): RawSchema = nodeToRawSchema(value)
}
/** Compiles the schema `file` into package `packageName`, writing sources under `outdir`.
  *
  * @return the generated files
  */
def process(file: File, packageName: String, outdir: File): List[File] = {
  val config = Config.default
    .update(PackageNames(Map(None -> Some(packageName))))
    .update(Outdir(outdir))
  process(file, config)
}

/** Compiles the schema `file` with `config`.
  *
  * @return the generated files
  */
def process(file: File, config: Config): List[File] =
  processFiles(file :: Nil, config)
/** Compiles `files` and reports every generated file on stdout and in the log.
  *
  * @return the generated files
  */
def processFiles(files: Seq[File], config: Config): List[File] = {
  val (_, outfiles) = infoFiles(files, config)
  // Reporting is a side effect only, so iterate with foreach instead of building a discarded list.
  outfiles foreach { x =>
    println("generated " + x + ".")
    logger.info("generated " + x + ".")
  }
  outfiles
}
/** Compiles schema files and writes the generated sources to disk.
  *
  * Outputs go to `config.outdir`, or to a per-package subdirectory of it when
  * `config.packageDir` is set, and are written with `encoding`.
  *
  * @return the compile source and the generated files
  * @throws RuntimeException if any input file does not exist
  */
def infoFiles(files: Seq[File], config: Config): (CompileSource[File], List[File]) = {
  for (file <- files if !file.exists) sys.error("file not found: " + file.toString)

  implicit val fileWriter: CanBeWriter[File] = new CanBeWriter[File] {
    override def newInstance(packageName: Option[String], fileName: String) = {
      val targetDir =
        if (config.packageDir) packageDir(packageName, config.outdir)
        else config.outdir
      targetDir.mkdirs()
      new File(targetDir, fileName)
    }
    override def toWriter(value: File) = {
      val stream = new java.io.FileOutputStream(value)
      new PrintWriter(new java.io.OutputStreamWriter(stream, encoding))
    }
  }

  processReaders(files, config)
}
/** Resolves the directory for `packageName` beneath `dir`, one subdirectory per dot-separated
  * segment (`Some("a.b")` gives `dir/a/b`); `None` yields `dir` itself.
  */
def packageDir(packageName: Option[String], dir: File) = packageName match {
  case Some(name) =>
    name.split('.').foldLeft(dir)((parent, segment) => new File(parent, segment))
  case None =>
    dir
}
/** Compiles an in-memory schema string into package `packageName`.
  *
  * @return the generated sources, one string per output
  */
def processString(input: String, packageName: String): List[String] = {
  val packages = PackageNames(Map(None -> Some(packageName)))
  processString(input, Config.default.update(packages))
}

/** Compiles an in-memory schema string with `config`.
  *
  * @return the generated sources, one string per output
  */
def processString(input: String, config: Config): List[String] = {
  val (_, sources) = infoString(input, config)
  sources
}
/** Compiles an in-memory schema string without writing any files.
  *
  * @return the compile source and the generated sources, one string per output
  */
def infoString(input: String, config: Config): (CompileSource[String], List[String]) = {
  implicit val fromString: CanBeRawSchema[String, RawSchema] = stringReader
  val (source, writers) = processReaders(Seq(input), config)
  source -> writers.map(_.toString)
}
/** Compiles an in-memory schema node into package `packageName`.
  *
  * @return the generated sources, one string per output
  */
def processNode(input: Node, packageName: String): List[String] = {
  val packages = PackageNames(Map(None -> Some(packageName)))
  processNode(input, Config.default.update(packages))
}

/** Compiles an in-memory schema node with `config`.
  *
  * @return the generated sources, one string per output
  */
def processNode(input: Node, config: Config): List[String] = {
  val (_, sources) = infoNode(input, config)
  sources
}
/** Compiles an in-memory schema node without writing any files.
  *
  * @return the compile source and the generated sources, one string per output
  */
def infoNode(input: Node, config: Config): (CompileSource[Node], List[String]) = {
  implicit val fromNode: CanBeRawSchema[Node, RawSchema] = nodeReader
  val (source, writers) = processReaders(Seq(input), config)
  source -> writers.map(_.toString)
}
// Header snippet for package `pkg`. Default: an empty Snippet (no-argument Snippet.apply).
// Not called anywhere in this trait; presumably an override point for drivers — confirm.
def headerSnippet(pkg: Option[String]): Snippet =
Snippet()
/** Builds a [[CompileSource]] from `files` and generates code for it.
  *
  * @param ev   reads each input as a raw schema and supplies its location
  * @param evTo creates and opens the output targets
  * @return the compile source together with every generated output
  */
def processReaders[From, To](files: Seq[From], config: Config)
    (implicit ev: CanBeRawSchema[From, RawSchema], evTo: CanBeWriter[To]): (CompileSource[From], List[To]) = {
  val source = buildCompileSource(files)
  val outputs: List[To] = processCompileSource(source, config)
  (source, outputs)
}
/** Loads `files`, adds missing dependencies and rewrites namespace-less includes, producing a
  * [[CompileSource]] ready for code generation.
  *
  * Steps:
  *  1. every input becomes an `Importable` and is parsed with one fresh `buildContext`;
  *  2. `addMissingFiles` repeatedly adds unresolved imports/includes that exist as regular files
  *     with the same file name in the working directory;
  *  3. `processUnnamedIncludes` gives included documents without a target namespace the
  *     namespace of the document that includes them, replacing the originals with the
  *     namespace-swapped copies.
  *
  * @param files the inputs; the first one determines `firstNamespace`, which is `None` when
  *              `files` is empty
  */
def buildCompileSource[From, To](files: Seq[From])
    (implicit ev: CanBeRawSchema[From, RawSchema]): CompileSource[From] = {
  logger.debug("%s", files.toString())
  val context = buildContext
  val importables0 = ListMap[From, Importable](files map { f =>
    f -> toImportable(ev.toURI(f), ev.toRawSchema(f))}: _*)
  val importables = ListBuffer[(Importable, From)](files map { f => importables0(f) -> f }: _*)
  val schemas = ListMap[Importable, Schema](importables.map { case (importable, _) =>
    importable -> parse(importable, context) }.toSeq: _*)
  val additionalImportables = ListMap.empty[Importable, File]

  // Adds files that satisfy unresolved imports/includes and repeats until a pass adds nothing.
  // A dependency is looked up by its bare file name, relative to the working directory.
  def addMissingFiles(): Unit = {
    val current = (importables map {_._1}) ++ additionalImportables.keysIterator.toList
    // check for all dependencies before proceeding.
    val missings = (current flatMap { importable =>
      missingDependencies(importable, current) }).distinct
    // Different locations can share a file name; add each file once, and only regular files
    // (a directory with a matching name cannot be read as a schema).
    val additional = (missings flatMap { x =>
      val uri = new URI(x)
      val file = new File(new File(uri.getPath).getName)
      if (file.isFile) Some(file)
      else None
    }).distinct
    additionalImportables ++= (additional map { x =>
      logger.warn("added " + x + " to compilation.")
      val importable = toImportable(implicitly[CanBeRawSchema[File, RawSchema]].toURI(x),
        implicitly[CanBeRawSchema[File, RawSchema]].toRawSchema(x))
      schemas(importable) = parse(importable, context)
      (importable, x) })
    if (additional.nonEmpty) addMissingFiles()
  }

  // Included documents with no targetNamespace take the namespace of their includer. A parent
  // is handled only once none of its children is still pending as a parent; the pass count is
  // bounded by parents.size squared.
  def processUnnamedIncludes(): Unit = {
    logger.debug("processUnnamedIncludes")
    val all = (importables.toList map {_._1}) ++ (additionalImportables.toList map {_._1})
    val parents: ListBuffer[Importable] = ListBuffer(all filter { !_.includeLocations.isEmpty}: _*)
    def children(importable: Importable): List[Importable] = {
      val uris = importable.includeLocations map { includedLoc => shortenUri(new URI(includedLoc)) }
      all filter { x => uris contains shortenUri(x.location) }
    }
    val mapping: ListMap[Importable, Option[String]] = ListMap()
    val used: ListBuffer[Importable] = ListBuffer()
    var count: Int = 0
    val len = parents.size * parents.size
    for {
      i <- 0 to len if parents.size > 0
    } {
      val parent = parents(i % parents.size)
      val xs = children(parent)
      if (xs forall { x => !(parents contains x) }) {
        val tns = mapping.get(parent) getOrElse {parent.targetNamespace}
        tns foreach { tnsstr => xs foreach { x =>
          x.targetNamespace match {
            case Some(_) =>
            case None =>
              logger.debug("processUnnamedIncludes - setting %s's outer namespace to %s", x.location, tnsstr)
              count += 1
              val swap = x.swapTargetNamespace(tns, count)
              schemas(swap) = parse(swap, context)
              additionalImportables(swap) = new File(swap.location.getPath)
              used += x
          }
          mapping(x) = tns
        }}
        parents.remove(i % parents.size)
      }
    }
    // The swapped copies replace the originals everywhere.
    used.toList.distinct foreach { x =>
      schemas -= x
      logger.debug("processUnnamedIncludes - removing %s", x.location)
      val idx = importables.indexWhere { case (i, _) => i == x }
      if (idx >= 0) {
        importables remove idx
      }
      if (additionalImportables contains x) {
        additionalImportables -= x
      }
    }
  }

  addMissingFiles()
  processUnnamedIncludes()
  // headOption: an empty input yields no first namespace instead of throwing from files.head.
  CompileSource(context, schemas, importables, additionalImportables,
    files.headOption flatMap { f => importables0(f).targetNamespace })
}
/** Generates code for every schema in `cs`, then the protocol, then (optionally) the runtime files.
  *
  * The snippets produced while writing the schema outputs are merged into the protocol, so the
  * protocol is generated only after every schema output has been written.
  *
  * @param ev   reads the caller's inputs; used here to derive each output's file name
  * @param evTo creates and opens the output targets
  * @return outputs for the caller's inputs, then for the additional importables, then the
  *         protocol output, then the runtime files when `config.generateRuntime` is set
  */
def processCompileSource[From, To](cs: CompileSource[From], config: Config)
    (implicit ev: CanBeRawSchema[From, RawSchema], evTo: CanBeWriter[To]): List[To] = {
  val snippets = ListBuffer.empty[Snippet]

  // File name of `file`'s location without its final extension; used as the output name part.
  def toFileNamePart[A](file: A)(implicit ev: CanBeRawSchema[A, RawSchema]): String =
    """([.]\w+)$""".r.replaceFirstIn(new File(ev.toURI(file).getPath).getName, "")

  // Opens `output`, prints `nodes` into it and always flushes and closes the writer.
  def writeNodes(output: To, nodes: Seq[Node], cfg: Config): Unit = {
    val out = evTo.toWriter(output)
    try printNodes(nodes, out, cfg)
    finally {
      out.flush()
      out.close()
    }
  }

  // Writes one output per generated (package, snippet, part) tuple and records each snippet.
  def processImportables[A](xs: List[(Importable, A)])(implicit ev: CanBeRawSchema[A, RawSchema]): List[To] =
    xs flatMap { case (importable, file) =>
      generate(cs.schemas(importable), toFileNamePart(file), cs.context, config) map { case (pkg, snippet, part) =>
        snippets += snippet
        val output = evTo.newInstance(pkg, part + ".scala")
        writeNodes(output, snippet.definition, config)
        output
      }
    }

  // Generates the protocol from every collected snippet and writes it to its own output.
  def processProtocol: To = {
    val pkg = config.protocolPackageName orElse packageName(cs.firstNamespace, cs.context)
    val config2 = config.update(ProtocolPackageName(pkg)).
      update(DefaultNamespace(config.defaultNamespace orElse cs.firstNamespace))
    // Generate before opening the output, so a generation failure does not leave an open,
    // truncated file behind.
    val protocolNodes = generateProtocol(Snippet(snippets.toSeq: _*), cs.context, config2)
    val output = evTo.newInstance(pkg, config.protocolFileName)
    writeNodes(output, protocolNodes, config2)
    output
  }

  processContext(cs.context, cs.schemas.valuesIterator.toSeq, config)
  cs.schemas.valuesIterator.toSeq foreach { schema =>
    processSchema(schema, cs.context, config)
  }
  // Order matters: the protocol needs the snippets collected by both processImportables calls.
  val schemaOutputs = processImportables(cs.importables.toList)
  val additionalOutputs = processImportables(cs.additionalImportables.toList)
  val protocolOutput = processProtocol
  val runtimeOutputs =
    if (config.generateRuntime) generateRuntimeFiles[To](cs.context, config)
    else Nil
  schemaOutputs ::: additionalOutputs ::: protocolOutput :: runtimeOutputs
}
/** Builds a transform that applies each `(regex, replacement)` pair in order with `replaceAll`;
  * each pair sees the output of the previous one.
  */
def substituteMany(subs: (String, String)*): String => String = {
  val steps: Seq[String => String] = subs.map { case (pattern, replacement) =>
    (text: String) => text.replaceAll(pattern, replacement)
  }
  Function.chain(steps)
}
/** Scala 3 varargs splice: rewrites `: _*` to `*`. */
private val scala3VarArgSub: (String, String) = """\:\s_\*""" -> "*"
/** Scala 3 intersection type: rewrites `scalaxb.A with scalaxb.B` to `scalaxb.A & scalaxb.B`. */
private val scala3With: (String, String) = """scalaxb\.(\S+)\swith\sscalaxb\.(\S+)""" -> """scalaxb.$1 & scalaxb.$2"""
/** Prints `nodes` to `out`. For Scala 3.4+ targets, the printed text has `: _*` splices and
  * `scalaxb.A with scalaxb.B` types rewritten into their Scala 3 forms.
  */
def printNodes(nodes: Seq[Node], out: PrintWriter, config: Config): Unit =
  if (config.isScala3_4Plus) printNodes(nodes, out, substituteMany(scala3VarArgSub, scala3With))
  else printNodes(nodes, out)
/** Copies the classpath resource `resourcePath` into a new output `fileName` in `packageName`.
  *
  * Every line passes through `subs0` (regex -> replacement, in order), followed by the `: _*`
  * rewrite for Scala 3.4+ targets. The writer is always flushed and closed.
  *
  * @return the output that was written
  */
def generateFromResource[To](
  packageName: Option[String],
  fileName: String,
  resourcePath: String,
  config: Config,
  subs0: (String, String)*
)(implicit evTo: CanBeWriter[To]): To = {
  val output = evTo.newInstance(packageName, fileName)
  val out = evTo.toWriter(output)
  try {
    val subs = if (config.isScala3_4Plus) subs0 :+ scala3VarArgSub else subs0
    printFromResource(resourcePath, out, substituteMany(subs: _*))
  } finally {
    out.flush()
    out.close()
  }
  output
}
/** Generates `scalaxb/scalaxb.scala` from the `/scalaxb.scala.template` resource.
  *
  * `%%JAXB_PACKAGE%%` is replaced by the configured JAXB package name. For Scala 3 targets,
  * `CanWriteXML[_]` is also replaced by `CanWriteXML[?]`.
  */
def generateBaseRuntimeFiles[To](cntxt: Context, config: Config)(implicit evTo: CanBeWriter[To]): List[To] = {
  val jaxbSub = "%%JAXB_PACKAGE%%" -> config.jaxbPackage.packageName
  val subs: List[(String, String)] =
    if (config.isScala3Plus) List(jaxbSub, """CanWriteXML\[_\]""" -> "CanWriteXML[?]")
    else List(jaxbSub)
  val runtime = generateFromResource[To](
    Some("scalaxb"), "scalaxb.scala", "/scalaxb.scala.template", config, subs: _*)
  runtime :: Nil
}
// Generates the runtime support files for a run; called by processCompileSource when
// config.generateRuntime is set.
def generateRuntimeFiles[To](context: Context, config: Config)(implicit evTo: CanBeWriter[To]): List[To]
// Generates code for one schema. `part` is the input's file name without extension. Returns
// (package name, snippet, file name part) tuples; processCompileSource writes one
// `<part>.scala` output per tuple.
def generate(schema: Schema, part: String, context: Context, config: Config): Seq[(Option[String], Snippet, String)]
// Generates the protocol nodes from the merged Snippet of every generated schema.
def generateProtocol(snippet: Snippet,
context: Context, config: Config): Seq[Node]
// Wraps a raw schema loaded from `location` as an Importable.
def toImportable(location: URI, rawschema: RawSchema): Importable
/** Returns the last path segment (file name) of `uri`, or `""` when it has no path. */
def shortenUri(uri: URI): String =
  Option(uri.getPath).fold("")(path => new File(path).getName)
/** Lists the import and include locations of `importable` that no document in `files` satisfies.
  *
  * A location is satisfied when some document in `files` has the same file name (see
  * [[shortenUri]]). The well-known `xml.xsd` location is never reported. Each unresolved
  * location is logged as a warning. Matching by namespace alone is deliberately not done.
  *
  * @return the unresolved locations, imports before includes, without duplicates
  */
def missingDependencies(importable: Importable, files: collection.Seq[Importable]): List[String] = {
  val XML_LOCATION = "http://www.w3.org/2001/xml.xsd"
  // File names of every known document, computed once instead of once per location.
  val knownNames = files.map(f => shortenUri(f.location)).toSet
  // Evaluated only when a warning is logged, as before.
  def sourceName = new File(importable.location.getPath).getName

  // Keeps the locations that nothing in `files` satisfies, logging `warning(loc)` for each.
  def unresolved(locations: Seq[String], warning: String => String): List[String] =
    locations.toList filter { loc =>
      val missing = !knownNames.contains(shortenUri(new URI(loc))) && loc != XML_LOCATION
      if (missing) logger.warn(warning(loc))
      missing
    }

  val locationBased = unresolved(importable.importLocations, loc =>
    sourceName + " imports " + loc +
      " but no schema with that name was compiled together.")
  val includes = unresolved(importable.includeLocations, loc =>
    "Warning: " + sourceName + " includes " + loc +
      " but no schema with that name was compiled together.")
  (locationBased ::: includes).distinct
}
// Creates the compilation context for one run (buildCompileSource calls it once per run).
def buildContext: Context
// Per-schema hook; processCompileSource calls it for every schema after processContext and
// before generating code.
def processSchema(schema: Schema, context: Context, config: Config): Unit
// Context-wide hook; processCompileSource calls it once with all schemas before processSchema.
def processContext(context: Context, schemas: Seq[Schema], config: Config): Unit
// Scala package for a namespace, if any; used for the protocol when no protocol package is set.
def packageName(namespace: Option[String], context: Context): Option[String]
// Reads a raw schema from `reader`.
def readerToRawSchema(reader: Reader): RawSchema
// Converts an in-memory XML node into a raw schema.
def nodeToRawSchema(node: Node): RawSchema
/** Parses `importable` into the compiler's schema model within `context`. */
def parse(importable: Importable, context: Context): Schema =
  importable.toSchema(context)

/** Parses the schema read from `in` as if it were located at `location`, in a fresh context. */
def parse(location: URI, in: Reader): Schema = {
  val importable = toImportable(location, readerToRawSchema(in))
  parse(importable, buildContext)
}
/** Writes the text of `nodes` to `out`, passing text and atoms through `transform`.
  *
  * Elements contribute their children's text. After an element with non-empty text a line break
  * is written, or two when that text spans several lines. The entities `&lt;`, `&gt;`, `&amp;`,
  * `&quot;` and `&apos;` are written as the characters they stand for. Any other node kind is
  * logged as an error and skipped.
  */
def printNodes(nodes: Seq[Node], out: PrintWriter, transform: String => String = identity): Unit = {
  // Hide the scala.xml.transform package so `transform` refers to the parameter.
  import scala.xml.{ transform => _, _ }
  def printNode(n: Node): Unit = n match {
    case Text(s) => out.print(transform(s))
    case EntityRef("lt") => out.print('<')
    case EntityRef("gt") => out.print('>')
    case EntityRef("amp") => out.print('&')
    // Previously these two fell through to the error case and were dropped from the output.
    case EntityRef("quot") => out.print('"')
    case EntityRef("apos") => out.print('\'')
    case atom: Atom[_] => out.print(transform(atom.text))
    case elem: Elem =>
      printNodes(elem.child, out, transform)
      val text = elem.text
      if (text != "") {
        if (text.contains(newline)) out.println("")
        out.println("")
      }
    case _ => logger.error("error in Module: encountered "
      + n.getClass() + " " + n.toString)
  }
  for (node <- nodes) { printNode(node) }
}
/** Copies the classpath resource `source` line by line to `out`, applying `transform` to each
  * line, then flushes `out` (it is not closed). The resource is decoded as UTF-8.
  *
  * @throws RuntimeException if `source` is not found on the classpath
  */
def printFromResource(source: String, out: PrintWriter, transform: String => String = identity): Unit = {
  val in = getClass.getResourceAsStream(source)
  // getResourceAsStream returns null for a missing resource; fail with a clear message instead
  // of a NullPointerException.
  if (in == null) sys.error("resource not found: " + source)
  try {
    // Explicit charset: the platform default would garble non-ASCII template text on systems
    // whose default encoding is not UTF-8.
    val reader = new java.io.BufferedReader(
      new java.io.InputStreamReader(in, java.nio.charset.StandardCharsets.UTF_8))
    Iterator.continually(reader.readLine()).takeWhile(_ != null).foreach { line =>
      out.println(transform(line))
    }
  } finally {
    in.close()
  }
  out.flush()
}
/** Copies the classpath resource `source` to the file `dest`, applying `transform` to each line.
  *
  * The file is written with `encoding` (consistent with the other generated files), and the
  * writer is always closed so the file handle is released.
  */
def copyFileFromResource(source: String, dest: File, transform: String => String = identity): Unit = {
  val out = new java.io.PrintWriter(
    new java.io.OutputStreamWriter(new java.io.FileOutputStream(dest), encoding))
  try printFromResource(source, out, transform)
  finally out.close()
}
/** Concatenates `snippets` field by field, in input order.
  *
  * Unlike `Snippet.apply(snippets*)`, definitions and companions stay in separate fields. All
  * five fields are merged; `elemToTypeClauses` used to be silently dropped.
  */
def mergeSnippets(snippets: Seq[Snippet]): Snippet =
  Snippet(snippets flatMap {_.definition},
    snippets flatMap {_.companion},
    snippets flatMap {_.defaultFormats},
    snippets flatMap {_.implicitValue},
    snippets flatMap {_.elemToTypeClauses})
/** Builds a relative URI from `location`'s file name, with its first `.xsd` removed and `n`
  * plus `.xsd` appended (e.g. `dir/a.xsd` with `n = 2` gives `a2.xsd`).
  */
def appendPostFix(location: URI, n: Int): URI = {
  val baseName = shortenUri(location).replaceFirst("\\.xsd", "")
  new URI(s"${baseName}${n}.xsd")
}
/** Re-parses `raw` through `CustomXML`, first giving it a target namespace if it has none.
  *
  * When `raw` is an element and `old` is `None`, the copy gets a `targetNamespace` attribute
  * set to `outerNamespace`. Every element in it also gets a default-namespace binding to
  * `outerNamespace` prepended to its scope. Otherwise `raw` is re-parsed unchanged.
  */
def replaceNamespace(raw: Node, old: Option[String], outerNamespace: Option[String]): Node = {
  val namespaceUri: String = outerNamespace.orNull
  def withDefaultNamespace(scope: NamespaceBinding): NamespaceBinding =
    NamespaceBinding(null, namespaceUri, scope)
  def rebind(nodes: Seq[Node]): Seq[Node] = nodes map {
    case elem: Elem =>
      elem.copy(scope = withDefaultNamespace(elem.scope), child = rebind(elem.child))
    case other =>
      other
  }
  val rewritten: Node = raw match {
    case elem: Elem if old.isEmpty =>
      elem.copy(
        attributes = new UnprefixedAttribute("targetNamespace", namespaceUri, elem.attributes),
        scope = withDefaultNamespace(elem.scope),
        child = rebind(elem.child))
    case node =>
      node
  }
  CustomXML.load(new java.io.StringReader(rewritten.toString))
}
}
/** XML loader with a non-validating SAX parser that accepts DOCTYPE declarations but does not
  * load DTD grammars or external DTDs.
  *
  * NOTE(review): external general/parameter entities are not disabled, so entity declarations
  * in a document's internal subset can still be resolved (XXE). Confirm schema inputs are
  * trusted before feeding this loader external data.
  */
object CustomXML extends XMLLoader[Elem] {
  override def parser: SAXParser = {
    val factory = javax.xml.parsers.SAXParserFactory.newInstance()
    Seq(
      "http://xml.org/sax/features/validation",
      "http://apache.org/xml/features/disallow-doctype-decl",
      "http://apache.org/xml/features/nonvalidating/load-dtd-grammar",
      "http://apache.org/xml/features/nonvalidating/load-external-dtd"
    ).foreach(feature => factory.setFeature(feature, false))
    factory.newSAXParser()
  }
}
/** Opens text files, choosing the charset from a leading byte-order mark. */
object UnicodeFileReader {
  /** Opens `value` for reading.
    *
    * A UTF-8 BOM (`EF BB BF`) selects UTF-8, `FE FF` selects UTF-16BE and `FF FE` selects
    * UTF-16LE; the BOM itself is skipped. Without a recognised BOM the file is read as UTF-8
    * from its first byte. Empty files yield a reader that is immediately at end of input.
    *
    * @return a buffered reader; the caller is responsible for closing it
    */
  def reader(value: File): Reader = {
    val BOM_SIZE = 4
    val EF = 0xEF.toByte
    val BB = 0xBB.toByte
    val BF = 0xBF.toByte
    val FE = 0xFE.toByte
    val FF = 0xFF.toByte
    val bom = Array.ofDim[Byte](BOM_SIZE)
    val in = new java.io.PushbackInputStream(new java.io.FileInputStream(value), BOM_SIZE)
    try {
      // read returns -1 at end of stream (empty file). Treat that as zero bytes, otherwise the
      // unread below gets a negative length and throws IndexOutOfBoundsException.
      val readSize = math.max(in.read(bom, 0, bom.length), 0)
      // BOM bytes are non-zero, so the zero-filled tail of a short read never matches a BOM and
      // bomSize never exceeds readSize.
      val (bomSize, encoding) = bom.toList match {
        case EF :: BB :: BF :: _ => (3, "UTF-8")
        case FE :: FF :: _ => (2, "UTF-16BE")
        case FF :: FE :: _ => (2, "UTF-16LE")
        case _ => (0, "UTF-8")
      }
      // Push back whatever was read past the BOM so the decoder sees it.
      in.unread(bom, bomSize, readSize - bomSize)
      new BufferedReader(new java.io.InputStreamReader(in, encoding))
    } catch {
      case scala.util.control.NonFatal(e) =>
        // Do not leak the file handle if probing the BOM fails.
        in.close()
        throw e
    }
  }
}
/** Thrown when a referenced schema component cannot be found.
  *
  * @param kind      kind of the referenced component
  * @param namespace its namespace, shown as `{ns}`; `None` is shown as "(unqualified) "
  * @param name      its local name
  */
class ReferenceNotFound(kind: String, namespace: Option[String], name: String) extends RuntimeException(
  s"Error: Referenced $kind ${namespace.fold("(unqualified) ")(ns => "{" + ns + "}")}$name was not found.")
/** Thrown when a case class would need more than 22 parameters; deprecated and slated for removal.
  *
  * @param fqn     fully qualified name of the class that could not be generated
  * @param xmlname XML name suggested for the `--wrap-contents` option
  */
@deprecated(message = "will be removed")
class CaseClassTooLong(fqn: String, xmlname: String) extends RuntimeException(
  "Error: A case class with > 22 parameters cannot be created for " + fqn +
    ". Consider using --wrap-contents \"" + xmlname + "\" option."
)
© 2015 - 2025 Weber Informatics LLC | Privacy Policy