All Downloads are FREE. Search and download functionalities are using the official Maven repository.

androidx.compose.compiler.plugins.kotlin.lower.IrSourcePrinter.kt Maven / Gradle / Ivy

There is a newer version: 2.1.0-RC
Show newest version
/*
 * Copyright 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@file:OptIn(UnsafeDuringIrConstructionAPI::class)

package androidx.compose.compiler.plugins.kotlin.lower

import androidx.compose.compiler.plugins.kotlin.ComposeNames
import androidx.compose.compiler.plugins.kotlin.hasComposableAnnotation
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.IrBuiltIns
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrReturnTargetSymbol
import org.jetbrains.kotlin.ir.symbols.IrTypeAliasSymbol
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptVoid
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.types.Variance
import org.jetbrains.kotlin.utils.Printer
import java.util.*
import kotlin.collections.HashSet
import kotlin.collections.List
import kotlin.collections.all
import kotlin.collections.any
import kotlin.collections.contains
import kotlin.collections.filter
import kotlin.collections.firstOrNull
import kotlin.collections.forEach
import kotlin.collections.forEachIndexed
import kotlin.collections.hashSetOf
import kotlin.collections.isNotEmpty
import kotlin.collections.joinToString
import kotlin.collections.last
import kotlin.collections.lastOrNull
import kotlin.collections.map
import kotlin.collections.mapNotNull
import kotlin.collections.mutableListOf
import kotlin.collections.mutableMapOf
import kotlin.collections.mutableSetOf
import kotlin.collections.set
import kotlin.collections.zip

/**
 * Renders this IR element as source-like text.
 *
 * The element is visited with an [IrSourcePrinterVisitor] that uses the placeholder
 * `%tab%` as its indent unit. The raw output is then cleaned up: leading placeholder
 * runs become two spaces per level, placeholders elsewhere are dropped, whitespace-only
 * line tails are removed, and a comma that landed on the line after a closing brace is
 * pulled back up next to the brace.
 *
 * @param useFir forwarded unchanged to [IrSourcePrinterVisitor].
 * @return the cleaned-up text.
 */
fun IrElement.dumpSrc(useFir: Boolean = false): String {
    val buffer = StringBuilder()
    accept(IrSourcePrinterVisitor(buffer, "%tab%", useFir), null)

    val leadingTabs = Regex("\\n(%tab%)+", RegexOption.MULTILINE)
    val strayTab = Regex("%tab%", RegexOption.MULTILINE)
    val blankTail = Regex("\\n(\\s)*$", RegexOption.MULTILINE)
    val braceThenComma = Regex("}\\n(\\s)*,", RegexOption.MULTILINE)

    // replace tabs at beginning of line with white space; the match is one newline
    // followed by N five-character placeholders, so N == (length - 1) / 5
    val withIndents = leadingTabs.replace(buffer.toString()) { match ->
        val depth = (match.value.length - 1) / 5
        "\n" + "  ".repeat(depth)
    }
    // tabs that are inserted in the middle of lines should be replaced with empty strings
    val withoutStrayTabs = strayTab.replace(withIndents, "")
    // remove empty lines
    val withoutBlankLines = blankTail.replace(withoutStrayTabs, "")
    // brackets with comma on new line
    return braceThenComma.replace(withoutBlankLines, "},")
}

/**
 * Printing context for the function whose body is currently being rendered.
 *
 * @property owner the function this scope belongs to, or `null` for the initial scope.
 * @property localValues value declarations recorded for this scope; the printer seeds it
 * with a function's value parameters and adds variables as they are visited.
 */
class Scope(
    val owner: IrFunction? = null,
    // The element type must be spelled out: a bare `HashSet` is not a valid Kotlin type
    // and `hashSetOf()` cannot infer its type argument without it.
    val localValues: HashSet<IrValueDeclaration> = hashSetOf(),
)

class IrSourcePrinterVisitor(
    out: Appendable,
    indentUnit: String = "  ",
    private val useFir: Boolean = false,
) : IrElementVisitorVoid {
    private val printer = Printer(out, indentUnit)
    private var currentScope: Scope = Scope()

    /** Sends this element through the enclosing visitor so the matching `visit*` override prints it. */
    private fun IrElement.print() {
        val visitor = this@IrSourcePrinterVisitor
        accept(visitor, null)
    }

    // Thin forwarding helpers so the visitor body can write `print(...)` / `println(...)`
    // directly; all output goes through the shared [printer].
    private fun print(obj: Any?) = printer.print(obj)
    private fun println(obj: Any?) = printer.println(obj)
    private fun println() = printer.println()

    /** Returns the rendering of [type] produced by `renderSrc()`; nothing is written to the printer. */
    fun printType(type: IrType) = type.renderSrc()

    /**
     * Runs [block] with [currentScope] replaced by a fresh [Scope] owned by this function
     * and seeded with its value parameters, then restores the previous scope.
     *
     * The restore happens in a `finally` block: [block] is inlined, so it may leave via an
     * exception or a non-local `return`, and the previous scope must not be lost then.
     */
    private inline fun IrFunction.scoped(block: (IrFunction) -> Unit) {
        val previousScope = currentScope
        currentScope = Scope(
            this,
            HashSet(valueParameters)
        )
        try {
            block(this)
        } finally {
            currentScope = previousScope
        }
    }

    /**
     * Runs [body] with the printer's indentation increased by one level.
     *
     * The indent is popped in a `finally` block so that an exception or a non-local
     * `return` out of the inlined [body] cannot leave the printer permanently indented.
     */
    private inline fun indented(body: () -> Unit) {
        printer.pushIndent()
        try {
            body()
        } finally {
            printer.popIndent()
        }
    }

    /**
     * Prints `{` on its own line, runs [body] one indent level deeper, then ends the
     * current line and prints `}` on its own line.
     */
    private inline fun bracedBlock(body: () -> Unit) {
        println("{")
        indented { body() }
        println()
        println("}")
    }

    /**
     * Prints every element of this list in order, writing [separator] between
     * consecutive elements (never after the last one).
     *
     * The receiver's element type is declared explicitly; a bare `List` receiver is not
     * valid Kotlin and would leave `it.print()` unresolved.
     */
    private fun List<IrElement>.printJoin(separator: String = "") {
        val finalIndex = size - 1
        forEachIndexed { index, element ->
            element.print()
            if (index < finalIndex) print(separator)
        }
    }

    /** Prints every file of the module fragment back to back, with no separator. */
    override fun visitModuleFragment(declaration: IrModuleFragment) {
//        println("// MODULE: ${declaration.name}")
        val files = declaration.files
        files.printJoin()
    }

    /**
     * Prints the file's top-level declarations separated by newlines, inside
     * `includeFileNameInExceptionTrace` for this file.
     */
    override fun visitFile(declaration: IrFile) {
        includeFileNameInExceptionTrace(declaration) {
//          println("// FILE: ${declaration.fileEntry.name}")
            val topLevelDeclarations = declaration.declarations
            topLevelDeclarations.printJoin("\n")
        }
    }

    /**
     * Prints a value parameter as
     * `[crossinline ][noinline ]<annotations>name: Type[ = default]`.
     */
    override fun visitValueParameter(declaration: IrValueParameter) {
        if (declaration.isCrossinline) print("crossinline ")
        if (declaration.isNoinline) print("noinline ")

        declaration.printAnnotations()
        print(declaration.normalizedName)
        print(": ")
        print(declaration.type.renderSrc())

        val default = declaration.defaultValue
        if (default != null) {
            print(" = ")
            default.print()
        }
    }

    /**
     * Prints a function declaration: annotations (one per line), modifiers, type
     * parameters, extension receiver type, name, value parameters, the return type when
     * it is not `Unit`, and finally the body. Fake overrides are not printed at all.
     */
    override fun visitSimpleFunction(declaration: IrSimpleFunction) {
        if (declaration.origin == IrDeclarationOrigin.FAKE_OVERRIDE) return
        declaration.scoped {
            declaration.printAnnotations(onePerLine = true)

            val isOverride = declaration.overriddenSymbols.isNotEmpty()
            if (isOverride) {
                print("override ")
            } else {
                // Visibility is only spelled out when it is neither public nor local.
                val visibility = declaration.visibility
                val isImplicitVisibility = visibility == DescriptorVisibilities.PUBLIC ||
                    visibility == DescriptorVisibilities.LOCAL
                if (!isImplicitVisibility) {
                    print(visibility.toString().lowercase(Locale.ROOT))
                    print(" ")
                }
                // Likewise, `final` is never printed.
                val modality = declaration.modality
                if (modality != Modality.FINAL) {
                    print(modality.toString().lowercase(Locale.ROOT))
                    print(" ")
                }
            }

            if (declaration.isSuspend) print("suspend ")
            print("fun ")

            val typeParameters = declaration.typeParameters
            if (typeParameters.isNotEmpty()) {
                print("<")
                typeParameters.printJoin(", ")
                print("> ")
            }

            val receiverParameter = declaration.extensionReceiverParameter
            if (receiverParameter != null) {
                print(receiverParameter.type.renderSrc())
                print(".")
            }

            print(declaration.name)
            print("(")
            declaration.valueParameters.printJoin(", ")
            print(")")

            val returnType = declaration.returnType
            if (!returnType.isUnit()) {
                print(": ")
                print(returnType.renderSrc())
            }

            print(" ")
            declaration.printBody()
        }
    }

    /**
     * Prints this function's body: nothing when there is no body, `{ }` when the body has
     * no statements, otherwise the statements wrapped in a braced, indented block.
     */
    fun IrFunction.printBody() {
        val functionBody = body ?: return
        if (functionBody.statements.isNotEmpty()) {
            bracedBlock { functionBody.print() }
        } else {
            println("{ }")
        }
    }

    /**
     * Prints a constructor as its annotations, the `constructor` keyword, a parenthesised
     * parameter list (omitted when there are no parameters) and the body.
     */
    override fun visitConstructor(declaration: IrConstructor) {
        declaration.printAnnotations(onePerLine = true)
        print("constructor")
        declaration.valueParameters
            .takeIf { it.isNotEmpty() }
            ?.let { params ->
                print("(")
                params.printJoin(", ")
                print(")")
            }
        declaration.printBody()
    }

    // True while printing the operand of a `not(...)` call that wraps an equality call,
    // so the equality is rendered with its negated symbol (`!=` / `!==`).
    private var isInNotCall = false

    /**
     * Prints a call expression.
     *
     * Operator calls (functions marked `operator` or with the builtin-operator origin) are
     * rendered with operator syntax, infix calls as `a name b`, property accessors as
     * property reads/writes (unless the accessor is composable), and everything else as
     * `receiver.name(args)`.
     *
     * Throws via `error(...)` for an operator name that has no known rendering.
     */
    override fun visitCall(expression: IrCall) {
        val function = expression.symbol.owner
        val name = function.name.asString()
        val isOperator = function.isOperator || function.origin == IrBuiltIns.BUILTIN_OPERATOR
        val isInfix = function.isInfix
        if (isOperator) {
            if (name == "not") {
                // IR tree for `a != b` looks like `not(equals(a, b))` which makes
                // it challenging to print it like the former. To do so, we capture when we are in
                // a "not" call, and then check to see if the argument is an equals call. if it is,
                // we will just print the child call and put the transformer into a mode where it
                // knows to print the negative
                //
                // The operand is looked up on either receiver so that an extension `not`
                // operator (which has no dispatch receiver) does not crash the printer.
                val arg = expression.dispatchReceiver ?: expression.extensionReceiver
                if (arg is IrCall) {
                    val fn = arg.symbol.owner
                    if (fn.origin == IrBuiltIns.BUILTIN_OPERATOR) {
                        when (fn.name.asString()) {
                            "equals",
                            "EQEQ",
                            "EQEQEQ",
                            -> {
                                val prevIsInNotCall = isInNotCall
                                isInNotCall = true
                                arg.print()
                                isInNotCall = prevIsInNotCall
                                return
                            }
                        }
                    }
                }
            }
            val opSymbol = when (name) {
                "contains" -> "in"
                "equals" -> if (isInNotCall) "!=" else "=="
                "plus" -> "+"
                "not" -> "!"
                "minus" -> "-"
                "times" -> "*"
                "div" -> "/"
                "rem" -> "%"
                "rangeTo" -> ".."
                "plusAssign" -> "+="
                "minusAssign" -> "-="
                // `unaryPlus` has a printing branch below, so it needs a symbol here too;
                // without it the `else` branch would reject it as unhandled.
                "unaryPlus" -> "+"
                "unaryMinus" -> "-"
                "timesAssign" -> "*="
                "divAssign" -> "/="
                "remAssign" -> "%="
                "inc" -> "++"
                "dec" -> "--"
                "greater" -> ">"
                "less" -> "<"
                "lessOrEqual" -> "<="
                "greaterOrEqual" -> ">="
                "EQEQ" -> if (isInNotCall) "!=" else "=="
                "EQEQEQ" -> if (isInNotCall) "!==" else "==="
                "OROR" -> "||"
                "ieee754equals" -> "=="

                // no names for
                "invoke", "get", "set" -> ""
                "iterator", "hasNext", "next", "getValue", "setValue",
                "noWhenBranchMatchedException",
                -> name
                "CHECK_NOT_NULL" -> "!!"
                else -> {
                    if (name.startsWith("component")) name
                    else error("Unhandled operator $name")
                }
            }

            // Int operands of an equality check are printed as binary literals.
            val printBinary = when (name) {
                "equals",
                "EQEQ",
                "EQEQEQ",
                -> when {
                    expression.dispatchReceiver?.type?.isInt() == true -> true
                    expression.extensionReceiver?.type?.isInt() == true -> true
                    expression.valueArgumentsCount > 0 &&
                            expression.getValueArgument(0)?.type?.isInt() == true -> true
                    else -> false
                }
                else -> false
            }
            val prevPrintBinary = printIntsAsBinary
            printIntsAsBinary = printBinary
            when (name) {
                // unary prefix
                "unaryPlus", "unaryMinus", "not" -> {
                    print(opSymbol)
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                }
                // unary postfix
                "inc", "dec" -> {
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                    print(opSymbol)
                }
                "CHECK_NOT_NULL" -> {
                    expression.getValueArgument(0)?.print()
                    print(opSymbol)
                }
                "getValue", "setValue" -> {
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                    print(".")
                    print(opSymbol)
                    expression.printArgumentList()
                }
                "invoke" -> {
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                    expression.printArgumentList()
                }
                // get indexer
                "get" -> {
                    // Fall back to the extension receiver like the `set` indexer does, and
                    // separate multiple indices with commas so `a[i, j]` is not printed
                    // as `a[ij]`.
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                    print("[")
                    var isFirstIndex = true
                    for (i in 0 until expression.valueArgumentsCount) {
                        val arg = expression.getValueArgument(i) ?: continue
                        if (!isFirstIndex) print(", ")
                        arg.print()
                        isFirstIndex = false
                    }
                    print("]")
                }
                // set indexer
                "set" -> {
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                    print("[")
                    expression.getValueArgument(0)?.print()
                    print("] = ")
                    expression.getValueArgument(1)?.print()
                }
                // builtin static operators
                "greater", "less", "lessOrEqual", "greaterOrEqual", "EQEQ", "EQEQEQ",
                "ieee754equals",
                -> {
                    expression.getValueArgument(0)?.print()
                    print(" $opSymbol ")
                    expression.getValueArgument(1)?.print()
                }
                "iterator", "hasNext", "next",
                "noWhenBranchMatchedException",
                -> {
                    (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                    print(".")
                    print(opSymbol)
                    print("()")
                }
                else -> {
                    if (name.startsWith("component")) {
                        (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                        print(".")
                        print(opSymbol)
                        print("()")
                    } else {
                        // else binary
                        (expression.dispatchReceiver ?: expression.extensionReceiver)?.print()
                        print(" $opSymbol ")
                        expression.getValueArgument(0)?.print()
                    }
                }
            }
            printIntsAsBinary = prevPrintBinary
            return
        }

        if (isInfix) {
            val prev = printIntsAsBinary
            // Bitwise infix functions print their Int operands as binary literals.
            if (name == "xor" || name == "and" || name == "or") {
                printIntsAsBinary = true
            }
            expression.dispatchReceiver?.print()
                ?: expression.extensionReceiver?.print()
            print(" $name ")
            expression.getValueArgument(0)?.print()
            printIntsAsBinary = prev
            return
        }

        expression.printExplicitReceiver(".", expression.superQualifierSymbol)

        val prop = function.correspondingPropertySymbol?.owner

        if (prop != null && !function.hasComposableAnnotation()) {
            // Non-composable property accessor: print as a property read or assignment.
            val propName = prop.name.asString()
            print(propName)
            if (function == prop.setter) {
                print(" = ")
                expression.getValueArgument(0)?.print()
            }
        } else {
            print(name)
            expression.printArgumentList()
        }
    }

    /**
     * Prints this container's annotations, if any. With [onePerLine] they are separated by
     * newlines and followed by a line break; otherwise they are separated, and followed,
     * by a single space.
     */
    private fun IrAnnotationContainer.printAnnotations(onePerLine: Boolean = false) {
        if (annotations.isEmpty()) return
        if (onePerLine) {
            annotations.printJoin("\n")
            println()
        } else {
            annotations.printJoin(" ")
            print(" ")
        }
    }

    /**
     * Prints the receiver part of a member access, followed by [suffix] when one is given.
     *
     * Precedence: a `super<Name>` qualifier when [superQualifierSymbol] is set, otherwise
     * the dispatch receiver, otherwise the extension receiver. A receiver that is a value
     * read whose declaration has a special name is skipped. Nothing is printed when no
     * candidate qualifies.
     */
    private fun IrMemberAccessExpression<*>.printExplicitReceiver(
        suffix: String? = null,
        superQualifierSymbol: IrClassSymbol? = null,
    ) {
        fun readsSpecialName(receiver: IrExpression?): Boolean =
            receiver is IrGetValue && receiver.symbol.owner.name.isSpecial

        fun printSuffix() {
            if (suffix != null) print(suffix)
        }

        val dispatch = dispatchReceiver
        val extension = extensionReceiver
        when {
            superQualifierSymbol != null -> {
                print("super<${superQualifierSymbol.owner.name}>")
                printSuffix()
            }
            dispatch != null && !readsSpecialName(dispatch) -> {
                dispatch.print()
                printSuffix()
            }
            extension != null && !readsSpecialName(extension) -> {
                extension.print()
                printSuffix()
            }
        }
    }

    /**
     * Prints the argument list of this call.
     *
     * A lambda passed for the last value parameter is printed after the parentheses as a
     * trailing lambda. Parameter names are printed when [forceParameterNames] is set or
     * when any argument is absent; in that case the arguments go one per line
     * unless [forceSingleLine] is set. The parentheses are omitted only when the trailing
     * lambda is the sole argument.
     *
     * Each lambda argument is also registered in [returnTargetToCall] so that returns
     * targeting it can be printed with a qualifier.
     */
    private fun IrFunctionAccessExpression.printArgumentList(
        forceParameterNames: Boolean = false,
        forceSingleLine: Boolean = false,
    ) {
        // Explicit element types: `mutableListOf()` on its own cannot be inferred.
        val arguments = mutableListOf<IrExpression>()
        val paramNames = mutableListOf<String>()
        var trailingLambda: IrExpression? = null
        var useParameterNames = forceParameterNames
        for (i in 0 until valueArgumentsCount) {
            val arg = getValueArgument(i)
            if (arg != null) {
                val param = symbol.owner.valueParameters[i]
                val isLambda = arg is IrFunctionExpression ||
                        (arg is IrBlock &&
                                (arg.origin == IrStatementOrigin.LAMBDA))
                if (isLambda) {
                    arg.unwrapLambda()?.let {
                        returnTargetToCall[it] = this
                    }
                }
                val isTrailingLambda = i == symbol.owner.valueParameters.size - 1 && isLambda
                if (isTrailingLambda) {
                    trailingLambda = arg
                } else {
                    arguments.add(arg)
                    paramNames.add(param.normalizedName)
                }
            } else {
                // An omitted argument makes positions ambiguous, so switch to named form.
                useParameterNames = true
            }
        }
        val multiline = arguments.isNotEmpty() && useParameterNames && !forceSingleLine
        if (arguments.isNotEmpty() || trailingLambda == null) {
            print("(")
            if (multiline) {
                // if we are using parameter names, we go on multiple lines
                println()
                indented {
                    arguments.zip(paramNames).forEachIndexed { i, (arg, name) ->
                        if (useParameterNames) {
                            print(name)
                            print(" = ")
                        }
                        arg.print()
                        if (i < arguments.size - 1) println(", ")
                    }
                }
                println()
            } else {
                arguments.zip(paramNames).forEachIndexed { i, (arg, name) ->
                    if (useParameterNames) {
                        print(name)
                        print(" = ")
                    }
                    when {
                        // Arguments for the Compose default/changed parameters are
                        // printed with ints as binary literals.
                        name.startsWith(ComposeNames.DEFAULT_PARAMETER.identifier) ||
                                name.startsWith(ComposeNames.CHANGED_PARAMETER.identifier) -> {
                            withIntsAsBinaryLiterals {
                                arg.print()
                            }
                        }
                        else -> {
                            arg.printWithExplicitBlock()
                        }
                    }
                    if (i < arguments.size - 1) print(", ")
                }
            }
            print(")")
        }
        trailingLambda?.let {
            print(" ")
            it.print()
        }
    }

    /**
     * Prints this element, wrapping it in explicit `{ ... }` braces on separate lines when
     * it is an [IrBlock]; any other element is printed as is.
     */
    fun IrElement.printWithExplicitBlock() {
        if (this !is IrBlock) {
            this.print()
            return
        }
        println("{")
        indented { this.print() }
        println()
        print("}")
    }

    /** Prints a function expression by rendering its function as a lambda literal. */
    override fun visitFunctionExpression(expression: IrFunctionExpression) {
        val lambda = expression.function
        lambda.printAsLambda()
    }

    /**
     * Prints this function as a lambda literal: `{ params ->` (or just `{` when there are
     * no value parameters), the indented body, and a closing `}`. The body is printed
     * inside a new scope for this function.
     */
    fun IrFunction.printAsLambda() {
        scoped { function ->
            print("{")
            val params = function.valueParameters
            if (params.isEmpty()) {
                println()
            } else {
                print(" ")
                params.printJoin(", ")
                println(" ->")
            }
            indented {
                function.body?.print()
            }
            println()
            println("}")
        }
    }

    /**
     * Prints a type-operator call. SAM conversions are printed as the rendered target
     * type, a space, and the argument; every other supported operator prints only its
     * argument. An unsupported operator fails via `error(...)`.
     */
    override fun visitTypeOperator(expression: IrTypeOperatorCall) {
        when (expression.operator) {
            IrTypeOperator.SAM_CONVERSION -> {
                print(expression.type.renderSrc())
                print(" ")
                expression.argument.print()
            }
            IrTypeOperator.IMPLICIT_COERCION_TO_UNIT,
            IrTypeOperator.NOT_INSTANCEOF,
            IrTypeOperator.CAST,
            IrTypeOperator.SAFE_CAST,
            IrTypeOperator.IMPLICIT_CAST,
            IrTypeOperator.IMPLICIT_NOTNULL,
            IrTypeOperator.INSTANCEOF,
            -> expression.argument.print()
            else -> error("Unknown type operator: ${expression.operator}")
        }
    }

    /** Prints the statements of a composite, one per line. */
    override fun visitComposite(expression: IrComposite) {
        val statements = expression.statements
        statements.printJoin("\n")
    }

    /** Prints a do-while loop as `do { <body> } while (<condition>)`. */
    override fun visitDoWhileLoop(loop: IrDoWhileLoop) {
        println("do {")
        val loopBody = loop.body
        indented {
            if (loopBody != null) loopBody.print()
            println()
        }
        print("} while (")
        loop.condition.print()
        println(")")
    }

    /**
     * Prints a constructor call as `[receiver.]ClassName(args)`.
     *
     * Annotation constructors are prefixed with `@`, always use named single-line
     * arguments, and omit the argument list entirely when no argument is present.
     */
    override fun visitConstructorCall(expression: IrConstructorCall) {
        val constructedClass = expression.symbol.owner.constructedClass
        val className = constructedClass.name
        val isAnnotation = constructedClass.isAnnotationClass

        if (isAnnotation) print("@")

        val receiver = expression.dispatchReceiver
        if (receiver != null) {
            receiver.print()
            print(".")
        }
        print(className)

        // Non-annotation calls always get parentheses; annotations only when at least
        // one argument is actually supplied.
        val showArguments = !isAnnotation ||
            (0 until expression.valueArgumentsCount).any { expression.getValueArgument(it) != null }
        if (!showArguments) return

        expression.printArgumentList(
            forceParameterNames = isAnnotation,
            forceSingleLine = isAnnotation
        )
    }

    /**
     * Prints a string concatenation as a single quoted template: string constants are
     * written verbatim, value reads as `$name`, and any other expression as `${...}`.
     */
    override fun visitStringConcatenation(expression: IrStringConcatenation) {
        print("\"")
        expression.arguments.forEach { part ->
            if (part is IrConst && part.kind == IrConstKind.String) {
                print(part.value)
            } else if (part is IrGetValue) {
                print("$")
                part.print()
            } else {
                print("\${")
                part.print()
                print("}")
            }
        }
        print("\"")
    }

    /**
     * Prints the elements of a vararg separated by `, `. Spread elements are prefixed
     * with `*`; elements that are neither spreads nor expressions print nothing but still
     * take part in the comma placement.
     */
    override fun visitVararg(expression: IrVararg) {
        val elements = expression.elements
        elements.forEachIndexed { index, element ->
            when (element) {
                is IrSpreadElement -> {
                    print("*")
                    element.expression.print()
                }
                is IrExpression -> element.print()
                else -> {}
            }
            if (index < elements.size - 1) print(", ")
        }
    }

    /**
     * Prints an [IrWhen] in the form suggested by its origin: `a || b` for `OROR`,
     * `a && b` for `ANDAND`, an `if`/`else if`/`else` chain for `IF`, and a `when { ... }`
     * block for everything else.
     */
    override fun visitWhen(expression: IrWhen) {
        val isIf = expression.origin == IrStatementOrigin.IF
        when {
            expression.origin == IrStatementOrigin.OROR -> {
                // Operands are taken from the first branch's condition and the second
                // branch's result.
                val lhs = expression.branches[0].condition
                val rhs = expression.branches[1].result
                lhs.print()
                print(" || ")
                rhs.print()
            }
            expression.origin == IrStatementOrigin.ANDAND -> {
                // Both operands come from the first branch: its condition and its result.
                val lhs = expression.branches[0].condition
                val rhs = expression.branches[0].result
                lhs.print()
                print(" && ")
                rhs.print()
            }
            isIf -> {
                // The whole chain is printed on one line, without braces, only when every
                // branch result is a constant or a plain value read.
                val singleLine = expression.branches.all {
                    it.result is IrConst || it.result is IrGetValue
                }
                expression.branches.forEachIndexed { index, branch ->
                    // The last branch counts as `else` when its condition is the constant `true`.
                    val isElse = index == expression.branches.size - 1 &&
                            (branch.condition as? IrConst)?.value == true
                    when {
                        index == 0 -> {
                            print("if (")
                            branch.condition.print()
                            if (singleLine)
                                print(") ")
                            else
                                println(") {")
                        }
                        isElse -> {
                            if (singleLine)
                                print(" else ")
                            else
                                println("} else {")
                        }
                        else -> {
                            if (singleLine)
                                print(" else if (")
                            else
                                print("} else if (")
                            branch.condition.print()
                            if (singleLine)
                                print(") ")
                            else
                                println(") {")
                        }
                    }
                    if (singleLine)
                        branch.result.print()
                    else indented {
                        branch.result.print()
                        println()
                    }
                }
                // In multi-line form the final brace is still open at this point.
                if (!singleLine)
                    println("}")
            }
            else -> {
                print("when ")
                bracedBlock {
                    expression.branches.forEach {
                        // A branch whose condition is the constant `true` is printed as `else`.
                        val isElse = (it.condition as? IrConst)?.value == true

                        if (isElse) {
                            print("else")
                        } else {
                            it.condition.print()
                        }
                        print(" -> ")
                        bracedBlock {
                            it.result.print()
                        }
                    }
                }
            }
        }
    }

    /** Prints a while loop as `[label@]while (<condition>) { <body> }`. */
    override fun visitWhileLoop(loop: IrWhileLoop) {
        loop.label?.let { label ->
            print(label)
            print("@")
        }
        print("while (")
        loop.condition.print()
        println(") {")
        val loopBody = loop.body
        indented {
            if (loopBody != null) loopBody.print()
            println()
        }
        println("}")
    }

    // Map local return targets to the corresponding function call.
    // This is used to print qualified returns.
    // The key/value types are explicit because a bare `mutableMapOf()` initializer gives
    // the compiler nothing to infer them from.
    private val returnTargetToCall =
        mutableMapOf<IrReturnTargetSymbol, IrFunctionAccessExpression>()

    /**
     * True when this function is treated as a lambda by the printer: it is either an
     * adapter for a callable reference or carries the anonymous special name.
     */
    private val IrFunction.isLambda: Boolean
        get() {
            if (origin == IrDeclarationOrigin.ADAPTER_FOR_CALLABLE_REFERENCE) return true
            return name.asString() == SpecialNames.ANONYMOUS_STRING
        }

    /** True when this function is a setter whose origin is a delegated property accessor. */
    private val IrFunction.isDelegatedPropertySetter: Boolean
        get() {
            val isDelegatedAccessor = origin == IrDeclarationOrigin.DELEGATED_PROPERTY_ACCESSOR
            return isDelegatedAccessor && isSetter
        }

    /**
     * Returns true when this expression is (by identity) the last statement of
     * [statements], looking through trailing nested [IrBlock]s.
     *
     * The parameter's element type is declared explicitly; a bare `List` is not a valid
     * Kotlin type.
     */
    private fun IrExpression.isLastStatementIn(statements: List<IrStatement>): Boolean {
        val lastStatement = statements.lastOrNull()
        return when {
            lastStatement === this -> true
            // The last statement is itself a block: the real tail is inside it.
            lastStatement is IrBlock -> isLastStatementIn(lastStatement.statements)
            else -> false
        }
    }

    /**
     * Returns true when this expression is the last statement of [function]'s body;
     * false when the function has no body.
     */
    private fun IrExpression.isLastStatementIn(function: IrFunction): Boolean {
        val functionBody = function.body ?: return false
        return isLastStatementIn(functionBody.statements)
    }

    /**
     * Prints a return expression.
     *
     * The `return` keyword is left out when the return is the last statement of a lambda
     * or, when [useFir] is false, of a delegated property setter (those have a superfluous
     * "return" in K1). When the target has an entry in [returnTargetToCall], the keyword is
     * qualified with `@<callee name>`. A returned `Unit` object value is not printed.
     */
    override fun visitReturn(expression: IrReturn) {
        val returnTarget = expression.returnTargetSymbol.owner

        val keywordIsImplicit = returnTarget is IrFunction &&
            (returnTarget.isLambda || (!useFir && returnTarget.isDelegatedPropertySetter)) &&
            expression.isLastStatementIn(returnTarget)

        if (!keywordIsImplicit) {
            val call = returnTargetToCall[returnTarget.symbol]
            val suffix = if (call != null) "@${call.symbol.owner.name}" else ""
            print("return$suffix ")
        }

        val value = expression.value
        val isUnitObject = value.type.isUnit() && value is IrGetObjectValue
        if (!isUnitObject) {
            value.print()
        }
    }

    /**
     * Prints a block according to its origin: postfix `++`/`--`, lambda, object literal
     * and safe call get dedicated renderings; every other origin (including `FOR_LOOP`)
     * prints the block's statements one per line.
     */
    override fun visitBlock(expression: IrBlock) {
        val statements = expression.statements

        // Postfix blocks: print the initializer of the first statement (a variable),
        // followed by the operator.
        fun printPostfix(operator: String) {
            val temporary = statements[0] as IrVariable
            val operand = temporary.initializer ?: error("Expected initializer")
            operand.print()
            print(operator)
        }

        when (expression.origin) {
            IrStatementOrigin.POSTFIX_INCR -> printPostfix("++")
            IrStatementOrigin.POSTFIX_DECR -> printPostfix("--")
            IrStatementOrigin.LAMBDA -> {
                (statements[0] as IrFunction).printAsLambda()
            }
            IrStatementOrigin.OBJECT_LITERAL -> {
                (statements[0] as IrClass).printAsObject()
            }
            IrStatementOrigin.SAFE_CALL -> {
                val receiverVariable = statements[0] as IrVariable
                val second = statements[1]
                // A two-statement nested block is unwrapped to its second statement
                // (or that statement's initializer when it is a variable).
                val candidate: IrElement? =
                    if (second is IrBlock && second.statements.size == 2) {
                        val target = second.statements[1]
                        if (target is IrVariable) target.initializer else target
                    } else {
                        second
                    }
                val call = (candidate as? IrWhen)?.branches?.last()?.result as? IrCall
                if (call == null) {
                    // Not the expected shape: fall back to printing the raw statements.
                    statements.printJoin("\n")
                    return
                }
                receiverVariable.initializer?.print()
                print("?.")
                print(call.symbol.owner.name)
                call.printArgumentList()
            }
            else -> {
                statements.printJoin("\n")
            }
        }
    }

    /**
     * Prints a local variable declaration in the form
     * `[lateinit ](const|var|val) name[ = initializer]` and registers the variable in the
     * current scope's local values.
     */
    override fun visitVariable(declaration: IrVariable) {
        currentScope.localValues.add(declaration)

        if (declaration.isLateinit) {
            // The trailing space keeps the modifier apart from the keyword printed next;
            // without it the output reads "lateinitvar".
            print("lateinit ")
        }
        when {
            declaration.isConst -> print("const ")
            declaration.isVar -> print("var ")
            else -> print("val ")
        }
        print(declaration.normalizedName)
        declaration.initializer?.let {
            print(" = ")
            it.printWithExplicitBlock()
        }
    }

    /** Prints an object reference as the name of the declaration the symbol points to. */
    override fun visitGetObjectValue(expression: IrGetObjectValue) {
        val referenced = expression.symbol.owner
        print(referenced.name)
    }

    /**
     * Prints a value read by its normalized name. When the value is declared outside the
     * current scope's owner and a local value of the current scope has the same name, the
     * name is followed by `@` and the fully-qualified name of the value's parent.
     */
    override fun visitGetValue(expression: IrGetValue) {
        val target = expression.symbol.owner
        print(target.normalizedName)

        // Values owned by the current scope never need a qualifier.
        if (target.parent == currentScope.owner) return

        val clashesWithLocal = currentScope.localValues.any { local -> local.name == target.name }
        if (clashesWithLocal) {
            print("@")
            print(target.parent.kotlinFqName)
        }
    }

    /**
     * Prints a field as `[visibility ][static ](val|var) name: Type[ = initializer]`.
     * Public and local visibilities are omitted; a final field is printed as `val`.
     */
    override fun visitField(declaration: IrField) {
        val visibility = declaration.visibility
        val isImplicitVisibility = visibility == DescriptorVisibilities.PUBLIC ||
            visibility == DescriptorVisibilities.LOCAL
        if (!isImplicitVisibility) {
            print(visibility.toString().lowercase(Locale.ROOT))
            print(" ")
        }
        if (declaration.isStatic) print("static ")
        print(if (declaration.isFinal) "val " else "var ")
        print(declaration.symbol.owner.name)
        print(": ")
        print(declaration.type.renderSrc())
        declaration.initializer?.let { initializer ->
            print(" = ")
            initializer.printWithExplicitBlock()
        }
    }

    /**
     * Prints a field read as `<qualifier>.<fieldName>`. The qualifier is the receiver when
     * there is one, the parent class name for a static field declared in a class, and empty
     * otherwise.
     */
    override fun visitGetField(expression: IrGetField) {
        val receiver = expression.receiver
        val field = expression.symbol.owner
        val container = field.parent
        when {
            receiver != null -> receiver.print()
            field.isStatic && container is IrClass -> print(container.name)
        }
        print(".")
        print(field.name)
    }

    /** Prints a field assignment as `<receiver>.<fieldName> = <value>`; the receiver is optional. */
    override fun visitSetField(expression: IrSetField) {
        expression.receiver?.print()
        print(".")
        val field = expression.symbol.owner
        print(field.name)
        print(" = ")
        val assigned = expression.value
        assigned.printWithExplicitBlock()
    }

    /**
     * Prints an enum entry reference as `<EnumClass>.<Entry>`. Nothing is printed when the
     * expression type has no class classifier.
     */
    override fun visitGetEnumValue(expression: IrGetEnumValue) {
        val enumClass = expression.type.classOrNull?.owner ?: return
        print(enumClass.name)
        print(".")
        val entry = expression.symbol.owner
        print(entry.name)
    }

    /** Prints an assignment to a value as `<normalizedName> = <value>`. */
    override fun visitSetValue(expression: IrSetValue) {
        val target = expression.symbol.owner
        print(target.normalizedName)
        print(" = ")
        val assigned = expression.value
        assigned.printWithExplicitBlock()
    }

    /** Prints an expression body by visiting the expression it wraps. */
    override fun visitExpressionBody(body: IrExpressionBody) {
        val wrapped = body.expression
        wrapped.acceptVoid(this)
    }

    /**
     * Prints a property declaration: annotations (one per line), optional `lateinit`,
     * `const`/`var`/`val`, the name, the type (taken from the backing field, falling back to
     * the getter's return type) and the backing field initializer if any. Accessors whose
     * origin is not [IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR] are printed below,
     * indented.
     */
    override fun visitProperty(declaration: IrProperty) {
        declaration.printAnnotations(onePerLine = true)
        if (declaration.isLateinit) {
            // The trailing space keeps the modifier apart from the keyword printed next;
            // without it the output reads "lateinitvar".
            print("lateinit ")
        }
        when {
            declaration.isConst -> print("const ")
            declaration.isVar -> print("var ")
            else -> print("val ")
        }
        print(declaration.name)
        print(": ")
        val type = declaration.backingField?.type
            ?: declaration.getter?.returnType
            ?: error("Couldn't find return type")
        print(type.renderSrc())
        declaration.backingField?.let { field ->
            field.initializer?.let { initializer ->
                print(" = ")
                initializer.print()
            }
        }
        indented {
            declaration.getter?.scoped {
                // Default accessors carry no user code, so only custom ones are printed.
                if (it.origin != IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR) {
                    println()
                    it.printAnnotations()
                    println()
                    println("get() {")
                    indented {
                        it.body?.accept(this, null)
                    }
                    println()
                    println("}")
                }
            }
            declaration.setter?.scoped {
                if (it.origin != IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR) {
                    println()
                    it.printAnnotations()
                    println("set(value) {")
                    indented {
                        it.body?.accept(this, null)
                    }
                    println()
                    println("}")
                }
            }
        }
    }

    // When true, visitConst renders Int constants through intAsBinaryString.
    private var printIntsAsBinary = false

    /**
     * Runs [block] with binary rendering of Int constants switched on and returns its result.
     * The previous setting is restored afterwards, even if [block] throws, so calls may nest.
     */
    fun <T> withIntsAsBinaryLiterals(block: () -> T): T {
        val prev = printIntsAsBinary
        try {
            printIntsAsBinary = true
            return block()
        } finally {
            printIntsAsBinary = prev
        }
    }

    /**
     * Renders [value] as a Kotlin binary literal whose digit count is padded to a multiple
     * of four, e.g. `5` becomes `0b0101`. Zero is rendered as plain `0`. Negative values are
     * rendered as the literal of their bitwise complement followed by `.inv()`, e.g. `-2`
     * becomes `0b0001.inv()`.
     */
    private fun intAsBinaryString(value: Int): String {
        if (value == 0) return "0"
        val isNegative = value < 0
        // The complement of a negative Int is always non-negative, so radix-2 formatting
        // never yields a sign character.
        val magnitude = if (isNegative) value.inv() else value
        val digits = magnitude.toString(radix = 2)
        // Round the width up to the next multiple of four. For -1 the magnitude is 0, which
        // formats as "0" and pads to "0000", so the literal always has digits.
        val paddedWidth = (digits.length + 3) / 4 * 4
        val literal = "0b" + digits.padStart(paddedWidth, '0')
        return if (isNegative) literal + ".inv()" else literal
    }

    /**
     * Prints a constant. Chars are wrapped in single quotes and strings in double quotes
     * (no escaping is applied), longs get an `L` suffix and floats an `f` suffix. Ints are
     * printed as binary literals while `printIntsAsBinary` is set.
     */
    override fun visitConst(expression: IrConst) {
        val value = expression.value
        val rendered = when (expression.kind) {
            is IrConstKind.Null,
            is IrConstKind.Boolean,
            is IrConstKind.Byte,
            is IrConstKind.Short,
            is IrConstKind.Double -> "$value"
            is IrConstKind.Char -> "'$value'"
            is IrConstKind.Int ->
                if (printIntsAsBinary) intAsBinaryString(value as Int) else "$value"
            is IrConstKind.Long -> "${value}L"
            is IrConstKind.Float -> "${value}f"
            is IrConstKind.String -> "\"$value\""
        }
        print(rendered)
    }

    /** Prints the statements of a block body, joined with newlines. */
    override fun visitBlockBody(body: IrBlockBody) {
        val statements = body.statements
        statements.printJoin("\n")
    }

    /**
     * Finds the property of this class that is declared through constructor parameter
     * [param]: it has the same name and its backing field initializer is a value read with
     * origin [IrStatementOrigin.INITIALIZE_PROPERTY_FROM_PARAMETER]. Returns null if there
     * is none.
     */
    private fun IrClass.correspondingProperty(param: IrValueParameter): IrProperty? =
        declarations
            .filterIsInstance<IrProperty>()
            .firstOrNull { property ->
                property.name == param.name &&
                    (property.backingField?.initializer?.expression as? IrGetValue)?.origin ==
                    IrStatementOrigin.INITIALIZE_PROPERTY_FROM_PARAMETER
            }

    /**
     * Prints a class-like declaration: annotations, non-default visibility, `inner`/`data`
     * modifiers, the kind keyword (`interface`, `object`, or `[modality ][annotation ]class`),
     * the name, type parameters, the primary constructor's value parameters, the supertypes
     * (when any is not `Any`) and finally the body.
     *
     * Properties that correspond to primary-constructor parameters are printed inline in the
     * parameter list as `val`/`var` and left out of the body, as are the primary constructor
     * itself and fake overrides.
     */
    override fun visitClass(declaration: IrClass) {
        val primaryConstructor = declaration.primaryConstructor
        declaration.printAnnotations(onePerLine = true)
        if (
            declaration.visibility != DescriptorVisibilities.PUBLIC &&
            declaration.visibility != DescriptorVisibilities.LOCAL
        ) {
            print(declaration.visibility.toString().lowercase(Locale.ROOT))
            print(" ")
        }
        if (declaration.isInner) {
            print("inner ")
        }
        if (declaration.isData) {
            print("data ")
        }
        if (declaration.isInterface) {
            print("interface ")
        } else if (declaration.isObject) {
            print("object ")
        } else {
            if (declaration.modality != Modality.FINAL) {
                print(declaration.modality.toString().lowercase(Locale.ROOT))
                print(" ")
            }
            if (declaration.isAnnotationClass) {
                print("annotation ")
            }
            print("class ")
        }
        print(declaration.name)
        if (declaration.typeParameters.isNotEmpty()) {
            print("<")
            declaration.typeParameters.printJoin(", ")
            print("> ")
        }
        // Properties already rendered inside the constructor parameter list. The explicit
        // element type is required: nothing at this call site lets the compiler infer it.
        val ctorProperties = mutableSetOf<IrProperty>()
        if (primaryConstructor != null) {
            if (primaryConstructor.valueParameters.isNotEmpty()) {
                print("(")
                primaryConstructor.valueParameters.forEachIndexed { index, param ->
                    val property = declaration.correspondingProperty(param)
                    if (property != null) {
                        ctorProperties.add(property)
                        print(if (property.isVar) "var " else "val ")
                    }
                    param.print()
                    if (index < primaryConstructor.valueParameters.size - 1) print(", ")
                }
                print(")")
            }
        }
        print(" ")
        if (declaration.superTypes.any { !it.isAny() }) {
            print(": ")
            print(declaration.superTypes.joinToString(", ") { it.renderSrc() })
            print(" ")
        }
        val nonParamDeclarations = declaration
            .declarations
            .filter { it != primaryConstructor && !ctorProperties.contains(it) }
            .filter { it.origin != IrDeclarationOrigin.FAKE_OVERRIDE }
        if (nonParamDeclarations.isNotEmpty()) {
            bracedBlock {
                nonParamDeclarations.printJoin("\n")
            }
        } else {
            println()
        }
    }

    /**
     * Prints this class as an `object` expression or declaration: the keyword, the name
     * unless it is a special name, the supertypes when any is not `Any`, and a braced body
     * containing every declaration that is neither a constructor nor a fake override. An
     * empty body is printed as a bare line break.
     */
    fun IrClass.printAsObject() {
        print("object ")
        if (!name.isSpecial) {
            print(name)
            print(" ")
        }
        val hasExplicitSuperType = superTypes.any { !it.isAny() }
        if (hasExplicitSuperType) {
            print(": ")
            print(superTypes.joinToString(", ") { it.renderSrc() })
            print(" ")
        }
        val members = declarations.filter { member ->
            member !is IrConstructor && member.origin != IrDeclarationOrigin.FAKE_OVERRIDE
        }
        if (members.isEmpty()) {
            println()
        } else {
            bracedBlock {
                members.printJoin("\n")
            }
        }
    }

    /** Prints `break`, followed by `@label` when the jump carries a label. */
    override fun visitBreak(jump: IrBreak) {
        print("break")
        val label = jump.label ?: return
        print("@")
        print(label)
    }

    /** Prints `continue`, followed by `@label` when the jump carries a label. */
    override fun visitContinue(jump: IrContinue) {
        print("continue")
        val label = jump.label ?: return
        print("@")
        print(label)
    }

    /**
     * Prints a type parameter name, followed by `: bounds` unless there are no supertypes
     * or the first supertype is `Any?`.
     */
    override fun visitTypeParameter(declaration: IrTypeParameter) {
        print(declaration.name)
        val bounds = declaration.superTypes
        if (bounds.isEmpty() || bounds[0].isNullableAny()) return
        print(": ")
        print(bounds.joinToString(", ") { bound -> bound.renderSrc() })
    }

    /** Prints `throw ` followed by the thrown expression. */
    override fun visitThrow(expression: IrThrow) {
        print("throw ")
        val thrown = expression.value
        thrown.print()
    }

    /** Prints a class reference as `<Type>::class`. */
    override fun visitClassReference(expression: IrClassReference) {
        val renderedType = expression.classType.renderSrc()
        print(renderedType)
        print("::class")
    }

    /** Adds nothing of its own; function-access expressions are handled by the base visitor. */
    override fun visitFunctionAccess(expression: IrFunctionAccessExpression): Unit =
        super.visitFunctionAccess(expression)

    /**
     * Prints the literal placeholder `<>` instead of rendering the branch.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitBranch(branch: IrBranch) {
        print("<>")
    }

    /**
     * Prints the literal placeholder `<>` for a generic break/continue jump.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitBreakContinue(jump: IrBreakContinue) {
        print("<>")
    }

    /**
     * Prints the literal placeholder `<>` instead of rendering the catch clause.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitCatch(aCatch: IrCatch) {
        print("<>")
    }

    /**
     * Prints the literal placeholder `<>` for a generic container expression.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitContainerExpression(expression: IrContainerExpression) {
        print("<>")
    }

    /**
     * Prints a delegating constructor call as `ctor<ClassName>` followed by its argument
     * list, where `ClassName` is the class constructed by the called constructor.
     */
    override fun visitDelegatingConstructorCall(expression: IrDelegatingConstructorCall) {
        val className = expression.symbol.owner.constructedClass.name

        print("ctor<")
        print(className)
        print(">")

        expression.printArgumentList()
    }

    /**
     * Prints the literal placeholder `<>` instead of rendering the else branch.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitElseBranch(branch: IrElseBranch) {
        print("<>")
    }

    /**
     * Prints the literal placeholder `<>` for a generic function declaration.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitFunction(declaration: IrFunction) {
        print("<>")
    }

    /**
     * Prints a callable reference to a function. An explicit receiver is printed with a
     * `::` suffix; a bare `::` is printed when there is neither a dispatch nor an extension
     * receiver. A property accessor is printed as the property name followed by `::set` or
     * `::get`; any other function is printed by its own name.
     */
    override fun visitFunctionReference(expression: IrFunctionReference) {
        val target = expression.symbol.owner

        expression.printExplicitReceiver("::")
        val hasReceiver =
            expression.dispatchReceiver != null || expression.extensionReceiver != null
        if (!hasReceiver) {
            print("::")
        }

        val property = (target as? IrSimpleFunction)?.correspondingPropertySymbol?.owner
        if (property == null) {
            print(target.name.asString())
            return
        }

        print(property.name.asString())
        when (target) {
            property.setter -> print("::set")
            property.getter -> print("::get")
        }
    }

    /** Prints an instance-initializer call as `init<ClassName>()`. */
    override fun visitInstanceInitializerCall(expression: IrInstanceInitializerCall) {
        val name = expression.classSymbol.owner.name
        print("init<$name>()")
    }

    /**
     * Prints a local delegated property as `(var|val) name by { ... }`, where the braced
     * block holds the delegate followed by the getter and, if present, the setter.
     */
    override fun visitLocalDelegatedProperty(declaration: IrLocalDelegatedProperty) {
        declaration.printAnnotations(onePerLine = true)
        val keyword = if (declaration.isVar) "var " else "val "
        print(keyword)
        print(declaration.name)
        print(" by ")
        bracedBlock {
            declaration.delegate.acceptVoid(this)
            declaration.getter.scoped { getter -> getter.printPropertyAccessor() }
            declaration.setter?.scoped { setter -> setter.printPropertyAccessor(isSetter = true) }
        }
    }

    /**
     * Prints this function as a property accessor: annotations, then `set(...)` or
     * `get(...)` with the value parameters, then the body in braces. Accessors with origin
     * [IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR] print nothing.
     *
     * @param isSetter selects the `set` keyword over `get`.
     */
    private fun IrFunction.printPropertyAccessor(isSetter: Boolean = this.isSetter) {
        if (origin == IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR) return

        println()
        printAnnotations()
        println()
        val keyword = if (isSetter) "set" else "get"
        print(keyword)
        print("(")
        valueParameters.printJoin(", ")
        print(") ")
        bracedBlock {
            body?.acceptVoid(this@IrSourcePrinterVisitor)
        }
    }

    /** Prints a reference to a local delegated property as `::` plus the delegate's name. */
    override fun visitLocalDelegatedPropertyReference(
        expression: IrLocalDelegatedPropertyReference,
    ) {
        print("::")
        val delegate = expression.delegate.owner
        print(delegate.name)
    }

    /**
     * Prints the literal placeholder `<>` for a generic loop.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitLoop(loop: IrLoop) {
        print("<>")
    }

    /** Prints a property reference as `<receiver>::<propertyName>`; the receiver is optional. */
    override fun visitPropertyReference(expression: IrPropertyReference) {
        val referenced = expression.symbol.owner
        expression.printExplicitReceiver()
        print("::")
        print(referenced.name)
    }

    /**
     * Prints the literal placeholder `<>` instead of rendering the spread element.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitSpreadElement(spread: IrSpreadElement) {
        print("<>")
    }

    /**
     * Prints the literal placeholder `<>` for a generic value access.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitValueAccess(expression: IrValueAccessExpression) {
        print("<>")
    }

    /**
     * Prints a try expression: the `try` body, one `} catch() {` section per catch clause
     * and an optional `} finally {` section, each with indented content, closed by `}`.
     */
    override fun visitTry(aTry: IrTry) {
        println("try {")
        indented { aTry.tryResult.print() }
        println()
        for (catchClause in aTry.catches) {
            println("} catch() {")
            indented { catchClause.print() }
            println()
        }
        val finallyBlock = aTry.finallyExpression
        if (finallyBlock != null) {
            println("} finally {")
            indented { finallyBlock.print() }
            println()
        }
        println("}")
    }

    /**
     * Prints the literal placeholder `<>` instead of rendering the type alias.
     * NOTE(review): the placeholder may be missing its node label (visitElement prints
     * `<<ClassName>>`) — confirm the intended text.
     */
    override fun visitTypeAlias(declaration: IrTypeAlias) {
        print("<>")
    }

    /** Prints an anonymous initializer as an `init { ... }` block with an indented body. */
    override fun visitAnonymousInitializer(declaration: IrAnonymousInitializer) {
        println("init {")
        indented { declaration.body.print() }
        println()
        println("}")
    }

    /** Renders this type as source text: its type annotations followed by the type itself. */
    private fun IrType.renderSrc(): String {
        val annotationPrefix = renderTypeAnnotations(annotations)
        return annotationPrefix + renderTypeInner()
    }

    /**
     * Renders this type without its annotations. Simple types are rendered as the
     * classifier name, the type arguments in angle brackets, a `?` when marked nullable and
     * the type abbreviation if present; trailing whitespace is trimmed.
     */
    private fun IrType.renderTypeInner() =
        when (this) {
            is IrDynamicType -> "dynamic"

            is IrErrorType -> "IrErrorType"

            is IrSimpleType -> buildTrimEnd {
                val classifierDeclaration = classifier.owner as IrDeclarationWithName
                append(classifierDeclaration.name)
                if (arguments.isNotEmpty()) {
                    val renderedArguments = arguments.joinToString(", ", "<", ">") { argument ->
                        argument.renderTypeArgument()
                    }
                    append(renderedArguments)
                }
                if (isMarkedNullable()) {
                    append('?')
                }
                abbreviation?.let { alias ->
                    append(alias.renderTypeAbbreviation())
                }
            }
        }

    /** Builds a string with [fn] and returns it with trailing whitespace removed. */
    private inline fun buildTrimEnd(fn: StringBuilder.() -> Unit): String =
        StringBuilder().apply(fn).toString().trimEnd()

    /**
     * Renders this abbreviation as `{ <annotations><aliasFqn><typeArguments>[?] }`, with
     * the type arguments in angle brackets when there are any.
     */
    private fun IrTypeAbbreviation.renderTypeAbbreviation(): String =
        buildString {
            append("{ ")
            append(renderTypeAnnotations(annotations))
            append(typeAlias.renderTypeAliasFqn())
            if (arguments.isNotEmpty()) {
                arguments.joinTo(this, separator = ", ", prefix = "<", postfix = ">") { argument ->
                    argument.renderTypeArgument()
                }
            }
            if (hasQuestionMark) append('?')
            append(" }")
        }

    /**
     * Renders a type argument: `*` for a star projection, otherwise the variance label
     * (followed by a space unless invariant) and the rendered type, with trailing
     * whitespace trimmed.
     */
    private fun IrTypeArgument.renderTypeArgument(): String =
        when (this) {
            is IrStarProjection -> "*"

            is IrTypeProjection -> {
                val separator = if (variance != Variance.INVARIANT) " " else ""
                "${variance.label}$separator${type.renderSrc()}".trimEnd()
            }
        }

    /**
     * Renders type annotations as `@[Annotation] ` entries separated by spaces, including a
     * trailing space so the result can be prefixed directly to a type. Returns an empty
     * string when there are no annotations.
     */
    private fun renderTypeAnnotations(annotations: List<IrConstructorCall>): String =
        if (annotations.isEmpty())
            ""
        else
            annotations.joinToString(prefix = "", postfix = " ", separator = " ") {
                "@[${renderAsAnnotation(it)}]"
            }

    /** Renders [irAnnotation] into a fresh string via the StringBuilder-based overload. */
    private fun renderAsAnnotation(irAnnotation: IrConstructorCall): String {
        val builder = StringBuilder()
        builder.renderAsAnnotation(irAnnotation)
        return builder.toString()
    }

    /**
     * Appends [irAnnotation] as `ClassName(param = value, ...)`. The argument list is
     * omitted when the call has no value arguments. If the annotation class name cannot be
     * resolved, an empty name is appended instead.
     */
    private fun StringBuilder.renderAsAnnotation(irAnnotation: IrConstructorCall) {
        val className = try {
            irAnnotation.symbol.owner.parentAsClass.name.asString()
        } catch (e: Exception) {
            // Best effort: fall back to an empty name rather than failing the whole dump.
            ""
        }
        append(className)

        val argumentCount = irAnnotation.valueArgumentsCount
        if (argumentCount == 0) return

        val parameterNames = irAnnotation.getValueParameterNamesForDebug()
        append("(")
        for (index in 0 until argumentCount) {
            if (index > 0) append(", ")
            append(parameterNames[index])
            append(" = ")
            renderAsAnnotationArgument(irAnnotation.getValueArgument(index))
        }
        append(")")
    }

    /**
     * Renders the qualified name of the type alias behind this symbol, or an empty string
     * when the symbol is not bound.
     */
    @OptIn(ObsoleteDescriptorBasedAPI::class)
    private fun IrTypeAliasSymbol.renderTypeAliasFqn(): String {
        if (!isBound) return ""
        val builder = StringBuilder()
        owner.renderDeclarationFqn(builder)
        return builder.toString()
    }

    /**
     * Appends the qualified name of this declaration to [sb]: the parent's qualified name,
     * a dot, then this declaration's name, or its `toString()` when it has no name.
     */
    private fun IrDeclaration.renderDeclarationFqn(sb: StringBuilder) {
        renderDeclarationParentFqn(sb)
        sb.append('.')
        val label = if (this is IrDeclarationWithName) name.asString() else toString()
        sb.append(label)
    }

    /**
     * Appends the qualified name of this declaration's parent to [sb]: recursively for a
     * declaration parent, or the package name for a package fragment. Any other parent
     * appends nothing, and so does a parent that has not been initialized yet.
     */
    private fun IrDeclaration.renderDeclarationParentFqn(sb: StringBuilder) {
        try {
            when (val container = this.parent) {
                is IrDeclaration -> container.renderDeclarationFqn(sb)
                is IrPackageFragment -> sb.append(container.packageFqName.toString())
            }
        } catch (e: UninitializedPropertyAccessException) {
            // Reading an unset parent throws; the append below adds no text.
            sb.append("")
        }
    }

    /**
     * Returns one display name per value argument of this call. When the symbol is bound to
     * a function, the normalized names of its value parameters are used, with the 1-based
     * position as a fallback for arguments beyond the declared parameters. Otherwise
     * 1-based positional placeholders are returned for every argument.
     */
    private fun IrMemberAccessExpression<*>.getValueParameterNamesForDebug(): List<String> {
        val expectedCount = valueArgumentsCount
        if (!symbol.isBound) return getPlaceholderParameterNames(expectedCount)

        val owner = symbol.owner as? IrFunction
            ?: return getPlaceholderParameterNames(expectedCount)

        return (0 until expectedCount).map { index ->
            owner.valueParameters.getOrNull(index)?.normalizedName ?: "${index + 1}"
        }
    }

    /** Returns the 1-based positions `"1"` to `"expectedCount"` as placeholder names. */
    private fun getPlaceholderParameterNames(expectedCount: Int): List<String> =
        (1..expectedCount).map { position -> position.toString() }

    /**
     * Appends an annotation argument: nested annotations are rendered recursively,
     * constants are wrapped in single quotes, and varargs become a bracketed,
     * comma-separated list. A null argument appends no text.
     */
    private fun StringBuilder.renderAsAnnotationArgument(irElement: IrElement?) {
        when (irElement) {
            null -> append("")
            is IrConstructorCall -> renderAsAnnotation(irElement)
            is IrConst -> append('\'').append(irElement.value.toString()).append('\'')
            is IrVararg -> appendListWith(irElement.elements, "[", "]", ", ") { element ->
                renderAsAnnotationArgument(element)
            }
            // NOTE(review): this appends the value returned by accept(); with a void visitor
            // that is presumably Unit rather than rendered text, while the visitor itself
            // prints to the main output — confirm this is intended.
            else -> append(irElement.accept(this@IrSourcePrinterVisitor, null))
        }
    }

    // K1 and K2 do not agree on the names of temporary variables and synthesized
    // parameters. This property maps those names to one fixed value so that both frontends
    // print the same text.
    // NOTE(review): every fixed value below is the empty string as written; it looks as if
    // placeholder text may have been lost — confirm the intended names.
    private val IrValueDeclaration.normalizedName: String
        get() {
            val rawName = name.asString()
            // FIR generates more than one name for for-loop iterators (tmp0_for_iterator...).
            if (origin == IrDeclarationOrigin.FOR_LOOP_ITERATOR) return ""
            // $anonymous$parameter$x vs $unused$var$x
            if (origin == IrDeclarationOrigin.UNDERSCORE_PARAMETER) return ""
            // The remaining two normalizations only apply when FIR is not in use.
            if (!useFir && rawName.endsWith("_elvis_lhs")) return ""
            if (!useFir && rawName == "\$this\$null") return ""
            return rawName
        }

    /** Fallback for otherwise unhandled elements: prints the element's simple class name in `<<...>>`. */
    override fun visitElement(element: IrElement) {
        val nodeType = element::class.java.simpleName
        print("<<$nodeType>>")
    }
}

/**
 * Appends [prefix], then every item of [list] rendered by [renderItem] with [separator]
 * between consecutive items, then [postfix]. An empty list yields just the prefix followed
 * by the postfix.
 */
private inline fun <T> StringBuilder.appendListWith(
    list: List<T>,
    prefix: String,
    postfix: String,
    separator: String,
    renderItem: StringBuilder.(T) -> Unit,
) {
    append(prefix)
    list.forEachIndexed { index, item ->
        if (index > 0) append(separator)
        renderItem(item)
    }
    append(postfix)
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy