All downloads are free. The search and download functionality uses the official Maven repository.

org.jetbrains.kotlin.codegen.optimization.boxing.RedundantBoxingMethodTransformer.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2015 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.codegen.optimization.boxing

import com.intellij.openapi.util.Pair
import org.jetbrains.kotlin.codegen.inline.insnOpcodeText
import org.jetbrains.kotlin.codegen.inline.insnText
import org.jetbrains.kotlin.codegen.intrinsics.IntrinsicMethods
import org.jetbrains.kotlin.codegen.optimization.common.StrictBasicValue
import org.jetbrains.kotlin.codegen.optimization.common.remapLocalVariables
import org.jetbrains.kotlin.codegen.optimization.fixStack.peek
import org.jetbrains.kotlin.codegen.optimization.fixStack.top
import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer
import org.jetbrains.kotlin.codegen.state.GenerationState
import org.jetbrains.org.objectweb.asm.Label
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue
import org.jetbrains.org.objectweb.asm.tree.analysis.Frame
import java.util.*

/**
 * Method transformer that rewrites bytecode to work on unboxed values wherever
 * [RedundantBoxingInterpreter] reports a boxed value as a removal candidate.
 *
 * All work is driven from [transform]: the method is analyzed, candidates that clash with
 * local variable usage are discarded, and then the local variable table, the variable slots
 * and the instructions associated with every remaining candidate are adapted.
 */
class RedundantBoxingMethodTransformer(private val generationState: GenerationState) : MethodTransformer() {

    /**
     * Analyzes [node] with a [RedundantBoxingInterpreter] and, when the interpreter reports
     * candidate boxed values, rewrites [node] in place. Does nothing beyond the analysis when
     * there are no candidates.
     */
    override fun transform(internalClassName: String, node: MethodNode) {
        val interpreter = RedundantBoxingInterpreter(node.instructions, generationState)
        val frames = MethodTransformer.analyze(internalClassName, node, interpreter)

        interpretPopInstructionsForBoxedValues(interpreter, node, frames)

        val valuesToOptimize = interpreter.candidatesBoxedValues
        if (valuesToOptimize.isEmpty) return

        // has side effect on valuesToOptimize and frames, containing BoxedBasicValues that are unsafe to remove
        removeValuesClashingWithVariables(valuesToOptimize, node, frames)

        adaptLocalVariableTableForBoxedValues(node, frames)

        node.remapLocalVariables(buildVariablesRemapping(valuesToOptimize, node))

        adaptInstructionsForBoxedValues(node, valuesToOptimize)
    }

    /**
     * Reports every reachable POP/POP2 instruction to [interpreter] together with the value(s) it pops.
     *
     * A POP2 whose top-of-stack value has size 1 pops two values, so the value below the top
     * is reported as well.
     */
    private fun interpretPopInstructionsForBoxedValues(
        interpreter: RedundantBoxingInterpreter,
        node: MethodNode,
        frames: Array<out Frame<BasicValue>?>
    ) {
        for (i in frames.indices) {
            val insn = node.instructions[i]
            if (insn.opcode != Opcodes.POP && insn.opcode != Opcodes.POP2) {
                continue
            }

            // A null frame means the analyzer did not compute a frame for this instruction.
            val frame = frames[i] ?: continue

            val top = frame.top()!!
            interpreter.processPopInstruction(insn, top)

            if (top.size == 1 && insn.opcode == Opcodes.POP2) {
                interpreter.processPopInstruction(insn, frame.peek(1)!!)
            }
        }
    }

    /**
     * Repeats [removeValuesClashingWithVariablesPass] until a pass removes nothing.
     */
    private fun removeValuesClashingWithVariables(
        values: RedundantBoxedValuesCollection,
        node: MethodNode,
        frames: Array<out Frame<BasicValue>?>
    ) {
        while (removeValuesClashingWithVariablesPass(values, node, frames)) {
            // do nothing: each pass reports whether another one is needed
        }
    }

    /**
     * Runs a single pass over the object-typed local variables of [node] and removes from [values]
     * every boxed value that shares a variable with something that makes unboxing unsafe
     * (see [isUnsafeToRemoveBoxingForConnectedValues]).
     *
     * @return `true` if at least one value was removed, i.e. another pass is required.
     */
    private fun removeValuesClashingWithVariablesPass(
        values: RedundantBoxedValuesCollection,
        node: MethodNode,
        frames: Array<out Frame<BasicValue>?>
    ): Boolean {
        var needToRepeat = false

        for (localVariableNode in node.localVariables) {
            if (Type.getType(localVariableNode.desc).sort != Type.OBJECT) {
                continue
            }

            val variableValues = getValuesStoredOrLoadedToVariable(localVariableNode, node, frames)

            val boxed = variableValues.filterIsInstance<BoxedBasicValue>()

            if (boxed.isEmpty()) continue

            val firstBoxed = boxed.first().descriptor
            if (isUnsafeToRemoveBoxingForConnectedValues(variableValues, firstBoxed.unboxedType)) {
                for (value in boxed) {
                    val descriptor = value.descriptor
                    if (descriptor.isSafeToRemove) {
                        values.remove(descriptor)
                        needToRepeat = true
                    }
                }
            }
        }

        return needToRepeat
    }

    /**
     * Returns `true` if any of [usedValues] prevents unboxing to [unboxedType]: a value that is
     * not a [BoxedBasicValue] (other than the uninitialized marker), a boxed value that is not
     * safe to remove, or a boxed value with a different unboxed type.
     */
    private fun isUnsafeToRemoveBoxingForConnectedValues(usedValues: List<BasicValue>, unboxedType: Type): Boolean =
        usedValues.any { input ->
            if (input === StrictBasicValue.UNINITIALIZED_VALUE) return@any false
            if (input !is BoxedBasicValue) return@any true

            val descriptor = input.descriptor
            !descriptor.isSafeToRemove || descriptor.unboxedType != unboxedType
        }

    /**
     * Rewrites the descriptor of every object-typed local variable that holds a boxed value
     * which is safe to remove, so that the variable table entry carries the unboxed type.
     */
    private fun adaptLocalVariableTableForBoxedValues(node: MethodNode, frames: Array<out Frame<BasicValue>?>) {
        for (localVariableNode in node.localVariables) {
            if (Type.getType(localVariableNode.desc).sort != Type.OBJECT) {
                continue
            }

            for (value in getValuesStoredOrLoadedToVariable(localVariableNode, node, frames)) {
                if (value !is BoxedBasicValue) continue

                val descriptor = value.descriptor
                if (!descriptor.isSafeToRemove) continue
                localVariableNode.desc = descriptor.unboxedType.descriptor
            }
        }
    }

    /**
     * Collects the values associated with [localVariableNode] within its range:
     * the value held in the variable's slot at the range start (if a frame exists there),
     * the stack top at every ASTORE into the slot, and the slot value at every ALOAD from it.
     */
    private fun getValuesStoredOrLoadedToVariable(
        localVariableNode: LocalVariableNode,
        node: MethodNode,
        frames: Array<out Frame<BasicValue>?>
    ): List<BasicValue> {
        val values = ArrayList<BasicValue>()
        val insnList = node.instructions
        val localVariableStart = insnList.indexOf(localVariableNode.start)
        val localVariableEnd = insnList.indexOf(localVariableNode.end)

        // getOrNull keeps this access consistent with the bounds check in the loop below.
        frames.getOrNull(localVariableStart)?.let { frameForStartInsn ->
            frameForStartInsn.getLocal(localVariableNode.index)?.let { localVarValue ->
                values.add(localVarValue)
            }
        }

        for (i in localVariableStart until localVariableEnd) {
            if (i < 0 || i >= insnList.size()) continue
            val frame = frames[i] ?: continue
            val insn = insnList[i]
            if (insn !is VarInsnNode || insn.`var` != localVariableNode.index) continue

            when (insn.opcode) {
                // The frame is the state before the instruction, so the stored value is on the stack top.
                Opcodes.ASTORE -> values.add(frame.top()!!)
                Opcodes.ALOAD -> values.add(frame.getLocal(insn.`var`))
            }
        }

        return values
    }

    /**
     * Builds the old-index to new-index table for local variable slots.
     *
     * Every variable holding a double-size value needs one extra slot, so `node.maxLocals` is
     * increased by the number of such variables (a side effect of this function) and every slot
     * located after such a variable is shifted by one.
     */
    private fun buildVariablesRemapping(values: RedundantBoxedValuesCollection, node: MethodNode): IntArray {
        val doubleSizedVars = HashSet<Int>()
        for (valueDescriptor in values) {
            if (valueDescriptor.isDoubleSize()) {
                doubleSizedVars.addAll(valueDescriptor.getVariablesIndexes())
            }
        }

        node.maxLocals += doubleSizedVars.size
        val remapping = IntArray(node.maxLocals) { it }

        for (varIndex in doubleSizedVars) {
            for (i in (varIndex + 1)..remapping.lastIndex) {
                remapping[i]++
            }
        }

        return remapping
    }

    /**
     * Adapts the instructions of every value in [values]; see [adaptInstructionsForBoxedValue].
     */
    private fun adaptInstructionsForBoxedValues(
        node: MethodNode,
        values: RedundantBoxedValuesCollection
    ) {
        for (value in values) {
            adaptInstructionsForBoxedValue(node, value)
        }
    }

    /**
     * Adapts, in this order, the boxing instruction of [value], its unboxing-with-cast
     * instructions, and all other instructions associated with it.
     */
    private fun adaptInstructionsForBoxedValue(node: MethodNode, value: BoxedValueDescriptor) {
        adaptBoxingInstruction(node, value)

        for (cast in value.getUnboxingWithCastInsns()) {
            adaptCastInstruction(node, value, cast)
        }

        for (insn in value.getAssociatedInsns()) {
            adaptInstruction(node, insn, value)
        }
    }

    /**
     * Removes the boxing instruction of [value], or, when the value comes from a progression
     * iterator, replaces it with a checkcast to the iterator type followed by a virtual call
     * of the iterator's own `next` method.
     */
    private fun adaptBoxingInstruction(node: MethodNode, value: BoxedValueDescriptor) {
        if (!value.isFromProgressionIterator()) {
            node.instructions.remove(value.boxingInsn)
        } else {
            val iterator = value.progressionIterator ?: error("iterator should not be null because isFromProgressionIterator returns true")

            //add checkcast to kotlin/Iterator before next() call
            node.instructions.insertBefore(value.boxingInsn, TypeInsnNode(Opcodes.CHECKCAST, iterator.type.internalName))

            //invoke concrete method (kotlin/iterator.next())
            node.instructions.set(
                value.boxingInsn,
                MethodInsnNode(
                    Opcodes.INVOKEVIRTUAL,
                    iterator.type.internalName, iterator.nextMethodName, iterator.nextMethodDesc,
                    false
                )
            )
        }
    }

    /**
     * Replaces an unboxing-with-cast instruction by the primitive conversion from the unboxed
     * type of [value] to the target type.
     *
     * @param castWithType the instruction to replace paired with the type to cast to.
     */
    private fun adaptCastInstruction(
        node: MethodNode,
        value: BoxedValueDescriptor,
        castWithType: Pair<AbstractInsnNode, Type>
    ) {
        val castInsn = castWithType.getFirst()

        // Let InstructionAdapter emit the conversion into a scratch method, then move the result over.
        val castInsnsListener = MethodNode(Opcodes.API_VERSION)
        InstructionAdapter(castInsnsListener).cast(value.unboxedType, castWithType.getSecond())

        for (insn in castInsnsListener.instructions.toArray()) {
            node.instructions.insertBefore(castInsn, insn)
        }

        node.instructions.remove(castInsn)
    }

    /**
     * Rewrites a single instruction associated with [value] so that it works on the unboxed value.
     *
     * @throws AssertionError if [insn] is of a kind this transformer does not know how to adapt.
     */
    private fun adaptInstruction(
        node: MethodNode, insn: AbstractInsnNode, value: BoxedValueDescriptor
    ) {
        val isDoubleSize = value.isDoubleSize()

        when (insn.opcode) {
            Opcodes.POP ->
                if (isDoubleSize) {
                    node.instructions.set(insn, InsnNode(Opcodes.POP2))
                }

            Opcodes.DUP ->
                if (isDoubleSize) {
                    node.instructions.set(insn, InsnNode(Opcodes.DUP2))
                }

            Opcodes.ASTORE, Opcodes.ALOAD -> {
                // Pick the typed counterpart (e.g. ISTORE/LSTORE or ILOAD/LLOAD) for the unboxed type.
                val baseOpcode = if (insn.opcode == Opcodes.ASTORE) Opcodes.ISTORE else Opcodes.ILOAD
                val typedOpcode = value.unboxedType.getOpcode(baseOpcode)
                node.instructions.set(insn, VarInsnNode(typedOpcode, (insn as VarInsnNode).`var`))
            }

            Opcodes.INSTANCEOF -> {
                // The check is replaced by a constant `true`; the checked value is popped first.
                node.instructions.insertBefore(
                    insn,
                    InsnNode(if (isDoubleSize) Opcodes.POP2 else Opcodes.POP)
                )
                node.instructions.set(insn, InsnNode(Opcodes.ICONST_1))
            }

            Opcodes.INVOKESTATIC -> {
                when {
                    insn.isAreEqualIntrinsic() ->
                        adaptAreEqualIntrinsic(node, insn, value)
                    insn.isJavaLangComparableCompareTo() ->
                        adaptJavaLangComparableCompareTo(node, insn, value)
                    insn.isJavaLangClassBoxing() ||
                            insn.isJavaLangClassUnboxing() ->
                        node.instructions.remove(insn)
                    else ->
                        throwCannotAdaptInstruction(insn)
                }
            }

            Opcodes.INVOKEINTERFACE -> {
                if (insn.isJavaLangComparableCompareTo()) {
                    adaptJavaLangComparableCompareTo(node, insn, value)
                } else {
                    throwCannotAdaptInstruction(insn)
                }
            }

            Opcodes.CHECKCAST,
            Opcodes.INVOKEVIRTUAL ->
                node.instructions.remove(insn)

            else ->
                throwCannotAdaptInstruction(insn)
        }
    }

    /** Always throws an [AssertionError] naming the instruction that could not be adapted. */
    private fun throwCannotAdaptInstruction(insn: AbstractInsnNode): Nothing =
        throw AssertionError("Cannot adapt instruction: ${insn.insnText}")

    /**
     * Replaces a call of the `areEqual` intrinsic on boxed operands with a primitive comparison,
     * dispatching on the unboxed type of [value].
     *
     * @throws AssertionError for unboxed type sorts other than the int-like ones, LONG and OBJECT.
     */
    private fun adaptAreEqualIntrinsic(
        node: MethodNode,
        insn: AbstractInsnNode,
        value: BoxedValueDescriptor
    ) {
        val unboxedType = value.unboxedType

        when (unboxedType.sort) {
            Type.BOOLEAN, Type.BYTE, Type.SHORT, Type.INT, Type.CHAR ->
                adaptAreEqualIntrinsicForInt(node, insn)
            Type.LONG ->
                adaptAreEqualIntrinsicForLong(node, insn)
            Type.OBJECT -> {
                // Intentionally left as is: the call is kept for object-typed "unboxed" values.
            }
            else ->
                throw AssertionError("Unexpected unboxed type kind: $unboxedType")
        }
    }

    /**
     * Int-like version of the `areEqual` rewrite: fuses the call with a directly following
     * IFEQ/IFNE into a single IF_ICMPxx, otherwise materializes the result as 1 or 0.
     */
    private fun adaptAreEqualIntrinsicForInt(node: MethodNode, insn: AbstractInsnNode) {
        node.instructions.run {
            val next = insn.next
            if (next != null && (next.opcode == Opcodes.IFEQ || next.opcode == Opcodes.IFNE)) {
                fuseAreEqualWithBranch(node, insn, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPEQ)
                remove(insn)
                remove(next)
            } else {
                ifEqual1Else0(node, insn, Opcodes.IF_ICMPNE)
                remove(insn)
            }
        }
    }

    /**
     * Long version of the `areEqual` rewrite: an LCMP is inserted first, and its result is
     * either fused with a directly following IFEQ/IFNE or materialized as 1 or 0.
     */
    private fun adaptAreEqualIntrinsicForLong(node: MethodNode, insn: AbstractInsnNode) {
        node.instructions.run {
            insertBefore(insn, InsnNode(Opcodes.LCMP))
            val next = insn.next
            if (next != null && (next.opcode == Opcodes.IFEQ || next.opcode == Opcodes.IFNE)) {
                // LCMP yields 0 for equal operands, so the branch sense is inverted relative to areEqual.
                fuseAreEqualWithBranch(node, insn, Opcodes.IFNE, Opcodes.IFEQ)
                remove(insn)
                remove(next)
            } else {
                ifEqual1Else0(node, insn, Opcodes.IFNE)
                remove(insn)
            }
        }
    }

    /**
     * Inserts before [insn] a jump to the target of the branch that follows [insn].
     *
     * @param ifEqualOpcode opcode of the inserted jump when the following branch is IFEQ.
     * @param ifNotEqualOpcode opcode of the inserted jump when the following branch is IFNE.
     * @throws AssertionError if the following instruction is a jump other than IFEQ/IFNE.
     */
    private fun fuseAreEqualWithBranch(
        node: MethodNode,
        insn: AbstractInsnNode,
        ifEqualOpcode: Int,
        ifNotEqualOpcode: Int
    ) {
        node.instructions.run {
            val next = insn.next
            assert(next is JumpInsnNode) { "JumpInsnNode expected: $next" }
            val nextLabel = (next as JumpInsnNode).label
            val fusedOpcode = when (next.opcode) {
                Opcodes.IFEQ -> ifEqualOpcode
                Opcodes.IFNE -> ifNotEqualOpcode
                else -> throw AssertionError("IFEQ or IFNE expected: ${next.insnOpcodeText}")
            }
            insertBefore(insn, JumpInsnNode(fusedOpcode, nextLabel))
        }
    }

    /**
     * Inserts before [insn] a sequence that pushes 1 when the comparison succeeds and 0 otherwise.
     *
     * @param ifneOpcode conditional jump opcode that is taken when the operands are NOT equal.
     */
    private fun ifEqual1Else0(node: MethodNode, insn: AbstractInsnNode, ifneOpcode: Int) {
        node.instructions.run {
            val lNotEqual = LabelNode(Label())
            val lDone = LabelNode(Label())
            insertBefore(insn, JumpInsnNode(ifneOpcode, lNotEqual))
            insertBefore(insn, InsnNode(Opcodes.ICONST_1))
            insertBefore(insn, JumpInsnNode(Opcodes.GOTO, lDone))
            insertBefore(insn, lNotEqual)
            insertBefore(insn, InsnNode(Opcodes.ICONST_0))
            insertBefore(insn, lDone)
        }
    }

    /**
     * Replaces a `Comparable.compareTo` call on boxed operands with a primitive comparison,
     * dispatching on the unboxed type of [value].
     *
     * @throws AssertionError for unboxed type sorts that are not primitive numeric/boolean/char.
     */
    private fun adaptJavaLangComparableCompareTo(
        node: MethodNode,
        insn: AbstractInsnNode,
        value: BoxedValueDescriptor
    ) {
        val unboxedType = value.unboxedType

        when (unboxedType.sort) {
            Type.BOOLEAN, Type.BYTE, Type.SHORT, Type.INT, Type.CHAR ->
                adaptJavaLangComparableCompareToForInt(node, insn)
            Type.LONG ->
                adaptJavaLangComparableCompareToForLong(node, insn)
            Type.FLOAT ->
                adaptJavaLangComparableCompareToForFloat(node, insn)
            Type.DOUBLE ->
                adaptJavaLangComparableCompareToForDouble(node, insn)
            else ->
                throw AssertionError("Unexpected unboxed type kind: $unboxedType")
        }
    }

    /**
     * Int-like version of the `compareTo` rewrite. Tries to fuse the call with the branch that
     * consumes its result; falls back to a static `compare(int, int)` intrinsic call otherwise.
     */
    private fun adaptJavaLangComparableCompareToForInt(node: MethodNode, insn: AbstractInsnNode) {
        node.instructions.run {
            val next = insn.next
            val next2 = next?.next
            when {
                next != null && next2 != null &&
                        next.opcode == Opcodes.ICONST_0 &&
                        next2.opcode >= Opcodes.IF_ICMPEQ && next2.opcode <= Opcodes.IF_ICMPLE -> {
                    // Fuse: compareTo + ICONST_0 + IF_ICMPxx -> IF_ICMPxx
                    remove(insn)
                    remove(next)
                }

                next != null &&
                        next.opcode >= Opcodes.IFEQ && next.opcode <= Opcodes.IFLE -> {
                    // Fuse: compareTo + IFxx -> IF_ICMPxx
                    // The IFxx and IF_ICMPxx opcode ranges are laid out in the same order,
                    // so the offset within one range selects the matching opcode in the other.
                    val nextLabel = (next as JumpInsnNode).label
                    val ifCmpOpcode = next.opcode - Opcodes.IFEQ + Opcodes.IF_ICMPEQ
                    insertBefore(insn, JumpInsnNode(ifCmpOpcode, nextLabel))
                    remove(insn)
                    remove(next)
                }

                else -> {
                    // Can't fuse with branching instruction. Use Intrinsics#compare(int, int).
                    set(insn, MethodInsnNode(Opcodes.INVOKESTATIC, IntrinsicMethods.INTRINSICS_CLASS_NAME, "compare", "(II)I", false))
                }
            }
        }
    }

    /** Long version of the `compareTo` rewrite: the call becomes a single LCMP. */
    private fun adaptJavaLangComparableCompareToForLong(node: MethodNode, insn: AbstractInsnNode) {
        node.instructions.set(insn, InsnNode(Opcodes.LCMP))
    }

    /** Float version of the `compareTo` rewrite: the call becomes `java.lang.Float.compare(float, float)`. */
    private fun adaptJavaLangComparableCompareToForFloat(node: MethodNode, insn: AbstractInsnNode) {
        node.instructions.set(insn, MethodInsnNode(Opcodes.INVOKESTATIC, "java/lang/Float", "compare", "(FF)I", false))
    }

    /** Double version of the `compareTo` rewrite: the call becomes `java.lang.Double.compare(double, double)`. */
    private fun adaptJavaLangComparableCompareToForDouble(node: MethodNode, insn: AbstractInsnNode) {
        node.instructions.set(insn, MethodInsnNode(Opcodes.INVOKESTATIC, "java/lang/Double", "compare", "(DD)I", false))
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy