All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.backend.wasm.ir2wasm.WasmCompiledModuleFragment.kt Maven / Gradle / Ivy

There is a newer version: 2.1.0-RC
Show newest version
/*
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.backend.wasm.ir2wasm

import org.jetbrains.kotlin.ir.IrBuiltIns
import org.jetbrains.kotlin.ir.declarations.IrDeclarationWithName
import org.jetbrains.kotlin.ir.declarations.IrExternalPackageFragment
import org.jetbrains.kotlin.ir.symbols.*
import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable
import org.jetbrains.kotlin.ir.util.getPackageFragment
import org.jetbrains.kotlin.ir.util.isInterface
import org.jetbrains.kotlin.wasm.ir.*
import org.jetbrains.kotlin.wasm.ir.source.location.SourceLocation

/**
 * Collects the Wasm declarations generated for one compiled module together with the
 * not-yet-resolved references between them, and links everything into a [WasmModule]
 * in [linkWasmCompiledFragments].
 *
 * Each registry below maps an IR-side key to a [WasmSymbol] (a reference that is bound later)
 * and, for [ReferencableAndDefinable], also to the Wasm-side definition itself.
 *
 * NOTE(review): the generic type arguments in this class were restored after having been
 * stripped from the source text. Where a registry's key/value type is not exercised inside this
 * file (e.g. the key types of [functions], [globalFields], [globalVTables] and the value type of
 * [classITableInterfaceSlot]) it is inferred from naming — confirm against the call sites.
 *
 * @property irBuiltIns built-ins of the module; used here to reference the `Throwable` class type.
 * @param generateTrapsInsteadOfExceptions when `true` no exception tag is emitted ([tags] is empty).
 */
class WasmCompiledModuleFragment(
    val irBuiltIns: IrBuiltIns,
    generateTrapsInsteadOfExceptions: Boolean,
) {
    val functions =
        ReferencableAndDefinable<IrFunctionSymbol, WasmFunction>()
    val globalFields =
        ReferencableAndDefinable<IrFieldSymbol, WasmGlobal>()
    val globalVTables =
        ReferencableAndDefinable<IrClassSymbol, WasmGlobal>()
    val globalClassITables =
        ReferencableAndDefinable<IrClassSymbol, WasmGlobal>()
    val functionTypes =
        ReferencableAndDefinable<IrFunctionSymbol, WasmFunctionType>()
    val gcTypes =
        ReferencableAndDefinable<IrClassSymbol, WasmTypeDeclaration>()
    val vTableGcTypes =
        ReferencableAndDefinable<IrClassSymbol, WasmTypeDeclaration>()
    val classITableGcType =
        ReferencableAndDefinable<IrClassSymbol, WasmTypeDeclaration>()
    val classITableInterfaceSlot =
        ReferencableAndDefinable<IrClassSymbol, Int>()

    // Bound during linking: classes get their byte offset in the type-info area,
    // interfaces get the negative ids -1, -2, ... in iteration order.
    val typeIds =
        ReferencableElements<IrClassSymbol, Int>()

    // Bound during linking to the byte offset of the literal inside the string data segment.
    val stringLiteralAddress =
        ReferencableElements<String, Int>()

    // Bound during linking to the ordinal of the literal (0, 1, ...) in iteration order.
    val stringLiteralPoolId =
        ReferencableElements<String, Int>()

    // Key is (array contents, element type); bound during linking to the index of the data segment.
    val constantArrayDataSegmentId =
        ReferencableElements<Pair<List<Long>, WasmType>, Int>()

    // Signature of the exception tag: a single nullable reference to the Throwable GC type, no results.
    private val tagFuncType = WasmFunctionType(
        listOf(
            WasmRefNullType(WasmHeapType.Type(gcTypes.reference(irBuiltIns.throwableClass)))
        ),
        emptyList()
    )
    val tags = if (generateTrapsInsteadOfExceptions) emptyList() else listOf(WasmTag(tagFuncType))

    // Per-class constant data; its size and bytes are laid out in linear memory during linking.
    val typeInfo = ReferencableAndDefinable<IrClassSymbol, ConstantDataElement>()

    val exports = mutableListOf<WasmExport<*>>()

    /** A JavaScript code fragment registered under [importName]. */
    class JsCodeSnippet(val importName: String, val jsCode: String)

    val jsFuns = mutableListOf<JsCodeSnippet>()
    val jsModuleImports = mutableSetOf<String>()

    /** An initializer [function] together with the [priority] string it is ordered by. */
    class FunWithPriority(val function: WasmFunction, val priority: String)

    val initFunctions = mutableListOf<FunWithPriority>()

    // Bound during linking to the first INT_SIZE_BYTES-aligned address after the type-info area.
    val scratchMemAddr = WasmSymbol<Int>()

    // Bound during linking to the number of referenced string literals.
    val stringPoolSize = WasmSymbol<Int>()

    /**
     * Registry of elements that can be referenced before they are known:
     * [reference] hands out one shared, initially unbound [WasmSymbol] per key.
     */
    open class ReferencableElements<Ir, Wasm : Any> {
        val unbound = mutableMapOf<Ir, WasmSymbol<Wasm>>()

        /**
         * Returns the symbol associated with [ir], creating it on first use.
         *
         * Fails via [error] when [ir] is an [IrSymbol] whose owner is a named declaration
         * located in an [IrExternalPackageFragment].
         */
        fun reference(ir: Ir): WasmSymbol<Wasm> {
            val declaration = (ir as? IrSymbol)?.owner as? IrDeclarationWithName
            if (declaration != null) {
                val packageFragment = declaration.getPackageFragment()
                if (packageFragment is IrExternalPackageFragment) {
                    error("Referencing declaration without package fragment ${declaration.fqNameWhenAvailable}")
                }
            }
            return unbound.getOrPut(ir) { WasmSymbol() }
        }
    }

    /**
     * A [ReferencableElements] registry whose elements can also be defined.
     * Definitions are kept in insertion order ([defined] is a [LinkedHashMap], [elements] a list).
     */
    class ReferencableAndDefinable<Ir, Wasm : Any> : ReferencableElements<Ir, Wasm>() {
        /**
         * Registers [wasm] as the definition of [ir].
         * Fails via [error] if [ir] already has a definition.
         */
        fun define(ir: Ir, wasm: Wasm) {
            if (ir in defined)
                error("Trying to redefine element: IR: $ir Wasm: $wasm")

            elements += wasm
            defined[ir] = wasm
            wasmToIr[wasm] = ir
        }

        val defined = LinkedHashMap<Ir, Wasm>()
        val elements = mutableListOf<Wasm>()

        // Reverse lookup from a Wasm definition back to the IR key it was defined for.
        val wasmToIr = mutableMapOf<Wasm, Ir>()
    }

    /**
     * Resolves all pending references, lays out type info and constant data, generates the
     * `_initialize` function and assembles the resulting [WasmModule].
     *
     * Side effects: binds the symbols held by the registries of this fragment and appends the
     * `_initialize` and `memory` entries to [exports].
     *
     * Fails via [error] when a referenced element has no definition.
     */
    fun linkWasmCompiledFragments(): WasmModule {
        bind(functions.unbound, functions.defined)
        bind(globalFields.unbound, globalFields.defined)
        bind(globalVTables.unbound, globalVTables.defined)
        bind(gcTypes.unbound, gcTypes.defined)
        bind(vTableGcTypes.unbound, vTableGcTypes.defined)
        bind(classITableGcType.unbound, classITableGcType.defined)
        bind(classITableInterfaceSlot.unbound, classITableInterfaceSlot.defined)
        bind(globalClassITables.unbound, globalClassITables.defined)

        // Associate function types to a single canonical function type
        // (presumably relies on WasmFunctionType equality to merge duplicates — confirm).
        val canonicalFunctionTypes =
            functionTypes.elements.associateWithTo(LinkedHashMap()) { it }

        functionTypes.unbound.forEach { (irSymbol, wasmSymbol) ->
            if (irSymbol !in functionTypes.defined)
                error("Can't link symbol ${irSymbolDebugDump(irSymbol)}")
            wasmSymbol.bind(canonicalFunctionTypes.getValue(functionTypes.defined.getValue(irSymbol)))
        }

        // Assign type ids: interfaces count down from -1, classes receive the running
        // byte offset of their type info.
        var currentDataSectionAddress = 0
        var interfaceId = 0
        typeIds.unbound.forEach { (klassSymbol, wasmSymbol) ->
            if (klassSymbol.owner.isInterface) {
                interfaceId--
                wasmSymbol.bind(interfaceId)
            } else {
                wasmSymbol.bind(currentDataSectionAddress)
                currentDataSectionAddress += typeInfo.defined.getValue(klassSymbol).sizeInBytes
            }
        }
        currentDataSectionAddress = alignUp(currentDataSectionAddress, INT_SIZE_BYTES)
        scratchMemAddr.bind(currentDataSectionAddress)

        // Concatenate all string literals into one byte sequence, recording each literal's
        // start offset and ordinal.
        val stringDataSectionBytes = mutableListOf<Byte>()
        var stringDataSectionStart = 0
        var stringLiteralCount = 0
        for ((string, symbol) in stringLiteralAddress.unbound) {
            symbol.bind(stringDataSectionStart)
            stringLiteralPoolId.reference(string).bind(stringLiteralCount)
            val constData = ConstantDataCharArray("string_literal", string.toCharArray())
            stringDataSectionBytes += constData.toBytes().toList()
            stringDataSectionStart += constData.sizeInBytes
            stringLiteralCount++
        }
        stringPoolSize.bind(stringLiteralCount)

        // Data segment 0 is the string pool; constant arrays follow, each bound to its own index.
        val data = mutableListOf<WasmData>()
        data.add(WasmData(WasmDataMode.Passive, stringDataSectionBytes.toByteArray()))
        constantArrayDataSegmentId.unbound.forEach { (constantArraySegment, symbol) ->
            symbol.bind(data.size)
            val integerSize = when (constantArraySegment.second) {
                WasmI8 -> BYTE_SIZE_BYTES
                WasmI16 -> SHORT_SIZE_BYTES
                WasmI32 -> INT_SIZE_BYTES
                WasmI64 -> LONG_SIZE_BYTES
                else -> TODO("type ${constantArraySegment.second} is not implemented")
            }
            val constData = ConstantDataIntegerArray("constant_array", constantArraySegment.first, integerSize)
            data.add(WasmData(WasmDataMode.Passive, constData.toBytes()))
        }

        // One active data segment per class, placed at the offset that was bound as its type id.
        typeIds.unbound.forEach { (klassSymbol, typeId) ->
            if (!klassSymbol.owner.isInterface) {
                val instructions = mutableListOf<WasmInstr>()
                WasmIrExpressionBuilder(instructions).buildConstI32(
                    typeId.owner,
                    SourceLocation.NoLocation("Compile time data per class")
                )
                val typeData = WasmData(
                    WasmDataMode.Active(0, instructions),
                    typeInfo.defined.getValue(klassSymbol).toBytes()
                )
                data.add(typeData)
            }
        }

        // Generated entry point that calls every registered init function, ordered by priority string.
        val masterInitFunctionType = WasmFunctionType(emptyList(), emptyList())
        val masterInitFunction = WasmFunction.Defined("_initialize", WasmSymbol(masterInitFunctionType))
        with(WasmIrExpressionBuilder(masterInitFunction.instructions)) {
            initFunctions.sortedBy { it.priority }.forEach {
                buildCall(WasmSymbol(it.function), SourceLocation.NoLocation("Generated service code"))
            }
        }
        exports += WasmExport.Function("_initialize", masterInitFunction)

        // Initial memory: enough 65_536-byte pages to hold the type-info area, plus one.
        val typeInfoSize = currentDataSectionAddress
        val memorySizeInPages = (typeInfoSize / 65_536) + 1
        val memory = WasmMemory(WasmLimits(memorySizeInPages.toUInt(), null /* "unlimited" */))

        // Need to export the memory in order to pass complex objects to the host language.
        // Export name "memory" is a WASI ABI convention.
        exports += WasmExport.Memory("memory", memory)

        val importedFunctions = functions.elements.filterIsInstance<WasmFunction.Imported>()

        // Sort key for type declarations: arrays and function types are 0,
        // structs are ordered by the length of their supertype chain.
        fun wasmTypeDeclarationOrderKey(declaration: WasmTypeDeclaration): Int {
            return when (declaration) {
                is WasmArrayDeclaration -> 0
                is WasmFunctionType -> 0
                is WasmStructDeclaration ->
                    // Subtype depth
                    declaration.superType?.let { wasmTypeDeclarationOrderKey(it.owner) + 1 } ?: 0
            }
        }

        val recGroupTypes = mutableListOf<WasmTypeDeclaration>()
        recGroupTypes.addAll(vTableGcTypes.elements)
        recGroupTypes.addAll(this.gcTypes.elements)
        recGroupTypes.addAll(classITableGcType.elements.distinct())
        // Ascending depth places every struct after its supertype.
        recGroupTypes.sortBy(::wasmTypeDeclarationOrderKey)

        val globals = mutableListOf<WasmGlobal>()
        globals.addAll(globalFields.elements)
        globals.addAll(globalVTables.elements)
        globals.addAll(globalClassITables.elements.distinct())

        val allFunctionTypes = canonicalFunctionTypes.values.toList() + tagFuncType + masterInitFunctionType

        // Partition out function types that can't be recursive,
        // we don't need to put them into a rec group
        // so that they can be matched with function types from other Wasm modules.
        val (potentiallyRecursiveFunctionTypes, nonRecursiveFunctionTypes) =
            allFunctionTypes.partition { it.referencesTypeDeclarations() }
        // Appended after sorting, so these function types come last in the rec group.
        recGroupTypes.addAll(potentiallyRecursiveFunctionTypes)

        val module = WasmModule(
            functionTypes = nonRecursiveFunctionTypes,
            recGroupTypes = recGroupTypes,
            importsInOrder = importedFunctions,
            importedFunctions = importedFunctions,
            definedFunctions = functions.elements.filterIsInstance<WasmFunction.Defined>() + masterInitFunction,
            tables = emptyList(),
            memories = listOf(memory),
            globals = globals,
            exports = exports,
            startFunction = null,  // Module is initialized via export call
            elements = emptyList(),
            data = data,
            dataCount = true,
            tags = tags
        )
        module.calculateIds()
        return module
    }
}

/**
 * Binds every pending symbol in [unbound] to the declaration registered under the same key
 * in [defined].
 *
 * NOTE(review): the type-parameter list was restored after having been stripped from the
 * source text — confirm the bounds against [WasmSymbol]'s declaration.
 *
 * @param unbound references handed out so far, keyed by the IR-side element.
 * @param defined definitions collected so far, keyed by the IR-side element.
 * @throws IllegalStateException (via [error]) if a key of [unbound] has no entry in [defined].
 */
fun <IrSymbolType, WasmDeclarationType : Any, WasmSymbolType : WasmSymbol<WasmDeclarationType>> bind(
    unbound: Map<IrSymbolType, WasmSymbolType>,
    defined: Map<IrSymbolType, WasmDeclarationType>
) {
    unbound.forEach { (irSymbol, wasmSymbol) ->
        if (irSymbol !in defined)
            error("Can't link symbol ${irSymbolDebugDump(irSymbol)}")
        wasmSymbol.bind(defined.getValue(irSymbol))
    }
}

/**
 * Renders [symbol] for link-error messages: function and class symbols are shown with a
 * kind prefix and the owner's fully qualified name (when available); anything else,
 * including `null`, falls back to its `toString()`.
 */
private fun irSymbolDebugDump(symbol: Any?): String {
    if (symbol is IrFunctionSymbol) {
        return "function ${symbol.owner.fqNameWhenAvailable}"
    }
    if (symbol is IrClassSymbol) {
        return "class ${symbol.owner.fqNameWhenAvailable}"
    }
    return symbol.toString()
}

/**
 * Rounds [x] up to the nearest multiple of [alignment].
 *
 * @param x the value to align.
 * @param alignment the alignment; must be a positive power of two.
 * @return the smallest multiple of [alignment] that is greater than or equal to [x]
 *         (barring `Int` overflow of `x + alignment - 1`).
 * @throws IllegalArgumentException if [alignment] is not a positive power of two.
 */
fun alignUp(x: Int, alignment: Int): Int {
    // `alignment and (alignment - 1) == 0` alone also accepts 0 and Int.MIN_VALUE, for which the
    // mask below would silently yield 0; checked unconditionally rather than only under -ea.
    require(alignment > 0 && alignment and (alignment - 1) == 0) { "power of 2 expected" }
    val mask = alignment - 1
    return (x + mask) and mask.inv()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy