All downloads are free. The search and download features use the official Maven repository.

ldbc.connector.net.protocol.StatementImpl.scala Maven / Gradle / Ivy

The newest version!
/**
 * Copyright (c) 2023-2024 by Takahiko Tominaga
 * This software is licensed under the MIT License (MIT).
 * For more information see LICENSE or https://opensource.org/licenses/MIT
 */

package ldbc.connector.net.protocol

import scala.collection.immutable.ListMap

import cats.*
import cats.syntax.all.*

import cats.effect.*

import org.typelevel.otel4s.Attribute
import org.typelevel.otel4s.trace.{ Tracer, Span }

import ldbc.sql.{ Statement, ResultSet }

import ldbc.connector.ResultSetImpl
import ldbc.connector.data.*
import ldbc.connector.exception.SQLException
import ldbc.connector.net.Protocol
import ldbc.connector.net.packet.response.*
import ldbc.connector.net.packet.request.*

/**
 * Statement implementation that sends plain SQL text to the server as a `ComQueryPacket` through
 * the given `protocol` and decodes the response packets into result sets or update counts.
 *
 * @param protocol
 *   protocol used to send request packets and receive response packets
 * @param serverVariables
 *   server variables handed to every `ResultSetImpl` created by this statement
 * @param batchedArgs
 *   SQL strings queued by `addBatch` and consumed by `executeLargeBatch`
 * @param connectionClosed
 *   flag read by `isClosed`
 * @param statementClosed
 *   flag read by `isClosed` and set to `true` by `close`
 * @param resultSetClosed
 *   flag set to `true` by `close`
 * @param currentResultSet
 *   result set most recently produced by `executeQuery` or `getGeneratedKeys`
 * @param updateCount
 *   affected-row count stored by `executeLargeUpdate`
 * @param moreResults
 *   value returned by `getMoreResults`
 * @param autoGeneratedKeys
 *   flag stored by the `autoGeneratedKeys` overloads and read by `getGeneratedKeys`
 * @param lastInsertId
 *   last insert id taken from the most recent OK packet of an update or batch
 * @param resultSetType
 *   result set type passed to `ResultSetImpl` for queries that return columns
 * @param resultSetConcurrency
 *   result set concurrency passed to `ResultSetImpl` for queries that return columns
 */
private[ldbc] case class StatementImpl[F[_]: Temporal: Exchange: Tracer](
  protocol:             Protocol[F],
  serverVariables:      Map[String, String],
  batchedArgs:          Ref[F, Vector[String]],
  connectionClosed:     Ref[F, Boolean],
  statementClosed:      Ref[F, Boolean],
  resultSetClosed:      Ref[F, Boolean],
  currentResultSet:     Ref[F, Option[ResultSet]],
  updateCount:          Ref[F, Long],
  moreResults:          Ref[F, Boolean],
  autoGeneratedKeys:    Ref[F, Int],
  lastInsertId:         Ref[F, Long],
  resultSetType:        Int = ResultSet.TYPE_FORWARD_ONLY,
  resultSetConcurrency: Int = ResultSet.CONCUR_READ_ONLY
)(using ev: MonadError[F, Throwable])
  extends StatementImpl.ShareStatement[F]:

  // Span attributes shared by every operation of this statement.
  private val attributes = protocol.initialPacket.attributes ++ List(Attribute("type", "Statement"))

  /**
   * Sends `sql` as a query and builds a result set from the response.
   *
   * An OK packet yields an empty result set, an ERR packet raises the exception derived from it,
   * and a column-count packet is followed by reading that many column definitions and then all
   * rows until EOF. The produced result set is also stored in `currentResultSet`.
   *
   * @param sql
   *   SQL text to send; rejected when the statement is closed or the text is null/empty
   */
  override def executeQuery(sql: String): F[ResultSet] =
    checkClosed() *> checkNullOrEmptyQuery(sql) *> exchange[F, ResultSet]("statement") { (span: Span[F]) =>
      span.addAttributes((attributes ++ List(Attribute("execute", "query"), Attribute("sql", sql)))*) *>
        protocol.resetSequenceId *>
        protocol.send(ComQueryPacket(sql, protocol.initialPacket.capabilityFlags, ListMap.empty)) *>
        protocol.receive(ColumnsNumberPacket.decoder(protocol.initialPacket.capabilityFlags)).flatMap {
          case _: OKPacket =>
            for
              resultSet <- ev.pure(
                             ResultSetImpl
                               .empty(
                                 serverVariables,
                                 protocol.initialPacket.serverVersion
                               )
                           )
              _ <- currentResultSet.set(Some(resultSet))
            yield resultSet
          case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
          case result: ColumnsNumberPacket =>
            for
              columnDefinitions <-
                protocol.repeatProcess(
                  result.size,
                  ColumnDefinitionPacket.decoder(protocol.initialPacket.capabilityFlags)
                )
              resultSetRow <- protocol.readUntilEOF[ResultSetRowPacket](
                                ResultSetRowPacket.decoder(protocol.initialPacket.capabilityFlags, columnDefinitions)
                              )
              resultSet = ResultSetImpl(
                            columnDefinitions,
                            resultSetRow,
                            serverVariables,
                            protocol.initialPacket.serverVersion,
                            resultSetType,
                            resultSetConcurrency
                          )
              _ <- currentResultSet.set(Some(resultSet))
            yield resultSet
        }
    }

  /** Same as `executeLargeUpdate`, with the affected-row count converted to `Int`. */
  override def executeUpdate(sql: String): F[Int] =
    executeLargeUpdate(sql).map(_.toInt)

  /**
   * Sends `sql` as an update and returns the number of affected rows.
   *
   * On an OK packet the reported last insert id and affected-row count are stored in
   * `lastInsertId` and `updateCount`. An ERR packet raises the exception derived from it and an
   * EOF packet raises an `SQLException`.
   *
   * @param sql
   *   SQL text to send; rejected when the statement is closed or the text is null/empty
   */
  override def executeLargeUpdate(sql: String): F[Long] =
    checkClosed() *> checkNullOrEmptyQuery(sql) *> exchange[F, Long]("statement") { (span: Span[F]) =>
      span.addAttributes(
        (attributes ++ List(Attribute("execute", "update"), Attribute("sql", sql)))*
      ) *> protocol.resetSequenceId *> (
        protocol.send(ComQueryPacket(sql, protocol.initialPacket.capabilityFlags, ListMap.empty)) *>
          protocol.receive(GenericResponsePackets.decoder(protocol.initialPacket.capabilityFlags)).flatMap {
            case result: OKPacket =>
              lastInsertId.set(result.lastInsertId) *> updateCount.updateAndGet(_ => result.affectedRows)
            case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
            case _: EOFPacket     => ev.raiseError(new SQLException("Unexpected EOF packet"))
          }
      )
    }

  /** Marks both the statement and its result set as closed. */
  override def close(): F[Unit] = statementClosed.set(true) *> resultSetClosed.set(true)

  /**
   * Runs `sql` through `executeQuery` when it starts with `SELECT` (case-insensitive, leading
   * whitespace ignored) and through `executeUpdate` otherwise.
   *
   * @param sql
   *   SQL text to execute
   * @return
   *   `true` when the statement was run as a query and the resulting `ResultSetImpl` reports rows,
   *   `false` otherwise
   */
  override def execute(sql: String): F[Boolean] =
    // `flatMap` keeps the dispatch lazy: the SQL text is only inspected after both checks passed,
    // so a rejected (e.g. null) query never reaches `sql.trim`.
    (checkClosed() *> checkNullOrEmptyQuery(sql)).flatMap { _ =>
      // Leading whitespace/newlines must not push a SELECT onto the update path, whose decoder
      // does not handle a column-count response.
      if sql.trim.toUpperCase.startsWith("SELECT") then
        executeQuery(sql).map {
          case resultSet: ResultSetImpl => resultSet.hasRows()
          case _                        => false
        }
      else executeUpdate(sql).map(_ => false)
    }

  /** Appends `sql` to the list of statements to be run by the next batch execution. */
  override def addBatch(sql: String): F[Unit] = batchedArgs.update(_ :+ sql)

  /** Drops every statement queued with `addBatch`. */
  override def clearBatch(): F[Unit] = batchedArgs.set(Vector.empty)

  /**
   * Executes all queued statements in one round trip and returns one affected-row count per
   * statement.
   *
   * Multi-statement mode is switched on, the queued statements are joined with `;` and sent as a
   * single query, and one response packet is read per queued statement. Afterwards multi-statement
   * mode is switched off again and the batch is cleared. An empty batch sends no query and yields
   * an empty array.
   *
   * NOTE(review): the trailing cleanup is chained with `<*`, so it is skipped when the batch
   * raises an error; the queued statements and the multi-statement option are then left as they
   * were. Confirm whether that is intended.
   */
  override def executeLargeBatch(): F[Array[Long]] =
    checkClosed() *> protocol.resetSequenceId *>
      protocol.comSetOption(EnumMySQLSetOption.MYSQL_OPTION_MULTI_STATEMENTS_ON) *>
      exchange[F, Array[Long]]("statement") { (span: Span[F]) =>
        batchedArgs.get.flatMap { args =>
          span.addAttributes(
            (attributes ++ List(
              Attribute("execute", "batch"),
              Attribute("size", args.length.toLong),
              Attribute("sql", args.toArray.toSeq)
            ))*
          ) *> (
            if args.isEmpty then ev.pure(Array.empty)
            else
              protocol.resetSequenceId *>
                protocol.send(
                  ComQueryPacket(args.mkString(";"), protocol.initialPacket.capabilityFlags, ListMap.empty)
                ) *>
                // Read exactly one response per queued statement, accumulating the affected rows.
                args
                  .foldLeft(ev.pure(Vector.empty[Long])) { ($acc, _) =>
                    for
                      acc <- $acc
                      result <-
                        protocol
                          .receive(GenericResponsePackets.decoder(protocol.initialPacket.capabilityFlags))
                          .flatMap {
                            case result: OKPacket =>
                              lastInsertId.set(result.lastInsertId) *> ev.pure(acc :+ result.affectedRows)
                            case error: ERRPacket =>
                              ev.raiseError(error.toException("Failed to execute batch", acc))
                            case _: EOFPacket => ev.raiseError(new SQLException("Unexpected EOF packet"))
                          }
                    yield result
                  }
                  .map(_.toArray)
          )
        }
      } <* protocol.resetSequenceId <* protocol.comSetOption(
        EnumMySQLSetOption.MYSQL_OPTION_MULTI_STATEMENTS_OFF
      ) <* clearBatch()

  /**
   * Returns a one-row, one-column (`GENERATED_KEYS`) result set holding the stored last insert id
   * and records it as the current result set.
   *
   * Raises an `SQLException` unless the stored flag is `Statement.RETURN_GENERATED_KEYS`.
   */
  override def getGeneratedKeys(): F[ResultSet] =
    autoGeneratedKeys.get.flatMap {
      case Statement.RETURN_GENERATED_KEYS =>
        for
          lastInsertId <- lastInsertId.get
          resultSet = ResultSetImpl(
                        Vector(new ColumnDefinitionPacket:
                          override def table: String = ""

                          override def name: String = "GENERATED_KEYS"

                          override def columnType: ColumnDataType = ColumnDataType.MYSQL_TYPE_LONGLONG

                          override def flags: Seq[ColumnDefinitionFlags] = Seq.empty
                        ),
                        Vector(ResultSetRowPacket(Array(Some(lastInsertId.toString)))),
                        serverVariables,
                        protocol.initialPacket.serverVersion
                      )
          _ <- currentResultSet.set(Some(resultSet))
        yield resultSet
      case _ =>
        // Covers Statement.NO_GENERATED_KEYS as well as any unrecognised flag value, which would
        // otherwise surface as a MatchError instead of an SQLException.
        ev.raiseError(
          new SQLException(
            "Generated keys not requested. You need to specify Statement.RETURN_GENERATED_KEYS to Statement.executeUpdate(), Statement.executeLargeUpdate() or Connection.prepareStatement()."
          )
        )
    }

object StatementImpl:

  /**
   * Behaviour shared by statement implementations: accessors over the statement's state
   * references, the `autoGeneratedKeys` overloads, and the closed/query validation helpers.
   */
  private[ldbc] trait ShareStatement[F[_]: Temporal](using ev: MonadError[F, Throwable]) extends Statement[F]:

    def statementClosed:   Ref[F, Boolean]
    def connectionClosed:  Ref[F, Boolean]
    def currentResultSet:  Ref[F, Option[ResultSet]]
    def updateCount:       Ref[F, Long]
    def moreResults:       Ref[F, Boolean]
    def autoGeneratedKeys: Ref[F, Int]
    def lastInsertId:      Ref[F, Long]

    // Each accessor first fails if the statement is closed, then reads the matching reference.
    override def getResultSet():        F[Option[ResultSet]] = checkClosed() *> currentResultSet.get
    override def getUpdateCount():      F[Int]               = checkClosed() *> updateCount.get.map(_.toInt)
    override def getLargeUpdateCount(): F[Long]              = checkClosed() *> updateCount.get
    override def getMoreResults():      F[Boolean]           = checkClosed() *> moreResults.get

    /** Stores the `autoGeneratedKeys` flag, then delegates to `executeUpdate(sql)`. */
    override def executeUpdate(sql: String, autoGeneratedKeys: Int): F[Int] =
      this.autoGeneratedKeys.set(autoGeneratedKeys) *> executeUpdate(sql)

    /** Stores the `autoGeneratedKeys` flag, then delegates to `executeLargeUpdate(sql)`. */
    override def executeLargeUpdate(sql: String, autoGeneratedKeys: Int): F[Long] =
      this.autoGeneratedKeys.set(autoGeneratedKeys) *> executeLargeUpdate(sql)

    /** Stores the `autoGeneratedKeys` flag, then delegates to `execute(sql)`. */
    override def execute(sql: String, autoGeneratedKeys: Int): F[Boolean] =
      this.autoGeneratedKeys.set(autoGeneratedKeys) *> execute(sql)

    /** The statement counts as closed when either the connection or the statement flag is set. */
    override def isClosed(): F[Boolean] =
      for
        connClosed <- connectionClosed.get
        stmtClosed <- statementClosed.get
      yield connClosed || stmtClosed

    /** Same as `executeLargeBatch`, with every count converted to `Int`. */
    override def executeBatch(): F[Array[Int]] = executeLargeBatch().map(_.map(_.toInt))

    /**
     * Succeeds when the statement is still open. Otherwise calls `close()` and raises an
     * `SQLException`.
     */
    protected def checkClosed(): F[Unit] =
      isClosed().ifM(
        close() *> ev.raiseError(new SQLException("No operations allowed after statement closed.")),
        ev.unit
      )

    /**
     * Raises an `SQLException` when `sql` is null or empty and succeeds otherwise.
     *
     * @param sql
     *   SQL text about to be sent to the server
     */
    protected def checkNullOrEmptyQuery(sql: String): F[Unit] =
      // The null test has to come first: calling `isEmpty` on a null reference throws a
      // NullPointerException before the dedicated NULL-query error could be raised.
      if sql == null then ev.raiseError(new SQLException("Can not issue NULL query."))
      else if sql.isEmpty then ev.raiseError(new SQLException("Can not issue empty query."))
      else ev.unit

    /**
     * Evaluates `f` and lifts its result into `F`; a non-fatal exception thrown by `f` becomes a
     * raised error instead of propagating.
     */
    protected def shiftF[T](f: => T): F[T] =
      try ev.pure(f)
      catch case scala.util.control.NonFatal(e) => ev.raiseError(e)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy