All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlin.js.coroutine.CoroutinePasses.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2017 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.js.coroutine

import org.jetbrains.kotlin.js.backend.ast.*
import org.jetbrains.kotlin.js.backend.ast.metadata.*
import org.jetbrains.kotlin.js.inline.util.collectFreeVariables
import org.jetbrains.kotlin.js.inline.util.replaceNames
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils

/**
 * Collects the nodes of this JS subtree that take part in suspension-related control flow.
 *
 * A node ends up in the result when it is:
 * - a suspend expression used directly as an expression statement, or as the value side of an
 *   assignment statement (the expression is collected, not the statement);
 * - a `return`, once the root itself has been collected or whenever it is nested inside a `try`
 *   that has a `finally` block;
 * - a `break`/`continue` whose target statement (looked up in [breakContinueTargets]) has been collected;
 * - a node processed by the overridden `visitElement` that has a collected node somewhere inside it
 *   (presumably every node not otherwise overridden, if `RecursiveJsVisitor` routes default visits
 *   through `visitElement` — verify).
 *
 * @param breakContinueTargets the statement each `break`/`continue` jumps to. Every break/continue in
 *   this subtree must have an entry; a missing entry throws via `!!`.
 * @return the collected nodes.
 */
fun JsNode.collectNodesToSplit(breakContinueTargets: Map): Set {
    val root = this
    val nodes = mutableSetOf()

    val visitor = object : RecursiveJsVisitor() {
        // Set when the node just processed (or something inside it) was added to `nodes`;
        // visitElement reads it to pull the enclosing node into `nodes` as well.
        var childrenInSet = false
        // Number of enclosing `try` statements that have a `finally` block.
        var finallyLevel = 0

        override fun visitExpressionStatement(x: JsExpressionStatement) {
            super.visitExpressionStatement(x)
            if (x.expression.isSuspend) {
                nodes += x.expression
                childrenInSet = true
            }
            else {
                // decomposeAssignment presumably yields (lhs, rhs) for an assignment expression and null
                // otherwise; a suspend value side is collected.
                val assignment = JsAstUtils.decomposeAssignment(x.expression)
                if (assignment != null && assignment.second.isSuspend) {
                    nodes += assignment.second
                    childrenInSet = true
                }
            }
        }

        override fun visitReturn(x: JsReturn) {
            super.visitReturn(x)

            // The root is only added after its whole subtree has been traversed, so the `root in nodes`
            // condition can typically succeed only on a later pass of the loop below.
            if (root in nodes || finallyLevel > 0) {
                nodes += x
                childrenInSet = true
            }
        }

        // We don't handle JsThrow case here the same way as we do for JsReturn.
        // Exception will be caught by the surrounding catch and then dispatched to a corresponding $exceptionState.
        // Even if there's no `catch` clause, we generate a fake one that dispatches to a finally block.

        override fun visitBreak(x: JsBreak) {
            super.visitBreak(x)

            // Throws if the break has no registered target.
            val breakTarget = breakContinueTargets[x]!!
            if (breakTarget in nodes) {
                nodes += x
                childrenInSet = true
            }
        }

        override fun visitContinue(x: JsContinue) {
            super.visitContinue(x)

            // Throws if the continue has no registered target.
            val continueTarget = breakContinueTargets[x]!!
            if (continueTarget in nodes) {
                nodes += x
                childrenInSet = true
            }
        }

        override fun visitTry(x: JsTry) {
            // The counter wraps the whole super.visitTry call, so everything traversed under this `try`
            // sees finallyLevel > 0.
            if (x.finallyBlock != null) {
                finallyLevel++
            }
            super.visitTry(x)
            if (x.finallyBlock != null) {
                finallyLevel--
            }
        }

        override fun visitElement(node: JsNode) {
            // Post-order propagation. The flag is cleared so that only this node's subtree is observed.
            // Afterwards it is either left set (so the parent gets collected too) or restored to the value
            // accumulated from earlier siblings.
            val oldChildrenInSet = childrenInSet
            childrenInSet = false

            node.acceptChildren(this)

            if (childrenInSet) {
                nodes += node
            }
            else {
                childrenInSet = oldChildrenInSet
            }
        }
    }

    // Repeat until a fixed point. Nodes added during one pass (ancestors, loop targets, the root) can
    // make the `return`/`break`/`continue` checks succeed on the next pass.
    while (true) {
        val countBefore = nodes.size
        visitor.accept(this)
        val countAfter = nodes.size
        if (countAfter == countBefore) break
    }

    return nodes
}

/**
 * Lowers the [JsDebugger] placeholders in the bodies of these coroutine blocks into assignments to
 * state-machine fields. Each block is numbered by its position in this list.
 *
 * Each placeholder is handled in this order:
 * - a `targetBlock` yields `<stateMachineReceiver>.<stateName> = <index>`, flagged as `targetBlock`;
 * - a `targetExceptionBlock` yields `<stateMachineReceiver>.<exceptionStateName> = <index>`,
 *   flagged as `targetExceptionBlock`;
 * - a non-empty `finallyPath` yields `<stateMachineReceiver>.<finallyPathName> = [<indices>]`,
 *   flagged as `finallyPath`; an empty `finallyPath` removes the placeholder instead.
 *
 * Every block referenced by a placeholder must be an element of this list; otherwise `!!` throws.
 */
fun List.replaceCoroutineFlowStatements(context: CoroutineTransformationContext) {
    val stateIndexOf = withIndex().associate { (position, coroutineBlock) -> coroutineBlock to position }

    val placeholderLowering = object : JsVisitorWithContextImpl() {
        override fun endVisit(x: JsDebugger, ctx: JsContext) {
            x.targetBlock?.let { target ->
                val assignment = JsAstUtils.assignment(
                        JsNameRef(context.metadata.stateName, JsAstUtils.stateMachineReceiver()),
                        JsIntLiteral(stateIndexOf[target]!!)
                ).source(x.source)
                val statement = JsExpressionStatement(assignment)
                statement.targetBlock = true
                ctx.replaceMe(statement)
            }

            x.targetExceptionBlock?.let { target ->
                val assignment = JsAstUtils.assignment(
                        JsNameRef(context.metadata.exceptionStateName, JsAstUtils.stateMachineReceiver()),
                        JsIntLiteral(stateIndexOf[target]!!)
                ).source(x.source)
                val statement = JsExpressionStatement(assignment)
                statement.targetExceptionBlock = true
                ctx.replaceMe(statement)
            }

            x.finallyPath?.let { path ->
                if (path.isEmpty()) {
                    // Nothing to record: the placeholder simply disappears.
                    ctx.removeMe()
                }
                else {
                    val assignment = JsAstUtils.assignment(
                            JsNameRef(context.metadata.finallyPathName, JsAstUtils.stateMachineReceiver()),
                            JsArrayLiteral(path.map { step -> JsIntLiteral(stateIndexOf[step]!!) })
                    ).source(x.source)
                    val statement = JsExpressionStatement(assignment)
                    statement.finallyPath = true
                    ctx.replaceMe(statement)
                }
            }
        }
    }

    for (coroutineBlock in this) {
        placeholderLowering.accept(coroutineBlock.jsBlock)
    }
}

/**
 * Builds the control-flow graph of the coroutine blocks reachable from this (entry) block.
 *
 * The outgoing edges of a visited block are:
 * - every `targetBlock` / `targetExceptionBlock` referenced by a [JsDebugger] placeholder in its body;
 * - for every `finallyPath` recorded in its body, the chain `block -> path[0] -> path[1] -> ...`;
 * - for the entry block only, an edge to [globalCatchBlock] when it is non-null.
 *
 * @param globalCatchBlock optional block linked as a successor of the entry block.
 * @return adjacency map from each visited block (and each finally-path element) to its successors.
 */
fun CoroutineBlock.buildGraph(globalCatchBlock: CoroutineBlock?): Map> {
    // That's a little more than DFS due to need of tracking finally paths

    val visitedBlocks = mutableSetOf()
    val graph = mutableMapOf>()

    // Recursive DFS; recursion depth grows with the length of the explored path.
    fun visitBlock(block: CoroutineBlock) {
        if (block in visitedBlocks) return

        for (finallyPath in block.collectFinallyPaths()) {
            // Pairs (block, path[0]), (path[0], path[1]), ... : the chain of edges along the path.
            for ((finallySource, finallyTarget) in (listOf(block) + finallyPath).zip(finallyPath)) {
                // A new edge out of an already visited block un-marks it, so the block is traversed
                // again (and reaches the new target) the next time it is reached.
                if (graph.getOrPut(finallySource) { mutableSetOf() }.add(finallyTarget)) {
                    visitedBlocks -= finallySource
                }
            }
        }

        visitedBlocks += block

        val successors = graph.getOrPut(block) { mutableSetOf() }
        successors += block.collectTargetBlocks()
        // Inside this local function `this` is still the receiver of buildGraph, i.e. the entry block.
        if (block == this && globalCatchBlock != null) {
            successors += globalCatchBlock
        }
        // NOTE(review): `successors` is the live set stored in `graph`. If a nested visitBlock call adds an
        // edge out of `block` (via the finally-path loop above), this iteration would throw
        // ConcurrentModificationException. Presumably unreachable for generated graphs — verify.
        successors.forEach(::visitBlock)
    }

    visitBlock(this)

    return graph
}

/**
 * Returns the blocks this block can jump to directly. These are the `targetExceptionBlock` and
 * `targetBlock` of every [JsDebugger] placeholder found in [jsBlock].
 *
 * The returned set preserves insertion order: placeholders in traversal order, and for each one the
 * exception target before the normal target.
 */
private fun CoroutineBlock.collectTargetBlocks(): Set<CoroutineBlock> {
    val targetBlocks = mutableSetOf<CoroutineBlock>()
    jsBlock.accept(object : RecursiveJsVisitor() {
        override fun visitDebugger(x: JsDebugger) {
            // Add the non-null targets directly instead of allocating two temporary lists per placeholder.
            x.targetExceptionBlock?.let { targetBlocks.add(it) }
            x.targetBlock?.let { targetBlocks.add(it) }
        }
    })
    return targetBlocks
}

/**
 * Returns every `finallyPath` recorded on a [JsDebugger] placeholder in [jsBlock], in traversal order.
 *
 * Each path is returned as-is, including empty ones; placeholders without a path are skipped.
 */
private fun CoroutineBlock.collectFinallyPaths(): List<List<CoroutineBlock>> {
    val finallyPaths = mutableListOf<List<CoroutineBlock>>()
    jsBlock.accept(object : RecursiveJsVisitor() {
        override fun visitDebugger(x: JsDebugger) {
            x.finallyPath?.let { finallyPaths.add(it) }
        }
    })
    return finallyPaths
}

/**
 * Rewrites coroutine-related references in this block in place, without descending into nested functions.
 *
 * - Every `this` becomes `this.<receiverFieldName>`.
 * - A name reference flagged `coroutineReceiver` becomes a plain `this`.
 * - One flagged `coroutineController` becomes `<qualifier>.<controllerFieldName>`, marked `SideEffectKind.PURE`.
 * - One flagged `coroutineResult` becomes `<qualifier>.<resultName>`, marked `SideEffectKind.DEPENDS_ON_STATE`.
 *
 * The flags are tested in the order listed, and only the first one that is set applies. The controller
 * and result replacements copy the original reference's source information.
 */
fun JsBlock.replaceSpecialReferences(context: CoroutineTransformationContext) {
    val referenceRewriter = object : JsVisitorWithContextImpl() {
        // Do not descend into nested functions.
        override fun visit(x: JsFunction, ctx: JsContext<*>): Boolean = false

        override fun endVisit(x: JsThisRef, ctx: JsContext) {
            val receiverField = JsNameRef(context.receiverFieldName, JsThisRef())
            ctx.replaceMe(receiverField)
        }

        override fun endVisit(x: JsNameRef, ctx: JsContext) {
            if (x.coroutineReceiver) {
                ctx.replaceMe(JsThisRef())
            }
            else if (x.coroutineController) {
                val controllerRef = JsNameRef(context.controllerFieldName, x.qualifier)
                controllerRef.source = x.source
                controllerRef.sideEffects = SideEffectKind.PURE
                ctx.replaceMe(controllerRef)
            }
            else if (x.coroutineResult) {
                val resultRef = JsNameRef(context.metadata.resultName, x.qualifier)
                resultRef.source = x.source
                resultRef.sideEffects = SideEffectKind.DEPENDS_ON_STATE
                ctx.replaceMe(resultRef)
            }
        }
    }
    referenceRewriter.accept(this)
}

/**
 * Moves the given local variables of this block onto the state-machine object (`this`).
 * Coroutine-specific references are rewritten first via `replaceSpecialReferences`.
 *
 * - An unqualified reference to one of [localVariables] becomes `this.<field>`.
 * - A `var` statement becomes a comma sequence of `this.<field> = <init>` for its initialized entries,
 *   or is removed when none of its entries has an initializer.
 * - A nested function that has any of these locals as free variables is wrapped in an invoked wrapper
 *   function. The wrapper takes one fresh temporary parameter per captured local, is called with
 *   `this.<field>` for each, and returns the nested function. Inside the nested function body, the
 *   captured names are renamed to those parameters. Nested function bodies are not otherwise visited.
 */
fun JsBlock.replaceLocalVariables(context: CoroutineTransformationContext, localVariables: Set) {
    replaceSpecialReferences(context)

    val localsToFields = object : JsVisitorWithContextImpl() {
        override fun visit(x: JsFunction, ctx: JsContext<*>): Boolean = false

        override fun endVisit(x: JsFunction, ctx: JsContext) {
            val captured = x.collectFreeVariables().intersect(localVariables)
            if (captured.isEmpty()) return

            val wrapper = JsFunction(x.scope.parent, JsBlock(), "")
            val call = JsInvocation(wrapper)
            wrapper.body.statements += JsReturn(x)

            val renaming = captured.associate { local -> local to JsScope.declareTemporaryName(local.ident) }
            captured.forEach { local ->
                wrapper.parameters += JsParameter(renaming[local]!!)
                call.arguments += JsNameRef(context.getFieldName(local), JsThisRef())
            }

            x.body = replaceNames(x.body, renaming.mapValues { (_, temporary) -> temporary.makeRef() })
            ctx.replaceMe(call)
        }

        override fun endVisit(x: JsNameRef, ctx: JsContext) {
            if (x.qualifier != null) return
            val name = x.name
            if (name !in localVariables) return

            ctx.replaceMe(JsNameRef(context.getFieldName(name!!), JsThisRef()).source(x.source))
        }

        override fun endVisit(x: JsVars, ctx: JsContext) {
            val assignments = x.vars.mapNotNull { declaration ->
                // The field name is resolved for every declared variable, initialized or not.
                val fieldName = context.getFieldName(declaration.name)
                declaration.initExpression?.let { initializer ->
                    JsAstUtils.assignment(JsNameRef(fieldName, JsThisRef()), initializer)
                }
            }

            if (assignments.isEmpty()) {
                ctx.removeMe()
            }
            else {
                ctx.replaceMe(JsExpressionStatement(JsAstUtils.newSequence(assignments)))
            }
        }
    }
    localsToFields.accept(this)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy