All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.com.squareup.sqldelight.Transacter.kt Maven / Gradle / Ivy

There is a newer version: 1.5.5
Show newest version
/*
 * Copyright (C) 2018 Square, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.squareup.sqldelight

import com.squareup.sqldelight.Transacter.Transaction
import com.squareup.sqldelight.db.SqlDriver
import com.squareup.sqldelight.internal.Atomic
import com.squareup.sqldelight.internal.AtomicBoolean
import com.squareup.sqldelight.internal.Supplier
import com.squareup.sqldelight.internal.getValue
import com.squareup.sqldelight.internal.setValue
import com.squareup.sqldelight.internal.sharedMap
import com.squareup.sqldelight.internal.sharedSet
import com.squareup.sqldelight.internal.threadLocalRef

/**
 * Unwraps the function held by this [Supplier] and invokes it, returning that function's result.
 */
private fun <T> Supplier<() -> T>.run(): T = invoke().invoke()

/**
 * Registration point for callbacks tied to the outcome of a transaction.
 */
interface TransactionCallbacks {
  /** Registers [function] to be run after the transaction commits. */
  fun afterCommit(function: () -> Unit)

  /** Registers [function] to be run after the transaction rolls back. */
  fun afterRollback(function: () -> Unit)
}

/**
 * Receiver of a [Transacter.transactionWithResult] body that produces a value of type [R].
 */
interface TransactionWithReturn<R> : TransactionCallbacks {
  /**
   * Rolls back this transaction.
   *
   * Never returns normally: it throws an internal exception carrying [returnValue], which unwinds
   * the transaction body.
   */
  fun rollback(returnValue: R): Nothing = throw RollbackException(returnValue)

  /**
   * Begin an inner transaction that runs [body] and returns its result.
   *
   * The inner result type [T] is independent of this transaction's result type [R].
   */
  fun <T> transaction(body: TransactionWithReturn<T>.() -> T): T
}

/**
 * Receiver of a [Transacter.transaction] body that does not produce a result.
 */
interface TransactionWithoutReturn : TransactionCallbacks {
  /**
   * Rolls back this transaction.
   *
   * Never returns normally: it throws an internal exception that unwinds the transaction body.
   */
  fun rollback(): Nothing {
    throw RollbackException()
  }

  /**
   * Begin an inner transaction that runs [body].
   */
  fun transaction(body: TransactionWithoutReturn.() -> Unit)
}

/**
 * A transaction-aware [SqlDriver] wrapper which can begin a [Transaction] on the current connection.
 */
interface Transacter {
  /**
   * Starts a [Transaction] and runs [bodyWithReturn] in that transaction, returning the value it
   * produces.
   *
   * @throws IllegalStateException if [noEnclosing] is true and there is already an active
   *   [Transaction] on this thread.
   */
  fun <R> transactionWithResult(
    noEnclosing: Boolean = false,
    bodyWithReturn: TransactionWithReturn<R>.() -> R
  ): R

  /**
   * Starts a [Transaction] and runs [body] in that transaction.
   *
   * @throws IllegalStateException if [noEnclosing] is true and there is already an active
   *   [Transaction] on this thread.
   */
  fun transaction(
    noEnclosing: Boolean = false,
    body: TransactionWithoutReturn.() -> Unit
  )

  /**
   * A SQL transaction. Can be created through the driver via [SqlDriver.newTransaction] or
   * through an implementation of [Transacter] by calling [Transacter.transaction].
   */
  abstract class Transaction : TransactionCallbacks {
    // Callbacks queued through afterCommit / afterRollback, each wrapped with threadLocalRef.
    internal val postCommitHooks = sharedSet<Supplier<() -> Unit>>()
    internal val postRollbackHooks = sharedSet<Supplier<() -> Unit>>()

    // Deferred query-change notifications, keyed by query identifier.
    internal val queriesFuncs = sharedMap<Int, Supplier<() -> List<Query<*>>>>()

    // Whether this transaction's own body completed without throwing; starts out false.
    internal var successful: Boolean by AtomicBoolean(false)

    // Outcome of the nested transactions folded into this one; starts out true.
    internal var childrenSuccessful: Boolean by AtomicBoolean(true)

    // The Transacter that opened this transaction; nested transactions are started through it.
    internal var transacter: Transacter? by Atomic<Transacter?>(null)

    /**
     * The parent transaction, if there is any.
     */
    protected abstract val enclosingTransaction: Transaction?

    internal fun enclosingTransaction() = enclosingTransaction

    /**
     * Signal to the underlying SQL driver that this transaction should be finished.
     *
     * @param successful Whether the transaction completed successfully or not.
     */
    protected abstract fun endTransaction(successful: Boolean)

    /**
     * Finishes this transaction, reporting success only when both [successful] and
     * [childrenSuccessful] are true.
     */
    internal fun endTransaction() = endTransaction(successful && childrenSuccessful)

    /**
     * Queues [function] to be run after this transaction successfully commits.
     */
    override fun afterCommit(function: () -> Unit) {
      postCommitHooks.add(threadLocalRef(function))
    }

    /**
     * Queues [function] to be run after this transaction rolls back.
     */
    override fun afterRollback(function: () -> Unit) {
      postRollbackHooks.add(threadLocalRef(function))
    }
  }
}

/**
 * Control-flow signal thrown by the `rollback` functions to unwind a transaction body.
 *
 * [value] carries the argument given to `rollback(returnValue)`, or null for the no-result form.
 */
private class RollbackException(
  val value: Any? = null
) : Throwable()

/**
 * Receiver handed to transaction bodies.
 *
 * Exposes a [Transaction] as both a [TransactionWithoutReturn] and a [TransactionWithReturn].
 * Callback registrations are forwarded to that transaction. Nested transactions are started
 * through the [Transacter] that opened it.
 */
private class TransactionWrapper<R>(
  val transaction: Transaction
) : TransactionWithoutReturn, TransactionWithReturn<R> {
  /**
   * Queues [function] to be run after this transaction successfully commits.
   */
  override fun afterCommit(function: () -> Unit) {
    transaction.afterCommit(function)
  }

  /**
   * Queues [function] to be run after this transaction rolls back.
   */
  override fun afterRollback(function: () -> Unit) {
    transaction.afterRollback(function)
  }

  override fun transaction(body: TransactionWithoutReturn.() -> Unit) {
    // This private wrapper is only constructed after `transacter` has been assigned, so the
    // non-null assertion cannot fail.
    transaction.transacter!!.transaction(false, body)
  }

  override fun <T> transaction(body: TransactionWithReturn<T>.() -> T): T {
    // Same invariant as the no-result overload: `transacter` is always set by now.
    return transaction.transacter!!.transactionWithResult(false, body)
  }
}

/**
 * A transaction-aware [SqlDriver] wrapper which can begin a [Transaction] on the current connection.
 */
abstract class TransacterImpl(private val driver: SqlDriver) : Transacter {
  /**
   * For internal use, notifies the listeners of [queryList] that their underlying result set has
   * changed.
   *
   * While a transaction is active on [driver], the notification is deferred. [queryList] is
   * recorded under [identifier] (the first registration for an identifier wins) and is only
   * evaluated once the outermost transaction commits. Outside a transaction, the queries are
   * notified immediately.
   */
  protected fun notifyQueries(identifier: Int, queryList: () -> List<Query<*>>) {
    val transaction = driver.currentTransaction()
    if (transaction != null) {
      if (!transaction.queriesFuncs.containsKey(identifier)) {
        transaction.queriesFuncs[identifier] = threadLocalRef(queryList)
      }
    } else {
      queryList.invoke().forEach { it.notifyDataChanged() }
    }
  }

  /**
   * For internal use, creates a parenthesized, comma-separated list of [count] `?` bind
   * placeholders, e.g. `"(?,?,?)"` for a count of 3 and `"()"` for a count of 0.
   *
   * @throws IllegalArgumentException if [count] is negative.
   */
  protected fun createArguments(count: Int): String {
    require(count >= 0) { "count must not be negative: $count" }
    if (count == 0) return "()"

    // Exact length: two parentheses, `count` placeholders and `count - 1` commas.
    return buildString(2 * count + 1) {
      append("(?")
      repeat(count - 1) {
        append(",?")
      }
      append(')')
    }
  }

  override fun transaction(
    noEnclosing: Boolean,
    body: TransactionWithoutReturn.() -> Unit
  ) {
    transactionWithWrapper(noEnclosing, body)
  }

  override fun <R> transactionWithResult(
    noEnclosing: Boolean,
    bodyWithReturn: TransactionWithReturn<R>.() -> R
  ): R {
    return transactionWithWrapper(noEnclosing, bodyWithReturn)
  }

  /**
   * Runs [wrapperBody] inside a new transaction obtained from [driver], then finishes that
   * transaction.
   *
   * Any [Throwable] thrown by the body leaves the transaction unsuccessful. For an outermost
   * transaction, the rollback or commit work is performed here. A nested transaction instead
   * hands its outcome, hooks and pending notifications to its enclosing transaction. A
   * [RollbackException] that reaches an outermost transaction is turned into its carried value;
   * every other exception is rethrown.
   *
   * @throws IllegalStateException if [noEnclosing] is true and a transaction is already active.
   */
  private fun <R> transactionWithWrapper(
    noEnclosing: Boolean,
    wrapperBody: TransactionWrapper<R>.() -> R
  ): R {
    val transaction = driver.newTransaction()
    val enclosing = transaction.enclosingTransaction()

    if (noEnclosing && enclosing != null) {
      // Every other path ends the transaction returned by newTransaction(). End this one too
      // (unsuccessfully) before failing, so the driver is not left with an unfinished transaction.
      transaction.endTransaction()
      error("Already in a transaction")
    }

    var thrownException: Throwable? = null
    var returnValue: R? = null

    try {
      transaction.transacter = this
      returnValue = TransactionWrapper<R>(transaction).wrapperBody()
      transaction.successful = true
    } catch (e: Throwable) {
      thrownException = e
    }

    // The catch above traps every Throwable, so the completion work below always runs.
    transaction.endTransaction()
    if (enclosing == null) {
      completeOutermostTransaction(transaction, thrownException)
    } else {
      mergeIntoEnclosing(transaction, enclosing)
    }

    if (enclosing == null && thrownException is RollbackException) {
      // NOTE(review): a rollback(value) issued inside a nested transactionWithResult is rethrown
      // by the nested call and can land here too. In that case `value` has the nested result type
      // rather than R.
      @Suppress("UNCHECKED_CAST")
      return thrownException.value as R
    }
    if (thrownException != null) throw thrownException

    // No exception means the body returned normally and returnValue holds its result.
    @Suppress("UNCHECKED_CAST")
    return returnValue as R
  }

  /**
   * Finishing work for an outermost [transaction] that has already been ended.
   *
   * On failure (its body threw, or a nested transaction failed), runs the rollback hooks. On
   * success, delivers the deferred query notifications and then runs the commit hooks.
   *
   * @param thrownException What the transaction body threw, if anything. If a rollback hook also
   *   throws, both are described in one combined [Throwable].
   */
  private fun completeOutermostTransaction(transaction: Transaction, thrownException: Throwable?) {
    if (!transaction.successful || !transaction.childrenSuccessful) {
      try {
        transaction.postRollbackHooks.forEach { it.run() }
      } catch (rollbackException: Throwable) {
        if (thrownException != null) {
          throw Throwable("Exception while rolling back from an exception.\nOriginal exception: $thrownException\nwith cause ${thrownException.cause}\n\nRollback exception: $rollbackException", rollbackException)
        }
        throw rollbackException
      }
      transaction.postRollbackHooks.clear()
    } else {
      // distinct(): a query listed under several identifiers is still notified only once.
      transaction.queriesFuncs
        .flatMap { (_, queryListSupplier) -> queryListSupplier.run() }
        .distinct()
        .forEach { it.notifyDataChanged() }
      transaction.queriesFuncs.clear()

      transaction.postCommitHooks.forEach { it.run() }
      transaction.postCommitHooks.clear()
    }
  }

  /**
   * Folds a nested [transaction] that has just ended into its [enclosing] transaction.
   *
   * A failure of the nested transaction (or of anything nested in it) is recorded on [enclosing].
   * Its commit/rollback hooks and pending query notifications are handed over, so they only take
   * effect when the outermost transaction completes.
   */
  private fun mergeIntoEnclosing(transaction: Transaction, enclosing: Transaction) {
    // Accumulate rather than overwrite: a successful sibling must not mask an earlier failed one.
    enclosing.childrenSuccessful = enclosing.childrenSuccessful &&
      transaction.successful &&
      transaction.childrenSuccessful
    enclosing.postCommitHooks.addAll(transaction.postCommitHooks)
    enclosing.postRollbackHooks.addAll(transaction.postRollbackHooks)
    enclosing.queriesFuncs.putAll(transaction.queriesFuncs)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy