All downloads are FREE. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlin.codegen.optimization.CapturedVarsOptimizationMethodTransformer.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2017 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.codegen.optimization

import org.jetbrains.kotlin.builtins.PrimitiveType
import org.jetbrains.kotlin.codegen.AsmUtil
import org.jetbrains.kotlin.codegen.optimization.common.*
import org.jetbrains.kotlin.codegen.optimization.fixStack.peek
import org.jetbrains.kotlin.codegen.optimization.fixStack.top
import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
import org.jetbrains.kotlin.utils.addToStdlib.safeAs
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue
import org.jetbrains.org.objectweb.asm.tree.analysis.Frame

/**
 * Method-level optimization that replaces `kotlin.jvm.internal.Ref.*` boxes allocated for captured
 * variables with plain local variables of the boxed value type.
 *
 * A box is only replaced when analysis shows it is used exclusively through the limited set of
 * operations listed on [CapturedVarDescriptor]; anything else marks it as a hazard and leaves it untouched.
 */
class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
    override fun transform(internalClassName: String, methodNode: MethodNode) {
        Transformer(internalClassName, methodNode).run()
    }

    // Tracks proper usages of objects corresponding to captured variables.
    //
    // The 'kotlin.jvm.internal.Ref.*' instance can be replaced with a local variable
    // if all of the following conditions are satisfied:
    //  * It is created inside the current method.
    //  * The only permitted operations on it are:
    //      - store to a local variable
    //      - ALOAD, ASTORE
    //      - DUP, POP
    //      - GETFIELD .element, PUTFIELD .element
    //      - INVOKESPECIAL <init> (from a single call site)
    //  * There's a corresponding local variable definition,
    //      and all ALOAD/ASTORE instructions operate on that particular local variable.
    //  * Its 'element' field is initialized at the start of the local variable's visibility range.
    //
    // Note that for code that doesn't create Ref objects explicitly these conditions are true,
    // unless the Ref object escapes to a local class constructor (including local classes for lambdas).
    //
    private class CapturedVarDescriptor(val newInsn: TypeInsnNode, val refType: Type, val valueType: Type) : ReferenceValueDescriptor {
        /** Set as soon as any use disqualifies this Ref from being rewritten. */
        var hazard = false

        /** The single `INVOKESPECIAL <init>` call on this Ref, if one was seen. */
        var initCallInsn: MethodInsnNode? = null

        /** The local variable table entry holding this Ref, if exactly one was found. */
        var localVar: LocalVariableNode? = null

        /** Slot the unboxed value will occupy; stays -1 until a slot has been chosen. */
        var localVarIndex = -1

        val astoreInsns: MutableCollection<VarInsnNode> = LinkedHashSet()
        val aloadInsns: MutableCollection<VarInsnNode> = LinkedHashSet()

        /** DUP / POP instructions operating on the Ref; they are dropped during the rewrite. */
        val stackInsns: MutableCollection<AbstractInsnNode> = LinkedHashSet()
        val getFieldInsns: MutableCollection<FieldInsnNode> = LinkedHashSet()
        val putFieldInsns: MutableCollection<FieldInsnNode> = LinkedHashSet()

        /** An `ASTORE` of `null` into the Ref's original slot within its live range, removed on rewrite. */
        var cleanVarInstruction: VarInsnNode? = null

        fun canRewrite(): Boolean =
            !hazard &&
                    initCallInsn != null &&
                    localVar != null &&
                    localVarIndex >= 0

        override fun onUseAsTainted() {
            hazard = true
        }
    }

    /** Per-method worker: collects Ref allocations, analyzes their uses, and rewrites the qualifying ones. */
    private class Transformer(private val internalClassName: String, private val methodNode: MethodNode) {
        /** Candidate Ref values; pruned to the rewritable ones by [analyze]. */
        private val refValues = ArrayList<CapturedVarDescriptor>()
        private val refValuesByNewInsn = LinkedHashMap<TypeInsnNode, CapturedVarDescriptor>()
        private val insns = methodNode.instructions.toArray()
        private lateinit var frames: Array<out Frame<BasicValue>?>

        val hasRewritableRefValues: Boolean
            get() = refValues.isNotEmpty()

        fun run() {
            createRefValues()
            if (!hasRewritableRefValues) return

            analyze()
            if (!hasRewritableRefValues) return

            rewrite()
        }

        private fun AbstractInsnNode.getIndex() = methodNode.instructions.indexOf(this)

        /** Registers a descriptor for every `NEW` of a shared-var (Ref) type with a known element type. */
        private fun createRefValues() {
            for (insn in insns) {
                if (insn.opcode == Opcodes.NEW && insn is TypeInsnNode) {
                    val type = Type.getObjectType(insn.desc)
                    if (AsmTypes.isSharedVarType(type)) {
                        val valueType = REF_TYPE_TO_ELEMENT_TYPE[type.internalName] ?: continue
                        val refValue = CapturedVarDescriptor(insn, type, valueType)
                        refValues.add(refValue)
                        refValuesByNewInsn[insn] = refValue
                    }
                }
            }
        }

        /** Data-flow interpreter that tags values produced by tracked `NEW` instructions and records their uses. */
        private inner class Interpreter : ReferenceTrackingInterpreter() {
            override fun newOperation(insn: AbstractInsnNode): BasicValue =
                refValuesByNewInsn[insn]?.let { descriptor ->
                    ProperTrackedReferenceValue(descriptor.refType, descriptor)
                }
                        ?: super.newOperation(insn)

            override fun processRefValueUsage(value: TrackedReferenceValue, insn: AbstractInsnNode, position: Int) {
                for (descriptor in value.descriptors) {
                    if (descriptor !is CapturedVarDescriptor) throw AssertionError("Unexpected descriptor: $descriptor")
                    when {
                        insn.opcode == Opcodes.ALOAD ->
                            descriptor.aloadInsns.add(insn as VarInsnNode)
                        insn.opcode == Opcodes.ASTORE ->
                            descriptor.astoreInsns.add(insn as VarInsnNode)
                        insn.opcode == Opcodes.GETFIELD && insn is FieldInsnNode && insn.name == REF_ELEMENT_FIELD && position == 0 ->
                            descriptor.getFieldInsns.add(insn)
                        insn.opcode == Opcodes.PUTFIELD && insn is FieldInsnNode && insn.name == REF_ELEMENT_FIELD && position == 0 ->
                            descriptor.putFieldInsns.add(insn)
                        insn.opcode == Opcodes.INVOKESPECIAL && insn is MethodInsnNode && insn.name == INIT_METHOD_NAME && position == 0 ->
                            // More than one distinct constructor call site for the same NEW is not supported.
                            if (descriptor.initCallInsn != null && descriptor.initCallInsn != insn)
                                descriptor.hazard = true
                            else
                                descriptor.initCallInsn = insn
                        insn.opcode == Opcodes.DUP ->
                            descriptor.stackInsns.add(insn)
                        else ->
                            descriptor.hazard = true
                    }
                }
            }

        }

        private fun analyze() {
            frames = MethodTransformer.analyze(internalClassName, methodNode, Interpreter())
            trackPops()
            assignLocalVars()

            refValues.removeAll { !it.canRewrite() }
        }

        /**
         * Records POPs of tracked Refs (they are removed on rewrite). A POP2 that discards a Ref together
         * with another one-slot value is not handled and marks both values as hazards.
         */
        private fun trackPops() {
            for (i in insns.indices) {
                val frame = frames[i] ?: continue
                val insn = insns[i]

                when (insn.opcode) {
                    Opcodes.POP -> {
                        frame.top()?.getCapturedVarOrNull()?.run { stackInsns.add(insn) }
                    }
                    Opcodes.POP2 -> {
                        val top = frame.top()
                        if (top?.size == 1) {
                            top.getCapturedVarOrNull()?.hazard = true
                            frame.peek(1)?.getCapturedVarOrNull()?.hazard = true
                        }
                    }
                }
            }
        }

        private fun BasicValue.getCapturedVarOrNull() =
            safeAs<ProperTrackedReferenceValue>()?.descriptor?.safeAs<CapturedVarDescriptor>()

        /**
         * Binds each candidate Ref to its unique local variable table entry and picks the slot the unboxed
         * value will use.
         */
        private fun assignLocalVars() {
            for (localVar in methodNode.localVariables) {
                val type = Type.getType(localVar.desc)
                if (!AsmTypes.isSharedVarType(type)) continue

                val startFrame = frames[localVar.start.getIndex()] ?: continue

                val refValue = startFrame.getLocal(localVar.index) as? ProperTrackedReferenceValue ?: continue
                val descriptor = refValue.descriptor as? CapturedVarDescriptor ?: continue

                if (descriptor.hazard) continue

                // A Ref visible through more than one local variable entry is not rewritten.
                if (descriptor.localVar == null) {
                    descriptor.localVar = localVar
                } else {
                    descriptor.hazard = true
                }
            }

            for (refValue in refValues) {
                if (refValue.hazard) continue
                val localVar = refValue.localVar ?: continue
                val oldVarIndex = localVar.index

                // Decide rewritability BEFORE touching the local variable table or maxLocals: a Ref rejected
                // here keeps its original code, so its LocalVariableNode must keep pointing at the original slot.
                val cleanInstructions = findCleanInstructions(localVar, oldVarIndex, methodNode.instructions)
                if (cleanInstructions.size > 1) {
                    refValue.hazard = true
                    continue
                }
                refValue.cleanVarInstruction = cleanInstructions.firstOrNull()

                if (refValue.valueType.size != 1) {
                    // Two-slot values (size 2) can't reuse the Ref's single slot: allocate a fresh pair at the end.
                    refValue.localVarIndex = methodNode.maxLocals
                    methodNode.maxLocals += 2
                    localVar.index = refValue.localVarIndex
                } else {
                    refValue.localVarIndex = oldVarIndex
                }
            }
        }

        /**
         * Finds `ACONST_NULL; ASTORE oldVarIndex` pairs strictly inside [localVar]'s live range and returns
         * their `ASTORE` instructions.
         */
        private fun findCleanInstructions(localVar: LocalVariableNode, oldVarIndex: Int, instructions: InsnList): List<VarInsnNode> {
            val startIndex = instructions.indexOf(localVar.start)
            val endIndex = instructions.indexOf(localVar.end)
            return InsnSequence(instructions).filterIsInstance<VarInsnNode>().filter {
                it.opcode == Opcodes.ASTORE && it.`var` == oldVarIndex
            }.filter {
                it.previous?.opcode == Opcodes.ACONST_NULL
            }.filter {
                val operationIndex = instructions.indexOf(it)
                startIndex < operationIndex && operationIndex < endIndex
            }.toList()
        }

        private fun rewrite() {
            for (refValue in refValues) {
                if (!refValue.canRewrite()) continue

                rewriteRefValue(refValue)
            }

            methodNode.removeEmptyCatchBlocks()
            methodNode.removeUnusedLocalVariables()
        }

        // Be careful to not remove instructions that are the only instruction for a line number. That will
        // break debugging. If the previous instruction is a line number and the following instruction is
        // a label followed by a line number, insert a nop instead of deleting the instruction.
        private fun InsnList.removeOrReplaceByNop(insn: AbstractInsnNode) {
            if (insn.previous is LineNumberNode && insn.next is LabelNode && insn.next.next is LineNumberNode) {
                set(insn, InsnNode(Opcodes.NOP))
            } else {
                remove(insn)
            }
        }

        /**
         * Rewrites one Ref: drops its allocation, constructor call and stack/ALOAD/ASTORE traffic, and turns
         * `element` field accesses into loads/stores of the chosen local slot.
         */
        private fun rewriteRefValue(capturedVar: CapturedVarDescriptor) {
            methodNode.instructions.run {
                // Non-null: rewrite() only calls this for values that pass canRewrite().
                val localVar = capturedVar.localVar!!
                localVar.signature = null
                localVar.desc = capturedVar.valueType.descriptor

                val loadOpcode = capturedVar.valueType.getOpcode(Opcodes.ILOAD)
                val storeOpcode = capturedVar.valueType.getOpcode(Opcodes.ISTORE)

                if (capturedVar.putFieldInsns.none { it.getIndex() < localVar.start.getIndex() }) {
                    // variable needs to be initialized before its live range can begin
                    insertBefore(capturedVar.newInsn, InsnNode(AsmUtil.defaultValueOpcode(capturedVar.valueType)))
                    insertBefore(capturedVar.newInsn, VarInsnNode(storeOpcode, capturedVar.localVarIndex))
                }

                remove(capturedVar.newInsn)
                remove(capturedVar.initCallInsn!!)

                capturedVar.stackInsns.forEach { removeOrReplaceByNop(it) }
                capturedVar.aloadInsns.forEach { removeOrReplaceByNop(it) }
                capturedVar.astoreInsns.forEach { removeOrReplaceByNop(it) }
                capturedVar.getFieldInsns.forEach { set(it, VarInsnNode(loadOpcode, capturedVar.localVarIndex)) }
                capturedVar.putFieldInsns.forEach { set(it, VarInsnNode(storeOpcode, capturedVar.localVarIndex)) }

                //after visiting block codegen tries to delete all allocated references:
                // see ExpressionCodegen.addLeaveTaskToRemoveLocalVariableFromFrameMap
                capturedVar.cleanVarInstruction?.let {
                    remove(it.previous)
                    remove(it)
                }
            }
        }

    }
}

/** Name of the field through which `Ref.*` boxes hold their value; matched against GETFIELD/PUTFIELD. */
internal const val REF_ELEMENT_FIELD = "element"

/** JVM name of instance constructors; matched against INVOKESPECIAL targets. */
internal const val INIT_METHOD_NAME = "<init>"

/** Maps the internal name of each `Ref` box type (object and per-primitive) to the type of value it holds. */
internal val REF_TYPE_TO_ELEMENT_TYPE = HashMap<String, Type>().apply {
    put(AsmTypes.OBJECT_REF_TYPE.internalName, AsmTypes.OBJECT_TYPE)
    PrimitiveType.values().forEach {
        put(AsmTypes.sharedTypeForPrimitive(it).internalName, AsmTypes.valueTypeForPrimitive(it))
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy