All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlin.backend.wasm.ir2wasm.WasmCompiledModuleFragment.kt Maven / Gradle / Ivy

/*
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.backend.wasm.ir2wasm

import org.jetbrains.kotlin.wasm.ir.*
import org.jetbrains.kotlin.backend.wasm.lower.WasmSignature
import org.jetbrains.kotlin.ir.declarations.IrDeclarationWithName
import org.jetbrains.kotlin.ir.declarations.IrExternalPackageFragment
import org.jetbrains.kotlin.ir.symbols.*
import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable
import org.jetbrains.kotlin.ir.util.getPackageFragment

/**
 * Collects the Wasm declarations produced for a compiled fragment and links them into one [WasmModule].
 *
 * Declarations may be referenced before they are defined: [ReferencableElements.reference] hands out an
 * unbound [WasmSymbol], and [linkWasmCompiledFragments] later binds every such symbol to its definition
 * (or to a numeric id/index) before assembling the module.
 *
 * NOTE(review): type arguments that are not directly evidenced by usages in this file (`IrFieldSymbol`,
 * `WasmGlobal`, `WasmFunctionType`, `WasmStructDeclaration`, `ConstantDataElement`, `WasmExport<*>`,
 * `WasmInstr`, `WasmFunction.Imported`/`WasmFunction.Defined`) should be confirmed against the
 * declarations in `org.jetbrains.kotlin.wasm.ir` and this package.
 */
class WasmCompiledModuleFragment {
    // Entities that are both referenced and defined; each is bound by definition during linking.
    val functions =
        ReferencableAndDefinable<IrFunctionSymbol, WasmFunction>()
    val globals =
        ReferencableAndDefinable<IrFieldSymbol, WasmGlobal>()
    val functionTypes =
        ReferencableAndDefinable<IrFunctionSymbol, WasmFunctionType>()
    val structTypes =
        ReferencableAndDefinable<IrClassSymbol, WasmStructDeclaration>()

    // Entities that are only referenced; each is bound to an Int computed during linking.
    val classIds =
        ReferencableElements<IrClassSymbol, Int>()
    val interfaceId =
        ReferencableElements<IrClassSymbol, Int>()
    val virtualFunctionId =
        ReferencableElements<IrFunctionSymbol, Int>()
    val signatureId =
        ReferencableElements<WasmSignature, Int>()
    val stringLiteralId =
        ReferencableElements<String, Int>()

    // Globals whose type is a WasmRtt; they are emitted after the regular globals, sorted by depth.
    val runtimeTypes =
        ReferencableAndDefinable<IrClassSymbol, WasmGlobal>()

    // Ordered collections: the position of an element is the index bound to its id symbol
    // (see the bindIndices calls in linkWasmCompiledFragments). `classes` is not read in this class.
    val classes = mutableListOf<IrClassSymbol>()
    val interfaces = mutableListOf<IrClassSymbol>()
    val virtualFunctions = mutableListOf<IrSimpleFunctionSymbol>()
    val signatures = LinkedHashSet<WasmSignature>()
    val stringLiterals = mutableListOf<String>()

    // Per-class constant data; each element becomes one data segment, and its byte offset is the class id.
    val typeInfo =
        ReferencableAndDefinable<IrClassSymbol, ConstantDataElement>()
    val exports = mutableListOf<WasmExport<*>>()

    /** A piece of JavaScript code together with the import name it is registered under. */
    class JsCodeSnippet(val importName: String, val jsCode: String)

    val jsFuns = mutableListOf<JsCodeSnippet>()

    // Must be assigned before linkWasmCompiledFragments() is called.
    var startFunction: WasmFunction? = null

    /**
     * A registry of symbols that were referenced and still have to be bound to a [Wasm] value.
     *
     * @param Ir key identifying the referenced entity (an IR symbol, a signature, a string, ...)
     * @param Wasm the value each symbol is eventually bound to
     */
    open class ReferencableElements<Ir, Wasm : Any> {
        val unbound = mutableMapOf<Ir, WasmSymbol<Wasm>>()

        /**
         * Returns the symbol standing for [ir], creating an unbound one on first use.
         *
         * When [ir] is an [IrSymbol] owned by an [IrDeclarationWithName], the declaration is required to
         * live in a non-external package fragment.
         *
         * @throws IllegalStateException if that declaration has no package fragment or an external one
         */
        fun reference(ir: Ir): WasmSymbol<Wasm> {
            val declaration = (ir as? IrSymbol)?.owner as? IrDeclarationWithName
            if (declaration != null) {
                val packageFragment = declaration.getPackageFragment()
                    ?: error("Referencing declaration without package fragment ${declaration.fqNameWhenAvailable}")
                if (packageFragment is IrExternalPackageFragment) {
                    // The fragment exists but is external, so say so instead of repeating the message above.
                    error("Referencing declaration from external package fragment ${declaration.fqNameWhenAvailable}")
                }
            }
            return unbound.getOrPut(ir) { WasmSymbol() }
        }
    }

    /**
     * A [ReferencableElements] registry that also records definitions, in definition order.
     */
    class ReferencableAndDefinable<Ir, Wasm : Any> : ReferencableElements<Ir, Wasm>() {
        /**
         * Registers [wasm] as the definition of [ir].
         *
         * @throws IllegalStateException if [ir] has already been defined
         */
        fun define(ir: Ir, wasm: Wasm) {
            if (ir in defined)
                error("Trying to redefine element: IR: $ir Wasm: $wasm")

            elements += wasm
            defined[ir] = wasm
            wasmToIr[wasm] = ir
        }

        // Definition lookup by IR key (insertion-ordered).
        val defined = LinkedHashMap<Ir, Wasm>()

        // All definitions in the order they were registered.
        val elements = mutableListOf<Wasm>()

        // Reverse lookup from a definition to its IR key.
        val wasmToIr = mutableMapOf<Wasm, Ir>()
    }

    /**
     * Binds all referenced symbols and assembles the final [WasmModule].
     *
     * @return the linked module, after `calculateIds()` has been invoked on it
     * @throws IllegalStateException if a referenced symbol has no definition/index, or if
     * [startFunction] has not been set
     */
    @OptIn(ExperimentalUnsignedTypes::class)
    fun linkWasmCompiledFragments(): WasmModule {
        bind(functions.unbound, functions.defined)
        bind(globals.unbound, globals.defined)
        bind(functionTypes.unbound, functionTypes.defined)
        bind(structTypes.unbound, structTypes.defined)
        bind(runtimeTypes.unbound, runtimeTypes.defined)

        // A class id is the running byte offset of the class's type-info element.
        val klassIds = mutableMapOf<IrClassSymbol, Int>()
        var classId = 0
        for (typeInfoElement in typeInfo.elements) {
            val ir = typeInfo.wasmToIr.getValue(typeInfoElement)
            klassIds[ir] = classId
            classId += typeInfoElement.sizeInBytes
        }

        bind(classIds.unbound, klassIds)
        bindIndices(virtualFunctionId.unbound, virtualFunctions)
        bindIndices(signatureId.unbound, signatures.toList())
        bindIndices(interfaceId.unbound, interfaces)
        bindIndices(stringLiteralId.unbound, stringLiterals)

        // One active data segment per type-info element, placed at the offset used as its class id.
        val data = typeInfo.elements.map {
            val ir = typeInfo.wasmToIr.getValue(it)
            val id = klassIds.getValue(ir)
            val offset = mutableListOf<WasmInstr>()
            WasmIrExpressionBuilder(offset).buildConstI32(id)
            WasmData(WasmDataMode.Active(0, offset), it.toBytes())
        }

        // Debugging aid: flip to true to print the linked tables.
        val logTypeInfo = false
        if (logTypeInfo) {
            println("Signatures: ")
            for ((index, signature: WasmSignature) in signatures.withIndex()) {
                println("  -- $index $signature")
            }

            println("Interfaces: ")
            for ((index, iface: IrClassSymbol) in interfaces.withIndex()) {
                println("  -- $index ${iface.owner.fqNameWhenAvailable}")
            }

            println("Virtual functions: ")
            for ((index, vf: IrSimpleFunctionSymbol) in virtualFunctions.withIndex()) {
                println("  -- $index ${vf.owner.fqNameWhenAvailable}")
            }

            println(
                ConstantDataStruct("typeInfo", typeInfo.elements).dump("", 0)
            )
        }

        // The function table holds exactly the virtual functions (min size == max size).
        val table = WasmTable(
            limits = WasmLimits(virtualFunctions.size.toUInt(), virtualFunctions.size.toUInt()),
            elementType = WasmFuncRef,
        )

        val offsetExpr = mutableListOf<WasmInstr>()
        WasmIrExpressionBuilder(offsetExpr).buildConstI32(0)

        // Element segment filling the table from offset 0, in virtualFunctions order.
        val elements = WasmElement(
            WasmFuncRef,
            values = virtualFunctions.map {
                WasmTable.Value.Function(functions.defined.getValue(it))
            },
            WasmElement.Mode.Active(table, offsetExpr)
        )

        // After the loop above, classId equals the total size of all type info in bytes.
        val typeInfoSize = classId
        val memorySizeInPages = (typeInfoSize / 65_536) + 1
        val memory = WasmMemory(WasmLimits(memorySizeInPages.toUInt(), memorySizeInPages.toUInt()))

        val importedFunctions = functions.elements.filterIsInstance<WasmFunction.Imported>()

        // Sorting by depth for a valid init order
        val sortedRttGlobals = runtimeTypes.elements.sortedBy { (it.type as WasmRtt).depth }

        val module = WasmModule(
            functionTypes = functionTypes.elements,
            structs = structTypes.elements,
            importsInOrder = importedFunctions,
            importedFunctions = importedFunctions,
            definedFunctions = functions.elements.filterIsInstance<WasmFunction.Defined>(),
            tables = listOf(table),
            memories = listOf(memory),
            globals = globals.elements + sortedRttGlobals,
            exports = exports,
            startFunction = checkNotNull(startFunction) { "Start function must be set before linking" },
            elements = listOf(elements),
            data = data
        )
        module.calculateIds()
        return module
    }
}

/**
 * Binds every symbol in [unbound] to the definition registered under the same key in [defined].
 *
 * @param unbound symbols awaiting a definition, keyed by the entity they stand for
 * @param defined definitions keyed by the same kind of entity
 * @throws IllegalStateException if a key of [unbound] is missing from [defined]
 */
fun <IrSymbolType, WasmDeclarationType : Any, WasmSymbolType : WasmSymbol<WasmDeclarationType>> bind(
    unbound: Map<IrSymbolType, WasmSymbolType>,
    defined: Map<IrSymbolType, WasmDeclarationType>
) {
    unbound.forEach { (irSymbol, wasmSymbol) ->
        // Values are non-null by the type bound, so a null lookup means the key is absent.
        val declaration = defined[irSymbol]
            ?: error("Can't link symbol ${irSymbolDebugDump(irSymbol)}")
        wasmSymbol.bind(declaration)
    }
}

/**
 * Renders [symbol] for use in linker error messages.
 *
 * Function and class symbols are shown with a kind prefix followed by the owner's fully qualified
 * name as reported by `fqNameWhenAvailable`; anything else (including `null`) falls back to `toString()`.
 */
private fun irSymbolDebugDump(symbol: Any?): String {
    if (symbol is IrFunctionSymbol) {
        return "function ${symbol.owner.fqNameWhenAvailable}"
    }
    if (symbol is IrClassSymbol) {
        return "class ${symbol.owner.fqNameWhenAvailable}"
    }
    return symbol.toString()
}

/**
 * Binds every symbol in [unbound] to the position of its key within [ordered].
 *
 * If a key occurs more than once in [ordered], the first position is used.
 *
 * @param unbound symbols awaiting an index, keyed by the entity they stand for
 * @param ordered the entities in their final order
 * @throws IllegalStateException if a key of [unbound] does not occur in [ordered]
 */
fun <IrSymbolType> bindIndices(
    unbound: Map<IrSymbolType, WasmSymbol<Int>>,
    ordered: List<IrSymbolType>
) {
    // Precompute positions once instead of scanning `ordered` for every symbol.
    // getOrPut keeps the earliest position, matching what List.indexOf would return.
    val firstIndexOf = HashMap<IrSymbolType, Int>()
    ordered.forEachIndexed { position, element ->
        firstIndexOf.getOrPut(element) { position }
    }

    unbound.forEach { (irSymbol, wasmSymbol) ->
        val index = firstIndexOf[irSymbol]
            ?: error("Can't link symbol with indices ${irSymbolDebugDump(irSymbol)}")
        wasmSymbol.bind(index)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy