org.jetbrains.kotlin.backend.wasm.ir2wasm.WasmCompiledModuleFragment.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of kotlin-compiler-embeddable Show documentation
Show all versions of kotlin-compiler-embeddable Show documentation
the Kotlin compiler embeddable
/*
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.backend.wasm.ir2wasm
import org.jetbrains.kotlin.ir.IrBuiltIns
import org.jetbrains.kotlin.ir.declarations.IrDeclarationWithName
import org.jetbrains.kotlin.ir.declarations.IrExternalPackageFragment
import org.jetbrains.kotlin.ir.symbols.*
import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable
import org.jetbrains.kotlin.ir.util.getPackageFragment
import org.jetbrains.kotlin.ir.util.isInterface
import org.jetbrains.kotlin.wasm.ir.*
import org.jetbrains.kotlin.wasm.ir.source.location.SourceLocation
class WasmCompiledModuleFragment(
val irBuiltIns: IrBuiltIns,
generateTrapsInsteadOfExceptions: Boolean,
) {
val functions =
ReferencableAndDefinable()
val globalFields =
ReferencableAndDefinable()
val globalVTables =
ReferencableAndDefinable()
val globalClassITables =
ReferencableAndDefinable()
val functionTypes =
ReferencableAndDefinable()
val gcTypes =
ReferencableAndDefinable()
val vTableGcTypes =
ReferencableAndDefinable()
val classITableGcType =
ReferencableAndDefinable()
val classITableInterfaceSlot =
ReferencableAndDefinable()
val typeIds =
ReferencableElements()
val stringLiteralAddress =
ReferencableElements()
val stringLiteralPoolId =
ReferencableElements()
val constantArrayDataSegmentId =
ReferencableElements, WasmType>, Int>()
private val tagFuncType = WasmFunctionType(
listOf(
WasmRefNullType(WasmHeapType.Type(gcTypes.reference(irBuiltIns.throwableClass)))
),
emptyList()
)
val tags = if (generateTrapsInsteadOfExceptions) emptyList() else listOf(WasmTag(tagFuncType))
val typeInfo = ReferencableAndDefinable()
val exports = mutableListOf>()
class JsCodeSnippet(val importName: String, val jsCode: String)
val jsFuns = mutableListOf()
val jsModuleImports = mutableSetOf()
class FunWithPriority(val function: WasmFunction, val priority: String)
val initFunctions = mutableListOf()
val scratchMemAddr = WasmSymbol()
val stringPoolSize = WasmSymbol()
open class ReferencableElements {
val unbound = mutableMapOf>()
fun reference(ir: Ir): WasmSymbol {
val declaration = (ir as? IrSymbol)?.owner as? IrDeclarationWithName
if (declaration != null) {
val packageFragment = declaration.getPackageFragment()
if (packageFragment is IrExternalPackageFragment) {
error("Referencing declaration without package fragment ${declaration.fqNameWhenAvailable}")
}
}
return unbound.getOrPut(ir) { WasmSymbol() }
}
}
class ReferencableAndDefinable : ReferencableElements() {
fun define(ir: Ir, wasm: Wasm) {
if (ir in defined)
error("Trying to redefine element: IR: $ir Wasm: $wasm")
elements += wasm
defined[ir] = wasm
wasmToIr[wasm] = ir
}
val defined = LinkedHashMap()
val elements = mutableListOf()
val wasmToIr = mutableMapOf()
}
fun linkWasmCompiledFragments(): WasmModule {
bind(functions.unbound, functions.defined)
bind(globalFields.unbound, globalFields.defined)
bind(globalVTables.unbound, globalVTables.defined)
bind(gcTypes.unbound, gcTypes.defined)
bind(vTableGcTypes.unbound, vTableGcTypes.defined)
bind(classITableGcType.unbound, classITableGcType.defined)
bind(classITableInterfaceSlot.unbound, classITableInterfaceSlot.defined)
bind(globalClassITables.unbound, globalClassITables.defined)
// Associate function types to a single canonical function type
val canonicalFunctionTypes =
functionTypes.elements.associateWithTo(LinkedHashMap()) { it }
functionTypes.unbound.forEach { (irSymbol, wasmSymbol) ->
if (irSymbol !in functionTypes.defined)
error("Can't link symbol ${irSymbolDebugDump(irSymbol)}")
wasmSymbol.bind(canonicalFunctionTypes.getValue(functionTypes.defined.getValue(irSymbol)))
}
var currentDataSectionAddress = 0
var interfaceId = 0
typeIds.unbound.forEach { (klassSymbol, wasmSymbol) ->
if (klassSymbol.owner.isInterface) {
interfaceId--
wasmSymbol.bind(interfaceId)
} else {
wasmSymbol.bind(currentDataSectionAddress)
currentDataSectionAddress += typeInfo.defined.getValue(klassSymbol).sizeInBytes
}
}
currentDataSectionAddress = alignUp(currentDataSectionAddress, INT_SIZE_BYTES)
scratchMemAddr.bind(currentDataSectionAddress)
val stringDataSectionBytes = mutableListOf()
var stringDataSectionStart = 0
var stringLiteralCount = 0
for ((string, symbol) in stringLiteralAddress.unbound) {
symbol.bind(stringDataSectionStart)
stringLiteralPoolId.reference(string).bind(stringLiteralCount)
val constData = ConstantDataCharArray("string_literal", string.toCharArray())
stringDataSectionBytes += constData.toBytes().toList()
stringDataSectionStart += constData.sizeInBytes
stringLiteralCount++
}
stringPoolSize.bind(stringLiteralCount)
val data = mutableListOf()
data.add(WasmData(WasmDataMode.Passive, stringDataSectionBytes.toByteArray()))
constantArrayDataSegmentId.unbound.forEach { (constantArraySegment, symbol) ->
symbol.bind(data.size)
val integerSize = when (constantArraySegment.second) {
WasmI8 -> BYTE_SIZE_BYTES
WasmI16 -> SHORT_SIZE_BYTES
WasmI32 -> INT_SIZE_BYTES
WasmI64 -> LONG_SIZE_BYTES
else -> TODO("type ${constantArraySegment.second} is not implemented")
}
val constData = ConstantDataIntegerArray("constant_array", constantArraySegment.first, integerSize)
data.add(WasmData(WasmDataMode.Passive, constData.toBytes()))
}
typeIds.unbound.forEach { (klassSymbol, typeId) ->
if (!klassSymbol.owner.isInterface) {
val instructions = mutableListOf()
WasmIrExpressionBuilder(instructions).buildConstI32(
typeId.owner,
SourceLocation.NoLocation("Compile time data per class")
)
val typeData = WasmData(
WasmDataMode.Active(0, instructions),
typeInfo.defined.getValue(klassSymbol).toBytes()
)
data.add(typeData)
}
}
val masterInitFunctionType = WasmFunctionType(emptyList(), emptyList())
val masterInitFunction = WasmFunction.Defined("_initialize", WasmSymbol(masterInitFunctionType))
with(WasmIrExpressionBuilder(masterInitFunction.instructions)) {
initFunctions.sortedBy { it.priority }.forEach {
buildCall(WasmSymbol(it.function), SourceLocation.NoLocation("Generated service code"))
}
}
exports += WasmExport.Function("_initialize", masterInitFunction)
val typeInfoSize = currentDataSectionAddress
val memorySizeInPages = (typeInfoSize / 65_536) + 1
val memory = WasmMemory(WasmLimits(memorySizeInPages.toUInt(), null /* "unlimited" */))
// Need to export the memory in order to pass complex objects to the host language.
// Export name "memory" is a WASI ABI convention.
exports += WasmExport.Memory("memory", memory)
val importedFunctions = functions.elements.filterIsInstance()
fun wasmTypeDeclarationOrderKey(declaration: WasmTypeDeclaration): Int {
return when (declaration) {
is WasmArrayDeclaration -> 0
is WasmFunctionType -> 0
is WasmStructDeclaration ->
// Subtype depth
declaration.superType?.let { wasmTypeDeclarationOrderKey(it.owner) + 1 } ?: 0
}
}
val recGroupTypes = mutableListOf()
recGroupTypes.addAll(vTableGcTypes.elements)
recGroupTypes.addAll(this.gcTypes.elements)
recGroupTypes.addAll(classITableGcType.elements.distinct())
recGroupTypes.sortBy(::wasmTypeDeclarationOrderKey)
val globals = mutableListOf()
globals.addAll(globalFields.elements)
globals.addAll(globalVTables.elements)
globals.addAll(globalClassITables.elements.distinct())
val allFunctionTypes = canonicalFunctionTypes.values.toList() + tagFuncType + masterInitFunctionType
// Partition out function types that can't be recursive,
// we don't need to put them into a rec group
// so that they can be matched with function types from other Wasm modules.
val (potentiallyRecursiveFunctionTypes, nonRecursiveFunctionTypes) =
allFunctionTypes.partition { it.referencesTypeDeclarations() }
recGroupTypes.addAll(potentiallyRecursiveFunctionTypes)
val module = WasmModule(
functionTypes = nonRecursiveFunctionTypes,
recGroupTypes = recGroupTypes,
importsInOrder = importedFunctions,
importedFunctions = importedFunctions,
definedFunctions = functions.elements.filterIsInstance() + masterInitFunction,
tables = emptyList(),
memories = listOf(memory),
globals = globals,
exports = exports,
startFunction = null, // Module is initialized via export call
elements = emptyList(),
data = data,
dataCount = true,
tags = tags
)
module.calculateIds()
return module
}
}
fun > bind(
unbound: Map,
defined: Map
) {
unbound.forEach { (irSymbol, wasmSymbol) ->
if (irSymbol !in defined)
error("Can't link symbol ${irSymbolDebugDump(irSymbol)}")
wasmSymbol.bind(defined.getValue(irSymbol))
}
}
private fun irSymbolDebugDump(symbol: Any?): String =
when (symbol) {
is IrFunctionSymbol -> "function ${symbol.owner.fqNameWhenAvailable}"
is IrClassSymbol -> "class ${symbol.owner.fqNameWhenAvailable}"
else -> symbol.toString()
}
fun alignUp(x: Int, alignment: Int): Int {
assert(alignment and (alignment - 1) == 0) { "power of 2 expected" }
return (x + alignment - 1) and (alignment - 1).inv()
}