All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlin.codegen.optimization.CapturedVarsOptimizationMethodTransformer.kt Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.codegen.optimization

import org.jetbrains.kotlin.builtins.PrimitiveType
import org.jetbrains.kotlin.codegen.AsmUtil
import org.jetbrains.kotlin.codegen.InsnSequence
import org.jetbrains.kotlin.codegen.asSequence
import org.jetbrains.kotlin.codegen.optimization.common.*
import org.jetbrains.kotlin.codegen.optimization.fixStack.peek
import org.jetbrains.kotlin.codegen.optimization.fixStack.top
import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue
import org.jetbrains.org.objectweb.asm.tree.analysis.Frame

/**
 * Method-level optimization that replaces `kotlin.jvm.internal.Ref.*` wrappers created inside a method
 * (objects corresponding to captured variables) with plain local variables, provided every use of the
 * wrapper is one of the simple operations listed on [CapturedVarDescriptor].
 */
class CapturedVarsOptimizationMethodTransformer : MethodTransformer() {
    override fun transform(internalClassName: String, methodNode: MethodNode) {
        Transformer(internalClassName, methodNode).run()
    }

    // Tracks proper usages of objects corresponding to captured variables.
    //
    // The 'kotlin.jvm.internal.Ref.*' instance can be replaced with a local variable, if
    //  * it is created inside a current method;
    //  * the only operations on it are ALOAD, ASTORE, DUP, POP, GETFIELD element, PUTFIELD element.
    //
    // Note that for code that doesn't create Ref objects explicitly these conditions are true,
    // unless the Ref object escapes to a local class constructor (including local classes for lambdas).
    //
    private class CapturedVarDescriptor(val newInsn: TypeInsnNode, val refType: Type, val valueType: Type) : ReferenceValueDescriptor {
        // Set as soon as any usage disqualifies this Ref from being unwrapped.
        var hazard = false

        // The single INVOKESPECIAL <init> on this Ref; stays null if none was observed.
        var initCallInsn: MethodInsnNode? = null
        // The unique LVT entry holding this Ref, if any (a second entry marks the Ref as hazardous).
        var localVar: LocalVariableNode? = null
        // Slot that holds the unwrapped value after rewriting; -1 until assignLocalVars runs.
        var localVarIndex = -1
        // ALOAD/ASTORE/DUP/POP instructions that only move the Ref around; removed on rewrite.
        val wrapperInsns: MutableCollection<AbstractInsnNode> = LinkedHashSet()
        // GETFIELD element instructions; replaced by loads from localVarIndex on rewrite.
        val getFieldInsns: MutableCollection<FieldInsnNode> = LinkedHashSet()
        // PUTFIELD element instructions; replaced by stores to localVarIndex on rewrite.
        val putFieldInsns: MutableCollection<FieldInsnNode> = LinkedHashSet()

        override fun onUseAsTainted() {
            hazard = true
        }

        fun canRewrite() = !hazard && initCallInsn != null
    }

    /**
     * State for a single run of the optimization over one [MethodNode].
     */
    private class Transformer(private val internalClassName: String, private val methodNode: MethodNode) {
        // Every candidate Ref allocation, in instruction order.
        private val refValues = ArrayList<CapturedVarDescriptor>()
        // Lookup from a Ref's NEW instruction to its descriptor, used by the interpreter.
        private val refValuesByNewInsn = LinkedHashMap<TypeInsnNode, CapturedVarDescriptor>()

        /**
         * Collects candidate Refs, analyzes the method to classify every usage, rewrites each Ref that
         * passed all checks, and finally drops catch blocks / local variables left empty or unused.
         */
        fun run() {
            createRefValues()
            if (refValues.isEmpty()) return

            val frames = analyze(internalClassName, methodNode, Interpreter())
            trackPops(frames)
            assignLocalVars(frames)

            for (refValue in refValues) {
                if (refValue.canRewrite()) {
                    rewriteRefValue(refValue)
                }
            }

            methodNode.removeEmptyCatchBlocks()
            methodNode.removeUnusedLocalVariables()
        }

        private fun AbstractInsnNode.getIndex() = methodNode.instructions.indexOf(this)

        /** Registers a descriptor for every `NEW` of a shared-var (Ref) type with a known element type. */
        private fun createRefValues() {
            for (insn in methodNode.instructions.asSequence()) {
                if (insn.opcode == Opcodes.NEW && insn is TypeInsnNode) {
                    val type = Type.getObjectType(insn.desc)
                    if (AsmTypes.isSharedVarType(type)) {
                        // Ref types missing from the table are not candidates.
                        val valueType = REF_TYPE_TO_ELEMENT_TYPE[type.internalName] ?: continue
                        val refValue = CapturedVarDescriptor(insn, type, valueType)
                        refValues.add(refValue)
                        refValuesByNewInsn[insn] = refValue
                    }
                }
            }
        }

        /**
         * Starts tracking each candidate Ref at its `NEW` and sorts every instruction that uses it into the
         * descriptor's buckets; any usage outside the allowed set marks the Ref as hazardous.
         */
        private inner class Interpreter : ReferenceTrackingInterpreter() {
            override fun newOperation(insn: AbstractInsnNode): BasicValue =
                refValuesByNewInsn[insn]?.let { ProperTrackedReferenceValue(it.refType, it) } ?: super.newOperation(insn)

            override fun processRefValueUsage(value: TrackedReferenceValue, insn: AbstractInsnNode, position: Int) {
                for (descriptor in value.descriptors) {
                    if (descriptor !is CapturedVarDescriptor) throw AssertionError("Unexpected descriptor: $descriptor")
                    // `position == 0` restricts field/<init> accesses to the Ref being the owner operand
                    // (presumably the first consumed operand), not e.g. a value stored into another field.
                    when {
                        insn.opcode == Opcodes.DUP -> descriptor.wrapperInsns.add(insn)
                        insn.opcode == Opcodes.ALOAD -> descriptor.wrapperInsns.add(insn)
                        insn.opcode == Opcodes.ASTORE -> descriptor.wrapperInsns.add(insn)
                        insn.opcode == Opcodes.GETFIELD && insn is FieldInsnNode && insn.name == REF_ELEMENT_FIELD && position == 0 ->
                            descriptor.getFieldInsns.add(insn)
                        insn.opcode == Opcodes.PUTFIELD && insn is FieldInsnNode && insn.name == REF_ELEMENT_FIELD && position == 0 ->
                            descriptor.putFieldInsns.add(insn)
                        insn.opcode == Opcodes.INVOKESPECIAL && insn is MethodInsnNode && insn.name == INIT_METHOD_NAME && position == 0 -> {
                            // A Ref must be initialized by exactly one constructor call site.
                            if (descriptor.initCallInsn != null && descriptor.initCallInsn != insn)
                                descriptor.hazard = true
                            else
                                descriptor.initCallInsn = insn
                        }
                        else -> descriptor.hazard = true
                    }
                }
            }
        }

        /**
         * Classifies POP/POP2 instructions that consume a tracked Ref, using the analyzed frames:
         * a POP of a Ref is a removable wrapper instruction, while a POP2 over two single-slot values
         * conservatively marks any Ref among them as hazardous.
         */
        private fun trackPops(frames: Array<Frame<BasicValue>?>) {
            for ((i, insn) in methodNode.instructions.asSequence().withIndex()) {
                val frame = frames[i] ?: continue
                when (insn.opcode) {
                    Opcodes.POP -> {
                        frame.top()?.getCapturedVarOrNull()?.run { wrapperInsns.add(insn) }
                    }
                    Opcodes.POP2 -> {
                        val top = frame.top()
                        if (top?.size == 1) {
                            top.getCapturedVarOrNull()?.hazard = true
                            frame.peek(1)?.getCapturedVarOrNull()?.hazard = true
                        }
                    }
                }
            }
        }

        private fun BasicValue.getCapturedVarOrNull(): CapturedVarDescriptor? =
            (this as? ProperTrackedReferenceValue)?.descriptor as? CapturedVarDescriptor

        /**
         * Links each Ref to the LVT entry that holds it (if exactly one does), then picks the slot that
         * will hold the unwrapped value: the Ref's own slot for single-slot values, otherwise fresh
         * slot(s) appended after [MethodNode.maxLocals].
         */
        private fun assignLocalVars(frames: Array<Frame<BasicValue>?>) {
            for (localVar in methodNode.localVariables) {
                val type = Type.getType(localVar.desc)
                if (!AsmTypes.isSharedVarType(type)) continue

                val startFrame = frames[localVar.start.getIndex()] ?: continue

                val refValue = startFrame.getLocal(localVar.index) as? ProperTrackedReferenceValue ?: continue
                val descriptor = refValue.descriptor as? CapturedVarDescriptor ?: continue

                if (descriptor.hazard) continue

                // Several LVT entries for the same Ref cannot all be retargeted to one slot.
                if (descriptor.localVar == null) {
                    descriptor.localVar = localVar
                } else {
                    descriptor.hazard = true
                }
            }

            for (refValue in refValues) {
                if (refValue.hazard) continue
                val localVar = refValue.localVar
                if (localVar != null && refValue.valueType.size == 1) {
                    // The single-slot value fits exactly where the Ref reference used to live.
                    refValue.localVarIndex = localVar.index
                } else {
                    // No LVT entry, or a two-slot value: allocate fresh slot(s) at the end of the frame.
                    refValue.localVarIndex = methodNode.maxLocals
                    methodNode.maxLocals += refValue.valueType.size
                }
            }
        }

        /**
         * Lazily finds `ACONST_NULL; ASTORE <this variable>` pairs inside this variable's live range and
         * yields the ASTORE of each pair.
         */
        private fun LocalVariableNode.findCleanInstructions() =
            InsnSequence(methodNode.instructions).dropWhile { it != start }.takeWhile { it != end }.filter {
                it is VarInsnNode && it.opcode == Opcodes.ASTORE && it.`var` == index && it.previous?.opcode == Opcodes.ACONST_NULL
            }

        // Be careful to not remove instructions that are the only instruction for a line number. That will
        // break debugging. If the previous instruction is a line number and the following instruction is
        // a label followed by a line number, insert a nop instead of deleting the instruction.
        private fun InsnList.removeOrReplaceByNop(insn: AbstractInsnNode) {
            if (insn.previous is LineNumberNode && insn.next is LabelNode && insn.next.next is LineNumberNode) {
                set(insn, InsnNode(Opcodes.NOP))
            } else {
                remove(insn)
            }
        }

        /**
         * Unwraps one Ref whose usages were all classified as rewritable: removes its allocation,
         * constructor call and wrapper instructions, and turns `element` field accesses into plain
         * loads/stores of [CapturedVarDescriptor.localVarIndex]. The Ref's LVT entry, if any, is retargeted
         * to the new slot and value type.
         */
        private fun rewriteRefValue(capturedVar: CapturedVarDescriptor) {
            methodNode.instructions.run {
                val loadOpcode = capturedVar.valueType.getOpcode(Opcodes.ILOAD)
                val storeOpcode = capturedVar.valueType.getOpcode(Opcodes.ISTORE)

                val localVar = capturedVar.localVar
                if (localVar != null) {
                    if (capturedVar.putFieldInsns.none { it.getIndex() < localVar.start.getIndex() }) {
                        // variable needs to be initialized before its live range can begin
                        insertBefore(capturedVar.newInsn, InsnNode(AsmUtil.defaultValueOpcode(capturedVar.valueType)))
                        insertBefore(capturedVar.newInsn, VarInsnNode(storeOpcode, capturedVar.localVarIndex))
                    }

                    // Collect the matches up front: the loop body mutates the same InsnList that the
                    // lazy sequence walks, so iterating it directly would depend on iterator internals.
                    for (insn in localVar.findCleanInstructions().toList()) {
                        // after visiting block codegen tries to delete all allocated references:
                        // see ExpressionCodegen.addLeaveTaskToRemoveLocalVariableFromFrameMap
                        if (storeOpcode == Opcodes.ASTORE) {
                            set(insn.previous, InsnNode(AsmUtil.defaultValueOpcode(capturedVar.valueType)))
                        } else {
                            remove(insn.previous)
                            remove(insn)
                        }
                    }

                    localVar.index = capturedVar.localVarIndex
                    localVar.desc = capturedVar.valueType.descriptor
                    localVar.signature = null
                }

                // canRewrite() guarantees the constructor call was recorded.
                val initCallInsn = checkNotNull(capturedVar.initCallInsn) { "Rewritable Ref must have an <init> call" }
                remove(capturedVar.newInsn)
                remove(initCallInsn)
                capturedVar.wrapperInsns.forEach { removeOrReplaceByNop(it) }
                capturedVar.getFieldInsns.forEach { set(it, VarInsnNode(loadOpcode, capturedVar.localVarIndex)) }
                capturedVar.putFieldInsns.forEach { set(it, VarInsnNode(storeOpcode, capturedVar.localVarIndex)) }
            }
        }

    }
}

/** Name of the field through which `Ref.*` wrappers expose their wrapped value. */
internal const val REF_ELEMENT_FIELD = "element"

/**
 * JVM name of instance constructors, matched against `INVOKESPECIAL` targets.
 * Must be `<init>`: an empty name never matches, so no Ref would ever be considered initialized.
 */
internal const val INIT_METHOD_NAME = "<init>"

/**
 * Maps the internal name of each `Ref` wrapper class (the object Ref plus one per [PrimitiveType])
 * to the JVM type of the value it wraps.
 */
internal val REF_TYPE_TO_ELEMENT_TYPE = HashMap<String, Type>().apply {
    put(AsmTypes.OBJECT_REF_TYPE.internalName, AsmTypes.OBJECT_TYPE)
    PrimitiveType.entries.forEach { primitive ->
        put(AsmTypes.sharedTypeForPrimitive(primitive).internalName, AsmTypes.valueTypeForPrimitive(primitive))
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy