// scala.cli.graal.BytecodeProcessor.scala (Maven / Gradle / Ivy)
package scala.cli.graal
import org.objectweb.asm._
import java.io.{File, InputStream}
import java.nio.file.{Files, StandardOpenOption}
import java.util.jar.{Attributes, JarEntry, JarFile, JarOutputStream, Manifest}
import scala.jdk.CollectionConverters._
object BytecodeProcessor {
/** Lists the paths of entries that were produced in the temporary cache.
  *
  * A [[Processed]] entry contributes its path only when its cache is [[TempCache]]; a
  * [[PathingJar]] is unfolded into the jar entry itself plus the entries it refers to,
  * which are inspected recursively. Every other entry contributes nothing.
  * Presumably the result is what callers delete once they are done with the class path.
  *
  * @param classpath entries to inspect
  * @return paths of all temp-cache backed entries, in traversal order
  */
def toClean(classpath: Seq[ClassPathEntry]): Seq[os.Path] =
  classpath.flatMap { entry =>
    entry match {
      case Processed(tempPath, _, TempCache) => Seq(tempPath)
      case PathingJar(jar, nested)           => toClean(jar +: nested)
      case _                                 => Nil
    }
  }
/** Processes a class path that was handed over as a single, possibly "pathing", JAR.
  *
  * When the manifest of `pathingJar` carries a non-empty `Class-Path` attribute, every entry
  * listed there is processed with [[processClassPathEntries]] and a fresh temporary JAR is
  * written whose manifest is a copy of the original one with `Class-Path` pointing at the
  * processed entries. When there is no manifest, or no (non-empty) `Class-Path`, the JAR is
  * processed as a regular class path entry instead.
  *
  * @param pathingJar JAR to inspect
  * @param cache      cache handed to [[processClassPathEntries]]
  * @return a single [[PathingJar]] entry when a `Class-Path` was found, otherwise the result
  *         of processing `pathingJar` itself
  */
def processPathingJar(pathingJar: os.Path, cache: JarCache): Seq[ClassPathEntry] = {
  val jarFile = new JarFile(pathingJar.toIO)
  try {
    // JarFile.getManifest() returns null for a JAR without a manifest and getValue() returns
    // null for a missing attribute, so both lookups go through Option.
    val pathingInfo = for {
      manifest <- Option(jarFile.getManifest())
      cp       <- Option(manifest.getMainAttributes().getValue(Attributes.Name.CLASS_PATH))
      if cp.nonEmpty
    } yield (manifest, cp)

    pathingInfo match {
      case Some((originalManifest, cp)) =>
        // paths in pathing jars are separated by spaces; a leading space would yield an
        // empty first token, which must not be turned into a path
        val entries = cp.split(" +").toSeq.filter(_.nonEmpty).map { rawEntry =>
          // In manifest JARs, class path entries are supposed to be encoded as URL paths.
          // This especially matters on Windows, where we end up with paths like "/C:/…".
          // Theoretically, we should decode those paths with
          //   os.Path(java.nio.file.Paths.get(new java.net.URI("file://" + rawEntry)), os.pwd)
          // but native-image doesn't follow this, and decodes them with just Paths.get(…).
          // As the JARs we are handed are supposed to be passed to native-image, we follow
          // the native-image convention here.
          os.Path(rawEntry, os.pwd)
        }
        val processedCp = processClassPathEntries(entries, cache)
        val stringCp    = processedCp.map(_.path.toNIO).mkString(" ")

        val dest      = os.temp(suffix = ".jar")
        val outStream = Files.newOutputStream(dest.toNIO, StandardOpenOption.CREATE)
        try {
          val manifest = new Manifest(originalManifest)
          manifest.getMainAttributes().put(Attributes.Name.CLASS_PATH, stringCp)
          // The manifest is written by the JarOutputStream constructor; the new pathing
          // JAR has no further entries, so the stream can be closed right away.
          new JarOutputStream(outStream, manifest).close()
        }
        finally outStream.close()

        Seq(PathingJar(Processed(dest, pathingJar, TempCache), processedCp))

      case None =>
        processClassPathEntries(Seq(pathingJar), cache)
    }
  }
  finally jarFile.close()
}
/** Entry point that processes a class path given as a single string.
  *
  * The string is split on the platform path separator. If that yields exactly one element
  * ending in `.jar`, it is handled by [[processPathingJar]]; in every other case each element
  * is resolved against the current working directory and handed to
  * [[processClassPathEntries]].
  *
  * @param classPath class path string, elements separated by `File.pathSeparator`
  * @param cache     cache used for processed entries, [[TempCache]] by default
  * @return the processed class path entries
  */
def processClassPath(classPath: String, cache: JarCache = TempCache): Seq[ClassPathEntry] = {
  val elements           = classPath.split(File.pathSeparator)
  val isSingleJarElement = elements.length == 1 && elements(0).endsWith(".jar")
  if (isSingleJarElement)
    processPathingJar(os.Path(elements(0), os.pwd), cache)
  else {
    val resolved = elements.toSeq.map(element => os.Path(element, os.pwd))
    processClassPathEntries(resolved, cache)
  }
}
/** Processes every class path entry and, if anything was modified, prepends the runtime
  * fixes JAR.
  *
  * Each entry is routed through `cache.cache`: existing `.jar` files go to [[processJar]],
  * directories go to [[processDir]], and anything else is kept as is.
  *
  * @param entries class path entries to process
  * @param cache   cache providing the destination for processed entries
  * @return the processed entries, preceded by the `scala3RuntimeFixes.jar` entry when at
  *         least one entry reports itself as modified
  * @throws NoSuchElementException if an entry was modified but `scala3RuntimeFixes.jar`
  *                                cannot be found as a class loader resource
  */
def processClassPathEntries(entries: Seq[os.Path], cache: JarCache): Seq[ClassPathEntry] = {
  val cp = entries.map { path =>
    cache.cache(path) { dest =>
      if (path.ext == "jar" && os.isFile(path)) processJar(path, dest, cache)
      else if (os.isDir(path)) processDir(path, dest, cache)
      // Nothing was written to `dest`, so the entry has to keep pointing at the original
      // path, just like the unmodified results of processJar / processDir do.
      else Unmodified(path)
    }
  }
  if (cp.exists(_.modified)) {
    // The jar with runtime deps is shipped as a resource.
    // The name scala3RuntimeFixes.jar is also referenced from resource-config.json,
    // keep both in sync.
    val jarName = "scala3RuntimeFixes.jar"
    val runtimeJarIs =
      Option(getClass().getClassLoader.getResourceAsStream(jarName)).getOrElse {
        throw new NoSuchElementException(
          "Unable to find scala3RuntimeFixes.jar on classpath, did you add scala3-graal jar on classpath?"
        )
      }
    // Close the resource stream once its content has been read.
    val runtimeJarBytes =
      try runtimeJarIs.readAllBytes()
      finally runtimeJarIs.close()
    val created = cache.put(os.RelPath(jarName), runtimeJarBytes)
    created +: cp
  }
  else cp // No need to add the runtime fixes jar
}
/** Processes all class files found under a class directory.
  *
  * Every regular file below `dir` is visited. Class files that [[processClassFile]] rewrites
  * are written to the same relative location under `dest`. If at least one class file was
  * rewritten, all remaining files (non-class files and untouched class files) are copied to
  * `dest` as well, so that `dest` is a complete replacement for `dir`.
  *
  * @param dir   class directory to process
  * @param dest  directory receiving the processed content
  * @param cache cache recorded in the resulting [[Processed]] entry
  * @return `Processed(dest, dir, cache)` if any class file was rewritten, `Unmodified(dir)`
  *         otherwise
  */
def processDir(dir: os.Path, dest: os.Path, cache: JarCache): ClassPathEntry = {
  val paths = os.walk(dir).filter(os.isFile)
  val (skipped, processed) = paths.partitionMap {
    case p if p.ext != "class" =>
      Left(p)
    case clazzFile =>
      val original = os.read.bytes(clazzFile)
      processClassFile(original) match {
        case Some(content) =>
          // Arrays are compared by reference with `!=`, which would make this check
          // vacuous; compare the actual bytes instead.
          assert(
            !java.util.Arrays.equals(content, original),
            s"$clazzFile was reported as rewritten but its bytecode did not change"
          )
          val destPath = dest / clazzFile.relativeTo(dir)
          os.makeDir.all(destPath / os.up)
          os.write(destPath, content)
          Right(clazzFile)
        case None =>
          Left(clazzFile)
      }
  }
  if (processed.nonEmpty) {
    // Complete the destination directory with everything that was not rewritten.
    skipped.foreach { file =>
      os.copy.over(file, dest / file.relativeTo(dir), createFolders = true)
    }
    Processed(dest, dir, cache)
  }
  else Unmodified(dir)
}
/** Processes the class files of a JAR and writes a patched copy if anything changed.
  *
  * The class entries are scanned until the first one that [[processClassFile]] rewrites.
  * If there is none, the JAR is left alone. Otherwise all entries are copied to a new JAR at
  * `dest`: non-class entries and classes already known to be unchanged are copied verbatim,
  * the class found during the scan reuses its already computed bytecode, and the remaining
  * classes are processed on the fly.
  *
  * @param path  JAR to process
  * @param dest  location of the patched JAR
  * @param cache cache recorded in the resulting [[Processed]] entry
  * @return `Processed(dest, path, cache)` if any class was rewritten, `Unmodified(path)`
  *         otherwise
  */
def processJar(path: os.Path, dest: os.Path, cache: JarCache): ClassPathEntry = {
  val jarFile = new JarFile(path.toIO)
  try {
    val classEntries =
      jarFile.entries().asIterator().asScala.filter(_.getName().endsWith(".class"))

    // Consumes class entries up to and including the first one that gets rewritten.
    // Returns the names of the classes seen before it (known to be unchanged) together
    // with the name and new bytecode of the rewritten class, if any.
    @scala.annotation.tailrec
    def scan(unchanged: Set[String]): (Set[String], Option[(String, Array[Byte])]) =
      if (!classEntries.hasNext) (unchanged, None)
      else {
        val entry = classEntries.next()
        processClassFile(jarFile.getInputStream(entry)) match {
          case Some(patched) => (unchanged, Some(entry.getName() -> patched))
          case None          => scan(unchanged + entry.getName())
        }
      }

    val (unchangedClasses, firstPatched) = scan(Set.empty)

    firstPatched match {
      case None =>
        Unmodified(path)
      case Some((patchedName, patchedBytes)) =>
        os.makeDir.all(dest / os.up)
        val outStream = Files.newOutputStream(dest.toNIO, StandardOpenOption.CREATE)
        // Both streams are closed even if copying an entry fails.
        try {
          val outjar = new JarOutputStream(outStream)
          try {
            jarFile.entries().asIterator().asScala.foreach { entry =>
              val name = entry.getName()
              val in   = jarFile.getInputStream(entry)
              val content: Array[Byte] =
                try in.readAllBytes()
                finally in.close()
              val destBytes =
                if (unchangedClasses.contains(name) || !name.endsWith(".class")) content
                else if (name == patchedName) patchedBytes
                else processClassFile(content).getOrElse(content)
              outjar.putNextEntry(new JarEntry(name))
              outjar.write(destBytes)
              outjar.closeEntry()
            }
          }
          finally outjar.close()
        }
        finally outStream.close()
        Processed(dest, path, cache)
    }
  }
  finally jarFile.close()
}
/** Runs the [[LazyValVisitor]] over a class and returns the rewritten bytecode, if any.
  *
  * @param reader ASM reader positioned on the class to process
  * @return the new bytecode when the visitor reports a change and the visit completed
  *         without a (non-fatal) exception; `None` otherwise
  */
def processClassReader(reader: ClassReader): Option[Array[Byte]] = {
  val writerFlags    = ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS
  val classWriter    = new ClassWriter(reader, writerFlags)
  val lazyValVisitor = new LazyValVisitor(classWriter)
  // A failing visit is not propagated: the class is then simply treated as unchanged.
  val visitOutcome = util.Try(reader.accept(lazyValVisitor, 0))
  Option.when(lazyValVisitor.changed && visitOutcome.isSuccess)(classWriter.toByteArray)
}
/** Processes a class file read from a stream.
  *
  * The by-name argument is evaluated exactly once, and the resulting stream is always
  * closed, whether or not processing succeeds.
  *
  * @param content expression producing the stream with the class file content
  * @return the rewritten bytecode, or `None` if the class was not changed
  */
def processClassFile(content: => InputStream): Option[Array[Byte]] = {
  val stream = content
  try {
    val reader = new ClassReader(stream)
    processClassReader(reader)
  }
  finally stream.close()
}
/** Processes a class file that is already loaded in memory.
  *
  * @param content bytes of the class file
  * @return the rewritten bytecode, or `None` if the class was not changed
  */
def processClassFile(content: Array[Byte]): Option[Array[Byte]] = {
  val reader = new ClassReader(content)
  processClassReader(reader)
}
/** ASM class visitor that rewrites `scala/runtime/LazyVals$.getOffset` calls found in static
  * initializers.
  *
  * Each such call is replaced by `java/lang/Class.getDeclaredField(String)` followed by the
  * static `scala/cli/runtime/SafeLazyVals.getOffset(Object, Field)`. All other instructions
  * and all other methods are forwarded to `writer` untouched.
  *
  * @param writer class writer receiving the (possibly rewritten) class
  */
class LazyValVisitor(writer: ClassWriter) extends ClassVisitor(Opcodes.ASM9, writer) {

  /** Becomes `true` once at least one call has been rewritten. */
  var changed: Boolean = false

  /** Method visitor applied to static initializers; performs the actual call replacement. */
  class StaticInitVistor(parent: MethodVisitor)
      extends MethodVisitor(Opcodes.ASM9, parent) {
    override def visitMethodInsn(
      opcode: Int,
      owner: String,
      name: String,
      descr: String,
      isInterface: Boolean
    ): Unit =
      if (owner == "scala/runtime/LazyVals$" && name == "getOffset") {
        changed = true
        // The operands pushed for the original call are consumed by the two replacement
        // calls: first the field is looked up on the class, then its offset is computed.
        // NOTE(review): relies on the original call taking (Class, String) — confirm.
        super.visitMethodInsn(
          Opcodes.INVOKEVIRTUAL,
          "java/lang/Class",
          "getDeclaredField",
          "(Ljava/lang/String;)Ljava/lang/reflect/Field;",
          false
        )
        super.visitMethodInsn(
          Opcodes.INVOKESTATIC,
          "scala/cli/runtime/SafeLazyVals",
          "getOffset",
          "(Ljava/lang/Object;Ljava/lang/reflect/Field;)J",
          false
        )
      }
      else
        super.visitMethodInsn(opcode, owner, name, descr, isInterface)
  }

  override def visitMethod(
    access: Int,
    name: String,
    desc: String,
    sig: String,
    exceptions: Array[String]
  ): MethodVisitor =
    // "<clinit>" is the JVM name of the class initialization method; only that method is
    // wrapped. (An empty method name can never occur, so comparing against "" would
    // disable the rewriting altogether.)
    if (name == "<clinit>")
      new StaticInitVistor(super.visitMethod(access, name, desc, sig, exceptions))
    else
      super.visitMethod(access, name, desc, sig, exceptions)
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy