All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.js.translate.expression.WhenTranslator.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2018 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.js.translate.expression

import org.jetbrains.kotlin.backend.common.CodegenUtil
import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.cfg.WhenChecker
import org.jetbrains.kotlin.config.languageVersionSettings
import org.jetbrains.kotlin.descriptors.CallableDescriptor
import org.jetbrains.kotlin.descriptors.ClassDescriptor
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.js.backend.ast.*
import org.jetbrains.kotlin.js.translate.context.Namer
import org.jetbrains.kotlin.js.translate.context.TranslationContext
import org.jetbrains.kotlin.js.translate.general.AbstractTranslator
import org.jetbrains.kotlin.js.translate.general.Translation
import org.jetbrains.kotlin.js.translate.operation.InOperationTranslator
import org.jetbrains.kotlin.js.translate.utils.BindingUtils
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils.not
import org.jetbrains.kotlin.js.translate.utils.mutator.CoercionMutator
import org.jetbrains.kotlin.js.translate.utils.mutator.LastExpressionMutator
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.getTextWithLocation
import org.jetbrains.kotlin.resolve.DescriptorUtils
import org.jetbrains.kotlin.resolve.bindingContextUtil.getDataFlowInfoBefore
import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowValueFactory
import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowValueFactoryImpl
import org.jetbrains.kotlin.resolve.constants.CompileTimeConstant
import org.jetbrains.kotlin.resolve.constants.EnumValue
import org.jetbrains.kotlin.resolve.constants.evaluate.ConstantExpressionEvaluator
import org.jetbrains.kotlin.resolve.descriptorUtil.getSuperClassOrAny
import org.jetbrains.kotlin.types.KotlinType

// A `when` entry paired with the JS constant expressions that its conditions were reduced to.
private typealias EntryWithConstants = Pair<List<JsExpression>, KtWhenEntry>

/**
 * Translates a Kotlin `when` expression into a JavaScript AST node.
 *
 * Consecutive entries whose conditions are all compile-time constants (enum entries, or
 * String/Int/Short/Byte/Char values) are emitted as a [JsSwitch]; every other entry becomes a link
 * of an `if`/`else` chain. Instances are created only through the companion's [translate] function.
 */
class WhenTranslator
private constructor(private val whenExpression: KtWhenExpression, context: TranslationContext) : AbstractTranslator(context) {
    // Type of the `when` subject; null when there is no subject or its type could not be obtained.
    private val subjectType: KotlinType?
    // JS expression the conditions are matched against; null for a subject-less `when`.
    private val expressionToMatch: JsExpression?
    // Type of the whole `when` expression as recorded in the binding context.
    private val type: KotlinType?
    // Constants / enum entry names already placed into a switch; used to drop repeated values.
    private val uniqueConstants = mutableSetOf<Any>()
    private val uniqueEnumNames = mutableSetOf<String>()
    private val dataFlowValueFactory: DataFlowValueFactory = DataFlowValueFactoryImpl(context.languageVersionSettings)

    /**
     * Whether [CodegenUtil.isExhaustive] reports the `when` as exhaustive. The `when` is passed on
     * as a statement if its type is non-nullable `Unit`.
     */
    private val isExhaustive: Boolean
        get() {
            val whenType = bindingContext().getType(whenExpression)
            val isStatement = whenType != null && KotlinBuiltIns.isUnit(whenType) && !whenType.isMarkedNullable
            return CodegenUtil.isExhaustive(bindingContext(), whenExpression, isStatement)
        }

    init {
        val subjectVariable = whenExpression.subjectVariable
        val subjectExpression = whenExpression.subjectExpression

        when {
            // `when (val x = ...)`: emit the variable declaration and match against a reference to it.
            subjectVariable != null -> {
                val variable = Translation.translateAsStatement(subjectVariable, context) as JsVars
                context.addStatementToCurrentBlock(variable)

                val descriptor = BindingUtils.getDescriptorForElement(context.bindingContext(), subjectVariable) as? CallableDescriptor
                subjectType = descriptor?.returnType
                expressionToMatch = variable.vars.first().name.makeRef()
            }
            // `when (expr)`: match against the result of `defineTemporary` for the translated subject.
            subjectExpression != null -> {
                subjectType = bindingContext().getType(subjectExpression)
                expressionToMatch = context.defineTemporary(Translation.translateAsExpression(subjectExpression, context))
            }
            // `when { ... }` without a subject.
            else -> {
                subjectType = null
                expressionToMatch = null
            }
        }

        type = bindingContext().getType(whenExpression)
    }

    /**
     * Builds the JS node for the whole `when`.
     *
     * @return the first statement produced (a switch or an `if`), the translated `else` body when
     * `else` is reached before any other statement was produced, or a [JsNullLiteral] when nothing
     * was produced at all.
     */
    private fun translate(): JsNode {
        var resultIf: JsNode? = null
        // Continuation that attaches the next statement to whatever was produced last: initially it
        // records the root, later it fills a switch `default` case or an `else` branch.
        var setWhenStatement: (JsStatement) -> Unit = { resultIf = it }

        var i = 0
        var hasElse = false
        while (i < whenExpression.entries.size) {
            val asSwitch = translateAsSwitch(i)
            if (asSwitch != null) {
                val (jsSwitch, next) = asSwitch
                setWhenStatement(jsSwitch)
                // Anything that follows the switch goes into its `default` case.
                setWhenStatement = { whenStatement ->
                    jsSwitch.cases += JsDefault().apply {
                        statements += whenStatement
                        statements += JsBreak().apply { source = whenExpression }
                    }
                }
                i = next
                continue
            }

            val entry = whenExpression.entries[i++]
            val statementBlock = JsBlock()
            var statement = translateEntryExpression(entry, context(), statementBlock)

            // `else` reached with nothing emitted before it: no branching is needed at all.
            if (resultIf == null && entry.isElse) {
                context().addStatementsToCurrentBlockFrom(statementBlock)
                return statement
            }
            statement = JsAstUtils.mergeStatementInBlockIfNeeded(statement, statementBlock)

            val conditionsBlock = JsBlock()
            if (entry.isElse) {
                hasElse = true
                setWhenStatement(statement)
                break
            }
            val jsIf = JsAstUtils.newJsIf(translateConditions(entry, context().innerBlock(conditionsBlock)), statement)
            jsIf.source = entry

            val statementToAdd = JsAstUtils.mergeStatementInBlockIfNeeded(jsIf, conditionsBlock)
            setWhenStatement(statementToAdd)
            setWhenStatement = { jsIf.elseStatement = it }
        }

        // An exhaustive `when` without `else` gets a trailing call to `noWhenBranchMatched`.
        if (isExhaustive && !hasElse) {
            val noWhenMatchedInvocation = JsInvocation(JsAstUtils.pureFqn("noWhenBranchMatched", Namer.kotlinObject()))
            setWhenStatement(JsAstUtils.asSyntheticStatement(noWhenMatchedInvocation))
        }

        return resultIf ?: JsNullLiteral()
    }

    /**
     * Tries to translate the entries starting at [fromIndex] into a single [JsSwitch].
     *
     * @return the switch together with the index of the first entry not consumed by it, or `null`
     * when there is no subject/subject type or when the collected entries contribute at most one
     * case in total.
     */
    private fun translateAsSwitch(fromIndex: Int): Pair<JsSwitch, Int>? {
        val subjectType = subjectType ?: return null
        val ktSubject = whenExpression.subjectExpression ?: return null

        val dataFlow = dataFlowValueFactory.createDataFlowValue(
                ktSubject, subjectType, bindingContext(), context().declarationDescriptor ?: context().currentModule)
        val languageVersionSettings = context().config.configuration.languageVersionSettings
        // Candidate types for the subject: stable types from data flow info plus the declared type.
        val expectedTypes = bindingContext().getDataFlowInfoBefore(ktSubject).getStableTypes(dataFlow, languageVersionSettings) +
                            setOf(subjectType)
        val subject = expressionToMatch ?: return null
        var subjectSupplier = { subject }

        val enumClass = expectedTypes.asSequence().mapNotNull { it.getEnumClass() }.firstOrNull()
        val (entriesForSwitch, nextIndex) = if (enumClass != null) {
            // Enum subjects are switched on their `name` property; cases are string literals.
            subjectSupplier = {
                val enumBaseClass = enumClass.getSuperClassOrAny()
                val nameProperty = DescriptorUtils.getPropertyByName(enumBaseClass.unsubstitutedMemberScope, Name.identifier("name"))
                JsNameRef(context().getNameForDescriptor(nameProperty), subject)
            }
            collectEnumEntries(fromIndex, whenExpression.entries, enumClass.defaultType)
        }
        else {
            collectPrimitiveConstantEntries(fromIndex, whenExpression.entries, expectedTypes)
        }

        return if (entriesForSwitch.asSequence().map { it.first.size }.sum() > 1) {
            val switchEntries = mutableListOf<JsSwitchMember>()
            entriesForSwitch.flatMapTo(switchEntries) { (conditions, entry) ->
                val members = conditions.map {
                    JsCase().apply {
                        caseExpression = it.source(entry)
                    }
                }
                val block = JsBlock()
                val statement = translateEntryExpression(entry, context(), block)
                // Only the last case of an entry carries the body; earlier cases are left empty.
                // `conditions` is never empty here: entries without constants are not collected.
                val lastEntry = members.last()
                lastEntry.statements += block.statements
                lastEntry.statements += statement
                lastEntry.statements += JsBreak().apply { source = entry }
                members
            }
            Pair(JsSwitch(subjectSupplier(), switchEntries).apply { source = whenExpression }, nextIndex)
        }
        else {
            null
        }
    }

    /**
     * Collects consecutive entries, starting at [fromIndex], whose conditions are constants of one of
     * [expectedTypes] and can be expressed as a JS string or int literal (String, Int, Short, Byte, Char).
     *
     * @return the collected entries and the index at which collection stopped.
     */
    private fun collectPrimitiveConstantEntries(
            fromIndex: Int,
            entries: List<KtWhenEntry>,
            expectedTypes: Set<KotlinType>
    ): Pair<List<EntryWithConstants>, Int> {
        return collectConstantEntries(
                fromIndex, entries,
                { constant -> expectedTypes.asSequence().mapNotNull { constant.getValue(it) }.firstOrNull() },
                { uniqueConstants.add(it) },
                {
                    when (it) {
                        is String -> JsStringLiteral(it)
                        is Int -> JsIntLiteral(it)
                        is Short -> JsIntLiteral(it.toInt())
                        is Byte -> JsIntLiteral(it.toInt())
                        is Char -> JsIntLiteral(it.toInt())
                        else -> null
                    }
                }
        )
    }

    /**
     * Collects consecutive entries, starting at [fromIndex], whose conditions are entries of the enum
     * class denoted by [expectedType]. Each constant is represented by the entry name as a string literal.
     *
     * @return the collected entries and the index at which collection stopped.
     */
    private fun collectEnumEntries(
            fromIndex: Int,
            entries: List<KtWhenEntry>,
            expectedType: KotlinType
    ): Pair<List<EntryWithConstants>, Int> {
        val classId = WhenChecker.getClassIdForTypeIfEnum(expectedType)
        return collectConstantEntries(
            fromIndex, entries,
            {
                (it.toConstantValue(expectedType) as? EnumValue)
                    ?.takeIf { enumEntry -> enumEntry.enumClassId == classId }
                    ?.enumEntryName?.identifier
            },
            { uniqueEnumNames.add(it) },
            { JsStringLiteral(it) }
        )
    }

    /**
     * Walks [entries] from [fromIndex] and gathers those that can become switch cases.
     *
     * Collection stops at the `else` entry, or at the first entry with a condition that is not an
     * expression condition, has no compile-time constant, or for which [extractor] or [wrapper]
     * yields `null`.
     *
     * @param extractor maps a compile-time constant to a value, or `null` if it is unsuitable
     * @param filter decides whether an extracted value is kept; values it rejects are silently skipped
     * @param wrapper converts a kept value to a JS expression, or `null` if it cannot be represented
     * @return the collected entries (entries left without any constant are omitted, but still
     * consumed) and the index of the first entry that was not consumed.
     */
    private fun <T : Any> collectConstantEntries(
            fromIndex: Int,
            entries: List<KtWhenEntry>,
            extractor: (CompileTimeConstant<*>) -> T?,
            filter: (T) -> Boolean,
            wrapper: (T) -> JsExpression?
    ): Pair<List<EntryWithConstants>, Int> {
        val entriesForSwitch = mutableListOf<EntryWithConstants>()
        var i = fromIndex
        while (i < entries.size) {
            val entry = entries[i]
            if (entry.isElse) break

            var hasImproperConstants = false
            val constantValues = entry.conditions.mapNotNull { condition ->
                val expression = (condition as? KtWhenConditionWithExpression)?.expression
                expression?.let { ConstantExpressionEvaluator.getConstant(it, bindingContext()) }?.let(extractor) ?: run {
                    hasImproperConstants = true
                    null
                }
            }
            if (hasImproperConstants) break

            // NOTE: `filter` may have side effects (callers record values in a set), so values of an
            // entry that is subsequently rejected by `wrapper` have already been passed through it.
            val constants = constantValues.filter(filter).mapNotNull {
                wrapper(it) ?: run {
                    hasImproperConstants = true
                    null
                }
            }
            if (hasImproperConstants) break

            if (constants.isNotEmpty()) {
                entriesForSwitch += Pair(constants, entry)
            }
            i++
        }

        return Pair(entriesForSwitch, i)
    }

    /**
     * Returns the class descriptor if this type is a non-nullable, non-external enum class,
     * otherwise `null`.
     */
    private fun KotlinType.getEnumClass(): ClassDescriptor? {
        if (isMarkedNullable) return null
        val classDescriptor = (constructor.declarationDescriptor as? ClassDescriptor)
        return if (classDescriptor?.kind == ClassKind.ENUM_CLASS && !classDescriptor.isExternal) classDescriptor else null
    }

    /**
     * Translates the body of [entry] as a statement, passing [block] to the statement translation.
     * When the type of the `when` is known, the last expression of the result is run through a
     * [CoercionMutator] for that type.
     */
    private fun translateEntryExpression(
            entry: KtWhenEntry,
            context: TranslationContext,
            block: JsBlock
    ): JsStatement {
        val expressionToExecute = entry.expression ?: error("WhenEntry should have whenExpression to execute.")
        val result = Translation.translateAsStatement(expressionToExecute, context, block)
        return if (type != null) {
            LastExpressionMutator.mutateLastExpression(result, CoercionMutator(type, context))
        }
        else {
            result
        }
    }

    /**
     * Translates all conditions of a non-`else` [entry] and combines them, left to right, via
     * [translateOrCondition].
     */
    private fun translateConditions(entry: KtWhenEntry, context: TranslationContext): JsExpression {
        val conditions = entry.conditions
        assert(conditions.isNotEmpty()) { "When entry (not else) should have at least one condition" }

        val first = translateCondition(conditions[0], context)
        return conditions.asSequence().drop(1).fold(first) { acc, condition -> translateOrCondition(acc, condition, context) }
    }

    /**
     * Combines [leftExpression] with the translation of [condition].
     *
     * If translating [condition] produced no statements, the result is a plain `||` operation.
     * Otherwise the right-hand translation is expected to be a [JsNameRef]; an `if` statement built
     * from [leftExpression], an assignment of `true` to that reference, and the block of statements
     * produced for the right-hand side is added to the current block, and the reference is returned.
     */
    private fun translateOrCondition(
            leftExpression: JsExpression,
            condition: KtWhenCondition,
            context: TranslationContext
    ): JsExpression {
        val rightContext = context.innerBlock()
        val rightExpression = translateCondition(condition, rightContext)
        context.moveVarsFrom(rightContext)
        return if (rightContext.currentBlockIsEmpty()) {
            JsBinaryOperation(JsBinaryOperator.OR, leftExpression, rightExpression)
        }
        else {
            assert(rightExpression is JsNameRef) { "expected JsNameRef, but: $rightExpression" }
            val result = rightExpression as JsNameRef
            val ifStatement = JsAstUtils.newJsIf(leftExpression, JsAstUtils.assignment(result, JsBooleanLiteral(true)).makeStmt(),
                                                 rightContext.currentBlock)
            ifStatement.source = condition
            context.addStatementToCurrentBlock(ifStatement)
            result
        }
    }

    /** Translates a single condition to a boolean JS expression, negating it when [isNegated] says so. */
    private fun translateCondition(condition: KtWhenCondition, context: TranslationContext): JsExpression {
        val patternMatchExpression = translateWhenConditionToBooleanExpression(condition, context)
        return if (isNegated(condition)) not(patternMatchExpression) else patternMatchExpression
    }

    /** Dispatches on the kind of [condition]: `is` pattern, expression, or `in` range. */
    private fun translateWhenConditionToBooleanExpression(
            condition: KtWhenCondition,
            context: TranslationContext
    ): JsExpression = when (condition) {
        is KtWhenConditionIsPattern -> translateIsCondition(condition, context)
        is KtWhenConditionWithExpression -> translateExpressionCondition(condition, context)
        is KtWhenConditionInRange -> translateRangeCondition(condition, context)
        else -> error("Unsupported when condition ${condition.javaClass}")
    }

    /**
     * Translates an `is Type` condition against the subject. When the pattern translator returns no
     * check, the condition is translated as the literal `true`.
     */
    private fun translateIsCondition(conditionIsPattern: KtWhenConditionIsPattern, context: TranslationContext): JsExpression {
        val expressionToMatch = expressionToMatch ?: error("An is-check is not allowed in when() without subject.")
        val typeReference = conditionIsPattern.typeReference ?: error("An is-check must have a type reference.")

        val result = Translation.patternTranslator(context).translateIsCheck(expressionToMatch, typeReference)
        return (result ?: JsBooleanLiteral(true)).source(conditionIsPattern)
    }

    /**
     * Translates an expression condition: on its own for a subject-less `when`, or as a match of
     * the subject against the expression otherwise.
     */
    private fun translateExpressionCondition(condition: KtWhenConditionWithExpression, context: TranslationContext): JsExpression {
        val patternExpression = condition.expression ?: error("Expression pattern should have an expression.")

        val patternTranslator = Translation.patternTranslator(context)
        return if (expressionToMatch == null) {
            patternTranslator.translateExpressionForExpressionPattern(patternExpression)
        } else {
            // NOTE(review): subjectType can be null even with a subject (see init) — `!!` would then throw.
            patternTranslator.translateExpressionPattern(subjectType!!, expressionToMatch, patternExpression)
        }
    }

    /**
     * Translates an `in range` / `!in range` condition via [InOperationTranslator], aliasing the
     * subject expression to the JS expression being matched.
     */
    private fun translateRangeCondition(condition: KtWhenConditionInRange, context: TranslationContext): JsExpression {
        val expressionToMatch = expressionToMatch ?: error("Range pattern is only available for " +
                                                           "'when (C) { in ... }'  expressions: ${condition.getTextWithLocation()}")

        val subjectAliases = hashMapOf<KtExpression, JsExpression>()
        subjectAliases[whenExpression.subjectExpression!!] = expressionToMatch
        val callContext = context.innerContextWithAliasesForExpressions(subjectAliases)
        val negated = condition.operationReference.getReferencedNameElementType() === KtTokens.NOT_IN
        return InOperationTranslator(callContext, expressionToMatch, condition.rangeExpression!!, condition.operationReference,
                                     negated).translate().source(condition)
    }

    companion object {
        /** Entry point: translates [expression] in the given [context] to a JS AST node. */
        @JvmStatic
        fun translate(expression: KtWhenExpression, context: TranslationContext): JsNode = WhenTranslator(expression, context).translate()

        // Only `is` patterns are negated here; every other kind of condition yields false.
        private fun isNegated(condition: KtWhenCondition): Boolean = (condition as? KtWhenConditionIsPattern)?.isNegated ?: false
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy