All downloads are free. Search and download functionality uses the official Maven repository.

it.unibo.collektive.transformers.AlignmentTransformer.kt Maven / Gradle / Ivy

There is a newer version: 11.0.0
Show newest version
package it.unibo.collektive.transformers

import it.unibo.collektive.utils.common.AggregateFunctionNames.ALIGN_FUNCTION
import it.unibo.collektive.utils.common.AggregateFunctionNames.DEALIGN_RAW_FUNCTION
import it.unibo.collektive.utils.common.findAggregateReference
import it.unibo.collektive.utils.common.getAlignmentToken
import it.unibo.collektive.utils.common.irStatement
import it.unibo.collektive.utils.common.isAssignableFrom
import it.unibo.collektive.utils.common.simpleFunctionName
import it.unibo.collektive.utils.stack.StackFunctionCall
import it.unibo.collektive.visitors.collectAggregateReference
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.jvm.ir.receiverAndArgs
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.builders.IrBlockBodyBuilder
import org.jetbrains.kotlin.ir.builders.createTmpVariable
import org.jetbrains.kotlin.ir.builders.irBlock
import org.jetbrains.kotlin.ir.builders.irBoolean
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irString
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.expressions.IrBranch
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrConst
import org.jetbrains.kotlin.ir.expressions.IrContainerExpression
import org.jetbrains.kotlin.ir.expressions.IrElseBranch
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.putArgument
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.util.defaultType
import org.jetbrains.kotlin.ir.visitors.IrElementTransformer

/**
 * IR transformer that inserts alignment instructions around the code of [functionToAlign] that involves
 * the aggregate context.
 *
 * - Every call that receives an [aggregateContextClass] instance as receiver/argument, or for which
 *   `collectAggregateReference` finds one in the callee, is wrapped between a call to [alignRawFunction]
 *   (receiving a string alignment token) and a call to [dealignFunction]. Calls to the align/dealign
 *   functions themselves are left untouched.
 * - Every branch whose result references the aggregate context is wrapped in the same way, using `true`
 *   as token for regular branches and `false` for `else` branches.
 *
 * The [StackFunctionCall] passed as visitor data collects the alignment tokens of the non-aggregate calls
 * traversed so far; its string representation is used as a prefix of the next aggregate call's token.
 */
class AlignmentTransformer(
    private val pluginContext: IrPluginContext,
    private val aggregateContextClass: IrClass,
    private val functionToAlign: IrFunction,
    private val alignRawFunction: IrFunction,
    private val dealignFunction: IrFunction,
) : IrElementTransformer<StackFunctionCall> {
    /**
     * Number of aggregate calls processed so far by this transformer for each alignment token,
     * used to tell apart repeated calls carrying the same token.
     */
    private var alignedFunctions = emptyMap<String, Int>()

    /**
     * Wraps aggregate calls with alignment code; non-aggregate calls only contribute their token to [data].
     */
    @OptIn(UnsafeDuringIrConstructionAPI::class)
    override fun visitCall(expression: IrCall, data: StackFunctionCall): IrElement {
        val contextReference = expression.receiverAndArgs()
            .find { it.type.isAssignableFrom(aggregateContextClass.defaultType) }
            ?: collectAggregateReference(aggregateContextClass, expression.symbol.owner)
        val alignmentToken = expression.getAlignmentToken()
        // A null context means this is not an aggregate function: remember its token so that
        // aggregate calls found deeper in the tree are prefixed by it.
        if (contextReference == null) {
            data.push(alignmentToken)
            return super.visitCall(expression, data)
        }
        // We don't want to align the alignRaw and dealign functions :)
        val functionName = expression.simpleFunctionName()
        if (functionName == ALIGN_FUNCTION || functionName == DEALIGN_RAW_FUNCTION) {
            return super.visitCall(expression, data)
        }
        // The first occurrence of a token gets counter 1, each further occurrence increments it.
        val occurrence = (alignedFunctions[alignmentToken] ?: 0) + 1
        alignedFunctions += alignmentToken to occurrence
        // If the expression contains a lambda, this recursion is necessary to visit the children;
        // children start from an empty stack of non-aggregate calls.
        expression.transformChildren(this, StackFunctionCall())
        // NOTE(review): the counter is re-read after visiting the children, so nested calls sharing
        // the same token may have incremented it in the meantime — confirm this is the intended token.
        val tokenCount = alignedFunctions[alignmentToken] ?: error(
            """
                Unable to find the count for the token $alignmentToken.
                This is may due to a bug in collektive compiler plugin.
            """.trimIndent(),
        )
        val alignmentTokenRepresentation = "$data$alignmentToken.$tokenCount"
        // Return the modified function body to have as a first statement the alignRaw function,
        // then the body of the function to align and finally the dealign function
        return generateAlignmentCode(contextReference, functionToAlign, expression) {
            irString(alignmentTokenRepresentation)
        }
    }

    /** Aligns a regular branch (token `true`) before visiting its children. */
    override fun visitBranch(branch: IrBranch, data: StackFunctionCall): IrBranch {
        branch.generateBranchAlignmentCode(true)
        return super.visitBranch(branch, data)
    }

    /** Aligns an `else` branch (token `false`) before visiting its children. */
    override fun visitElseBranch(branch: IrElseBranch, data: StackFunctionCall): IrElseBranch {
        branch.generateBranchAlignmentCode(false)
        return super.visitElseBranch(branch, data)
    }

    /**
     * If this branch's result references the aggregate context, replaces [IrBranch.result] with the
     * result wrapped in alignment code, using [condition] as the alignment token.
     */
    private fun IrBranch.generateBranchAlignmentCode(condition: Boolean) {
        result.findAggregateReference(aggregateContextClass)?.let { context ->
            result = generateAlignmentCode(context, functionToAlign, result) { irBoolean(condition) }
        }
    }

    /**
     * Builds a block that:
     * 1. calls [alignRawFunction] on [context] with the token produced by [alignmentToken];
     * 2. evaluates [expressionBody] into a temporary variable;
     * 3. calls [dealignFunction] on [context];
     * 4. yields the temporary variable as the block's value.
     *
     * @param context the aggregate context expression used as dispatch receiver of both generated calls.
     *   NOTE(review): the same IR node is inserted twice without copying (and, for calls, may still be
     *   referenced by the original expression) — consider `deepCopyWithSymbols()` if this causes issues.
     * @param function the function handed to `irStatement` as the building scope
     *   (presumably the declaration being transformed — confirm).
     * @param expressionBody the expression to wrap.
     * @param alignmentToken builds the constant passed as the only value argument of [alignRawFunction].
     * @throws IllegalStateException if [alignRawFunction] or [dealignFunction] has no dispatch receiver.
     */
    private fun <T> generateAlignmentCode(
        context: IrExpression,
        function: IrFunction,
        expressionBody: IrExpression,
        alignmentToken: IrBlockBodyBuilder.() -> IrConst<T>,
    ): IrContainerExpression {
        return irStatement(pluginContext, function, expressionBody) {
            // Call the `alignRaw` function before the body of the function to align
            irBlock {
                // Call the alignRaw function
                +irCall(alignRawFunction).apply {
                    putArgument(
                        alignRawFunction.dispatchReceiverParameter
                            ?: error("The alignRaw function has no dispatch receiver parameter"),
                        context,
                    )
                    putValueArgument(0, alignmentToken(this@irStatement))
                }
                // Evaluate the body of the function to align and keep its value
                val code = irBlock { +expressionBody }
                val tmpVar = createTmpVariable(code, irType = expressionBody.type, nameHint = "blockResult")
                // Call the `dealign` function after the body of the function to align
                +irCall(dealignFunction).apply {
                    putArgument(
                        dealignFunction.dispatchReceiverParameter
                            ?: error("The dealign function has no dispatch receiver parameter"),
                        context,
                    )
                }
                // The wrapped block evaluates to the original body's result
                +irGet(tmpVar)
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy