All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.codegen.optimization.boxing.PopBackwardPropagationTransformer.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2016 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.codegen.optimization.boxing

import org.jetbrains.kotlin.codegen.optimization.OptimizationMethodVisitor
import org.jetbrains.kotlin.codegen.optimization.common.debugText
import org.jetbrains.kotlin.codegen.optimization.common.isLoadOperation
import org.jetbrains.kotlin.codegen.optimization.common.isMeaningful
import org.jetbrains.kotlin.codegen.optimization.fixStack.peekWords
import org.jetbrains.kotlin.codegen.optimization.fixStack.top
import org.jetbrains.kotlin.codegen.optimization.removeNodeGetNext
import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.Analyzer
import org.jetbrains.org.objectweb.asm.tree.analysis.Frame
import org.jetbrains.org.objectweb.asm.tree.analysis.SourceInterpreter
import org.jetbrains.org.objectweb.asm.tree.analysis.SourceValue
import java.util.*

/**
 * Method transformer that tries to get rid of values which are pushed on the stack only to be discarded by a `POP`.
 *
 * For every `POP` that has a frame in the data-flow analysis, the transformer looks at the instructions that
 * produced the popped value (as reported by a [SourceInterpreter]-based analysis) and, when none of them is
 * marked as "don't touch", replaces the `POP` with a removable `NOP` and propagates the "pop" backwards into
 * those producers.
 */
class PopBackwardPropagationTransformer : MethodTransformer() {
    /**
     * Runs the optimization on [methodNode].
     *
     * Does nothing when [OptimizationMethodVisitor.canBeOptimizedUsingSourceInterpreter] rejects the method.
     */
    override fun transform(internalClassName: String, methodNode: MethodNode) {
        if (!OptimizationMethodVisitor.canBeOptimizedUsingSourceInterpreter(methodNode)) return
        Transformer(methodNode).transform()
    }

    /**
     * Holds the per-method state of a single optimization run over [methodNode].
     */
    private class Transformer(val methodNode: MethodNode) {
        /** A pending edit of the instruction list, anchored at a single instruction. */
        private interface Transformation {
            fun apply(insn: AbstractInsnNode)
        }

        /** Factory that wraps [body] into a [Transformation]. */
        private inline fun Transformation(crossinline body: (AbstractInsnNode) -> Unit): Transformation =
            object : Transformation {
                override fun apply(insn: AbstractInsnNode) {
                    body(insn)
                }
            }

        // The lambdas below read `insnList` only when a transformation is applied,
        // so it is fine that `insnList` is declared after them.
        private val REPLACE_WITH_NOP = Transformation { insnList.set(it, createRemovableNopInsn()) }
        private val REPLACE_WITH_POP1 = Transformation { insnList.set(it, InsnNode(Opcodes.POP)) }
        private val REPLACE_WITH_POP2 = Transformation { insnList.set(it, InsnNode(Opcodes.POP2)) }
        private val INSERT_POP1_AFTER = Transformation { insnList.insert(it, InsnNode(Opcodes.POP)) }
        private val INSERT_POP2_AFTER = Transformation { insnList.insert(it, InsnNode(Opcodes.POP2)) }

        private val insnList = methodNode.instructions

        /** Snapshot of the instruction list taken before any transformation is applied. */
        private val insns = insnList.toArray()

        /** Bit `i` is set when the instruction at index `i` of [insnList] must not be transformed. */
        private val dontTouchInsnIndices = BitSet(insns.size)

        /** Edits scheduled by [computeTransformations]; at most one per instruction. */
        private val transformations = hashMapOf<AbstractInsnNode, Transformation>()

        /** `NOP` nodes created by this transformer, candidates for removal in [postprocessNops]. */
        private val removableNops = hashSetOf<InsnNode>()

        /**
         * Data-flow frames, one slot per instruction, computed on first access.
         * Computing them also populates [dontTouchInsnIndices].
         */
        private val frames by lazy { analyzeMethodBody() }

        /**
         * Computes and applies all transformations, then cleans up the `NOP`s introduced by them.
         *
         * Returns immediately when the method contains neither a pop nor a pure push instruction.
         */
        fun transform() {
            if (insns.none { it.isPop() || it.isPurePush() }) return

            computeTransformations()
            for ((insn, transformation) in transformations) {
                transformation.apply(insn)
            }
            postprocessNops()
        }

        /**
         * Runs the source-tracking analysis over [methodNode] and returns its frames.
         *
         * As a side effect, instructions whose results are consumed by "hazardous" operations
         * are recorded in [dontTouchInsnIndices].
         */
        private fun analyzeMethodBody(): Array<out Frame<SourceValue>?> {
            val frames = Analyzer<SourceValue>(HazardsTrackingInterpreter()).analyze("fake", methodNode)

            postprocessStackHazards(frames)

            return frames
        }

        /**
         * Marks as "don't touch" the producers of every stack word inspected by `POP2` and the
         * `DUP*_X*` instructions.
         *
         * @throws AssertionError if the expected words cannot be peeked from the frame.
         */
        private fun postprocessStackHazards(frames: Array<out Frame<SourceValue>?>) {
            for (i in frames.indices) {
                // No frame for this instruction: nothing to inspect.
                val frame = frames[i] ?: continue
                val insn = insns[i]

                when (insn.opcode) {
                    Opcodes.POP2 -> {
                        val top2 = frame.peekWords(2) ?: throwIncorrectBytecode(insn, frame)
                        top2.forEach { it.insns.markAsDontTouch() }
                    }
                    Opcodes.DUP_X1 -> {
                        val top2 = frame.peekWords(1, 1) ?: throwIncorrectBytecode(insn, frame)
                        top2.forEach { it.insns.markAsDontTouch() }
                    }
                    Opcodes.DUP2_X1 -> {
                        val top3 = frame.peekWords(2, 1) ?: throwIncorrectBytecode(insn, frame)
                        top3.forEach { it.insns.markAsDontTouch() }
                    }
                    Opcodes.DUP_X2 -> {
                        val top3 = frame.peekWords(1, 2) ?: throwIncorrectBytecode(insn, frame)
                        top3.forEach { it.insns.markAsDontTouch() }
                    }
                    Opcodes.DUP2_X2 -> {
                        val top4 = frame.peekWords(2, 2) ?: throwIncorrectBytecode(insn, frame)
                        top4.forEach { it.insns.markAsDontTouch() }
                    }
                }
            }
        }

        /** Always throws an [AssertionError] describing [insn] and [frame]. */
        private fun throwIncorrectBytecode(insn: AbstractInsnNode?, frame: Frame<SourceValue>): Nothing {
            throw AssertionError("Incorrect bytecode at ${methodNode.instructions.indexOf(insn)}: ${insn.debugText} $frame")
        }

        /**
         * [SourceInterpreter] that additionally marks the producers of consumed values as "don't touch".
         *
         * Operands of `CHECKCAST` and of primitive type conversions are deliberately NOT marked,
         * because [propagatePopBackwards] knows how to look through those instructions.
         */
        private inner class HazardsTrackingInterpreter : SourceInterpreter(Opcodes.API_VERSION) {
            override fun naryOperation(insn: AbstractInsnNode, values: MutableList<out SourceValue>): SourceValue {
                for (value in values) {
                    value.insns.markAsDontTouch()
                }
                return super.naryOperation(insn, values)
            }

            override fun copyOperation(insn: AbstractInsnNode, value: SourceValue): SourceValue {
                value.insns.markAsDontTouch()
                return super.copyOperation(insn, value)
            }

            override fun unaryOperation(insn: AbstractInsnNode, value: SourceValue): SourceValue {
                if (insn.opcode != Opcodes.CHECKCAST && !insn.isPrimitiveTypeConversion()) {
                    value.insns.markAsDontTouch()
                }
                return super.unaryOperation(insn, value)
            }

            override fun binaryOperation(insn: AbstractInsnNode, value1: SourceValue, value2: SourceValue): SourceValue {
                value1.insns.markAsDontTouch()
                value2.insns.markAsDontTouch()
                return super.binaryOperation(insn, value1, value2)
            }

            override fun ternaryOperation(
                insn: AbstractInsnNode,
                value1: SourceValue,
                value2: SourceValue,
                value3: SourceValue
            ): SourceValue {
                value1.insns.markAsDontTouch()
                value2.insns.markAsDontTouch()
                value3.insns.markAsDontTouch()
                return super.ternaryOperation(insn, value1, value2, value3)
            }
        }

        /** Sets the "don't touch" bit for every instruction in the receiver. */
        private fun Collection<AbstractInsnNode>.markAsDontTouch() {
            forEach {
                dontTouchInsnIndices[insnList.indexOf(it)] = true
            }
        }

        /**
         * Rebuilds [transformations] by starting a backward propagation from every `POP`
         * that has a frame (instructions without a frame are skipped).
         */
        private fun computeTransformations() {
            transformations.clear()

            for (i in insns.indices) {
                if (frames[i] == null) continue
                val insn = insns[i]

                if (insn.opcode == Opcodes.POP) {
                    // The size argument is not read by the POP branch of propagatePopBackwards.
                    propagatePopBackwards(insn, 0)
                }
            }
        }

        /**
         * Schedules a transformation for [insn] so that the value it leaves on the stack gets discarded.
         *
         * @param insn the instruction to process; ignored if it already has a scheduled transformation.
         * @param poppedValueSize size (in stack words) of the value produced by [insn] that has to be dropped.
         */
        private fun propagatePopBackwards(insn: AbstractInsnNode, poppedValueSize: Int) {
            if (transformations.containsKey(insn)) return

            when {
                insn.opcode == Opcodes.POP -> {
                    val inputTop = getInputTop(insn)
                    val sources = inputTop.insns
                    if (sources.none { isDontTouch(it) } && sources.any { isTransformablePopOperand(it) }) {
                        transformations[insn] = replaceWithNopTransformation()
                        sources.forEach { propagatePopBackwards(it, inputTop.size) }
                    }
                }

                insn.opcode == Opcodes.CHECKCAST -> {
                    val inputTop = getInputTop(insn)
                    val sources = inputTop.insns
                    val resultType = (insn as TypeInsnNode).desc
                    if (sources.none { isDontTouch(it) } && sources.any { isTransformableCheckcastOperand(it, resultType) }) {
                        transformations[insn] = replaceWithNopTransformation()
                        sources.forEach { propagatePopBackwards(it, inputTop.size) }
                    } else {
                        transformations[insn] = insertPopAfterTransformation(poppedValueSize)
                    }
                }

                insn.isPrimitiveBoxing() -> {
                    // The boxing call is replaced by a pop of its (primitive) argument.
                    val boxedValueSize = getInputTop(insn).size
                    transformations[insn] = replaceWithPopTransformation(boxedValueSize)
                }

                insn.isPurePush() -> {
                    transformations[insn] = replaceWithNopTransformation()
                }

                insn.isPrimitiveTypeConversion() -> {
                    val inputTop = getInputTop(insn)
                    val sources = inputTop.insns
                    if (sources.none { isDontTouch(it) }) {
                        transformations[insn] = replaceWithNopTransformation()
                        sources.forEach { propagatePopBackwards(it, inputTop.size) }
                    } else {
                        // The pop takes the place of the conversion, so it has to discard the conversion's
                        // operand, whose size may differ from the size of the converted result
                        // (e.g. I2L: one-word operand, two-word result).
                        transformations[insn] = replaceWithPopTransformation(inputTop.size)
                    }
                }

                else -> {
                    transformations[insn] = insertPopAfterTransformation(poppedValueSize)
                }
            }
        }

        /**
         * Walks the instruction list in ranges delimited by [LabelNode]s and, for every range that contains
         * at least one removable `NOP`, delegates to [removeUnneededNopsInRange].
         *
         * The node each outer iteration starts on (the list head, then each delimiting label) is stepped
         * over and is not part of any range.
         */
        private fun postprocessNops() {
            var node: AbstractInsnNode? = insnList.first
            var hasRemovableNops = false
            while (node != null) {
                node = node.next
                val begin = node ?: break
                while (node != null && node !is LabelNode) {
                    if (node in removableNops) {
                        hasRemovableNops = true
                    }
                    node = node.next
                }
                val end = node
                if (hasRemovableNops) {
                    removeUnneededNopsInRange(begin, end)
                }
                hasRemovableNops = false
            }
        }

        /**
         * Removes removable `NOP`s from the range starting at [begin] (inclusive) and ending at [end] (exclusive).
         *
         * A removable `NOP` is deleted only if some earlier node of the range reported `isMeaningful`;
         * until then nodes are kept as they are.
         */
        private fun removeUnneededNopsInRange(begin: AbstractInsnNode, end: AbstractInsnNode?) {
            var node: AbstractInsnNode? = begin
            var keepNop = true
            while (node != null && node != end) {
                if (node in removableNops && !keepNop) {
                    node = insnList.removeNodeGetNext(node)
                } else {
                    if (node.isMeaningful) keepNop = false
                    node = node.next
                }
            }
        }

        /**
         * Returns the transformation replacing an instruction with `POP` ([size] 1) or `POP2` ([size] 2).
         *
         * @throws AssertionError for any other [size].
         */
        private fun replaceWithPopTransformation(size: Int): Transformation =
            when (size) {
                1 -> REPLACE_WITH_POP1
                2 -> REPLACE_WITH_POP2
                else -> throw AssertionError("Unexpected pop value size: $size")
            }

        /**
         * Returns the transformation inserting `POP` ([size] 1) or `POP2` ([size] 2) after an instruction.
         *
         * @throws AssertionError for any other [size].
         */
        private fun insertPopAfterTransformation(size: Int): Transformation =
            when (size) {
                1 -> INSERT_POP1_AFTER
                2 -> INSERT_POP2_AFTER
                else -> throw AssertionError("Unexpected pop value size: $size")
            }

        /** Returns the transformation replacing an instruction with a removable `NOP`. */
        private fun replaceWithNopTransformation(): Transformation =
            REPLACE_WITH_NOP

        /** Creates a `NOP` node and registers it in [removableNops]. */
        private fun createRemovableNopInsn() =
            InsnNode(Opcodes.NOP).apply { removableNops.add(this) }

        /**
         * Returns the value on top of the stack right before [insn] is executed.
         *
         * @throws AssertionError if [insn] has no frame or its input stack is empty.
         */
        private fun getInputTop(insn: AbstractInsnNode): SourceValue {
            val i = insnList.indexOf(insn)
            val frame = frames[i] ?: throw AssertionError("Unexpected dead instruction #$i")
            return frame.top() ?: throw AssertionError("Instruction #$i has empty stack on input")
        }

        /** True when [it] is a primitive boxing call whose owner is exactly [resultType]. */
        private fun isTransformableCheckcastOperand(it: AbstractInsnNode, resultType: String) =
            it.isPrimitiveBoxing() && (it as MethodInsnNode).owner == resultType

        /** True when [insn] is a `CHECKCAST`, a primitive boxing call or a pure push. */
        private fun isTransformablePopOperand(insn: AbstractInsnNode) =
            insn.opcode == Opcodes.CHECKCAST || insn.isPrimitiveBoxing() || insn.isPurePush()

        /** True when [insn] has been marked as "don't touch" by the analysis. */
        private fun isDontTouch(insn: AbstractInsnNode) =
            dontTouchInsnIndices[insnList.indexOf(insn)]
    }

}

/**
 * True when this instruction is a load operation, has an opcode within `ACONST_NULL..LDC + 2`,
 * or reads `kotlin.Unit.INSTANCE`.
 *
 * NOTE(review): the upper bound `LDC + 2` presumably extends the range to the wide LDC opcodes;
 * confirm against the JVM opcode table.
 */
fun AbstractInsnNode.isPurePush() =
    when {
        isLoadOperation() -> true
        opcode in Opcodes.ACONST_NULL..Opcodes.LDC + 2 -> true
        else -> isUnitInstance()
    }

/** True when this instruction's opcode is `POP` or `POP2`. */
fun AbstractInsnNode.isPop() =
    when (opcode) {
        Opcodes.POP, Opcodes.POP2 -> true
        else -> false
    }

/** True when this instruction is a `GETSTATIC` of the field `INSTANCE` owned by `kotlin/Unit`. */
fun AbstractInsnNode.isUnitInstance() =
    this is FieldInsnNode &&
            opcode == Opcodes.GETSTATIC &&
            owner == "kotlin/Unit" &&
            name == "INSTANCE"

/** True when this instruction's opcode lies between `I2L` and `I2S`, both inclusive. */
fun AbstractInsnNode.isPrimitiveTypeConversion() =
    Opcodes.I2L <= opcode && opcode <= Opcodes.I2S




© 2015 - 2024 Weber Informatics LLC | Privacy Policy