All Downloads are FREE. Search and download functionalities are using the official Maven repository.

app.cash.sqldelight.driver.r2dbc.R2dbcDriver.kt Maven / Gradle / Ivy

package app.cash.sqldelight.driver.r2dbc

import app.cash.sqldelight.Query
import app.cash.sqldelight.Transacter
import app.cash.sqldelight.db.QueryResult
import app.cash.sqldelight.db.SqlCursor
import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.db.SqlPreparedStatement
import io.r2dbc.spi.Connection
import io.r2dbc.spi.Statement
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactive.awaitSingle
import org.intellij.lang.annotations.Language
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

/**
 * [SqlDriver] that runs every statement on a single R2DBC [connection].
 *
 * Statements are created and bound immediately, but executed inside the returned [QueryResult.AsyncValue].
 * Query listeners are not supported: [addListener], [removeListener] and [notifyListeners] are no-ops.
 */
class R2dbcDriver(
  val connection: Connection,
  /**
   * This callback is called after [close]. It either contains an error or null, representing a successful close.
   */
  val closed: (Throwable?) -> Unit = { },
) : SqlDriver {
  // The innermost open transaction, or null when none is open. It is tracked per driver, not per thread, because:
  // - the wrapped Connection has exactly one transaction state;
  // - a coroutine may resume on a different thread after any suspension point.
  // A thread-local value would therefore be missing (or stale) on the thread that later queries or ends the
  // transaction.
  private var transaction: Transacter.Transaction? = null

  override fun <R> executeQuery(
    identifier: Int?,
    @Language("SQL") sql: String,
    mapper: (SqlCursor) -> QueryResult<R>,
    parameters: Int,
    binders: (SqlPreparedStatement.() -> Unit)?,
  ): QueryResult<R> {
    val prepared = prepare(sql, binders)

    return QueryResult.AsyncValue {
      val result = prepared.execute().awaitSingle()

      // Materialize each row as a list of its column values so the cursor can read them by index.
      val rowPublisher = result.map { row, rowMetadata ->
        List<Any?>(rowMetadata.columnMetadatas.size) { index -> row.get(index) }
      }

      mapper(R2dbcCursor(rowPublisher.asIterator())).await()
    }
  }

  override fun execute(
    identifier: Int?,
    @Language("SQL") sql: String,
    parameters: Int,
    binders: (SqlPreparedStatement.() -> Unit)?,
  ): QueryResult<Long> {
    val prepared = prepare(sql, binders)

    return QueryResult.AsyncValue {
      val result = prepared.execute().awaitSingle()
      // A result without an update count is reported as 0 rows.
      result.rowsUpdated.awaitFirstOrNull() ?: 0L
    }
  }

  /** Creates a statement for [sql] on [connection] and applies [binders] to it, if any. */
  private fun prepare(
    sql: String,
    binders: (SqlPreparedStatement.() -> Unit)?,
  ): Statement {
    val statement = connection.createStatement(sql)
    if (binders != null) R2dbcPreparedStatement(statement).binders()
    return statement
  }

  override fun newTransaction(): QueryResult<Transacter.Transaction> = QueryResult.AsyncValue {
    val enclosing = transaction

    if (enclosing == null) {
      // Only the outermost transaction touches the connection; nested ones run inside it.
      connection.beginTransaction().awaitFirstOrNull()
    }

    // Publish the new transaction only after BEGIN succeeded, so a failed begin leaves no phantom transaction.
    val started = Transaction(enclosing)
    transaction = started
    started
  }

  override fun currentTransaction(): Transacter.Transaction? = transaction

  override fun addListener(vararg queryKeys: String, listener: Query.Listener) = Unit
  override fun removeListener(vararg queryKeys: String, listener: Query.Listener) = Unit
  override fun notifyListeners(vararg queryKeys: String) = Unit

  override fun close() {
    // Normally, this is just a Mono, so it completes directly without onNext.
    // But the standard allows any publisher, so we should request unlimited items
    // and wait until the close call is finally completed.
    connection.close().subscribe(object : Subscriber<Void> {
      override fun onSubscribe(sub: Subscription) {
        sub.request(Long.MAX_VALUE)
      }

      override fun onError(error: Throwable) {
        closed(error)
      }

      override fun onComplete() {
        closed(null)
      }

      override fun onNext(t: Void) {
        // Do nothing, we wait until completion.
      }
    })
  }

  /**
   * A (possibly nested) transaction on the driver's [connection].
   *
   * Only the outermost transaction ([enclosingTransaction] == null) commits or rolls back the connection.
   */
  private inner class Transaction(
    override val enclosingTransaction: Transacter.Transaction?,
  ) : Transacter.Transaction() {
    override fun endTransaction(successful: Boolean): QueryResult<Unit> = QueryResult.AsyncValue {
      try {
        if (enclosingTransaction == null) {
          if (successful) {
            connection.commitTransaction().awaitFirstOrNull()
          } else {
            connection.rollbackTransaction().awaitFirstOrNull()
          }
        }
      } finally {
        // Restore the enclosing transaction even if COMMIT/ROLLBACK failed. Otherwise later transactions would
        // wrongly believe they are nested and never issue BEGIN.
        transaction = enclosingTransaction
      }
    }
  }
}

/**
 * Creates and returns a [R2dbcDriver] with the given [connection].
 *
 * The scope waits until the driver is closed via [R2dbcDriver.close]. If closing the connection fails, the error is
 * rethrown from a child coroutine of this scope. It therefore fails the scope like any other child failure instead
 * of being silently dropped.
 */
fun CoroutineScope.R2dbcDriver(
  connection: Connection,
): R2dbcDriver {
  // Completed by the driver's close callback: normally on success, exceptionally with the close error.
  val closeResult = CompletableDeferred<Unit>()
  val driver = R2dbcDriver(connection) { error ->
    if (error == null) {
      closeResult.complete(Unit)
    } else {
      closeResult.completeExceptionally(error)
    }
  }
  // Keeps this scope alive until close() finishes. Unlike Job.join(), await() rethrows a close failure.
  launch {
    closeResult.await()
  }
  return driver
}

/**
 * Adapts an R2DBC [Statement] to SQLDelight's [SqlPreparedStatement] binding API.
 *
 * R2DBC uses boxed Java classes instead of primitives (https://r2dbc.io/spec/1.0.0.RELEASE/spec/html/#datatypes).
 * A `null` argument is therefore bound as a typed SQL `NULL`, using the boxed Java class (`javaObjectType`) for
 * Kotlin primitives.
 */
class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStatement {
  override fun bindBytes(index: Int, bytes: ByteArray?) =
    bindOrNull(index, bytes, ByteArray::class.java)

  fun bindShort(index: Int, short: Short?) =
    bindOrNull(index, short, Short::class.javaObjectType)

  fun bindInt(index: Int, int: Int?) =
    bindOrNull(index, int, Int::class.javaObjectType)

  override fun bindLong(index: Int, long: Long?) =
    bindOrNull(index, long, Long::class.javaObjectType)

  override fun bindDouble(index: Int, double: Double?) =
    bindOrNull(index, double, Double::class.javaObjectType)

  override fun bindString(index: Int, string: String?) =
    bindOrNull(index, string, String::class.java)

  override fun bindBoolean(index: Int, boolean: Boolean?) =
    bindOrNull(index, boolean, Boolean::class.javaObjectType)

  fun bindObject(index: Int, any: Any?) =
    bindOrNull(index, any, Any::class.java)

  /** Binds [value] at [index], or a typed SQL `NULL` of [nullType] when [value] is `null`. */
  private fun bindOrNull(index: Int, value: Any?, nullType: Class<*>) {
    when (value) {
      null -> statement.bindNull(index, nullType)
      else -> statement.bind(index, value)
    }
  }
}

/**
 * Wraps this [Publisher] in a pull-based [AsyncPublisherIterator] that requests one item per
 * [AsyncPublisherIterator.next] call.
 */
internal fun <T : Any> Publisher<T>.asIterator(): AsyncPublisherIterator<T> =
  AsyncPublisherIterator(this)

/**
 * Pull-based adapter over a Reactive Streams [Publisher]: each call to [next] requests exactly one more item.
 *
 * Signals are buffered in an unlimited [Channel] rather than a single reusable slot. A terminal signal
 * (`onComplete`/`onError`) may arrive right after the last `onNext`, before [next] has consumed that item. With a
 * single slot, that slot would still hold the item, the terminal signal would be dropped, and the following
 * [next] call would wait forever. The channel queues the signal instead.
 */
internal class AsyncPublisherIterator<T : Any>(
  pub: Publisher<T>,
) {
  // Holds at most one pending item (one is requested per next()), followed by the terminal close.
  private val signals = Channel<T>(Channel.UNLIMITED)
  private val subscription = CompletableDeferred<Subscription>()

  init {
    pub.subscribe(object : Subscriber<T> {
      override fun onSubscribe(sub: Subscription) {
        subscription.complete(sub)
      }

      override fun onError(error: Throwable) {
        signals.close(error)
      }

      override fun onComplete() {
        signals.close()
      }

      override fun onNext(next: T) {
        // Cannot fail for a spec-compliant publisher: the channel is unlimited and is only closed by a terminal
        // signal, which must come after every onNext.
        signals.trySend(next)
      }
    })
  }

  /**
   * Requests the next item and suspends until it (or a terminal signal) arrives.
   *
   * @return the next item, or `null` once the publisher has completed.
   * @throws Throwable the error the publisher terminated with, if any.
   */
  suspend fun next(): T? {
    val sub = subscription.await()
    sub.request(1)
    val received = try {
      signals.receiveCatching()
    } catch (cancel: CancellationException) {
      // Nobody will consume further items, so stop the upstream.
      sub.cancel()
      throw cancel
    }
    received.exceptionOrNull()?.let { throw it }
    return received.getOrNull()
  }
}

/**
 * [SqlCursor] over the rows of an R2DBC result, where each row is a list of its column values.
 *
 * Rows are pulled one at a time through [next]. The typed getters read the most recently fetched row, so they must
 * only be called after [next] has returned `true`.
 */
class R2dbcCursor
internal constructor(private val results: AsyncPublisherIterator<List<Any?>>) : SqlCursor {
  private lateinit var currentRow: List<Any?>

  override fun next(): QueryResult.AsyncValue<Boolean> = QueryResult.AsyncValue {
    val next = results.next() ?: return@AsyncValue false
    currentRow = next
    true
  }

  /** Returns column [index] of the current row, unchecked-cast to [T]; `null` for SQL `NULL`. */
  @PublishedApi
  internal fun <T> get(index: Int): T? {
    @Suppress("UNCHECKED_CAST")
    return currentRow[index] as T?
  }

  override fun getString(index: Int): String? = get(index)

  // Numeric getters read through Number so any boxed numeric type the driver returns (e.g. Float, BigDecimal) is
  // converted instead of failing with a ClassCastException.
  fun getShort(index: Int): Short? = get<Number>(index)?.toShort()
  fun getInt(index: Int): Int? = get<Number>(index)?.toInt()

  override fun getLong(index: Int): Long? = get<Number>(index)?.toLong()

  override fun getBytes(index: Int): ByteArray? = get(index)

  override fun getDouble(index: Int): Double? = get<Number>(index)?.toDouble()

  override fun getBoolean(index: Int): Boolean? = get(index)

  inline fun <reified T : Any> getObject(index: Int): T? = get(index)

  fun <T> getArray(index: Int): Array<T>? = get(index)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy