All downloads are free. The search and download features use the official Maven repository.

it.unibo.collektive.AlignmentIrGenerationExtension.kt Maven / Gradle / Ivy

@file:Suppress("ReturnCount")

package it.unibo.collektive

import it.unibo.collektive.transformers.AggregateCallTransformer
import it.unibo.collektive.utils.common.AggregateFunctionNames
import it.unibo.collektive.utils.common.AggregateFunctionNames.ALIGN_FUNCTION
import it.unibo.collektive.utils.common.AggregateFunctionNames.DEALIGN_RAW_FUNCTION
import it.unibo.collektive.utils.common.AggregateFunctionNames.PROJECT_FUNCTION
import it.unibo.collektive.utils.logging.error
import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.util.functions
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name

/**
 * IR generation extension that wires the alignment transformer into the compilation.
 *
 * On [generate] it resolves the aggregate class, the `project` function and the
 * align/dealign functions declared on the aggregate class, then runs an
 * [AggregateCallTransformer] over the whole module. If any of these references
 * cannot be resolved, an error is reported through [logger] and the module is
 * left untouched.
 */
@OptIn(UnsafeDuringIrConstructionAPI::class)
class AlignmentIrGenerationExtension(private val logger: MessageCollector) : IrGenerationExtension {
    /**
     * Resolves the symbols required by the alignment machinery and applies the
     * [AggregateCallTransformer] to [moduleFragment].
     *
     * @param moduleFragment the module whose IR is transformed in place.
     * @param pluginContext the context used to look up the class and function symbols.
     */
    override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
        // Symbol of the aggregate class; the align/dealign functions are searched among its members.
        val aggregateClassId = ClassId.topLevel(FqName(AggregateFunctionNames.AGGREGATE_CLASS))
        val aggregateSymbol = pluginContext.referenceClass(aggregateClassId)
            ?: return logger.error("Unable to find the aggregate class")

        // The `project` function is looked up by package and name; the first match is used.
        val projectCallableId = CallableId(
            FqName("it.unibo.collektive.aggregate.api.impl"),
            Name.identifier(PROJECT_FUNCTION),
        )
        val projectSymbol = pluginContext.referenceFunctions(projectCallableId).firstOrNull()
        if (projectSymbol == null) {
            return logger.error("Unable to find the 'project' function")
        }

        // Member of the aggregate class that performs the alignment.
        val alignSymbol = aggregateSymbol.getFunctionReferenceWithName(ALIGN_FUNCTION)
        if (alignSymbol == null) {
            return logger.error("Unable to find the `$ALIGN_FUNCTION` function")
        }

        // Member of the aggregate class that undoes the alignment.
        val dealignSymbol = aggregateSymbol.getFunctionReferenceWithName(DEALIGN_RAW_FUNCTION)
        if (dealignSymbol == null) {
            return logger.error("Unable to find the `$DEALIGN_RAW_FUNCTION` function")
        }

        // Every symbol is available: visit the module and instrument the aggregate function calls.
        val transformer = AggregateCallTransformer(
            pluginContext,
            logger,
            aggregateSymbol.owner,
            alignSymbol.owner,
            dealignSymbol.owner,
            projectSymbol.owner,
        )
        moduleFragment.transform(transformer, null)
    }

    /**
     * Returns the first function of this class whose name equals [functionName],
     * or `null` when the class declares no such function.
     */
    private fun IrClassSymbol.getFunctionReferenceWithName(functionName: String): IrFunctionSymbol? {
        val wantedName = Name.identifier(functionName)
        return functions.firstOrNull { candidate -> candidate.owner.name == wantedName }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy