All downloads are free. Search and download functionality uses the official Maven repository.

org.ufoss.kotysa.Field.kt Maven / Gradle / Ivy

/*
 * This is free and unencumbered software released into the public domain, following 
 */

package org.ufoss.kotysa

import org.ufoss.kotysa.columns.TsvectorColumn
import org.ufoss.kotysa.postgresql.Tsquery
import java.math.BigDecimal
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KMutableProperty1
import kotlin.reflect.KParameter
import kotlin.reflect.full.memberFunctions
import kotlin.reflect.full.primaryConstructor

/**
 * A selectable field of a query.
 *
 * [fieldNames] holds the SQL fragments this field contributes to the select list (for example `COUNT(...)`),
 * [builder] reads the field's value from the current result row, and [alias] is an optional alias for the field.
 */
public sealed interface Field {
    public val fieldNames: List

    public val builder: (RowImpl) -> T

    public var alias: String?
}

/**
 * Shared base of the internal [Field] implementations.
 *
 * Every field starts without an alias; subclasses cannot override how the alias is stored.
 */
internal sealed class AbstractField : Field {
    final override var alias: String? =
        null
}

/**
 * Modifier applied to a selected column by [ColumnField]: plain column (`NONE`), `DISTINCT column`,
 * `MAX(column)` or `MIN(column)`.
 */
public enum class FieldClassifier {
    NONE,
    DISTINCT,
    MAX,
    MIN,
}

/**
 * Marker for fields whose value is expected to be non-null.
 * The implementations in this file (COUNT, AVG, SUM, ts_rank_cd, ...) declare non-nullable builder return types.
 */
public interface FieldNotNull : Field

/**
 * Marker for fields whose value may be null, such as a plain column, a sub-query or `*`.
 */
public interface FieldNullable : Field

/**
 * `COUNT(column)` field, or `COUNT(*)` when no column is given, read back as a [Long].
 */
internal class CountField internal constructor(
    properties: DefaultSqlClientCommon.Properties,
    column: Column?,
) : AbstractField(), FieldNotNull {
    override val fieldNames: List = run {
        // a missing column means "count every row"
        val countedExpression =
            column?.getFieldName(properties.tables.allColumns, properties.tables.dbType) ?: "*"
        listOf("COUNT($countedExpression)")
    }

    override val builder: (RowImpl) -> Long = { resultRow ->
        resultRow.getAndIncrement(Long::class.javaObjectType)!!
    }
}

/**
 * A single selected column, optionally wrapped by a [FieldClassifier] (`DISTINCT`, `MAX(...)`, `MIN(...)`).
 * The value is read through the column's own row accessor and may be null.
 */
internal class ColumnField internal constructor(
    properties: DefaultSqlClientCommon.Properties,
    column: Column,
    classifier: FieldClassifier,
) : AbstractField(), FieldNullable {
    override val fieldNames: List =
        column.getFieldName(properties.tables.allColumns, properties.tables.dbType).let { columnName ->
            val selected = when (classifier) {
                FieldClassifier.NONE -> columnName
                FieldClassifier.DISTINCT -> "DISTINCT $columnName"
                FieldClassifier.MAX -> "MAX($columnName)"
                FieldClassifier.MIN -> "MIN($columnName)"
            }
            listOf(selected)
        }

    override val builder: (RowImpl) -> U? = { resultRow -> resultRow.getAndIncrement(column, properties) }
}

/**
 * `AVG(column)` field, read back as a [BigDecimal].
 *
 * With H2, the raw value is read as a Double (R2DBC module) or as an Int (Spring R2DBC module) and then converted
 * to BigDecimal. Every other database/module combination reads a BigDecimal directly.
 */
internal class AvgField internal constructor(
    properties: DefaultSqlClientCommon.Properties,
    column: Column,
) : AbstractField(), FieldNotNull {
    override val fieldNames: List =
        listOf("AVG(${column.getFieldName(properties.tables.allColumns, properties.tables.dbType)})")

    override val builder: (RowImpl) -> BigDecimal = { row ->
        if (properties.tables.dbType != DbType.H2) {
            row.getAndIncrement(BigDecimal::class.javaObjectType)!!
        } else {
            when (properties.module) {
                Module.R2DBC -> row.getAndIncrement(Double::class.javaObjectType)!!.toBigDecimal()
                // fixme : remove this branch when Spring uses r2dbc 0.9+
                Module.SPRING_R2DBC -> row.getAndIncrement(Int::class.javaObjectType)!!.toBigDecimal()
                else -> row.getAndIncrement(BigDecimal::class.javaObjectType)!!
            }
        }
    }
}

/**
 * `SUM(column)` field, read back as a [Long].
 *
 * With MySQL over R2DBC, the sum is read as a BigDecimal and converted to Long.
 */
internal class LongSumField internal constructor(
    properties: DefaultSqlClientCommon.Properties,
    column: Column,
) : AbstractField(), FieldNotNull {
    override val fieldNames: List =
        listOf("SUM(${column.getFieldName(properties.tables.allColumns, properties.tables.dbType)})")

    override val builder: (RowImpl) -> Long = { row ->
        val isMysqlR2dbc =
            properties.tables.dbType == DbType.MYSQL && properties.dbAccessType == DbAccessType.R2DBC
        if (isMysqlR2dbc) {
            row.getAndIncrement(BigDecimal::class.javaObjectType)!!.toLong()
        } else {
            row.getAndIncrement(Long::class.javaObjectType)!!
        }
    }
}

/**
 * Selects every column of [table], except `tsvector` columns, and maps a result row back to an instance of the
 * table's entity class.
 *
 * Mapping is done in two passes:
 * 1. The entity constructor is invoked: the primary constructor, or else the constructor with the most
 *    parameters. Each parameter is bound to the first column whose entity getter has the same name, or whose
 *    getter is `get` + the parameter name (case-insensitive). Parameters without a matching column must be
 *    optional.
 * 2. Each remaining column is written through its mutable property. Failing that, it is written through a
 *    two-parameter `setX` member function paired with a `getX` / `isX` getter. Columns with no writable
 *    counterpart are skipped.
 */
internal class TableField internal constructor(
    availableColumns: Map, KotysaColumn<*, *>>,
    availableTables: Map, KotysaTable<*>>,
    internal val table: AbstractTable,
    dbType: DbType,
) : AbstractField() {

    override val fieldNames: List =
        table.kotysaColumns
            // tsvector should not be fetched when querying a table
            .filter { column -> column.sqlType != SqlType.TSVECTOR }
            .map { column -> column.getFieldName(availableColumns, dbType) }

    override val builder: (RowImpl) -> T = { row ->
        val kotysaTable = table.getKotysaTable(availableTables)
        val associatedColumns = mutableListOf>()
        val entityConstructor = checkNotNull(getTableConstructor(kotysaTable.tableClass)) {
            "Cannot instantiate Table \"${kotysaTable.tableClass.qualifiedName}\", no constructor found"
        }

        // First pass: bind each constructor parameter to the column whose getter matches its name
        val args = mutableMapOf<KParameter, Any?>()
        entityConstructor.parameters.forEach { param ->
            val column = kotysaTable.dbColumns.firstOrNull { candidate ->
                getterMatchesParameter(candidate.entityGetter.toCallable().name, param.name)
            }
            if (column != null) {
                associatedColumns.add(column)
                args[param] =
                    row.getWithOffset(kotysaTable.columns.indexOf(column), column.columnClass.javaObjectType)
            } else {
                require(param.isOptional) {
                    "Cannot instantiate Table \"${kotysaTable.tableClass.qualifiedName}\", " +
                            "parameter \"${param.name}\" is required and is not mapped to a Column"
                }
            }
        }
        val instance = entityConstructor.callBy(args)

        // Second pass: write the remaining columns through a var property or a Java-style setter
        if (associatedColumns.size < table.kotysaColumns.size) {
            kotysaTable.dbColumns
                .filter { column -> !associatedColumns.contains(column) }
                .forEach { column ->
                    val getter = column.entityGetter
                    if (getter is KMutableProperty1) {
                        getter.set(
                            instance,
                            row.getWithOffset(kotysaTable.columns.indexOf(column), column.columnClass.javaObjectType)
                        )
                    } else {
                        val callable = getter.toCallable()
                        val setterName = if (callable is KFunction) setterNameFor(callable.name) else null
                        // setter = member function with the instance receiver + the value parameter
                        val setter = setterName?.let { name ->
                            kotysaTable.tableClass.memberFunctions.firstOrNull { function ->
                                function.name == name && function.parameters.size == 2
                            }
                        }
                        if (setter != null) {
                            setter.call(
                                instance,
                                row.getWithOffset(
                                    kotysaTable.columns.indexOf(column),
                                    column.columnClass.javaObjectType
                                )
                            )
                        }
                    }
                }
        }
        // increment row index by the number of selected columns in this table
        row.incrementWithDelayedIndex()
        instance
    }

    /**
     * Whether a getter named [getterName] maps to a constructor parameter named [paramName].
     * Matches either the exact name or `get` + the parameter name (case-insensitive).
     * An unnamed parameter never matches.
     */
    private fun getterMatchesParameter(getterName: String, paramName: String?): Boolean = when {
        paramName == null -> false
        getterName == paramName -> true
        else -> getterName.startsWith("get") && getterName.length > 3 &&
                getterName.substring(3).equals(paramName, ignoreCase = true)
    }

    /**
     * Setter name paired with a Java-style getter: `getX` -> `setX`, `isX` -> `setX`.
     * Returns `null` if [getterName] is not one of those prefixes followed by at least one character.
     */
    private fun setterNameFor(getterName: String): String? = when {
        getterName.startsWith("get") && getterName.length > 3 -> "set" + getterName.substring(3)
        getterName.startsWith("is") && getterName.length > 2 -> "set" + getterName.substring(2)
        else -> null
    }

    /**
     * The primary constructor if there is one, otherwise the first constructor with the highest parameter count.
     * Returns `null` if the class exposes no constructor.
     */
    private fun getTableConstructor(tableClass: KClass) =
        tableClass.primaryConstructor
            ?: tableClass.constructors.maxByOrNull { candidate -> candidate.parameters.size }
}

/**
 * A sub-query selected as a field.
 * Its SQL is rendered in parentheses each time [fieldNames] is read, using the parent query's properties.
 */
internal class SubQueryField internal constructor(
    private val subQueryReturn: SqlClientSubQuery.Return,
    override val builder: (RowImpl) -> T?,
    private val parentProperties: DefaultSqlClientCommon.Properties,
) : AbstractField(), FieldNullable {
    override val fieldNames
        get() = subQueryReturn.sql(parentProperties).let { subQuerySql -> listOf("( $subQuerySql )") }
}

/**
 * `CASE WHEN EXISTS(sub-query) THEN then ELSE elseVal END` field.
 *
 * The rendered SQL spans four lines. The value is read back using the runtime class of [then].
 */
internal class CaseWhenExistsSubQueryField internal constructor(
    private val dbType: DbType,
    private val subQueryReturn: SqlClientSubQuery.Return,
    private val then: U,
    private val elseVal: U,
    private val parentProperties: DefaultSqlClientCommon.Properties,
) : AbstractField(), FieldNotNull {
    override val fieldNames
        get() = listOf(
            listOf(
                "CASE WHEN",
                "EXISTS( ${subQueryReturn.sql(parentProperties)} )",
                "THEN ${then.defaultValue(dbType)} ELSE ${elseVal.defaultValue(dbType)}",
                "END",
            ).joinToString("\n")
        )

    override val builder: (RowImpl) -> U = { currentRow ->
        currentRow.getAndIncrement(then::class.javaObjectType)!!
    }
}

/**
 * Selects `*`. The caller supplies the [builder] that reads the resulting row.
 */
internal class StarField internal constructor(
    override val builder: (RowImpl) -> T?,
) : AbstractField(), FieldNullable {
    override val fieldNames: List =
        listOf("*")
}

/**
 * A field whose value is computed by the user-supplied [dsl] lambda.
 *
 * The SQL fragments come from `FieldValueProvider(properties).initialize(dsl)`. When a row is read, [dsl] is
 * invoked with a single [SelectDsl] instance that is re-pointed at the current row.
 * NOTE(review): that shared instance assumes rows are built one at a time — confirm with callers.
 */
internal class FieldDsl(
    properties: DefaultSqlClientSelect.Properties,
    private val dsl: (ValueProvider) -> T,
) : AbstractField(), FieldNotNull {
    private val selectDsl = SelectDsl(properties)

    override val fieldNames: List = FieldValueProvider(properties).initialize(dsl)

    override val builder: (RowImpl) -> T = { currentRow ->
        selectDsl.row = currentRow
        dsl(selectDsl)
    }
}

/**
 * PostgreSQL `ts_rank_cd(tsvector, tsquery)` field, read back as a [Float].
 * The tsquery is referenced by its alias.
 */
internal class TsRankCdField internal constructor(
    properties: DefaultSqlClientCommon.Properties,
    tsvectorColumn: TsvectorColumn<*>,
    tsquery: Tsquery,
) : AbstractField(), FieldNotNull {
    override val fieldNames: List = run {
        val vectorName = tsvectorColumn.getFieldName(properties.tables.allColumns, properties.tables.dbType)
        listOf("ts_rank_cd($vectorName,${tsquery.alias})")
    }

    override val builder: (RowImpl) -> Float = { currentRow ->
        currentRow.getAndIncrement(Float::class.javaObjectType)!!
    }
}