org.apache.flink.api.scala.ClosureCleaner.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala
import java.io._
import org.apache.flink.annotation.Internal
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.util.InstantiationUtil
import org.slf4j.LoggerFactory
import scala.collection.mutable.Map
import scala.collection.mutable.Set
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
/* This code is originally from the Apache Spark project. */
@Internal
object ClosureCleaner {

  val LOG = LoggerFactory.getLogger(this.getClass)

  // Get an ASM class reader for a given class from the JAR that loaded it.
  //
  // The class file is looked up as a resource relative to `cls` itself. Its bytes are
  // copied into memory and the resource stream is closed before delegating to
  // ClassReader - else we can run out of open file handles.
  //
  // Throws an IOException when the class file resource cannot be located.
  private def getClassReader(cls: Class[_]): ClassReader = {
    // Strip the package prefix: the resource is resolved relative to the class' package.
    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
    val resourceStream = cls.getResourceAsStream(className)
    if (resourceStream == null) {
      // The earlier behavior handed the null stream straight to ClassReader and relied on
      // it to fail. Fail explicitly instead (same exception type), naming the class so the
      // problem can actually be diagnosed.
      throw new IOException(
        "Class not found: no bytecode resource " + className + " for " + cls.getName)
    }
    val baos = new ByteArrayOutputStream(128)
    copyStream(resourceStream, baos, closeStreams = true)
    new ClassReader(new ByteArrayInputStream(baos.toByteArray))
  }

  // Check whether a class represents a Scala closure (decided purely by its name).
  private def isClosure(cls: Class[_]): Boolean = {
    cls.getName.contains("$anonfun$")
  }

  // Find the declared "$outer" field of the given object, if it has one, and make it
  // accessible for reflective reads and writes.
  private def findOuterField(obj: AnyRef): Option[java.lang.reflect.Field] = {
    val outerField = obj.getClass.getDeclaredFields.find(_.getName == "$outer")
    outerField.foreach(_.setAccessible(true))
    outerField
  }

  // Get a list of the classes of the outer objects of a given closure object, obj;
  // the outer objects are defined as any closures that obj is nested within, plus
  // possibly the class that the outermost closure is in, if any. We stop searching
  // for outer objects beyond that because cloning the user's object is probably
  // not a good idea (whereas we can clone closure objects just fine since we
  // understand how all their fields are used).
  //
  // The result is ordered from the innermost to the outermost enclosing class.
  private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
    findOuterField(obj) match {
      case Some(outerField) =>
        val outerType: Class[_] = outerField.getType
        if (isClosure(outerType)) {
          outerType :: getOuterClasses(outerField.get(obj))
        } else {
          outerType :: Nil // Stop at the first $outer that is not a closure
        }
      case None =>
        Nil
    }
  }

  // Get a list of the outer objects for a given closure object. The list is parallel to
  // the one returned by getOuterClasses: same length, same innermost-to-outermost order.
  private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
    findOuterField(obj) match {
      case Some(outerField) =>
        val outerObj = outerField.get(obj)
        if (isClosure(outerField.getType)) {
          outerObj :: getOuterObjects(outerObj)
        } else {
          outerObj :: Nil // Stop at the first $outer that is not a closure
        }
      case None =>
        Nil
    }
  }

  // Collect the classes reported by InnerClosureFinder for the class of obj and,
  // transitively, for every class found that way. The class of obj itself is not
  // part of the result.
  private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
    val seen = Set[Class[_]](obj.getClass)
    // Work list of classes whose bytecode still has to be scanned.
    var stack = List[Class[_]](obj.getClass)
    while (stack.nonEmpty) {
      val cr = getClassReader(stack.head)
      stack = stack.tail
      val set = Set[Class[_]]()
      cr.accept(new InnerClosureFinder(set), 0)
      // Only newly discovered classes are queued, which guarantees termination.
      for (cls <- set -- seen) {
        seen += cls
        stack = cls :: stack
      }
    }
    (seen - obj.getClass).toList
  }

  // Produce a placeholder argument for a constructor parameter of the given type:
  // null for reference types, a boxed zero value for primitive types.
  //
  // A boxed Byte is accepted by reflection for byte, short, int, long, float and double
  // parameters through widening, but NOT for boolean or char, which therefore need their
  // own boxed placeholder; otherwise Constructor.newInstance rejects the argument with an
  // IllegalArgumentException.
  private def createNullValue(cls: Class[_]): AnyRef = {
    if (cls.isPrimitive) {
      cls match {
        case java.lang.Boolean.TYPE => java.lang.Boolean.FALSE
        case java.lang.Character.TYPE => java.lang.Character.valueOf(0.toChar)
        case _ => java.lang.Byte.valueOf(0: Byte)
      }
    } else {
      null
    }
  }

  // Clean the given closure in place.
  //
  // The chain of enclosing closure objects reachable through "$outer" is replaced by
  // freshly instantiated copies in which only the fields read by `func` or by the inner
  // closures found for it are populated; every other field keeps its placeholder value.
  // An enclosing object that is not a closure is kept as-is and never cloned.
  //
  // func              the closure object to clean; its "$outer" field is overwritten
  // checkSerializable when true, the cleaned closure is serialized once as a sanity check
  //
  // Throws InvalidProgramException if the closure class contains a return statement, or
  // if checkSerializable is set and the cleaned closure cannot be serialized.
  def clean(func: AnyRef, checkSerializable: Boolean = true): Unit = {
    // TODO: cache outerClasses / innerClasses / accessedFields
    val outerClasses = getOuterClasses(func)
    val innerClasses = getInnerClasses(func)
    val outerObjects = getOuterObjects(func)

    val accessedFields = Map[Class[_], Set[String]]()

    // Reject closures that use `return` before doing any further work.
    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)

    // Only fields of the outer classes are tracked; seed an empty set for each of them.
    for (cls <- outerClasses)
      accessedFields(cls) = Set[String]()
    for (cls <- func.getClass :: innerClasses)
      getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)

    if (LOG.isDebugEnabled) {
      LOG.debug("accessedFields: " + accessedFields)
    }

    // Walk the enclosing objects from the outermost to the innermost.
    var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
    var outer: AnyRef = null
    if (outerPairs.nonEmpty && !isClosure(outerPairs.head._1)) {
      // The closure is ultimately nested inside a class; keep the object of that
      // class without cloning it since we don't want to clone the user's objects.
      outer = outerPairs.head._2
      outerPairs = outerPairs.tail
    }

    // Clone the closure objects themselves, nulling out any fields that are not
    // used in the closure we're working on or any of its inner closures.
    for ((cls, obj) <- outerPairs) {
      outer = instantiateClass(cls, outer)
      for (fieldName <- accessedFields(cls)) {
        val field = cls.getDeclaredField(fieldName)
        field.setAccessible(true)
        val value = field.get(obj)
        if (LOG.isDebugEnabled) {
          LOG.debug("1: Setting " + fieldName + " on " + cls + " to " + value)
        }
        field.set(outer, value)
      }
    }

    // Re-point the closure at the cleaned chain of outer objects.
    if (outer != null) {
      if (LOG.isDebugEnabled) {
        LOG.debug("2: Setting $outer on " + func.getClass + " to " + outer)
      }
      val field = func.getClass.getDeclaredField("$outer")
      field.setAccessible(true)
      field.set(func, outer)
    }

    if (checkSerializable) {
      ensureSerializable(func)
    }
  }

  // Verify that the given object can be serialized by actually serializing it once.
  //
  // Throws InvalidProgramException ("Task not serializable") wrapping the original
  // failure if serialization throws an Exception.
  def ensureSerializable(func: AnyRef): Unit = {
    try {
      InstantiationUtil.serializeObject(func)
    } catch {
      case ex: Exception => throw new InvalidProgramException("Task not serializable", ex)
    }
  }

  // Create a new instance of the closure class cls through its first public constructor.
  // Every constructor argument is a placeholder (see createNullValue), except the first
  // one, which is set to `outer` when `outer` is non-null.
  private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = {
    if (LOG.isDebugEnabled) {
      LOG.debug("Creating a " + cls + " with outer = " + outer)
    }
    // This is a bona fide closure class, whose constructor has no effects
    // other than to set its fields, so use its constructor
    val cons = cls.getConstructors()(0)
    val params = cons.getParameterTypes.map(createNullValue)
    if (outer != null) {
      params(0) = outer // First param is always outer object
    }
    cons.newInstance(params: _*).asInstanceOf[AnyRef]
  }

  /**
   * Copy all data from an InputStream to an OutputStream.
   *
   * @param in           stream to read from until it is exhausted
   * @param out          stream to write to
   * @param closeStreams when true, both streams are closed before returning, whether the
   *                     copy succeeded or failed
   * @return the number of bytes copied
   */
  def copyStream(in: InputStream,
                 out: OutputStream,
                 closeStreams: Boolean = false): Long = {
    var count = 0L
    try {
      (in, out) match {
        case (fileIn: FileInputStream, fileOut: FileOutputStream) =>
          // When both streams are File stream, use transferTo to improve copy performance.
          val inChannel = fileIn.getChannel()
          val outChannel = fileOut.getChannel()
          val size = inChannel.size()
          // In case transferTo method transferred less data than we have required.
          while (count < size) {
            count += inChannel.transferTo(count, size - count, outChannel)
          }
        case _ =>
          val buf = new Array[Byte](8192)
          var n = in.read(buf)
          // read returns -1 once the end of the stream has been reached
          while (n != -1) {
            out.write(buf, 0, n)
            count += n
            n = in.read(buf)
          }
      }
      count
    } finally {
      if (closeStreams) {
        // Make sure `out` is closed even if closing `in` fails.
        try {
          in.close()
        } finally {
          out.close()
        }
      }
    }
  }
}
@Internal
private[flink]
class ReturnStatementFinder extends ClassVisitor(ASM5) {

  // Methods whose name contains "apply" are scanned for the instantiation of
  // scala.runtime.NonLocalReturnControl; finding one aborts the visit with an
  // InvalidProgramException. All other methods get a visitor that does nothing.
  override def visitMethod(access: Int, name: String, desc: String,
                           sig: String, exceptions: Array[String]): MethodVisitor = {
    if (!name.contains("apply")) {
      new MethodVisitor(ASM5) {}
    } else {
      new MethodVisitor(ASM5) {
        override def visitTypeInsn(op: Int, tp: String): Unit = {
          val createsNonLocalReturn =
            op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")
          if (createsNonLocalReturn) {
            throw new InvalidProgramException("Return statements aren't allowed in Flink closures")
          }
        }
      }
    }
  }
}
@Internal
private[flink]
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM5) {

  // Every method of the visited class is scanned; member names read from one of the
  // classes that are keys of `output` are added to the set stored under that key.
  // Classes that are not already keys of `output` are ignored.
  override def visitMethod(access: Int, name: String, desc: String,
                           sig: String, exceptions: Array[String]): MethodVisitor = {
    new MethodVisitor(ASM5) {
      override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = {
        if (op == GETFIELD) {
          recordAccess(owner, name)
        }
      }

      // The five-argument overload is the one invoked for visitors created with the ASM5
      // API level; the deprecated four-argument form is not called back in that case, so
      // overriding it would leave this check dead.
      override def visitMethodInsn(op: Int, owner: String, name: String,
                                   desc: String, itf: Boolean): Unit = {
        // Check for calls a getter method for a variable in an interpreter wrapper object.
        // This means that the corresponding field will be accessed, so we should save it.
        if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
          recordAccess(owner, name)
        }
      }
    }
  }

  // Add `member` to the set of every tracked class whose name matches the given
  // internal (slash-separated) class name.
  private def recordAccess(owner: String, member: String): Unit = {
    val ownerName = owner.replace('/', '.')
    for (cl <- output.keys if cl.getName == ownerName) {
      output(cl) += member
    }
  }
}
@Internal
private[flink] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) {

  // Internal (slash-separated) name of the class being visited; set by visit().
  var myName: String = null

  override def visit(version: Int, access: Int, name: String, sig: String,
                     superName: String, interfaces: Array[String]): Unit = {
    myName = name
  }

  // Scans every method for constructor invocations whose first argument has the type of
  // the visited class, and adds the class being constructed to `output`.
  override def visitMethod(access: Int, name: String, desc: String,
                           sig: String, exceptions: Array[String]): MethodVisitor = {
    new MethodVisitor(ASM5) {
      // The five-argument overload is the one invoked for visitors created with the ASM5
      // API level; the deprecated four-argument form is not called back in that case.
      override def visitMethodInsn(op: Int, owner: String, name: String,
                                   desc: String, itf: Boolean): Unit = {
        val argTypes = Type.getArgumentTypes(desc)
        // Constructors are compiled to methods named "<init>" that are called through
        // INVOKESPECIAL; comparing against any other name can never match one.
        if (op == INVOKESPECIAL && name == "<init>" && argTypes.nonEmpty
            && argTypes(0).toString.startsWith("L") // is it an object?
            && argTypes(0).getInternalName == myName) {
          // Load without initializing the class, through the context class loader.
          output += Class.forName(
            owner.replace('/', '.'),
            false,
            Thread.currentThread.getContextClassLoader)
        }
      }
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy