All downloads are free. Search and download functionality uses the official Maven repository.

oc.compiler-plugin.0.8.0.source-code.Memoizer.kt Maven / Gradle / Ivy

package com.sschr15.aoc.compiler.internal

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.backend.common.lower.irIfThen
import org.jetbrains.kotlin.backend.common.lower.irNot
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.addBackingField
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrReturn
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.symbols.impl.IrPropertySymbolImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.typeWith
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import kotlin.reflect.full.allSupertypes

/**
 * IR transformer implementing the `com.sschr15.aoc.annotations.Memoize` annotation.
 *
 * For each function annotated with `@Memoize` this transformer:
 *  1. adds a private, final property whose backing field holds a `MutableMap<Key, ReturnType>`
 *     to the nearest enclosing [IrDeclarationContainer];
 *  2. prepends a lookup that returns the cached value when `map.get(key)` is non-null; and
 *  3. rewrites every `return` that targets the function itself so the returned value is stored
 *     with `map.put(key, value)` before being returned.
 *
 * The cache key is built from the extension receiver (if any) followed by the value parameters
 * (see [keyFor] / [createKeyFor]). The dispatch receiver is not part of the key; for member
 * functions the map is an instance field of the enclosing class.
 *
 * NOTE: a cache hit is detected with a null check, so a function whose result is `null` is
 * re-evaluated on every call.
 */
@OptIn(UnsafeDuringIrConstructionAPI::class)
class Memoizer(private val context: IrPluginContext) : IrElementTransformerVoid() {
    /** Origin attached to the generated cache property, marking it as synthetic. */
    private val memoizedOrigin = object : IrDeclarationOrigin {
        override val isSynthetic = true
        override val name = "From Memoize annotation"
    }

    // `Map.get(key)` / `MutableMap.put(key, value)`: read and fill the cache.
    private val mapGet = context.irBuiltIns.mapClass.functions.single {
        it.owner.name.asString() == "get" && it.owner.valueParameters.size == 1
    }
    private val mapPut = context.irBuiltIns.mutableMapClass.functions.single {
        it.owner.name.asString() == "put" && it.owner.valueParameters.size == 2
    }

    // No-argument `kotlin.collections.mutableMapOf()`: creates the cache map.
    private val mutableMapOf = context.referenceFunctions(CallableId(
        FqName("kotlin.collections"),
        null,
        Name.identifier("mutableMapOf")
    )).single { it.owner.valueParameters.isEmpty() }

    // Vararg `kotlin.collections.listOf(vararg elements)`: key for 0 or 4+ key parameters.
    private val listOf = context.referenceFunctions(CallableId(
        FqName("kotlin.collections"),
        null,
        Name.identifier("listOf")
    )).single { it.owner.valueParameters.singleOrNull()?.isVararg == true }

    // `kotlin.Pair` / `kotlin.Triple`: keys for exactly 2 / 3 key parameters.
    private val pair = context.referenceClass(ClassId(FqName("kotlin"), FqName("Pair"), false))!!
    private val triple = context.referenceClass(ClassId(FqName("kotlin"), FqName("Triple"), false))!!

    private val memoizeAnnotation = FqName("com.sschr15.aoc.annotations.Memoize")

    /** Parameters forming the cache key: the extension receiver (if any), then the value parameters. */
    private fun keyParametersOf(declaration: IrFunction): List<IrValueParameter> =
        listOfNotNull(declaration.extensionReceiverParameter) + declaration.valueParameters

    /**
     * Key type for the given key [params]: the parameter's own type for one parameter,
     * `Pair` / `Triple` of the parameter types for two / three, and `List<Any>` otherwise.
     */
    private fun IrPluginContext.keyFor(params: List<IrValueParameter>): IrType = when (params.size) {
        1 -> params.single().type
        2 -> pair.typeWith(params.map { it.type })
        3 -> triple.typeWith(params.map { it.type })
        else -> irBuiltIns.listClass.typeWith(irBuiltIns.anyType)
    }

    /** Key type of the cache map generated for [declaration]. */
    fun IrPluginContext.keyFor(declaration: IrFunction): IrType = keyFor(keyParametersOf(declaration))

    /**
     * Memoizes [declaration] when it is annotated with `@Memoize`; any other function is
     * traversed normally. The children of a memoized function are not visited afterwards.
     *
     * @throws IllegalStateException if no enclosing declaration container can be found for the
     *   cache property, or if the memoized function has no body.
     */
    override fun visitFunction(declaration: IrFunction): IrStatement {
        if (declaration.annotations.none { it.isAnnotationWithEqualFqName(memoizeAnnotation) })
            return super.visitFunction(declaration)

        val keyType = context.keyFor(declaration)

        val memoizeMap = context.irFactory.createProperty(
            name = Name.identifier("\$memoized-${declaration.name}-map"),
            origin = memoizedOrigin,
            visibility = DescriptorVisibilities.PRIVATE,
            modality = Modality.FINAL,
            symbol = IrPropertySymbolImpl(),
            isVar = false,
            isConst = false,
            isLateinit = false,
            isDelegated = false,
            startOffset = UNDEFINED_OFFSET,
            endOffset = UNDEFINED_OFFSET,
        )

        // Walk up through enclosing declarations (e.g. an outer function for a local function)
        // until reaching a parent that can hold the new property.
        var parent = declaration.parent
        while (parent !is IrDeclarationContainer) {
            if (parent !is IrDeclaration) {
                error("${parent::class.allSupertypes}\n${parent.dump()}")
            }
            parent = parent.parent
        }
        parent.addChild(memoizeMap)

        val memoizeMapField = memoizeMap.addBackingField {
            type = context.irBuiltIns.mutableMapClass.typeWith(keyType, declaration.returnType)
            // File-level or static functions get a static cache; members get one per instance.
            isStatic = parent is IrFile || declaration.isStatic
        }
        memoizeMapField.initializer = context.irFactory.createExpressionBody(
            expression = context.irBuiltIns.createIrBuilder(memoizeMap.symbol).run {
                irCall(mutableMapOf).apply {
                    type = memoizeMapField.type
                    putTypeArgument(0, keyType)
                    putTypeArgument(1, declaration.returnType)
                }
            },
            startOffset = UNDEFINED_OFFSET,
            endOffset = UNDEFINED_OFFSET
        )

        val body = declaration.body ?: error("Memoized function must have a body")

        // Store every result of this function in the cache before returning it.
        body.transformChildrenVoid(object : IrElementTransformerVoid() {
            override fun visitReturn(expression: IrReturn): IrExpression {
                // Returns targeting a nested lambda or local function do not produce this
                // function's result; rewriting them would write their values into the cache.
                if (expression.returnTargetSymbol != declaration.symbol) {
                    return super.visitReturn(expression)
                }
                // Handle returns nested inside the returned expression first.
                expression.transformChildrenVoid(this)
                return context.irBuiltIns.createIrBuilder(expression.returnTargetSymbol).irBlock {
                    val value = createTmpVariable(expression.value)
                    +irCall(mapPut).apply {
                        dispatchReceiver = irGetField(declaration.dispatchReceiverParameter?.let(::irGet), memoizeMapField)
                        putValueArgument(0, createKeyFor(declaration))
                        putValueArgument(1, irGet(value))
                    }
                    +irReturn(irGet(value))
                }
            }
        })

        val statements = body.statements.toMutableList()

        // Prologue: `val check = map.get(key); if (check != null) return check`.
        val prologue = context.irBuiltIns.createIrBuilder(declaration.symbol).irBlock {
            val check = createTmpVariable(irCall(mapGet).apply {
                dispatchReceiver = irGetField(declaration.dispatchReceiverParameter?.let(::irGet), memoizeMapField)
                putValueArgument(0, createKeyFor(declaration))
            })
            +irIfThen(
                condition = irNot(irEqualsNull(irGet(check))),
                thenPart = irReturn(irGet(check))
            )
        }
        statements.add(0, prologue)

        declaration.body = context.irFactory.createBlockBody(body.startOffset, body.endOffset, statements)

        return declaration
    }

    /** Builds the cache-key expression for [declaration] from its key parameters. */
    fun IrBuilderWithScope.createKeyFor(declaration: IrFunction): IrExpression =
        createKeyFor(keyParametersOf(declaration))

    /**
     * Builds a key expression matching [keyFor]: the single parameter itself, a `Pair` / `Triple`
     * of two / three parameters, or `listOf<Any>(...)` for any other count.
     */
    private fun IrBuilderWithScope.createKeyFor(params: List<IrValueParameter>): IrExpression = when (params.size) {
        1 -> irGet(params.single())
        2 -> irCall(pair.constructors.single()).apply {
            putValueArgument(0, irGet(params[0]))
            putValueArgument(1, irGet(params[1]))
        }
        3 -> irCall(triple.constructors.single()).apply {
            putValueArgument(0, irGet(params[0]))
            putValueArgument(1, irGet(params[1]))
            putValueArgument(2, irGet(params[2]))
        }
        else -> irCall(listOf).apply {
            // Pin `T = Any` so the call agrees with the `List<Any>` key type from keyFor.
            type = context.irBuiltIns.listClass.typeWith(context.irBuiltIns.anyType)
            putTypeArgument(0, context.irBuiltIns.anyType)
            putValueArgument(0, irVararg(context.irBuiltIns.anyType, params.map { irGet(it) }))
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy