All downloads are free. Search and download functionality uses the official Maven repository.

org.panteleyev.persistence.DAO.kt Maven / Gradle / Ivy

/*
 * Copyright (c) 2017, Petr Panteleyev 
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

package org.panteleyev.persistence

import org.panteleyev.persistence.DAOTypes.BAD_FIELD_TYPE
import org.panteleyev.persistence.annotations.Field
import org.panteleyev.persistence.annotations.ForeignKey
import org.panteleyev.persistence.annotations.Index
import org.panteleyev.persistence.annotations.RecordBuilder
import org.panteleyev.persistence.annotations.Table
import java.beans.IntrospectionException
import java.beans.Introspector
import java.lang.reflect.Constructor
import java.lang.reflect.Method
import java.lang.reflect.ParameterizedType
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.sql.SQLException
import java.util.ArrayList
import java.util.Arrays
import java.util.HashSet
import java.util.Objects
import java.util.concurrent.ConcurrentHashMap
import javax.sql.DataSource
import kotlin.reflect.KClass

/**
 * A persistent database record identified by an integer key.
 */
interface Record {
    /**
     * Identifier of the record. It is declared read-only so that immutable records can
     * implement it directly; mutable records must override this property as `var`.
     */
    val id: Int
}

/**
 * Persistence API entry point.
 *
 * Maps classes annotated with [Table] and [Field] onto database tables using reflection
 * and plain JDBC. The database dialect is selected from the class name of the data source.
 *
 * @param ds data source
 * @property dataSource data source used for all operations; assigning a new one resets
 * every per-class cache and re-selects the database dialect
 * @constructor Creates new DAO object with optional data source.
 */
open class DAO(ds: DataSource?) {
    // Per-class caches: last issued primary key and generated SQL statements.
    private val primaryKeys = ConcurrentHashMap<Class<out Record>, Int>()
    private val insertSQL = ConcurrentHashMap<Class<out Record>, String>()
    private val updateSQL = ConcurrentHashMap<Class<out Record>, String>()
    private val deleteSQL = ConcurrentHashMap<Class<out Record>, String>()

    private var proxy: DAOProxy? = null

    var dataSource: DataSource? = null
        set(ds) {
            // Resolve the proxy first: an unsupported data source must leave the DAO untouched.
            val newProxy = setupProxy(ds)
            field = ds
            proxy = newProxy
            primaryKeys.clear()
            insertSQL.clear()
            updateSQL.clear()
            deleteSQL.clear()
        }

    init {
        // Goes through the custom setter, which also selects the proxy.
        dataSource = ds
    }

    /**
     * Selects the dialect proxy for the given data source, or null if there is no data source.
     *
     * @throws IllegalStateException if the data source class name matches no supported database
     */
    private fun setupProxy(ds: DataSource?): DAOProxy? {
        if (ds == null) {
            return null
        }

        // TODO: figure out better way instead of class name check
        val dsClass = ds.javaClass.name
        return when {
            dsClass.contains("mysql", ignoreCase = true) -> MySQLProxy()
            dsClass.contains("sqlite", ignoreCase = true) -> SQLiteProxy()
            else -> throw IllegalStateException("Unsupported database type")
        }
    }

    /**
     * New connection obtained from the current [dataSource].
     *
     * @throws IllegalStateException if no data source is set
     */
    val connection: Connection
        get() = dataSource?.connection ?: throw IllegalStateException("Not initialized")

    /** Returns the active dialect proxy or fails if no data source has been set. */
    private fun requireProxy(): DAOProxy = proxy ?: throw IllegalStateException("Not initialized")

    /** Returns the [Table] annotation of the class or fails with [NOT_ANNOTATED]. */
    private fun tableOf(clazz: Class<*>): Table =
            clazz.getAnnotation(Table::class.java) ?: throw IllegalStateException(NOT_ANNOTATED)

    /** Returns bean getters annotated with [Field], in property descriptor order. */
    private fun fieldGetters(clazz: Class<*>): List<Method> =
            Introspector.getBeanInfo(clazz).propertyDescriptors
                    .mapNotNull { it.readMethod }
                    .filter { it.isAnnotationPresent(Field::class.java) }

    /** Returns the column name of the [Field] marked as primary key, or null if there is none. */
    private fun primaryKeyColumn(clazz: Class<*>): String? =
            clazz.methods
                    .mapNotNull { it.getAnnotation(Field::class.java) }
                    .firstOrNull { it.primaryKey }
                    ?.value

    /**
     * Retrieves record from the database using record [id].
     *
     * @return the record or null if there is no row with such id
     * @throws IllegalStateException if [clazz] is not annotated with [Table]
     */
    fun <T : Record> get(id: Int, clazz: KClass<T>): T? {
        return get(id, clazz.java)
    }

    private fun <T : Record> get(id: Int, clazz: Class<T>): T? {
        val tableName = tableOf(clazz).value
        // Fall back to the conventional column name when no primary key is annotated.
        val idName = primaryKeyColumn(clazz) ?: "id"

        return connection.use { conn ->
            conn.prepareStatement("SELECT * FROM $tableName WHERE $idName=?").use { ps ->
                ps.setInt(1, id)
                ps.executeQuery().use { rs ->
                    if (rs.next()) fromSQL(rs, clazz) else null
                }
            }
        }
    }

    /** Reads every row of the table mapped to [clazz] and passes each built record to [action]. */
    private fun <T : Record> forEachRecord(clazz: Class<T>, action: (T) -> Unit) {
        val tableName = tableOf(clazz).value

        connection.use { conn ->
            conn.prepareStatement("SELECT * FROM $tableName").use { ps ->
                ps.executeQuery().use { rs ->
                    while (rs.next()) {
                        action(fromSQL(rs, clazz))
                    }
                }
            }
        }
    }

    /**
     * Retrieves all records of the specified [class][clazz].
     *
     * @throws IllegalStateException if [clazz] is not annotated with [Table]
     */
    fun <T : Record> getAll(clazz: KClass<T>): List<T> {
        val result = ArrayList<T>()
        forEachRecord(clazz.java) { result.add(it) }
        return result
    }

    /**
     * Retrieves all records of the specified [class][clazz] and fills the [map], keyed by record id.
     *
     * @throws IllegalStateException if [clazz] is not annotated with [Table]
     */
    fun <T : Record> getAll(clazz: KClass<T>, map: MutableMap<Int, T>) {
        forEachRecord(clazz.java) { map[it.id] = it }
    }

    /**
     * Builds a record from the current row of [set]. A constructor annotated with
     * [RecordBuilder] is preferred; otherwise the no-arg constructor and setters are used.
     */
    private fun <T : Record> fromSQL(set: ResultSet, clazz: Class<T>): T {
        // First try to find @RecordBuilder constructor
        val builder = clazz.constructors.firstOrNull { it.isAnnotationPresent(RecordBuilder::class.java) }
        if (builder != null) {
            return fromSQL(set, builder)
        }

        val result = clazz.newInstance()
        fromSQL(set, result)
        return result
    }

    /**
     * Builds a record by calling [constructor]; every parameter must carry a [Field] annotation
     * naming the column it is read from.
     *
     * @throws IllegalStateException if a parameter has no [Field] annotation
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T : Record> fromSQL(set: ResultSet, constructor: Constructor<*>): T {
        val dbProxy = requireProxy()
        val paramAnnotations = constructor.parameterAnnotations
        val paramTypes = constructor.parameterTypes

        val params: List<Any?> = (0 until constructor.parameterCount).map { i ->
            val fld = paramAnnotations[i].filterIsInstance<Field>().firstOrNull()
                    ?: throw IllegalStateException(
                            "Parameter $i of ${constructor.declaringClass.name} constructor has no @Field annotation")
            dbProxy.getFieldValue(fld.value, paramTypes[i], set)
        }

        return constructor.newInstance(*params.toTypedArray()) as T
    }

    /** Fills writable, [Field]-annotated bean properties of [record] from the current row of [set]. */
    private fun fromSQL(set: ResultSet, record: Record) {
        val dbProxy = requireProxy()

        for (pd in Introspector.getBeanInfo(record.javaClass).propertyDescriptors) {
            // Write-only properties have no getter, read-only ones no setter: skip both.
            val getter: Method = pd.readMethod ?: continue
            val setter: Method = pd.writeMethod ?: continue
            val fld = getter.getAnnotation(Field::class.java) ?: continue

            setter.invoke(record, dbProxy.getFieldValue(fld.value, getter.returnType, set))
        }
    }

    /**
     * Returns the return type of [getter]; for a parameterized type with exactly one type
     * argument, that argument is returned instead.
     *
     * @throws IllegalStateException if the type cannot be resolved to a single class
     */
    private fun getEffectiveType(getter: Method): Class<*> {
        val rType = getter.genericReturnType

        if (rType !is ParameterizedType) {
            return rType as? Class<*> ?: throw IllegalStateException(BAD_FIELD_TYPE)
        }

        val actualTypeArguments = rType.actualTypeArguments
        if (actualTypeArguments.size != 1) {
            throw IllegalStateException(BAD_FIELD_TYPE)
        }
        return actualTypeArguments[0] as? Class<*> ?: throw IllegalStateException(BAD_FIELD_TYPE)
    }

    /**
     * Creates tables of the [specified classes][tables] according to their annotations.
     * Existing tables are dropped first, in reverse order, then created in the given order.
     *
     * @throws IllegalStateException if there is no data source or a class is not annotated with [Table]
     * @throws RuntimeException wrapping [SQLException] or [IntrospectionException]
     */
    fun createTables(tables: List<KClass<out Record>>) {
        if (dataSource == null) {
            throw IllegalStateException("Database not opened")
        }

        // Validate everything up front so that nothing is dropped when a class is not annotated.
        val annotated = tables.map { it.java to tableOf(it.java) }

        try {
            connection.use { conn ->
                conn.createStatement().use { st ->
                    // Step 1: drop tables in reverse order
                    for ((_, table) in annotated.asReversed()) {
                        st.executeUpdate("DROP TABLE IF EXISTS ${table.value}")
                    }

                    // Step 2: create new tables in natural order
                    for ((clazz, table) in annotated) {
                        createTable(st, clazz, table)
                    }
                }
            }
        } catch (ex: SQLException) {
            throw RuntimeException(ex)
        }
    }

    /** Creates a single table and the indexes requested by [Index] annotations on its getters. */
    private fun createTable(st: java.sql.Statement, clazz: Class<out Record>, table: Table) {
        val dbProxy = requireProxy()

        try {
            val getters = fieldGetters(clazz)
            // The proxy appends table-level constraints here while columns are being built.
            val constraints = mutableListOf<String>()

            val columns = getters.map { getter ->
                val fld = getter.getAnnotation(Field::class.java)
                val typeName = getEffectiveType(getter).typeName
                fld.value + " " + dbProxy.getColumnString(fld,
                        getter.getAnnotation(ForeignKey::class.java), typeName, constraints)
            }

            val definitions = (columns + constraints).joinToString(",")
            st.executeUpdate("CREATE TABLE IF NOT EXISTS ${table.value} ($definitions)")

            // Create indexes
            getters.filter { it.isAnnotationPresent(Index::class.java) }
                    .forEach { st.executeUpdate(dbProxy.buildIndex(table, it)) }
        } catch (ex: IntrospectionException) {
            throw RuntimeException(ex)
        }
    }

    /** Returns the cached INSERT statement for the class of [record], building it on first use. */
    private fun getInsertSQL(record: Record): String {
        return insertSQL.computeIfAbsent(record.javaClass) { clazz ->
            val table = clazz.getAnnotation(Table::class.java)
                    ?: throw IllegalStateException("Class " + clazz.name + " is not properly annotated")

            val columns = fieldGetters(clazz).map { it.getAnnotation(Field::class.java).value }
            if (columns.isEmpty()) {
                throw IllegalStateException("No fields")
            }

            val placeholders = columns.joinToString(",") { "?" }
            "INSERT INTO ${table.value} (${columns.joinToString(",")}) VALUES ($placeholders)"
        }
    }

    /**
     * Returns the cached UPDATE statement for the class of [record], building it on first use.
     * The row is matched by the annotated primary key column, or by `id` if none is annotated.
     */
    private fun getUpdateSQL(record: Record): String {
        return updateSQL.computeIfAbsent(record.javaClass) { clazz ->
            val table = tableOf(clazz)

            val assignments = fieldGetters(clazz)
                    .map { it.getAnnotation(Field::class.java) }
                    .filterNot { it.primaryKey }
                    .map { "${it.value}=?" }
            if (assignments.isEmpty()) {
                throw IllegalStateException("No fields")
            }

            val idName = primaryKeyColumn(clazz) ?: "id"
            "update ${table.value} set ${assignments.joinToString(", ")} WHERE $idName=?"
        }
    }

    /**
     * Returns the cached DELETE statement for [clazz], building it on first use.
     *
     * @throws IllegalStateException if the class has no [Table] annotation or no primary key field
     */
    private fun getDeleteSQL(clazz: Class<out Record>): String {
        return deleteSQL.computeIfAbsent(clazz) { cl ->
            val table = tableOf(cl)
            val idName = primaryKeyColumn(cl) ?: throw IllegalStateException(NOT_ANNOTATED)
            "delete from ${table.value} where $idName=?"
        }
    }

    /**
     * Binds field values of [record] to [st] in the same order the INSERT/UPDATE SQL lists them.
     * For an update the primary key is skipped in the value list and bound last, for the WHERE clause.
     */
    private fun setData(record: Record, st: PreparedStatement, update: Boolean) {
        val dbProxy = requireProxy()

        var index = 1
        for (getter in fieldGetters(record.javaClass)) {
            // if update skip ID at this point
            if (update && getter.getAnnotation(Field::class.java).primaryKey) {
                continue
            }

            val value: Any? = getter.invoke(record)
            dbProxy.setFieldData(st, index++, value, getter.returnType.name)
        }

        if (update) {
            st.setInt(index, record.id)
        }
    }

    /** Prepares, binds and executes a single INSERT or UPDATE for [record], closing all resources. */
    private fun execute(record: Record, update: Boolean) {
        val sql = if (update) getUpdateSQL(record) else getInsertSQL(record)

        connection.use { conn ->
            conn.prepareStatement(sql).use { st ->
                setData(record, st, update)
                st.executeUpdate()
            }
        }
    }

    /** Deletes the row of the table mapped to [clazz] whose primary key equals [id]. */
    private fun deleteById(id: Int, clazz: Class<out Record>) {
        connection.use { conn ->
            conn.prepareStatement(getDeleteSQL(clazz)).use { ps ->
                ps.setInt(1, id)
                ps.executeUpdate()
            }
        }
    }

    /**
     * Pre-loads necessary information about specified [list of tables][tables] from the just opened database.
     * This method must be called prior to any other database operations. Otherwise primary keys may be generated
     * incorrectly. Classes without [Table] annotation are ignored.
     */
    fun preload(tables: Collection<KClass<out Record>>) {
        // load primary key max values
        for (table in tables) {
            val clazz = table.java
            val ann = clazz.getAnnotation(Table::class.java) ?: continue
            primaryKeys[clazz] = getIdMaxValue(ann.value, primaryKeyColumn(clazz) ?: "id")
        }
    }

    /**
     * Returns next available primary key value for the specified [class][clazz]. This method is thread safe.
     */
    fun generatePrimaryKey(clazz: KClass<out Record>): Int {
        // The remapping function never returns null, so compute() cannot return null either.
        return primaryKeys.compute(clazz.java) { _, v -> (v ?: 0) + 1 }!!
    }

    /** Returns the largest value of column [idName] in [tableName], or 0 for an empty table. */
    private fun getIdMaxValue(tableName: String, idName: String = "id"): Int {
        return connection.use { conn ->
            conn.prepareStatement("SELECT MAX($idName) FROM $tableName").use { st ->
                st.executeQuery().use { rs ->
                    // MAX() over an empty table yields SQL NULL, which getInt() reports as 0.
                    if (rs.next()) rs.getInt(1) else 0
                }
            }
        }
    }

    /**
     * Inserts new [record] with predefined id into the database. No attempt to generate
     * new id is made. Calling code must ensure that predefined id is unique.
     *
     * @return the record as re-read from the database, or null if it cannot be found
     */
    fun <T : Record> insert(record: T): T? {
        execute(record, false)
        // Re-read only after the writing connection has been released.
        return get(record.id, record.javaClass)
    }

    /**
     * Inserts multiple [records] with predefined id using batch insert. No attempt to generate
     * new id is made. Calling code must ensure that predefined id is unique for all records.
     *
     * Supplied records are divided to batches of the specified [size]. To avoid memory issues [size] of the batch
     * must be tuned appropriately. The SQL is derived from the first record, so all records
     * must be of the same class.
     *
     * @throws IllegalArgumentException if [size] is less than 1
     */
    fun <T : Record> insert(size: Int, records: List<T>) {
        require(size >= 1) { "Batch size must be >= 1" }

        if (records.isEmpty()) {
            return
        }

        val sql = getInsertSQL(records[0])

        connection.use { conn ->
            conn.prepareStatement(sql).use { st ->
                var pending = 0

                for (r in records) {
                    setData(r, st, false)
                    st.addBatch()

                    if (++pending == size) {
                        st.executeBatch()
                        pending = 0
                    }
                }

                // Flush the incomplete last batch, if any.
                if (pending > 0) {
                    st.executeBatch()
                }
            }
        }
    }

    /**
     * Updates [record] in the database.
     *
     * @return the record as re-read from the database, or null if no such record exists
     */
    fun <T : Record> update(record: T): T? {
        execute(record, true)
        // Re-read only after the writing connection has been released.
        return get(record.id, record.javaClass)
    }

    /**
     * Deletes [record] from the database.
     */
    fun delete(record: Record) {
        deleteById(record.id, record.javaClass)
    }

    /**
     * Deletes record of the specified class from the database by [id].
     */
    fun delete(id: Int, clazz: KClass<out Record>) {
        deleteById(id, clazz.java)
    }

    companion object {
        private const val NOT_ANNOTATED = "Class is not properly annotated"
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy