All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlinx.dataframe.plugin.interpret.kt Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2010-2022 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlinx.dataframe.plugin

import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.checkers.toClassLikeSymbol
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.extensions.Marker
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
import org.jetbrains.kotlin.fir.declarations.findArgumentByName
import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId
import org.jetbrains.kotlin.fir.declarations.hasAnnotation
import org.jetbrains.kotlin.fir.declarations.utils.isLocal
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
import org.jetbrains.kotlin.fir.expressions.FirCallableReferenceAccess
import org.jetbrains.kotlin.fir.expressions.FirErrorExpression
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirPropertyAccessExpression
import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier
import org.jetbrains.kotlin.fir.expressions.FirReturnExpression
import org.jetbrains.kotlin.fir.expressions.FirThisReceiverExpression
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
import org.jetbrains.kotlin.fir.expressions.arguments
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
import org.jetbrains.kotlin.fir.references.FirResolvedCallableReference
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.resolved
import org.jetbrains.kotlin.fir.references.symbol
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.scopes.collectAllProperties
import org.jetbrains.kotlin.fir.scopes.getProperties
import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope
import org.jetbrains.kotlin.fir.symbols.impl.FirEnumEntrySymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirTypeAliasSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirValueParameterSymbol
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
import org.jetbrains.kotlin.fir.types.ConeIntersectionType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeStarProjection
import org.jetbrains.kotlin.fir.types.ConeTypeParameterType
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.isNullableString
import org.jetbrains.kotlin.fir.types.isPrimitiveOrNullablePrimitive
import org.jetbrains.kotlin.fir.types.isStarProjection
import org.jetbrains.kotlin.fir.types.isString
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.fir.types.returnType
import org.jetbrains.kotlin.fir.types.toConeTypeProjection
import org.jetbrains.kotlin.fir.types.toRegularClassSymbol
import org.jetbrains.kotlin.fir.types.toSymbol
import org.jetbrains.kotlin.fir.types.type
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addToStdlib.runIf
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.KPropertyApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation

/**
 * Evaluates [functionCall] with the given [processor].
 *
 * Steps, in order:
 * 1. The call's arguments (explicit receiver plus resolved value arguments) are collected with
 *    [collectArgumentExpressions].
 * 2. If a key of [additionalArguments] collides with the name of an actual argument, interpretation is
 *    abandoned (`null`); when `isTest` is set an [InterpretationFrameworkError] is thrown instead.
 * 3. Type arguments of the call are exposed under the keys `typeArg0`, `typeArg1`, ... but only for the keys
 *    the [processor] declares in its expected arguments.
 * 4. The set of expected argument names must match the set of actual names (arguments that have a default
 *    value are ignored on both sides); otherwise `null` is returned (or an error is thrown when `isTest`).
 * 5. Every actual argument is converted according to the `lens` of the matching expected argument
 *    (`Value`, `ReturnType`, `Dsl`, `Schema` or `Id`).
 * 6. If every argument was converted, `processor.interpret` is invoked; an `Interpreter.Error` result is
 *    passed to [reporter] and turned into `null`.
 *
 * @param functionCall the resolved call to interpret.
 * @param processor interpreter that declares the expected arguments and computes the result.
 * @param additionalArguments pre-computed arguments that are handed to [processor] as is.
 * @param reporter sink for interpretation errors.
 * @return the successful result, or `null` when the call could not be interpreted.
 */
fun <T> KotlinTypeFacade.interpret(
    functionCall: FirFunctionCall,
    processor: Interpreter<T>,
    additionalArguments: Map<String, Interpreter.Success<Any?>> = emptyMap(),
    reporter: InterpretationErrorReporter,
): Interpreter.Success<T>? {
    val refinedArguments: RefinedArguments = functionCall.collectArgumentExpressions()

    val defaultArguments = processor.expectedArguments.filter { it.defaultValue is Present }.map { it.name }.toSet()
    val actualValueArguments = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
    val conflictingKeys = additionalArguments.keys intersect actualValueArguments.keys
    if (conflictingKeys.isNotEmpty()) {
        if (isTest) {
            interpretationFrameworkError("Conflicting keys: $conflictingKeys")
        }
        return null
    }
    // Additional arguments are supplied by the caller, so they are not expected among the call's own arguments.
    val expectedArgsMap = processor.expectedArguments
        .associateBy { it.name }.toSortedMap().minus(additionalArguments.keys)

    val typeArguments = buildMap {
        functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
            val key = "typeArg$index"
            // Type arguments the interpreter did not ask for are skipped.
            val lens = expectedArgsMap[key]?.lens ?: return@forEachIndexed
            val value: Any = if (lens == Interpreter.Id) {
                firTypeProjection.toConeTypeProjection()
            } else {
                // A projection without a type (star projection) is approximated with `Any?`.
                val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
                if (type is ConeIntersectionType) return@forEachIndexed
                Marker(type)
            }
            put(key, Interpreter.Success(value))
        }
    }

    val unexpectedArguments = (expectedArgsMap.keys - defaultArguments) != (actualValueArguments.keys + typeArguments.keys - defaultArguments)
    if (unexpectedArguments) {
        if (isTest) {
            val message = buildString {
                appendLine("ERROR: Different set of arguments")
                appendLine("Implementation class: $processor")
                appendLine("Not found in actual: ${expectedArgsMap.keys - actualValueArguments.keys}")
                val diff = actualValueArguments.keys - expectedArgsMap.keys
                appendLine("Passed, but not expected: ${diff}")
                appendLine("add arguments to an interpeter:")
                appendLine(diff.map { actualValueArguments[it] })
            }
            interpretationFrameworkError(message)
        }
        return null
    }

    val arguments = mutableMapOf<String, Interpreter.Success<Any?>>()
    arguments += additionalArguments
    arguments += typeArguments
    val interpretationResults = refinedArguments.refinedArguments.mapNotNull { argument ->
        val name = argument.name.identifier
        val expectedArgument = expectedArgsMap[name] ?: error("$processor $name")
        val expectedReturnType = expectedArgument.klass
        val value: Interpreter.Success<Any?>? = when (expectedArgument.lens) {
            is Interpreter.Value -> {
                when (val expression = argument.expression) {
                    // NOTE(review): `!!` throws for a `null` literal argument.
                    is FirLiteralExpression -> Interpreter.Success(expression.value!!)
                    is FirVarargArgumentsExpression -> {
                        // Elements that cannot be evaluated are kept as `null` entries in the list.
                        val args = expression.arguments.map { element ->
                            when (element) {
                                is FirLiteralExpression -> element.value
                                is FirCallableReferenceAccess -> {
                                    toKPropertyApproximation(element, session)
                                }
                                is FirFunctionCall -> {
                                    element.loadInterpreter()?.let { nestedInterpreter ->
                                        interpret(element, nestedInterpreter, emptyMap(), reporter)
                                    }
                                }
                                else -> null
                            }

                        }
                        Interpreter.Success(args)
                    }

                    is FirFunctionCall -> {
                        val interpreter = expression.loadInterpreter()
                        if (interpreter == null) {
                            // if the plugin already transformed call, its original form is the last expression of .let { }
                            val firstArgument = expression.arguments.getOrNull(0)
                            val last = (firstArgument as? FirAnonymousFunctionExpression)?.anonymousFunction?.body?.statements?.lastOrNull()
                            val call = (last as? FirReturnExpression)?.result as? FirFunctionCall
                            call?.loadInterpreter()?.let {
                                interpret(call, it, emptyMap(), reporter)
                            }
                        } else {
                            interpret(expression, interpreter, emptyMap(), reporter)
                        }
                    }

                    is FirPropertyAccessExpression -> {
                        (expression.calleeReference as? FirResolvedNamedReference)?.let {
                            val symbol = it.resolvedSymbol
                            val firPropertySymbol = symbol as? FirPropertySymbol
                            // Only properties whose type can hold a literal are inspected for a literal initializer.
                            val literalInitializer = runIf(firPropertySymbol?.resolvedReturnType?.canHaveLiteralInitializer == true) {
                                firPropertySymbol?.resolvedInitializer as? FirLiteralExpression
                            }
                            if (symbol is FirEnumEntrySymbol) {
                                Interpreter.Success(
                                    DataFrameCallableId(
                                        packageName = symbol.callableId.packageName.asString(),
                                        className = symbol.callableId.className!!.asString(),
                                        callableName = symbol.callableId.callableName.asString()
                                    )
                                )
                            } else if (literalInitializer != null) {
                                Interpreter.Success(literalInitializer.value)
                            } else {
                                Interpreter.Success(columnWithPathApproximations(expression))
                            }
                        }
                    }

                    is FirCallableReferenceAccess -> {
                        Interpreter.Success(toKPropertyApproximation(expression, session))
                    }

                    is FirAnonymousFunctionExpression -> {
                        // Only the returned (last) expression of the lambda is evaluated.
                        val result = (expression.anonymousFunction.body?.statements?.lastOrNull() as? FirReturnExpression)?.result
                        val col: Any? = when (result) {
                            is FirPropertyAccessExpression -> {
                                columnWithPathApproximations(result)
                            }

                            is FirFunctionCall -> {
                                val interpreter = result.loadInterpreter()
                                if (interpreter == null) {
                                    reporter.reportInterpretationError(result, "Cannot load interpreter")
                                }
                                interpreter?.let {
                                    interpret(result, interpreter, reporter = reporter)?.value
                                }
                            }

                            is FirErrorExpression -> null

                            else -> null
                        }
                        col?.let { Interpreter.Success(it) }
                    }

                    else -> null
                }
            }

            is Interpreter.ReturnType -> {
                val returnType = argument.expression.resolvedType.returnType(session)
                Interpreter.Success(Marker(returnType))
            }

            is Interpreter.Dsl -> {
                // The argument is wrapped into a function that interprets every call statement of the lambda body,
                // passing the supplied receiver to each of them under the "dsl" key.
                { receiver: Any, dslArguments: Map<String, Interpreter.Success<Any?>> ->
                    val map = mapOf("dsl" to Interpreter.Success(receiver)) + dslArguments
                    (argument.expression as FirAnonymousFunctionExpression)
                        .anonymousFunction.body!!
                        .statements.filterIsInstance<FirFunctionCall>()
                        .forEach { call ->
                            val schemaProcessor = call.loadInterpreter() ?: return@forEach
                            interpret(
                                call,
                                schemaProcessor,
                                map,
                                reporter
                            )
                        }
                }.let { Interpreter.Success(it) }
            }

            is Interpreter.Schema -> {
                assert(expectedReturnType.toString() == PluginDataFrameSchema::class.qualifiedName!!) {
                    "'$name' should be ${PluginDataFrameSchema::class.qualifiedName!!}, but plugin expect $expectedReturnType"
                }

                val objectWithSchema = argument.expression.getSchema()
                if (objectWithSchema == null) {
                    reporter.doNotReportInterpretationError()
                    null
                } else {
                    val arg = objectWithSchema.schemaArg
                    val schemaTypeArg = (objectWithSchema.typeRef as ConeClassLikeType).typeArguments[arg]
                    val schema = pluginDataFrameSchema(schemaTypeArg)
                    Interpreter.Success(schema)
                }
            }

            is Interpreter.Id -> {
                Interpreter.Success(argument.expression)
            }
        }
        value?.let { name to it }
    }

    // At least one argument could not be evaluated: the call as a whole cannot be interpreted.
    if (interpretationResults.size != refinedArguments.refinedArguments.size) return null

    arguments.putAll(interpretationResults)
    return when (val res = processor.interpret(arguments, this)) {
        is Interpreter.Success -> res
        is Interpreter.Error -> {
            reporter.reportInterpretationError(functionCall, res.message ?: "")
            null
        }
    }
}

/**
 * Aborts the current computation by throwing an [InterpretationFrameworkError] carrying [message].
 * Never returns normally.
 */
fun interpretationFrameworkError(message: String): Nothing {
    throw InterpretationFrameworkError(message)
}

/**
 * Error thrown by [interpretationFrameworkError]. It extends [Error], so `catch (e: Exception)` does not
 * intercept it.
 */
class InterpretationFrameworkError(message: String) : Error(message)

/**
 * Callback used by the interpretation routines of this file to surface problems found while
 * interpreting a call.
 */
interface InterpretationErrorReporter {
    /** Flag exposed by the reporter; NOTE(review): presumably `true` once an error was reported - confirm with implementations. */
    val errorReported: Boolean

    /** Reports that [call] could not be interpreted, with a human-readable [message]. */
    fun reportInterpretationError(call: FirFunctionCall, message: String)

    /**
     * Invoked when interpretation stops without an error that should be shown
     * (for example when an expression has no schema to read).
     */
    fun doNotReportInterpretationError()

    companion object {
        /** Reporter that ignores every notification and always has `errorReported == false`. */
        val DEFAULT: InterpretationErrorReporter = object : InterpretationErrorReporter {
            override val errorReported: Boolean = false

            override fun reportInterpretationError(call: FirFunctionCall, message: String) = Unit

            override fun doNotReportInterpretationError() {
            }
        }
    }
}

/**
 * Builds the plugin schema described by a type argument.
 *
 * A star projection, or an argument whose type is not a [ConeClassLikeType], yields
 * [PluginDataFrameSchema.EMPTY]; otherwise the overload accepting a [ConeClassLikeType] is used.
 */
fun KotlinTypeFacade.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): PluginDataFrameSchema {
    if (schemaTypeArg.isStarProjection) return PluginDataFrameSchema.EMPTY

    val classLikeType = schemaTypeArg.type as? ConeClassLikeType
    return if (classLikeType == null) {
        PluginDataFrameSchema.EMPTY
    } else {
        pluginDataFrameSchema(classLikeType)
    }
}

/**
 * Builds the plugin schema from the properties declared by the class behind [coneClassLikeType].
 *
 * - A type that does not resolve to a regular class gives [PluginDataFrameSchema.EMPTY].
 * - For a local class whose first supertype is not `Any`, the properties are read from that supertype
 *   instead of the class itself.
 * - If every property carries the order annotation ([Names.ORDER_ANNOTATION]) with an `Int` literal,
 *   properties are sorted by that value; otherwise the scope order is kept.
 * - Properties for which [columnOf] returns `null` are omitted.
 */
fun KotlinTypeFacade.pluginDataFrameSchema(coneClassLikeType: ConeClassLikeType): PluginDataFrameSchema {
    val symbol = coneClassLikeType.toSymbol(session) as? FirRegularClassSymbol ?: return PluginDataFrameSchema.EMPTY

    // Supertypes are only looked at for local classes. An empty supertype list falls back to the class's own
    // scope instead of failing on `first()`.
    val rootSchemaType = if (symbol.isLocal) {
        symbol.resolvedSuperTypes.firstOrNull()?.takeIf { it != session.builtinTypes.anyType.type }
    } else {
        null
    }
    val declarationSymbols = if (rootSchemaType != null) {
        val rootSchemaSymbol = rootSchemaType.toSymbol(session) as? FirRegularClassSymbol
        rootSchemaSymbol?.declaredMemberScope(session, FirResolvePhase.DECLARATIONS)
    } else {
        symbol.declaredMemberScope(session, FirResolvePhase.DECLARATIONS)
    }.let { scope ->
        val names = scope?.getCallableNames() ?: emptySet()
        names.flatMap { scope?.getProperties(it) ?: emptyList() }
    }

    // Type parameter -> actual type argument; used to substitute properties typed with a type parameter.
    val mapping = symbol.typeParameterSymbols
        .mapIndexed { i, typeParameter -> typeParameter to coneClassLikeType.typeArguments[i] }
        .toMap()

    val declaredProperties = declarationSymbols.filterIsInstance<FirPropertySymbol>()
    val orders = declaredProperties.mapNotNull {
        val orderArgument = it.getAnnotationByClassId(
            Names.ORDER_ANNOTATION,
            session
        )?.argumentMapping?.mapping?.get(Names.ORDER_ARGUMENT)
        (orderArgument as? FirLiteralExpression)?.value as? Int
    }
    // Sorting is applied only when every property has an explicit order.
    val propertySymbols = if (declaredProperties.size == orders.size) {
        declaredProperties.zip(orders).sortedBy { it.second }.map { it.first }
    } else {
        declaredProperties
    }
    val columns = propertySymbols.mapNotNull { propertySymbol ->
        columnOf(propertySymbol, mapping)
    }

    return PluginDataFrameSchema(columns)
}

/**
 * Converts a property access into a [ColumnsResolver], based on the resolved type of [result]:
 *
 * - [Names.DATA_COLUMN_CLASS_ID]: a [SimpleDataColumn] typed with the single type argument
 *   (`Any?` when the argument carries no type, e.g. a star projection);
 * - [Names.COLUM_GROUP_CLASS_ID]: a [SimpleColumnGroup] whose columns come from the schema of the type argument;
 * - any other type: a resolver that resolves to no columns.
 */
private fun KotlinTypeFacade.columnWithPathApproximations(result: FirPropertyAccessExpression): ColumnsResolver {
    val resolvedType = result.resolvedType
    val column = when (resolvedType.classId) {
        Names.DATA_COLUMN_CLASS_ID -> {
            // `.type` is null for a star projection. Reading it also copes with variance-projected and
            // type-parameter arguments, which a hard cast to ConeClassLikeType would reject at runtime.
            val type = resolvedType.typeArguments.single().type ?: session.builtinTypes.nullableAnyType.type
            SimpleDataColumn(f(result), Marker(type))
        }
        Names.COLUM_GROUP_CLASS_ID -> {
            val arg = resolvedType.typeArguments.single()
            SimpleColumnGroup(f(result), pluginDataFrameSchema(arg).columns())
        }
        else -> return object : ColumnsResolver {
            override fun resolve(df: PluginDataFrameSchema): List<ColumnWithPathApproximation> {
                return emptyList()
            }
        }
    }
    return SingleColumnApproximation(
        ColumnWithPathApproximation(
            path = ColumnPathApproximation(path(result)),
            column
        )
    )
}

/**
 * Converts a schema property into a column description.
 *
 * The column name is the string literal of the [Names.COLUMN_NAME_ANNOTATION] annotation when present,
 * otherwise the property name. The kind of column depends on the property's return type:
 * a frame column ([shouldBeConvertedToFrameColumn]), a column group ([shouldBeConvertedToColumnGroup])
 * or a plain data column.
 *
 * @param it the property to convert.
 * @param mapping type parameter -> actual type argument of the declaring schema class, used when the
 *   property is typed with a type parameter.
 * @return the column, or `null` when the type of a plain data column cannot be determined.
 */
private fun KotlinTypeFacade.columnOf(it: FirPropertySymbol, mapping: Map<FirTypeParameterSymbol, ConeTypeProjection>): SimpleCol? {
    val annotation = it.getAnnotationByClassId(Names.COLUMN_NAME_ANNOTATION, session)
    val columnName = (annotation?.argumentMapping?.mapping?.get(Names.COLUMN_NAME_ARGUMENT) as? FirLiteralExpression)?.value as? String
    val name = columnName ?: it.name.identifier

    // Columns of the class behind `rowType`; empty when the type is missing or is not a regular class.
    fun nestedColumnsOf(rowType: ConeKotlinType?): List<SimpleCol> =
        rowType
            ?.toRegularClassSymbol(session)
            ?.declaredMemberScope(session, FirResolvePhase.DECLARATIONS)
            ?.collectAllProperties()
            ?.filterIsInstance<FirPropertySymbol>()
            ?.mapNotNull { nested -> columnOf(nested, mapping) }
            ?: emptyList()

    return when {
        shouldBeConvertedToFrameColumn(it) -> {
            SimpleFrameColumn(name, nestedColumnsOf(it.resolvedReturnType.typeArguments[0].type))
        }

        shouldBeConvertedToColumnGroup(it) -> {
            // For a star-projected DataRow the argument has no type; this yields an empty group rather than an NPE,
            // matching the frame-column branch above.
            val rowType = if (isDataRow(it)) it.resolvedReturnType.typeArguments[0].type else it.resolvedReturnType
            SimpleColumnGroup(name, nestedColumnsOf(rowType))
        }

        else -> {
            val columnType = when (val returnType = it.resolvedReturnType) {
                is ConeTypeParameterType -> {
                    val projection = mapping[returnType.lookupTag.typeParameterSymbol]
                    if (projection is ConeStarProjection) {
                        // A star projection is approximated by the type parameter's bound, if there is exactly one.
                        returnType.lookupTag.typeParameterSymbol.resolvedBounds.singleOrNull()?.type
                    } else {
                        projection as? ConeKotlinType
                    }
                }

                else -> returnType
            }
            columnType?.let { resolved ->
                SimpleDataColumn(name, TypeApproximation(resolved))
            }
        }
    }
}

/**
 * `true` when the property's return type is a data row ([isDataRow]) or a regular class annotated with
 * [Names.DATA_SCHEMA_CLASS_ID].
 */
private fun KotlinTypeFacade.shouldBeConvertedToColumnGroup(it: FirPropertySymbol): Boolean {
    if (isDataRow(it)) return true
    val returnTypeClass = it.resolvedReturnType.toRegularClassSymbol(session)
    return returnTypeClass?.hasAnnotation(Names.DATA_SCHEMA_CLASS_ID, session) == true
}

/** `true` when the class id of the property's return type is [Names.DATA_ROW_CLASS_ID]. */
private fun isDataRow(it: FirPropertySymbol): Boolean {
    val returnTypeClassId = it.resolvedReturnType.classId
    return returnTypeClassId == Names.DATA_ROW_CLASS_ID
}

/**
 * `true` when the property's return type is a data frame ([isDataFrame]), or a [Names.LIST] whose element
 * type is a regular class annotated with [Names.DATA_SCHEMA_CLASS_ID].
 */
private fun KotlinTypeFacade.shouldBeConvertedToFrameColumn(it: FirPropertySymbol): Boolean {
    if (isDataFrame(it)) return true
    if (it.resolvedReturnType.classId != Names.LIST) return false

    val elementClass = it.resolvedReturnType.typeArguments[0].type?.toRegularClassSymbol(session)
    return elementClass?.hasAnnotation(Names.DATA_SCHEMA_CLASS_ID, session) == true
}

/** `true` when the class id of the property's return type is [Names.DF_CLASS_ID]. */
private fun isDataFrame(it: FirPropertySymbol): Boolean {
    val returnTypeClassId = it.resolvedReturnType.classId
    return returnTypeClassId == Names.DF_CLASS_ID
}

/**
 * Computes the column path denoted by a chain of property accesses.
 *
 * - Accessing `right` on a dispatch receiver of type `org.jetbrains.kotlinx.dataframe.api.JoinDsl`
 *   contributes nothing (empty path).
 * - Without an explicit receiver, or with a `this` receiver, the path is the accessed name alone.
 * - With an explicit receiver that refers to a value parameter, the path is also the accessed name alone.
 * - Otherwise the receiver's path is computed recursively and the accessed name is appended.
 *
 * NOTE(review): an explicit receiver that is neither `this` nor a property access fails the cast below
 * with a `ClassCastException`.
 */
fun path(propertyAccessExpression: FirPropertyAccessExpression): List<String> {
    val colName = f(propertyAccessExpression)
    val typeRef = propertyAccessExpression.dispatchReceiver?.resolvedType
    val joinDsl = ClassId(FqName("org.jetbrains.kotlinx.dataframe.api"), Name.identifier("JoinDsl"))
    if (typeRef?.classId == joinDsl && colName == "right") {
        return emptyList()
    }
    return when (val explicitReceiver = propertyAccessExpression.explicitReceiver) {
        null, is FirThisReceiverExpression -> listOf(colName)
        else -> {
            val propertyAccess = explicitReceiver as FirPropertyAccessExpression
            if (propertyAccess.calleeReference.symbol is FirValueParameterSymbol) {
                listOf(colName)
            } else {
                path(propertyAccess) + colName
            }
        }
    }
}

/**
 * Returns the identifier of the declaration that [propertyAccessExpression] refers to.
 * Throws a `NullPointerException` when the callee reference is not resolved.
 */
fun f(propertyAccessExpression: FirPropertyAccessExpression): String {
    val resolvedReference = propertyAccessExpression.calleeReference.resolved!!
    return resolvedReference.name.identifier
}

/**
 * Describes the property behind a callable reference (`Type::property`) by column name and type.
 *
 * The column name is the string literal of the property's [ColumnName] annotation; when the annotation is
 * absent, or its argument is not a string literal, the referenced name is used. The type is the resolved
 * return type of the referenced symbol.
 *
 * Annotations whose fully-qualified name cannot be determined are skipped rather than causing a failure,
 * in line with how `getSchema` and `columnOf` read annotations.
 *
 * @throws ClassCastException if the callee reference is not a [FirResolvedCallableReference].
 * @throws NullPointerException if the reference has no resolved callable symbol.
 */
private fun KotlinTypeFacade.toKPropertyApproximation(
    firCallableReferenceAccess: FirCallableReferenceAccess,
    session: FirSession
): KPropertyApproximation {
    val propertyName = firCallableReferenceAccess.calleeReference.name.identifier
    val symbol = (firCallableReferenceAccess.calleeReference as FirResolvedCallableReference).toResolvedCallableSymbol()!!

    val columnNameAnnotation = symbol.annotations
        .find { it.fqName(session)?.asString() == ColumnName::class.qualifiedName!! }
    val columnNameArgument = columnNameAnnotation?.argumentMapping?.mapping?.get(Name.identifier(ColumnName::name.name))
    val columnName = (columnNameArgument as? FirLiteralExpression)?.value as? String

    val kotlinType = symbol.resolvedReturnTypeRef.type
    return KPropertyApproximation(columnName ?: propertyName, Marker(kotlinType))
}

/**
 * Collects the expressions passed to this call together with the names they are bound to.
 *
 * The explicit receiver, when present, comes first under the name `receiver`, unless it is a resolved
 * qualifier that resolves to a companion object. It is followed by the entries of the resolved argument
 * mapping, each named after its value parameter.
 *
 * @throws ClassCastException if the argument list is not a [FirResolvedArgumentList].
 */
internal fun FirFunctionCall.collectArgumentExpressions(): RefinedArguments {
    val collected = mutableListOf<RefinedArgument>()

    val receiver = explicitReceiver
    val isCompanionQualifier = receiver is FirResolvedQualifier && receiver.resolvedToCompanionObject
    if (receiver != null && !isCompanionQualifier) {
        collected += RefinedArgument(Name.identifier("receiver"), receiver)
    }

    (argumentList as FirResolvedArgumentList).mapping.forEach { (expression, parameter) ->
        collected += RefinedArgument(parameter.name, expression)
    }
    return RefinedArguments(collected)
}

/**
 * Function value that lets code with a [KotlinTypeFacade] receiver write `expression.getSchema()`;
 * it forwards to [FirExpression.getSchema] with the facade's session.
 */
internal val KotlinTypeFacade.getSchema: FirExpression.() -> ObjectWithSchema?
    get() = { this.getSchema(session) }

/**
 * Locates the [HasSchema] annotation on the class behind this expression's resolved type.
 *
 * A type alias is expanded first, and the expanded type is the one stored in the result.
 *
 * @return the index held by `schemaArg` together with the (expanded) type, or `null` when the resolved
 *   type has no symbol.
 * @throws IllegalStateException if the symbol carries no [HasSchema] annotation.
 */
internal fun FirExpression.getSchema(session: FirSession): ObjectWithSchema? {
    val classifier = resolvedType.toSymbol(session) ?: return null

    val typeRef: ConeKotlinType = if (classifier is FirTypeAliasSymbol) {
        classifier.resolvedExpandedTypeRef.coneType
    } else {
        resolvedType
    }
    val annotatedSymbol = if (classifier is FirTypeAliasSymbol) {
        classifier.resolvedExpandedTypeRef.toClassLikeSymbol(session)!!
    } else {
        classifier
    }

    val hasSchemaAnnotation = annotatedSymbol.annotations
        .firstOrNull { it.fqName(session)?.asString() == HasSchema::class.qualifiedName!! }
        ?: error("Annotate $annotatedSymbol with @HasSchema")

    val argumentName = Name.identifier(HasSchema::schemaArg.name)
    val schemaArg = (hasSchemaAnnotation.findArgumentByName(argumentName) as FirLiteralExpression).value
    return ObjectWithSchema((schemaArg as Number).toInt(), typeRef)
}

/**
 * Result of [FirExpression.getSchema].
 *
 * @property schemaArg index into the type arguments of [typeRef] at which the schema type is found.
 * @property typeRef type of the inspected expression (expanded when it was a type alias).
 */
internal class ObjectWithSchema(
    val schemaArg: Int,
    val typeRef: ConeKotlinType,
)

/** `true` for primitive types, `String`, and their nullable counterparts. */
internal val ConeKotlinType.canHaveLiteralInitializer: Boolean
    get() = isPrimitiveOrNullablePrimitive || isString || isNullableString




© 2015 - 2024 Weber Informatics LLC | Privacy Policy