All downloads are free. The search and download functionality uses the official Maven repository.

org.jetbrains.kotlin.codegen.optimization.nullCheck.RedundantNullCheckMethodTransformer.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2017 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.codegen.optimization.nullCheck

import org.jetbrains.kotlin.codegen.coroutines.withInstructionAdapter
import org.jetbrains.kotlin.codegen.inline.ReifiedTypeInliner
import org.jetbrains.kotlin.codegen.intrinsics.IntrinsicMethods
import org.jetbrains.kotlin.codegen.optimization.DeadCodeEliminationMethodTransformer
import org.jetbrains.kotlin.codegen.optimization.common.OptimizationBasicInterpreter
import org.jetbrains.kotlin.codegen.optimization.common.StrictBasicValue
import org.jetbrains.kotlin.codegen.optimization.common.debugText
import org.jetbrains.kotlin.codegen.optimization.common.isInsn
import org.jetbrains.kotlin.codegen.optimization.fixStack.peek
import org.jetbrains.kotlin.codegen.optimization.fixStack.top
import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
import org.jetbrains.kotlin.utils.SmartList
import org.jetbrains.org.objectweb.asm.Label
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter
import org.jetbrains.org.objectweb.asm.tree.*

/**
 * Method transformer that removes null checks, `instanceof` checks and
 * `checkExpressionValueIsNotNull` calls whose outcome is known from a nullability analysis.
 *
 * [transform] runs a fresh [TransformerPass] over the method again and again
 * until a pass reports that it changed nothing.
 */
class RedundantNullCheckMethodTransformer : MethodTransformer() {
    override fun transform(internalClassName: String, methodNode: MethodNode) {
        // Repeat until a pass makes no changes.
        while (TransformerPass(internalClassName, methodNode).run()) {}
    }

    /**
     * One optimization pass over [methodNode].
     *
     * [changes] is set as soon as the pass removes or rewrites anything.
     */
    private class TransformerPass(val internalClassName: String, val methodNode: MethodNode) {
        private var changes = false

        /**
         * Executes the pass.
         *
         * @return `true` if the method was modified by this pass
         */
        fun run(): Boolean {
            // Filled by analyzeTypesAndRemoveDeadCode: throwNpe call -> frame.maxStackSize at that call.
            val stackOnThrowExceptions = hashMapOf<AbstractInsnNode, Int>()
            val checkedReferenceTypes = analyzeTypesAndRemoveDeadCode(stackOnThrowExceptions)
            eliminateRedundantChecks(checkedReferenceTypes, stackOnThrowExceptions)

            return changes
        }

        /**
         * Analyzes the method with [OptimizationBasicInterpreter], collects type information
         * for the check instructions and then removes dead code using the computed frames.
         *
         * Instructions without a frame (or without the required stack value) are skipped.
         *
         * @param stackOnThrowExceptionsHolder output map; receives `frame.maxStackSize` for every `throwNpe` call
         * @return for each `INSTANCEOF`/`IFNULL`/`IFNONNULL` the type of `frame.top()`, and for each
         * `checkParameterIsNotNull`/`checkExpressionValueIsNotNull` call the type of `frame.peek(1)`
         */
        private fun analyzeTypesAndRemoveDeadCode(
                stackOnThrowExceptionsHolder: MutableMap<AbstractInsnNode, Int>
        ): Map<AbstractInsnNode, Type> {
            val insns = methodNode.instructions.toArray()
            val frames = analyze(internalClassName, methodNode, OptimizationBasicInterpreter())

            val checkedReferenceTypes = HashMap<AbstractInsnNode, Type>()
            for (i in insns.indices) {
                val insn = insns[i]
                val frame = frames[i]
                if (insn.isInstanceOfOrNullCheck()) {
                    checkedReferenceTypes[insn] = frame?.top()?.type ?: continue
                }
                else if (insn.isCheckParameterIsNotNull() || insn.isCheckExpressionValueIsNotNull()) {
                    // The checked value is read with peek(1); the call's other argument is the message string.
                    checkedReferenceTypes[insn] = frame?.peek(1)?.type ?: continue
                }
                else if (insn.isThrowNpeIntrinsic()) {
                    stackOnThrowExceptionsHolder[insn] = frame?.maxStackSize ?: continue
                }
            }

            val dceResult = DeadCodeEliminationMethodTransformer().removeDeadCodeByFrames(methodNode, frames)
            if (dceResult.hasRemovedAnything()) {
                changes = true
            }

            return checkedReferenceTypes
        }

        /**
         * Temporarily injects synthetic "assumption" instructions, runs the nullability analysis
         * on the augmented method, reverts the injection and finally rewrites the checks whose
         * result turned out to be known.
         */
        private fun eliminateRedundantChecks(
                checkedReferenceTypes: Map<AbstractInsnNode, Type>,
                stackOnThrowExceptions: MutableMap<AbstractInsnNode, Int>
        ) {
            val nullabilityAssumptions = injectNullabilityAssumptions(checkedReferenceTypes, stackOnThrowExceptions)

            val nullabilityMap = analyzeNullabilities()

            // The synthetic instructions exist only for the analysis above; drop them before transforming.
            nullabilityAssumptions.revert()

            transformTrivialChecks(nullabilityMap)
        }

        /** Delegates to [NullabilityAssumptionsBuilder] and returns the record of what was injected. */
        private fun injectNullabilityAssumptions(
                checkedReferenceTypes: Map<AbstractInsnNode, Type>,
                stackOnThrowExceptions: MutableMap<AbstractInsnNode, Int>
        ) = NullabilityAssumptionsBuilder(checkedReferenceTypes, stackOnThrowExceptions).injectNullabilityAssumptions()

        /**
         * Analyzes the method with [NullabilityInterpreter].
         *
         * @return in instruction order, every `INSTANCEOF`/`IFNULL`/`IFNONNULL` (value at `frame.top()`)
         * and every `checkExpressionValueIsNotNull` call (value at `frame.peek(1)`) whose checked value
         * is a [StrictBasicValue] with a nullability other than [Nullability.NULLABLE]
         */
        private fun analyzeNullabilities(): Map<AbstractInsnNode, StrictBasicValue> {
            val frames = analyze(internalClassName, methodNode, NullabilityInterpreter())
            val insns = methodNode.instructions.toArray()
            val nullabilityMap = LinkedHashMap<AbstractInsnNode, StrictBasicValue>()
            for (i in insns.indices) {
                val frame = frames[i] ?: continue
                val insn = insns[i]

                val value = when {
                    insn.isInstanceOfOrNullCheck() -> frame.top()
                    insn.isCheckExpressionValueIsNotNull() -> frame.peek(1)
                    else -> null
                } as? StrictBasicValue ?: continue

                val nullability = value.getNullability()
                // Nothing can be decided statically for a value that may or may not be null.
                if (nullability == Nullability.NULLABLE) continue
                nullabilityMap[insn] = value
            }
            return nullabilityMap
        }

        /** Dispatches every entry of [nullabilityMap] to the rewrite matching the instruction's opcode. */
        private fun transformTrivialChecks(nullabilityMap: Map<AbstractInsnNode, StrictBasicValue>) {
            for ((insn, value) in nullabilityMap) {
                val nullability = value.getNullability()
                when (insn.opcode) {
                    Opcodes.IFNULL -> transformTrivialNullJump(insn as JumpInsnNode, nullability == Nullability.NULL)
                    Opcodes.IFNONNULL -> transformTrivialNullJump(insn as JumpInsnNode, nullability == Nullability.NOT_NULL)
                    Opcodes.INSTANCEOF -> transformInstanceOf(insn as TypeInsnNode, nullability, value)

                    Opcodes.INVOKESTATIC -> {
                        if (insn.isCheckExpressionValueIsNotNull()) {
                            transformTrivialCheckExpressionValueIsNotNull(insn, nullability)
                        }
                    }
                }
            }
        }

        /**
         * Rewrites a null-check jump with a known outcome: the checked value is popped, and the jump
         * becomes an unconditional `GOTO` when [alwaysTrue], or is removed otherwise.
         */
        private fun transformTrivialNullJump(insn: JumpInsnNode, alwaysTrue: Boolean) {
            changes = true

            methodNode.instructions.run {
                popReferenceValueBefore(insn)
                if (alwaysTrue) {
                    set(insn, JumpInsnNode(Opcodes.GOTO, insn.label))
                }
                else {
                    remove(insn)
                }
            }
        }

        /**
         * Replaces an `INSTANCEOF` with a constant when its result is known:
         * `false` for a value that is always null, `true` for a non-null value whose
         * type's internal name equals the checked type. Instructions preceded by a
         * reified-operation marker are left untouched.
         */
        private fun transformInstanceOf(insn: TypeInsnNode, nullability: Nullability, value: StrictBasicValue) {
            if (ReifiedTypeInliner.isOperationReifiedMarker(insn.previous)) return
            if (nullability == Nullability.NULL) {
                changes = true
                transformTrivialInstanceOf(insn, false)
            }
            else if (nullability == Nullability.NOT_NULL && value.type.internalName == insn.desc) {
                changes = true
                transformTrivialInstanceOf(insn, true)
            }
        }

        /** Pops the checked value and replaces [insn] with `ICONST_1` or `ICONST_0` according to [constValue]. */
        private fun transformTrivialInstanceOf(insn: AbstractInsnNode, constValue: Boolean) {
            methodNode.instructions.run {
                popReferenceValueBefore(insn)
                set(insn, if (constValue) InsnNode(Opcodes.ICONST_1) else InsnNode(Opcodes.ICONST_0))
            }
        }

        /**
         * Removes a `checkExpressionValueIsNotNull` call on a value known to be non-null,
         * together with the `LDC` that immediately precedes it; the checked value is popped.
         * Does nothing if the value is not known to be non-null or the previous instruction is not an `LDC`.
         *
         * Note that [changes] is not set by this rewrite.
         */
        private fun transformTrivialCheckExpressionValueIsNotNull(insn: AbstractInsnNode, nullability: Nullability) {
            if (nullability != Nullability.NOT_NULL) return
            val ldcInsn = insn.previous?.takeIf { it.opcode == Opcodes.LDC } ?: return
            methodNode.instructions.run {
                popReferenceValueBefore(ldcInsn)
                remove(ldcInsn)
                remove(insn)
            }
        }

        /**
         * Finds checks applied directly to a local variable and injects synthetic stores
         * into that variable on the branches where the check's outcome is known.
         *
         * @param checkedReferenceTypes types of checked values, keyed by check instruction
         * @param stackOnThrowExceptions recorded stack sizes, keyed by `throwNpe` call
         */
        private inner class NullabilityAssumptionsBuilder(
                val checkedReferenceTypes: Map<AbstractInsnNode, Type>,
                val stackOnThrowExceptions: MutableMap<AbstractInsnNode, Int>
        ) {

            // Local variable index -> checks whose operand was loaded from that variable.
            private val checksDependingOnVariable = HashMap<Int, MutableList<AbstractInsnNode>>()

            fun injectNullabilityAssumptions(): NullabilityAssumptions {
                collectVariableDependentChecks()
                return injectAssumptions()
            }

            /**
             * Scans the method for checks whose operand comes from `ALOAD v`
             * (optionally followed by `DUP`, and by `LDC` for the intrinsic calls)
             * and registers them under variable `v`.
             */
            private fun collectVariableDependentChecks() {
                insnLoop@ for (insn in methodNode.instructions) {
                    when {
                        insn.isInstanceOfOrNullCheck() -> {
                            val previous = insn.previous ?: continue@insnLoop
                            if (previous.opcode == Opcodes.ALOAD) {
                                addDependentCheck(insn, previous as VarInsnNode)
                            }
                            else if (previous.opcode == Opcodes.DUP) {
                                val previous2 = previous.previous ?: continue@insnLoop
                                if (previous2.opcode == Opcodes.ALOAD) {
                                    addDependentCheck(insn, previous2 as VarInsnNode)
                                }
                            }
                        }

                        insn.isCheckParameterIsNotNull() -> {
                            // Expected shape: ALOAD v; LDC *; INVOKESTATIC checkParameterIsNotNull
                            val ldcInsn = insn.previous ?: continue@insnLoop
                            if (ldcInsn.opcode != Opcodes.LDC) continue@insnLoop
                            val aLoadInsn = ldcInsn.previous ?: continue@insnLoop
                            if (aLoadInsn.opcode != Opcodes.ALOAD) continue@insnLoop
                            addDependentCheck(insn, aLoadInsn as VarInsnNode)
                        }

                        insn.isCheckExpressionValueIsNotNull() -> {
                            // Expected shape: ALOAD v; [DUP;] LDC *; INVOKESTATIC checkExpressionValueIsNotNull
                            val ldcInsn = insn.previous ?: continue@insnLoop
                            if (ldcInsn.opcode != Opcodes.LDC) continue@insnLoop
                            var aLoadInsn: VarInsnNode? = null
                            val insn1 = ldcInsn.previous ?: continue@insnLoop
                            if (insn1.opcode == Opcodes.ALOAD) {
                                aLoadInsn = insn1 as VarInsnNode
                            }
                            else if (insn1.opcode == Opcodes.DUP) {
                                val insn2 = insn1.previous ?: continue@insnLoop
                                if (insn2.opcode == Opcodes.ALOAD) {
                                    aLoadInsn = insn2 as VarInsnNode
                                }
                            }
                            if (aLoadInsn == null) continue@insnLoop
                            addDependentCheck(insn, aLoadInsn)
                        }

                    }
                }
            }

            /** Registers [insn] as a check on the variable loaded by [aLoadInsn]. */
            private fun addDependentCheck(insn: AbstractInsnNode, aLoadInsn: VarInsnNode) {
                checksDependingOnVariable.getOrPut(aLoadInsn.`var`) {
                    SmartList()
                }.add(insn)
            }

            /**
             * Injects assumptions for every collected check and a synthetic throw after every `throwNpe` call.
             * A check without a recorded type falls back to [AsmTypes.OBJECT_TYPE].
             */
            private fun injectAssumptions(): NullabilityAssumptions {
                val nullabilityAssumptions = NullabilityAssumptions()
                for ((varIndex, dependentChecks) in checksDependingOnVariable) {
                    for (checkInsn in dependentChecks) {
                        val varType = checkedReferenceTypes[checkInsn]
                                      ?: AsmTypes.OBJECT_TYPE
                        nullabilityAssumptions.injectAssumptionsForCheck(varIndex, checkInsn, varType)
                    }
                }
                for (insn in methodNode.instructions) {
                    if (insn.isThrowNpeIntrinsic()) {
                        nullabilityAssumptions.injectCodeForThrowNpe(insn)
                    }
                }
                return nullabilityAssumptions
            }

            /** Selects the injection strategy from the opcode of the check instruction. */
            private fun NullabilityAssumptions.injectAssumptionsForCheck(varIndex: Int, insn: AbstractInsnNode, varType: Type) {
                when (insn.opcode) {
                    Opcodes.IFNULL,
                    Opcodes.IFNONNULL ->
                        injectAssumptionsForNullCheck(varIndex, insn as JumpInsnNode, varType)
                    Opcodes.INVOKESTATIC -> {
                        assert(insn.isCheckParameterIsNotNull() || insn.isCheckExpressionValueIsNotNull()) {
                            "Expected non-null assertion: ${insn.debugText}"
                        }
                        injectAssumptionsForNotNullAssertion(varIndex, insn, varType)
                    }
                    Opcodes.INSTANCEOF ->
                        injectAssumptionsForInstanceOfCheck(varIndex, insn, varType)
                }
            }

            /**
             * Redirects the jump to a new synthetic label appended at the end of the instruction list
             * and stores `null` into the variable on the null path and a `NEW varType` value on the
             * non-null path. The path that goes through the synthetic label ends with a `GOTO` to the
             * original target. The original label is remembered so that it can be restored on revert.
             */
            private fun NullabilityAssumptions.injectAssumptionsForNullCheck(varIndex: Int, insn: JumpInsnNode, varType: Type) {
                //  ALOAD v
                //  IFNULL L
                //  <...>   -- v is not null here
                // L:
                //  <...>   -- v is null here

                val jumpsIfNull = insn.opcode == Opcodes.IFNULL
                val originalLabel = insn.label
                originalLabels[insn] = originalLabel
                insn.label = synthetic(LabelNode(Label()))

                // For IFNULL the null path is the jump target; for IFNONNULL it is the fall-through.
                val insertAfterNull = if (jumpsIfNull) insn.label else insn
                val insertAfterNonNull = if (jumpsIfNull) insn else insn.label

                methodNode.instructions.run {
                    add(insn.label)

                    insert(insertAfterNull, listOfSynthetics {
                        aconst(null)
                        store(varIndex, varType)
                        if (jumpsIfNull) {
                            goTo(originalLabel.label)
                        }
                    })

                    insert(insertAfterNonNull, listOfSynthetics {
                        anew(varType)
                        store(varIndex, varType)
                        if (!jumpsIfNull) {
                            goTo(originalLabel.label)
                        }
                    })
                }
            }

            /** Stores a `NEW varType` value into the variable right after the non-null assertion call. */
            private fun NullabilityAssumptions.injectAssumptionsForNotNullAssertion(varIndex: Int, insn: AbstractInsnNode, varType: Type) {
                //  ALOAD v
                //  LDC *
                //  INVOKESTATIC checkParameterIsNotNull
                //  <...>   -- v is not null here (otherwise an exception was thrown)

                //  ALOAD v
                //  DUP
                //  LDC *
                //  INVOKESTATIC checkExpressionValueIsNotNull
                //  <...>   -- v is not null here (otherwise an exception was thrown)

                methodNode.instructions.insert(insn, listOfSynthetics {
                    anew(varType)
                    store(varIndex, varType)
                })
            }

            /**
             * Handles `INSTANCEOF` immediately followed by `IFEQ` or `IFNE`; any other shape is ignored.
             * On the branch taken when the value is an instance, a `NEW varType` value is stored into the
             * variable. For `IFNE` that branch is reached through a new synthetic label appended at the
             * end of the instruction list, followed by a `GOTO` to the original target.
             */
            private fun NullabilityAssumptions.injectAssumptionsForInstanceOfCheck(varIndex: Int, insn: AbstractInsnNode, varType: Type) {
                //  ALOAD v
                //  INSTANCEOF T
                //  IFEQ L
                //  <...>  -- v is not null here (because it is an instance of T)
                // L:
                //  <...>  -- v is something else here (maybe null)

                val next = insn.next ?: return
                if (next.opcode != Opcodes.IFEQ && next.opcode != Opcodes.IFNE) return
                if (next !is JumpInsnNode) return
                val jumpsIfInstance = next.opcode == Opcodes.IFNE

                val originalLabel: LabelNode?
                val insertAfterNotNull: AbstractInsnNode
                if (jumpsIfInstance) {
                    originalLabel = next.label
                    originalLabels[next] = next.label
                    val newLabel = synthetic(LabelNode(Label()))
                    methodNode.instructions.add(newLabel)
                    next.label = newLabel
                    insertAfterNotNull = newLabel
                }
                else {
                    originalLabel = null
                    insertAfterNotNull = next
                }

                methodNode.instructions.run {
                    insert(insertAfterNotNull, listOfSynthetics {
                        anew(varType)
                        store(varIndex, varType)
                        if (originalLabel != null) {
                            goTo(originalLabel.label)
                        }
                    })
                }
            }

            /**
             * Inserts `ACONST_NULL; ATHROW` after a `throwNpe` call and raises `maxStack` to at least
             * the recorded stack size for that call plus one (zero if nothing was recorded).
             */
            private fun NullabilityAssumptions.injectCodeForThrowNpe(insn: AbstractInsnNode) {
                methodNode.instructions.run {
                    insert(insn, listOfSynthetics {
                        aconst(null)
                        athrow()
                    })
                }

                methodNode.maxStack = Math.max(methodNode.maxStack, (stackOnThrowExceptions[insn] ?: -1) + 1)
            }

        }

        /**
         * Record of everything injected into the method for the analysis, so it can be undone by [revert].
         */
        inner class NullabilityAssumptions {
            // Jump instruction -> the label it pointed to before being redirected.
            val originalLabels = HashMap<JumpInsnNode, LabelNode>()
            // Every instruction (including labels) that was added for the analysis only.
            val syntheticInstructions = ArrayList<AbstractInsnNode>()

            /** Registers [insn] as synthetic and returns it unchanged. */
            fun <T : AbstractInsnNode> synthetic(insn: T): T {
                syntheticInstructions.add(insn)
                return insn
            }

            /** Builds an instruction list with [block] and registers each of its instructions as synthetic. */
            inline fun listOfSynthetics(block: InstructionAdapter.() -> Unit): InsnList {
                val insnList = withInstructionAdapter(block)
                for (insn in insnList) {
                    synthetic(insn)
                }
                return insnList
            }

            /** Removes all synthetic instructions from the method and restores the redirected jump labels. */
            fun revert() {
                methodNode.instructions.run {
                    syntheticInstructions.forEach { remove(it) }
                }
                for ((jumpInsn, originalLabel) in originalLabels) {
                    jumpInsn.label = originalLabel
                }
            }
        }
    }
}

/**
 * Tells whether this instruction's opcode is `INSTANCEOF`, `IFNULL` or `IFNONNULL`.
 */
internal fun AbstractInsnNode.isInstanceOfOrNullCheck() =
        when (opcode) {
            Opcodes.INSTANCEOF, Opcodes.IFNULL, Opcodes.IFNONNULL -> true
            else -> false
        }

/**
 * Tells whether this instruction is an `INVOKESTATIC` call of
 * `checkParameterIsNotNull(Object, String)` on [IntrinsicMethods.INTRINSICS_CLASS_NAME].
 */
internal fun AbstractInsnNode.isCheckParameterIsNotNull() =
        // The explicit type argument gives the lambda a MethodInsnNode receiver, which is where
        // owner/name/desc are declared.
        isInsn<MethodInsnNode>(Opcodes.INVOKESTATIC) {
            owner == IntrinsicMethods.INTRINSICS_CLASS_NAME &&
            name == "checkParameterIsNotNull" &&
            desc == "(Ljava/lang/Object;Ljava/lang/String;)V"
        }

/**
 * Tells whether this instruction is an `INVOKESTATIC` call of
 * `checkExpressionValueIsNotNull(Object, String)` on [IntrinsicMethods.INTRINSICS_CLASS_NAME].
 */
internal fun AbstractInsnNode.isCheckExpressionValueIsNotNull() =
        // The explicit type argument gives the lambda a MethodInsnNode receiver, which is where
        // owner/name/desc are declared.
        isInsn<MethodInsnNode>(Opcodes.INVOKESTATIC) {
            owner == IntrinsicMethods.INTRINSICS_CLASS_NAME &&
            name == "checkExpressionValueIsNotNull" &&
            desc == "(Ljava/lang/Object;Ljava/lang/String;)V"
        }

/**
 * Tells whether this instruction is an `INVOKESTATIC` call of the no-argument
 * `throwNpe()` on [IntrinsicMethods.INTRINSICS_CLASS_NAME].
 */
internal fun AbstractInsnNode.isThrowNpeIntrinsic() =
        // The explicit type argument gives the lambda a MethodInsnNode receiver, which is where
        // owner/name/desc are declared.
        isInsn<MethodInsnNode>(Opcodes.INVOKESTATIC) {
            owner == IntrinsicMethods.INTRINSICS_CLASS_NAME &&
            name == "throwNpe" &&
            desc == "()V"
        }

/**
 * Discards the reference value that [insn] would have consumed.
 *
 * If the instruction right before [insn] is `ACONST_NULL`, `DUP` or `ALOAD`, that instruction
 * is removed from this list; otherwise (including when there is no previous instruction)
 * a `POP` is inserted before [insn].
 */
internal fun InsnList.popReferenceValueBefore(insn: AbstractInsnNode) {
    val producer = insn.previous
    val producerOpcode = producer?.opcode
    val canDropProducer =
            producerOpcode == Opcodes.ACONST_NULL ||
            producerOpcode == Opcodes.DUP ||
            producerOpcode == Opcodes.ALOAD

    if (canDropProducer) {
        remove(producer)
    }
    else {
        insertBefore(insn, InsnNode(Opcodes.POP))
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy