All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.firefly.kotlin.ext.db.AsyncCoroutineContextTransactionalManager.kt Maven / Gradle / Ivy

package com.firefly.kotlin.ext.db

import com.firefly.db.RecordNotFound
import com.firefly.db.SQLClient
import com.firefly.db.SQLConnection
import com.firefly.kotlin.ext.common.CoroutineLocalContext
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.future.await
import kotlinx.coroutines.withTimeout
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

/**
 * Manages transactions in the HTTP request lifecycle by keeping all transaction state in
 * [CoroutineLocalContext], so every SQL statement run inside one coroutine context shares one connection.
 *
 * Per-context state, stored as [CoroutineLocalContext] attributes:
 * - [currentConnKey]: the `CompletableFuture<SQLConnection>` created lazily by the outermost [beginTransaction].
 * - [transactionCountKey]: an [AtomicInteger] nesting depth; only the outermost begin/end touches the database.
 * - [rollbackOnlyKey]: a `Boolean` flag that [execSQL] sets to `true` when a statement fails.
 *
 * NOTE(review): none of these attributes is removed when the outermost transaction ends, so a later
 * transaction started in the same coroutine context reuses the previous connection future and rollback
 * flag — confirm this is intended.
 *
 * @property sqlClient source of new database connections.
 * @author Pengtao Qiu
 */
class AsyncCoroutineContextTransactionalManager(val sqlClient: SQLClient) : AsyncTransactionalManager {

    /** Attribute key of the connection future shared by the current transaction. */
    val currentConnKey = "_currentConnKeyKt"

    /** Attribute key of the rollback-only `Boolean` flag. */
    val rollbackOnlyKey = "_rollbackOnlyKeyKt"

    /** Attribute key of the [AtomicInteger] transaction nesting counter. */
    val transactionCountKey = "_transactionCountKeyKt"

    /**
     * Obtains a new connection from [sqlClient], independent of any transaction in progress.
     *
     * @throws kotlinx.coroutines.TimeoutCancellationException if no connection arrives within [time] [unit]s.
     */
    override suspend fun getConnection(time: Long, unit: TimeUnit): SQLConnection =
        withTimeout(unit.toMillis(time)) {
            sqlClient.connection.await()
        }

    /**
     * Returns the connection bound to the current coroutine context, or `null` when no connection
     * attribute is present.
     *
     * @throws kotlinx.coroutines.TimeoutCancellationException if the connection future does not complete
     * within [time] [unit]s.
     */
    override suspend fun getCurrentConnection(time: Long, unit: TimeUnit): SQLConnection? =
        withTimeout(unit.toMillis(time)) {
            CoroutineLocalContext.getAttr<CompletableFuture<SQLConnection>>(currentConnKey)?.await()
        }

    /**
     * Runs [handler] with a connection, bounding the handler to [time] [unit]s.
     *
     * Outside a transaction, a fresh connection is obtained and handed to `safeUse` (presumably closing
     * it afterwards — confirm). Inside a transaction, the transaction's connection is reused, and any
     * failure other than [RecordNotFound] marks the transaction rollback-only before being rethrown.
     *
     * @throws IllegalStateException if the nesting counter is positive but no connection is bound.
     */
    override suspend fun <T> execSQL(time: Long, unit: TimeUnit, handler: suspend (conn: SQLConnection) -> T): T {
        if (!isInTransaction()) {
            return getConnection().safeUse {
                withTimeout(unit.toMillis(time)) { handler.invoke(it) }
            }
        }

        val conn = getCurrentConnection() ?: error("The transaction is not begun")
        return try {
            withTimeout(unit.toMillis(time)) { handler.invoke(conn) }
        } catch (e: RecordNotFound) {
            // A missing record is logged but deliberately does NOT mark the transaction rollback-only.
            sysLogger.warn("execute SQL exception. record not found", e)
            throw e
        } catch (e: TimeoutCancellationException) {
            // Must precede the generic handler: TimeoutCancellationException is itself an Exception.
            sysLogger.error("execute SQL exception. timeout", e)
            setRollback(true)
            throw e
        } catch (e: Exception) {
            sysLogger.error("execute SQL exception", e)
            setRollback(true)
            throw e
        }
    }

    /**
     * Begins a (possibly nested) transaction.
     *
     * Only the outermost call, where the counter goes from 0 to 1, creates the shared connection and
     * begins the database transaction. Nested calls just increment the counter.
     *
     * @return the result of the connection's `beginTransaction()` for the outermost call, `false` otherwise.
     * @throws IllegalStateException if there is no coroutine context.
     */
    override suspend fun beginTransaction(): Boolean {
        val depth = increaseTransactionCount()
        if (depth != 1) {
            return false
        }
        val conn = createConnectionIfEmpty().await()
        return conn.beginTransaction().await()
    }

    /**
     * Ends one nesting level. When the counter reaches zero (or below), the transaction is rolled back
     * unless the rollback-only flag was explicitly set to `false`, in which case it is committed.
     *
     * NOTE(review): a nested call (counter still positive) only decrements the counter; it does not mark
     * the transaction rollback-only.
     *
     * @throws IllegalStateException if no transaction counter exists in the coroutine context.
     */
    override suspend fun rollbackAndEndTransaction() {
        val count = decreaseTransactionCount()
        if (count > 0) {
            return
        }
        val rollback = isRollback()
        sysLogger.warn("the transaction rollback -> $rollback, $count")
        if (rollback) {
            getCurrentConnection()?.rollbackAndEndTransaction()?.await()
        } else {
            getCurrentConnection()?.commitAndEndTransaction()?.await()
        }
    }

    /**
     * Ends one nesting level. When the counter reaches zero (or below), the shared connection's
     * transaction is committed.
     *
     * NOTE(review): the rollback-only flag set by [execSQL] is not consulted here.
     *
     * @throws IllegalStateException if no transaction counter exists in the coroutine context.
     */
    override suspend fun commitAndEndTransaction() {
        if (decreaseTransactionCount() <= 0) {
            getCurrentConnection()?.commitAndEndTransaction()?.await()
        }
    }

    /** Returns the connection future bound to this coroutine context, requesting one from [sqlClient] on first use. */
    private fun createConnectionIfEmpty(): CompletableFuture<SQLConnection> =
        CoroutineLocalContext.computeIfAbsent(currentConnKey) { sqlClient.connection }
            ?: error("Not in coroutine context")

    /** `true` when the nesting counter exists and is positive. */
    private fun isInTransaction(): Boolean = (getTransactionCount() ?: 0) > 0

    /** An unset flag means "roll back"; only an explicit `false` selects commit. */
    private fun isRollback(): Boolean = CoroutineLocalContext.getAttr<Boolean>(rollbackOnlyKey) ?: true

    /** Stores the rollback-only flag in the coroutine context. */
    private fun setRollback(rollback: Boolean) {
        CoroutineLocalContext.setAttr(rollbackOnlyKey, rollback)
    }

    /** Increments the nesting counter (creating it at 0 if absent) and returns the new depth. */
    private fun increaseTransactionCount(): Int {
        val counter = CoroutineLocalContext.computeIfAbsent(transactionCountKey) { AtomicInteger() }
            ?: error("Not in coroutine context")
        return counter.incrementAndGet()
    }

    /** Decrements the nesting counter and returns the new depth; it may go negative on unbalanced calls. */
    private fun decreaseTransactionCount(): Int {
        val counter = CoroutineLocalContext.getAttr<AtomicInteger>(transactionCountKey)
            ?: error("The transaction is not begun")
        return counter.decrementAndGet()
    }

    /** Current nesting depth, or `null` if no counter has been created in this coroutine context. */
    private fun getTransactionCount(): Int? =
        CoroutineLocalContext.getAttr<AtomicInteger>(transactionCountKey)?.get()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy