All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scalapb.ScalaPBC.scala Maven / Gradle / Ivy

The newest version!
package scalapb

import java.io.File

import protocbridge.{ProtocBridge, ProtocCodeGenerator}
import coursier.parse.DependencyParser
import coursier.core.Configuration
import com.github.ghik.silencer.silent
import coursier.core.Dependency
import java.net.URLClassLoader
import java.util.jar.JarInputStream
import java.io.FileInputStream
import protocbridge.SandboxedJvmGenerator
import scala.util.{Try, Success, Failure}
import protocbridge.ProtocRunner

/** Settings for one `scalapbc` invocation, normally produced by `ScalaPBC.processArgs`.
  *
  * @param version
  *   protoc version to obtain when no custom protoc location is given (overridden by a
  *   `-v<version>` argument); defaults to `scalapb.compiler.Version.protobufVersion`.
  * @param throwException
  *   when true (`--throw`), errors and non-zero exit codes are reported by throwing
  *   [[ScalaPbcException]] instead of calling `sys.exit`.
  * @param args
  *   arguments that were not recognised by `processArgs` and are forwarded to protoc.
  * @param customProtocLocation
  *   path of a protoc executable to use (`--protoc=<path>`); when set, no protoc is downloaded.
  * @param namedGenerators
  *   in-process code generators keyed by name; starts with the built-in `scala` generator and is
  *   extended by `--custom-gen=<name>=<class>`.
  * @param executableArtifacts
  *   artifact coordinates given via `--plugin-artifact=`; each is fetched and passed to protoc as
  *   a `--plugin=` argument.
  * @param jvmPlugins
  *   `(name, artifact)` pairs given via `--jvm-plugin=<name>=<artifact>`; each artifact is fetched
  *   and loaded as a sandboxed JVM generator.
  */
case class Config(
    version: String = scalapb.compiler.Version.protobufVersion,
    throwException: Boolean = false,
    args: Seq[String] = Seq.empty,
    customProtocLocation: Option[String] = None,
    namedGenerators: Seq[(String, ProtocCodeGenerator)] = Seq("scala" -> ScalaPbCodeGenerator),
    executableArtifacts: Seq[String] = Seq.empty,
    jvmPlugins: Seq[(String, String)] = Seq.empty
)

/** Thrown by `ScalaPBC` in place of `sys.exit` when `Config.throwException` is set: either for a
  * fatal error while preparing the run or for a non-zero result code.
  */
class ScalaPbcException(msg: String) extends RuntimeException(msg)

object ScalaPBC {
  // Prefixes of the scalapbc-specific command-line options recognised by processArgs.
  // `--protoc=<path>`: use the protoc executable at <path> instead of downloading one.
  private val CustomPathArgument     = "--protoc="
  // `--custom-gen=<name>=<class>`: register the Scala object <class> as generator <name>.
  private val CustomGenArgument      = "--custom-gen="
  // `--plugin-artifact=<artifact>`: fetch <artifact> and hand it to protoc via `--plugin=`.
  private val PluginArtifactArgument = "--plugin-artifact="
  // `--jvm-plugin=<name>=<artifact>`: fetch <artifact> and load it as JVM generator <name>.
  private val JvmPluginArgument      = "--jvm-plugin="

  /** Parses the `scalapbc` command line into a [[Config]].
    *
    * Leading arguments are checked against the scalapbc-specific options (`--throw`,
    * `--custom-gen=`, `--jvm-plugin=`, `--protoc=`, `--plugin-artifact=`, `-v<version>`). The
    * first argument that is not one of those, or an explicit `--`, switches to pass-through mode:
    * from then on every argument is appended verbatim to `Config.args`. The `--` separator itself
    * is not forwarded.
    *
    * @param args
    *   the raw command-line arguments.
    * @return
    *   the resulting configuration.
    * @throws ScalaPbcException
    *   if a `--custom-gen=` or `--jvm-plugin=` option is not of the form `<name>=<value>`.
    */
  def processArgs(args: Array[String]): Config = {
    case class State(cfg: Config, passThrough: Boolean)

    // Splits the `<name>=<value>` payload that follows `prefix` at the FIRST '=' only, so the
    // value may itself contain '=' (e.g. coursier dependency attributes such as
    // `org:name:1.0,classifier=tests`). A missing '=' or an empty value is reported with a
    // descriptive error instead of a bare MatchError.
    def nameAndValue(prefix: String, arg: String): (String, String) = {
      val payload   = arg.substring(prefix.length)
      val separator = payload.indexOf('=')
      if (separator < 0 || separator == payload.length - 1) {
        throw new ScalaPbcException(
          s"Invalid argument '$arg': expected ${prefix}<name>=<value>"
        )
      }
      (payload.substring(0, separator), payload.substring(separator + 1))
    }

    args
      .foldLeft(State(Config(), false)) { case (state, item) =>
        (state.passThrough, item) match {
          case (false, "--")      => state.copy(passThrough = true)
          case (false, "--throw") => state.copy(cfg = state.cfg.copy(throwException = true))
          case (false, p) if p.startsWith(CustomGenArgument) =>
            val (genName, klassName) = nameAndValue(CustomGenArgument, p)
            // The generator must be a Scala `object`: load its module class (`Name$`) and read
            // the singleton instance from the static MODULE$ field.
            val klass = Class.forName(klassName + "$")
            val gen   = klass.getField("MODULE$").get(klass).asInstanceOf[ProtocCodeGenerator]
            state.copy(
              cfg = state.cfg.copy(namedGenerators = state.cfg.namedGenerators :+ (genName -> gen))
            )
          case (false, p) if p.startsWith(JvmPluginArgument) =>
            val (genName, artifactName) = nameAndValue(JvmPluginArgument, p)
            state.copy(
              cfg = state.cfg.copy(jvmPlugins = state.cfg.jvmPlugins :+ (genName -> artifactName))
            )
          case (false, p) if p.startsWith(CustomPathArgument) =>
            state.copy(
              cfg = state.cfg
                .copy(customProtocLocation = Some(p.substring(CustomPathArgument.length)))
            )
          case (false, p) if p.startsWith(PluginArtifactArgument) =>
            state.copy(cfg =
              state.cfg
                .copy(executableArtifacts =
                  state.cfg.executableArtifacts :+ p.substring(PluginArtifactArgument.length())
                )
            )
          // Any leading argument starting with "-v" is taken as the protoc version selector.
          case (false, v) if v.startsWith("-v") =>
            state.copy(cfg = state.cfg.copy(version = v.substring(2).trim))
          // First unrecognised argument: forward it and everything after it untouched.
          case (_, other) =>
            state.copy(passThrough = true, cfg = state.cfg.copy(args = state.cfg.args :+ other))
        }
      }
      .cfg
  }

  /** Resolves and downloads a single artifact with coursier.
    *
    * The coordinates are parsed with coursier's `DependencyParser`, which is given the running
    * Scala version (presumably to expand `::` cross-version coordinates).
    *
    * @param artifact
    *   dependency coordinates in coursier syntax.
    * @return
    *   `Right((dependency, files))` with the parsed dependency and a non-empty list of fetched
    *   files, or `Left(message)` when the coordinates cannot be parsed, the fetch fails, or the
    *   fetch yields no files.
    */
  @silent("method right in class Either is deprecated")
  def fetchArtifact(artifact: String): Either[String, (Dependency, Seq[File])] = {
    import coursier._

    val parsed = DependencyParser.dependency(
      artifact,
      scala.util.Properties.versionNumberString,
      Configuration.empty
    )

    parsed.right.flatMap { dep =>
      // `run()` signals resolution/download failures by throwing; convert those into a Left so
      // callers get every failure through the Either, the same way getProtoc reports errors.
      Try(Fetch().addDependencies(dep).run()) match {
        case Success(files) if files.nonEmpty => Right((dep, files))
        case Success(_)                       => Left(s"Could not find artifact for $artifact")
        case Failure(e) =>
          val reason = Option(e.getMessage).getOrElse(e.toString)
          Left(s"Could not fetch artifact $artifact: $reason")
      }
    }
  }

  /** Fetches a list of named artifacts in order, stopping at the first failure.
    *
    * @param artifacts
    *   `(name, coordinates)` pairs; the name is carried through to the result unchanged.
    * @return
    *   `Right` with one `(name, (dependency, files))` entry per input, in input order, or the
    *   `Left` of the first artifact that could not be fetched (later artifacts are not attempted).
    */
  def fetchArtifacts(
      artifacts: Seq[(String, String)]
  ): Either[String, Seq[(String, (Dependency, Seq[File]))]] = {
    type Fetched = Seq[(String, (Dependency, Seq[File]))]
    val nothingYet: Either[String, Fetched] = Right(Seq())

    artifacts.foldLeft(nothingYet) { case (soFar, (name, coordinates)) =>
      // flatMap on a Left skips the function, so no further fetches happen after an error.
      soFar.flatMap { done =>
        fetchArtifact(coordinates).map(resolved => done :+ ((name, resolved)))
      }
    }
  }

  /** Reads the `Main-Class` attribute from a jar's manifest.
    *
    * @param f
    *   the jar file to inspect.
    * @return
    *   `Right` with the main class name suffixed by `$` (the JVM class name of a Scala `object`),
    *   or `Left` when the jar has no manifest or the manifest has no `Main-Class` entry.
    */
  def findMainClass(f: File): Either[String, String] = {
    val fileIn = new FileInputStream(f)
    try {
      // Opened inside the try so the file handle is released even if the jar header is unreadable.
      val jin = new JarInputStream(fileIn)
      try {
        val mainClass = for {
          // getManifest() returns null for jars without a manifest; treat that as "not found".
          manifest <- Option(jin.getManifest())
          name     <- Option(manifest.getMainAttributes().getValue("Main-Class"))
        } yield name + "$"
        mainClass.toRight("Could not find main class for plugin")
      } finally {
        jin.close()
      }
    } finally {
      fileIn.close()
    }
  }

  /** Obtains the protoc executable for `version` through `protocbridge.CoursierProtocCache`.
    *
    * @param version
    *   the protoc version to look up.
    * @return
    *   `Right` with the absolute path of the executable, or `Left` with a description of the
    *   failure if the lookup throws.
    */
  private def getProtoc(version: String): Either[String, String] =
    Try(protocbridge.CoursierProtocCache.getProtoc(version)) match {
      case Success(protocFile) => Right(protocFile.getAbsolutePath())
      // Some exceptions carry no message; fall back to toString so the error is never null.
      case Failure(e) => Left(Option(e.getMessage).getOrElse(e.toString))
    }

  /** Prepares generators and plugins described by `config` and runs protoc through
    * `ProtocBridge.runWithGenerators`.
    *
    * Fatal preparation errors (artifact cannot be fetched, plugin has no main class, protoc cannot
    * be obtained, ...) either throw [[ScalaPbcException]] or print to stderr and exit with status
    * 1, depending on `config.throwException`.
    *
    * @param config
    *   the parsed command-line configuration.
    * @return
    *   the result of `ProtocBridge.runWithGenerators` (treated as the process exit code by `main`).
    * @throws RuntimeException
    *   if a `--jvm-plugin=` name is already used by a named generator.
    */
  @silent("method right in class Either is deprecated")
  private[scalapb] def runProtoc(config: Config): Int = {
    // Named generators (built-in or --custom-gen=) and JVM plugins share one namespace when they
    // are handed to ProtocBridge below, so a name may only be used once.
    val clashingNames = config.namedGenerators
      .map(_._1)
      .toSet
      .intersect(config.jvmPlugins.map(_._1).toSet)
    if (clashingNames.nonEmpty) {
      throw new RuntimeException(
        s"Same plugin name provided by $CustomGenArgument (or a built-in generator) and " +
          s"$JvmPluginArgument: ${clashingNames.toSeq.sorted.mkString(", ")}"
      )
    }

    // Reports an unrecoverable error in the way selected by --throw.
    def fatalError(err: String): Nothing = {
      if (config.throwException) {
        throw new ScalaPbcException(s"Error: $err")
      } else {
        System.err.println(err)
        sys.exit(1)
      }
    }

    val jvmGenerators = fetchArtifacts(
      config.jvmPlugins
    ) match {
      case Left(error) => fatalError(error)
      case Right(arts) =>
        arts.map { case (name, (_, files)) =>
          val urls = files.map(_.toURI().toURL()).toArray
          // Null parent: the plugin only sees its own fetched jars (plus the bootstrap classes).
          val loader = new URLClassLoader(urls, null)
          // The first fetched file is assumed to be the plugin's own jar.
          val mainClass = findMainClass(files.head) match {
            case Right(v)  => v
            case Left(err) => fatalError(err)
          }
          name -> SandboxedJvmGenerator.load(mainClass, loader)
        }
    }

    val pluginArgs = fetchArtifacts(
      config.executableArtifacts.map(a => ("", a))
    ) match {
      case Left(error) => fatalError(error)
      case Right(arts) =>
        arts.map {
          // Seq(...) extractor rather than `file :: Nil`, which only matches when the fetched
          // files happen to be a List at runtime.
          case (_, (dep, Seq(file))) =>
            file.setExecutable(true)
            s"--plugin=${dep.module.name.value}=${file.getAbsolutePath()}"
          case (_, (dep, files)) =>
            fatalError(s"Got ${files.length} files for dependency $dep. Only one expected.")
        }
    }

    // An explicit --protoc= location wins; otherwise obtain the requested protoc version.
    val protoc =
      config.customProtocLocation
        .getOrElse(getProtoc(config.version).fold(fatalError(_), identity(_)))

    ProtocBridge.runWithGenerators(
      ProtocRunner(protoc),
      namedGenerators = config.namedGenerators ++ jvmGenerators,
      params = config.args ++ pluginArgs
    )
  }

  /** Command-line entry point: parses `args`, runs protoc and reports the result.
    *
    * Without `--throw` the process always terminates through `sys.exit` with the result code,
    * including on success. With `--throw` the method returns normally on success and throws
    * [[ScalaPbcException]] for a non-zero result code.
    */
  def main(args: Array[String]): Unit = {
    val config = processArgs(args)

    (config.throwException, runProtoc(config)) match {
      case (false, exitCode) => sys.exit(exitCode)
      case (true, 0)         => ()
      case (true, exitCode)  => throw new ScalaPbcException(s"Exit with code $exitCode")
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy