All downloads are free. Search and download functionality uses the official Maven repository.

ldbc.connector.net.protocol.CallableStatementImpl.scala Maven / Gradle / Ivy

There is a newer version: 0.3.0-beta10
Show newest version
/**
 * Copyright (c) 2023-2024 by Takahiko Tominaga
 * This software is licensed under the MIT License (MIT).
 * For more information see LICENSE or https://opensource.org/licenses/MIT
 */

package ldbc.connector.net.protocol

import java.time.*

import scala.collection.immutable.{ ListMap, SortedMap }

import cats.*
import cats.syntax.all.*

import cats.effect.*

import org.typelevel.otel4s.Attribute
import org.typelevel.otel4s.trace.{ Tracer, Span }

import ldbc.sql.{ Statement, CallableStatement, ResultSet, DatabaseMetaData, ParameterMetaData }

import ldbc.connector.*
import ldbc.connector.data.*
import ldbc.connector.exception.SQLException
import ldbc.connector.net.Protocol
import ldbc.connector.net.packet.response.*
import ldbc.connector.net.packet.request.*

/**
 * MySQL implementation of [[CallableStatement]].
 *
 * Statements whose SQL starts with `CALL` are executed as stored-procedure calls:
 * - INOUT parameters are first assigned to server-side session variables (`SET @ldbc_mysql_outparam_...=value`).
 * - OUT parameters are bound to those variable names.
 * - The call is then sent.
 * - A follow-up `SELECT @...` query loads the output values into `outputParameterResult`.
 *
 * Any other SQL is sent as an ordinary text query with the bound parameters substituted in.
 *
 * @param protocol
 *   the protocol used to exchange packets with the server
 * @param serverVariables
 *   server variables handed to every [[ResultSetImpl]] this statement creates
 * @param sql
 *   the SQL text of this statement
 * @param paramInfo
 *   parameter metadata of the called procedure
 * @param params
 *   the parameter values bound so far, keyed by parameter index
 * @param batchedArgs
 *   the SQL produced by `buildBatchQuery` for each [[addBatch]] call
 * @param currentResultSet
 *   the result set most recently made current by this statement
 * @param outputParameterResult
 *   the result of the `SELECT @...` query that holds the OUT parameter values, once retrieved
 * @param resultSets
 *   result sets of the last `CALL` executed via [[execute]] that [[getMoreResults]] has not yet consumed
 * @param parameterIndexToRsIndex
 *   maps a parameter index to its 1-based column in `outputParameterResult`
 * @param moreResults
 *   whether [[getMoreResults]] may still return another result set
 * @param autoGeneratedKeys
 *   the generated-keys mode consulted by [[getGeneratedKeys]]
 * @param lastInsertId
 *   the last insert id reported by the server in an OK packet
 * @param resultSetType
 *   the type passed to result sets decoded from server responses
 * @param resultSetConcurrency
 *   the concurrency passed to result sets decoded from server responses
 */
case class CallableStatementImpl[F[_]: Temporal: Exchange: Tracer](
  protocol:                Protocol[F],
  serverVariables:         Map[String, String],
  sql:                     String,
  paramInfo:               CallableStatementImpl.ParamInfo,
  params:                  Ref[F, SortedMap[Int, Parameter]],
  batchedArgs:             Ref[F, Vector[String]],
  connectionClosed:        Ref[F, Boolean],
  statementClosed:         Ref[F, Boolean],
  resultSetClosed:         Ref[F, Boolean],
  currentResultSet:        Ref[F, Option[ResultSet]],
  outputParameterResult:   Ref[F, Option[ResultSetImpl]],
  resultSets:              Ref[F, List[ResultSetImpl]],
  parameterIndexToRsIndex: Ref[F, Map[Int, Int]],
  updateCount:             Ref[F, Long],
  moreResults:             Ref[F, Boolean],
  autoGeneratedKeys:       Ref[F, Int],
  lastInsertId:            Ref[F, Long],
  resultSetType:           Int = ResultSet.TYPE_FORWARD_ONLY,
  resultSetConcurrency:    Int = ResultSet.CONCUR_READ_ONLY
)(using ev: MonadError[F, Throwable])
  extends CallableStatement[F],
          SharedPreparedStatement[F]:

  /** Tracing attributes shared by every span this statement opens. */
  private val attributes = protocol.initialPacket.attributes ++ List(
    Attribute("type", "CallableStatement"),
    Attribute("sql", sql)
  )

  /**
   * Executes the statement and returns its first result set.
   *
   * For `CALL` statements, the procedure's first result set is returned. If the procedure returned none,
   * an empty result set is returned instead. OUT parameters are retrieved afterwards. Bound parameters
   * are cleared once the query has succeeded.
   */
  override def executeQuery(): F[ResultSet] =
    checkClosed() *>
      checkNullOrEmptyQuery(sql) *>
      exchange[F, ResultSet]("statement") { (span: Span[F]) =>
        if sql.toUpperCase.startsWith("CALL") then
          executeCallStatement(span).flatMap { resultSets =>
            resultSets.headOption match
              case None =>
                for
                  resultSet <- ev.pure(
                                 ResultSetImpl.empty(
                                   serverVariables,
                                   protocol.initialPacket.serverVersion
                                 )
                               )
                  _ <- currentResultSet.set(Some(resultSet))
                yield resultSet
              case Some(resultSet) =>
                currentResultSet.update(_ => Some(resultSet)) *> resultSet.pure[F]
          } <* retrieveOutParams()
        else
          params.get.flatMap { params =>
            span.addAttributes(
              (attributes ++ List(
                Attribute("params", params.map((_, param) => param.toString).mkString(", ")),
                Attribute("execute", "query")
              ))*
            ) *>
              protocol.resetSequenceId *>
              protocol.send(
                ComQueryPacket(buildQuery(sql, params), protocol.initialPacket.capabilityFlags, ListMap.empty)
              ) *>
              receiveQueryResult()
          }
      } <* params.set(SortedMap.empty)

  /**
   * Executes the statement as an update.
   *
   * For `CALL` statements, the first result set (or an empty one) is made current, OUT parameters are
   * retrieved, and `-1` is returned. Otherwise the affected-row count from the server's OK packet is
   * returned and its last insert id is recorded.
   */
  override def executeLargeUpdate(): F[Long] =
    checkClosed() *>
      checkNullOrEmptyQuery(sql) *>
      exchange[F, Long]("statement") { (span: Span[F]) =>
        if sql.toUpperCase.startsWith("CALL") then
          executeCallStatement(span).flatMap { resultSets =>
            resultSets.headOption match
              case None =>
                for
                  resultSet <- ev.pure(
                                 ResultSetImpl.empty(
                                   serverVariables,
                                   protocol.initialPacket.serverVersion
                                 )
                               )
                  _ <- currentResultSet.set(Some(resultSet))
                yield resultSet
              case Some(resultSet) =>
                currentResultSet.update(_ => Some(resultSet)) *> resultSet.pure[F]
          } *> retrieveOutParams() *> ev.pure(-1)
        else
          params.get.flatMap { params =>
            span.addAttributes(
              (attributes ++ List(
                Attribute("params", params.map((_, param) => param.toString).mkString(", ")),
                Attribute("execute", "update")
              ))*
            ) *>
              sendQuery(buildQuery(sql, params)).flatMap {
                case result: OKPacket => lastInsertId.set(result.lastInsertId) *> ev.pure(result.affectedRows)
                case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
                case _: EOFPacket     => ev.raiseError(new SQLException("Unexpected EOF packet"))
              }
          }
      }

  /**
   * Executes the statement.
   *
   * For `CALL` statements, all returned result sets are stored for [[getMoreResults]], the first one is
   * made current, OUT parameters are retrieved, and the result is `true` iff any result set was
   * returned. Any other statement is expected to answer with an OK packet, and the result is `false`.
   */
  override def execute(): F[Boolean] =
    checkClosed() *>
      checkNullOrEmptyQuery(sql) *>
      exchange[F, Boolean]("statement") { (span: Span[F]) =>
        if sql.toUpperCase.startsWith("CALL") then
          executeCallStatement(span).flatMap { results =>
            moreResults.update(_ => results.nonEmpty) *>
              currentResultSet.update(_ => results.headOption) *>
              resultSets.set(results.toList) *>
              ev.pure(results.nonEmpty)
          } <* retrieveOutParams()
        else
          params.get
            .flatMap { params =>
              span.addAttributes(
                (attributes ++ List(
                  Attribute("params", params.map((_, param) => param.toString).mkString(", ")),
                  Attribute("execute", "update")
                ))*
              ) *>
                sendQuery(buildQuery(sql, params)).flatMap {
                  case result: OKPacket => lastInsertId.set(result.lastInsertId) *> ev.pure(result.affectedRows)
                  case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
                  case _: EOFPacket     => ev.raiseError(new SQLException("Unexpected EOF packet"))
                }
            }
            .map(_ => false)
      }

  /**
   * Moves the next stored result set into the current result set and returns `true`. When none remain,
   * returns `false`.
   *
   * Note: [[execute]] stores *all* result sets, including the one it already made current. The first
   * call here therefore yields that same first result set again.
   */
  override def getMoreResults(): F[Boolean] =
    checkClosed() *> moreResults.get.flatMap { isMoreResults =>
      if isMoreResults then
        resultSets.get.flatMap {
          case Nil => moreResults.set(false) *> ev.pure(false)
          case resultSet :: tail =>
            currentResultSet.set(Some(resultSet)) *> resultSets.set(tail) *> ev.pure(true)
        }
      else ev.pure(false)
    }

  /**
   * Adds the currently bound parameters to the batch and clears them.
   *
   * For `CALL` statements, INOUT values are first sent to the server and OUT parameters are bound to
   * their session-variable names.
   */
  override def addBatch(): F[Unit] =
    checkClosed() *>
      checkNullOrEmptyQuery(sql) *> (
        sql.toUpperCase match
          case q if q.startsWith("CALL") =>
            setInOutParamsOnServer(paramInfo) *> setOutParams()
          case _ => ev.unit
      ) *>
      params.get.flatMap { params =>
        batchedArgs.update(_ :+ buildBatchQuery(sql, params))
      } *>
      params.set(SortedMap.empty)

  override def clearBatch(): F[Unit] = batchedArgs.set(Vector.empty)

  /**
   * Executes every batched entry.
   *
   * - INSERT batches are merged into one multi-row `INSERT ... VALUES`, and every entry reports
   *   `SUCCESS_NO_INFO`.
   * - UPDATE, DELETE and CALL batches run as one multi-statement query (see `executeMultiStatementBatch`).
   * - Any other statement type is rejected.
   *
   * Bound parameters and the batch are cleared after a successful run.
   */
  override def executeLargeBatch(): F[Array[Long]] =
    checkClosed() *>
      checkNullOrEmptyQuery(sql) *>
      exchange[F, Array[Long]]("statement") { (span: Span[F]) =>
        batchedArgs.get.flatMap { args =>
          span.addAttributes(
            (attributes ++ List(
              Attribute("execute", "batch"),
              Attribute("size", args.length.toLong),
              Attribute("sql", args.toArray.toSeq)
            ))*
          ) *> (
            if args.isEmpty then ev.pure(Array.empty)
            else
              // The SQL is upper-cased before matching, so every prefix literal below must be upper-case.
              sql.toUpperCase match
                case q if q.startsWith("INSERT") =>
                  sendQuery(sql.split("VALUES").head + " VALUES" + args.mkString(","))
                    .flatMap {
                      case _: OKPacket      => ev.pure(Array.fill(args.length)(Statement.SUCCESS_NO_INFO.toLong))
                      case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
                      case _: EOFPacket     => ev.raiseError(new SQLException("Unexpected EOF packet"))
                    }
                case q if q.startsWith("UPDATE") || q.startsWith("DELETE") || q.startsWith("CALL") =>
                  executeMultiStatementBatch(args)
                case _ =>
                  ev.raiseError(
                    new SQLException("The batch query must be an INSERT, UPDATE, or DELETE, CALL statement.")
                  )
          )
        }
      } <* params.set(SortedMap.empty) <* batchedArgs.set(Vector.empty)

  /**
   * Returns a one-column (`GENERATED_KEYS`) result set containing the last insert id and makes it current.
   *
   * Raises an [[SQLException]] unless `autoGeneratedKeys` is `Statement.RETURN_GENERATED_KEYS`.
   */
  override def getGeneratedKeys(): F[ResultSet] =
    autoGeneratedKeys.get.flatMap {
      case Statement.RETURN_GENERATED_KEYS =>
        for
          lastInsertId <- lastInsertId.get
          resultSet <- ev.pure(
                         ResultSetImpl(
                           Vector(new ColumnDefinitionPacket:
                             override def table:      String                     = ""
                             override def name:       String                     = "GENERATED_KEYS"
                             override def columnType: ColumnDataType             = ColumnDataType.MYSQL_TYPE_LONGLONG
                             override def flags:      Seq[ColumnDefinitionFlags] = Seq.empty
                           ),
                           Vector(ResultSetRowPacket(Array(Some(lastInsertId.toString)))),
                           serverVariables,
                           protocol.initialPacket.serverVersion
                         )
                       )
          _ <- currentResultSet.set(Some(resultSet))
        yield resultSet
      case _ =>
        // Any mode other than RETURN_GENERATED_KEYS (normally NO_GENERATED_KEYS) means keys were not
        // requested. This also avoids a MatchError on unexpected values.
        ev.raiseError(
          new SQLException(
            "Generated keys not requested. You need to specify Statement.RETURN_GENERATED_KEYS to Statement.executeUpdate(), Statement.executeLargeUpdate() or Connection.prepareStatement()."
          )
        )
    }

  /** Marks the statement and its result set as closed. No packet is sent to the server. */
  override def close(): F[Unit] = statementClosed.set(true) *> resultSetClosed.set(true)

  /**
   * Registers a parameter as an output parameter.
   *
   * - The declared JDBC type of the procedure parameter must equal `sqlType`.
   * - For an INOUT parameter, its currently bound value (or `NULL`) is immediately assigned to its
   *   session variable on the server.
   * - A parameter that is not INOUT raises an [[SQLException]].
   * - Does nothing when the procedure has no parameters.
   */
  override def registerOutParameter(parameterIndex: Int, sqlType: Int): F[Unit] =
    if paramInfo.numParameters > 0 then
      paramInfo.parameterList.find(_.index == parameterIndex) match
        case Some(param) =>
          (if param.jdbcType == sqlType then ev.unit
           else
             ev.raiseError(
               new SQLException(
                 "The type specified for the parameter does not match the type registered as a procedure."
               )
             )
          ) *> (
            if param.isOut && param.isIn then setInOutParameterOnServer(param)
            else ev.raiseError(new SQLException("No output parameters returned by procedure."))
          )
        case None =>
          ev.raiseError(
            new SQLException(s"Parameter index of $parameterIndex is out of range (1, ${ paramInfo.numParameters })")
          )
    else ev.unit

  // Output-parameter getters addressed by parameter index. They share the bounds and
  // "registered as output" checks implemented in `readOutParameter`.

  override def getString(parameterIndex: Int): F[Option[String]] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getString(index)).map(Option(_))

  override def getBoolean(parameterIndex: Int): F[Boolean] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getBoolean(index))

  override def getByte(parameterIndex: Int): F[Byte] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getByte(index))

  override def getShort(parameterIndex: Int): F[Short] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getShort(index))

  override def getInt(parameterIndex: Int): F[Int] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getInt(index))

  override def getLong(parameterIndex: Int): F[Long] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getLong(index))

  override def getFloat(parameterIndex: Int): F[Float] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getFloat(index))

  override def getDouble(parameterIndex: Int): F[Double] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getDouble(index))

  override def getBytes(parameterIndex: Int): F[Option[Array[Byte]]] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getBytes(index)).map(Option(_))

  override def getDate(parameterIndex: Int): F[Option[LocalDate]] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getDate(index)).map(Option(_))

  override def getTime(parameterIndex: Int): F[Option[LocalTime]] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getTime(index)).map(Option(_))

  override def getTimestamp(parameterIndex: Int): F[Option[LocalDateTime]] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getTimestamp(index)).map(Option(_))

  override def getBigDecimal(parameterIndex: Int): F[Option[BigDecimal]] =
    readOutParameter(parameterIndex)((resultSet, index) => resultSet.getBigDecimal(index)).map(Option(_))

  // Output-parameter getters addressed by parameter name. The value is looked up in the output-parameter
  // result set under the mangled session-variable name (see `mangleParameterName`).

  override def getString(parameterName: String): F[Option[String]] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getString(mangleParameterName(parameterName)))
    yield Option(value)

  override def getBoolean(parameterName: String): F[Boolean] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getBoolean(mangleParameterName(parameterName)))
    yield value

  override def getByte(parameterName: String): F[Byte] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getByte(mangleParameterName(parameterName)))
    yield value

  override def getShort(parameterName: String): F[Short] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getShort(mangleParameterName(parameterName)))
    yield value

  override def getInt(parameterName: String): F[Int] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getInt(mangleParameterName(parameterName)))
    yield value

  override def getLong(parameterName: String): F[Long] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getLong(mangleParameterName(parameterName)))
    yield value

  override def getFloat(parameterName: String): F[Float] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getFloat(mangleParameterName(parameterName)))
    yield value

  override def getDouble(parameterName: String): F[Double] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getDouble(mangleParameterName(parameterName)))
    yield value

  override def getBytes(parameterName: String): F[Option[Array[Byte]]] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getBytes(mangleParameterName(parameterName)))
    yield Option(value)

  override def getDate(parameterName: String): F[Option[LocalDate]] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getDate(mangleParameterName(parameterName)))
    yield Option(value)

  override def getTime(parameterName: String): F[Option[LocalTime]] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getTime(mangleParameterName(parameterName)))
    yield Option(value)

  override def getTimestamp(parameterName: String): F[Option[LocalDateTime]] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getTimestamp(mangleParameterName(parameterName)))
    yield Option(value)

  override def getBigDecimal(parameterName: String): F[Option[BigDecimal]] =
    for
      resultSet <- getOutputParameters()
      value     <- shiftF(resultSet.getBigDecimal(mangleParameterName(parameterName)))
    yield Option(value)

  /** Binds `value` as the parameter at `index`, replacing any previous binding. */
  private def setParameter(index: Int, value: String): F[Unit] =
    params.update(_ + (index -> Parameter.parameter(value)))

  /**
   * Sends `sql` as a COM_QUERY (after resetting the sequence id) and decodes the generic OK/ERR/EOF
   * response.
   */
  private def sendQuery(sql: String): F[GenericResponsePackets] =
    checkNullOrEmptyQuery(sql) *> protocol.resetSequenceId *> protocol.send(
      ComQueryPacket(sql, protocol.initialPacket.capabilityFlags, ListMap.empty)
    ) *> protocol.receive(GenericResponsePackets.decoder(protocol.initialPacket.capabilityFlags))

  /**
   * Reads consecutive result sets until an OK packet ends the sequence.
   *
   * Each result set (column count, column definitions, rows up to EOF) is appended to `resultSets`.
   * Raises on an ERR packet.
   *
   * @param resultSets
   *   the result sets read so far
   * @return
   *   all result sets read, in server order
   */
  private def receiveUntilOkPacket(resultSets: Vector[ResultSetImpl]): F[Vector[ResultSetImpl]] =
    protocol.receive(ColumnsNumberPacket.decoder(protocol.initialPacket.capabilityFlags)).flatMap {
      case _: OKPacket      => resultSets.pure[F]
      case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
      case result: ColumnsNumberPacket =>
        for
          columnDefinitions <-
            protocol.repeatProcess(
              result.size,
              ColumnDefinitionPacket.decoder(protocol.initialPacket.capabilityFlags)
            )
          resultSetRow <-
            protocol.readUntilEOF[ResultSetRowPacket](
              ResultSetRowPacket.decoder(protocol.initialPacket.capabilityFlags, columnDefinitions)
            )
          resultSet = ResultSetImpl(
                        columnDefinitions,
                        resultSetRow,
                        serverVariables,
                        protocol.initialPacket.serverVersion,
                        resultSetType,
                        resultSetConcurrency
                      )
          resultSets <- receiveUntilOkPacket(resultSets :+ resultSet)
        yield resultSets
    }

  /**
   * Reads the response to a single query.
   *
   * - An OK packet yields an empty result set.
   * - A result-set response is decoded, made the current result set, and returned.
   * - An ERR packet raises.
   */
  private def receiveQueryResult(): F[ResultSet] =
    protocol.receive(ColumnsNumberPacket.decoder(protocol.initialPacket.capabilityFlags)).flatMap {
      case _: OKPacket =>
        ev.pure(
          ResultSetImpl
            .empty(
              serverVariables,
              protocol.initialPacket.serverVersion
            )
        )
      case error: ERRPacket => ev.raiseError(error.toException(Some(sql), None))
      case result: ColumnsNumberPacket =>
        for
          columnDefinitions <-
            protocol.repeatProcess(
              result.size,
              ColumnDefinitionPacket.decoder(protocol.initialPacket.capabilityFlags)
            )
          resultSetRow <- protocol.readUntilEOF[ResultSetRowPacket](
                            ResultSetRowPacket.decoder(protocol.initialPacket.capabilityFlags, columnDefinitions)
                          )
          resultSet = ResultSetImpl(
                        columnDefinitions,
                        resultSetRow,
                        serverVariables,
                        protocol.initialPacket.serverVersion,
                        resultSetType,
                        resultSetConcurrency
                      )
          _ <- currentResultSet.set(Some(resultSet))
        yield resultSet
    }

  /**
   * Maps a parameter name to its session-variable name: the namespace prefix followed by the name
   * with one leading `@` (if any) removed.
   *
   * @param origParameterName
   *   the original parameter name
   * @return
   *   the prefixed session-variable name
   */
  private def mangleParameterName(origParameterName: String): String =
    CallableStatementImpl.PARAMETER_NAMESPACE_PREFIX + origParameterName.stripPrefix("@")

  /**
   * Returns the session-variable name used for `param`.
   *
   * Unnamed parameters fall back to `nullnp<index>` before mangling.
   */
  private def outParameterName(param: CallableStatementImpl.CallableStatementParameter): String =
    mangleParameterName(param.paramName.getOrElse("nullnp" + param.index))

  /**
   * Assigns the currently bound value of an INOUT parameter (or `NULL` if unbound) to its session
   * variable on the server, via `SET @var=value`.
   *
   * @param param
   *   the INOUT parameter to send
   */
  private def setInOutParameterOnServer(param: CallableStatementImpl.CallableStatementParameter): F[Unit] =
    val assignment = s"SET ${ outParameterName(param) }="
    params.get.flatMap { bound =>
      val setSql = assignment + bound.get(param.index).fold("NULL")(_.sql)
      sendQuery(setSql).flatMap {
        case _: OKPacket      => ev.unit
        case error: ERRPacket => ev.raiseError(error.toException(Some(setSql), None))
        case _: EOFPacket     => ev.raiseError(new SQLException("Unexpected EOF packet"))
      }
    }

  /**
   * Sends the value of every INOUT parameter to the server, one `SET` statement each, in parameter-list
   * order.
   *
   * @param paramInfo
   *   the parameter information
   */
  private def setInOutParamsOnServer(paramInfo: CallableStatementImpl.ParamInfo): F[Unit] =
    if paramInfo.numParameters > 0 then
      paramInfo.parameterList.foldLeft(ev.unit) { (acc, param) =>
        if param.isOut && param.isIn then acc *> setInOutParameterOnServer(param)
        else acc
      }
    else ev.unit

  /**
   * Binds every OUT parameter's placeholder to its session-variable name, so the `CALL` writes the
   * output into that variable. Skipped for function calls.
   */
  private def setOutParams(): F[Unit] =
    if paramInfo.numParameters > 0 then
      paramInfo.parameterList.foldLeft(ev.unit) { (acc, param) =>
        if !paramInfo.isFunctionCall && param.isOut then
          val variableName = outParameterName(param)

          acc *> params.get.flatMap { params =>
            for
              outParamIndex <- (
                                 if params.isEmpty then ev.pure(param.index)
                                 else
                                   params.keys
                                     .find(_ == param.index)
                                     .fold(
                                       ev.raiseError(
                                         new SQLException(
                                           s"Parameter ${ param.index } is not registered as an output parameter"
                                         )
                                       )
                                     )(_.pure[F])
                               )
              _ <- setParameter(outParamIndex, variableName)
            yield ()
          }
        else acc
      }
    else ev.unit

  /**
   * Issues a second query (`SELECT @var, ...`) that retrieves the values of every parameter flagged
   * `isOut`.
   *
   * The result is stored in `outputParameterResult`, and for each parameter index the 1-based column
   * holding its value is recorded. Does nothing when no parameter is flagged `isOut`.
   *
   * NOTE(review): the response is read through `receiveQueryResult`, which also replaces
   * `currentResultSet` with the output-parameter result set when it has columns. Confirm this is
   * intended for callers that read the current result set after a `CALL`.
   */
  private def retrieveOutParams(): F[Unit] =
    val parameters: Vector[(Int, String)] =
      paramInfo.parameterList.toVector.collect {
        case param if param.isOut => param.index -> outParameterName(param)
      }

    if paramInfo.numParameters > 0 && parameters.nonEmpty then
      val selectSql = parameters
        .map((_, paramName) => if paramName.startsWith("@") then paramName else s"@$paramName")
        .mkString("SELECT ", ", ", "")

      checkClosed() *>
        checkNullOrEmptyQuery(selectSql) *>
        protocol.resetSequenceId *>
        protocol.send(ComQueryPacket(selectSql, protocol.initialPacket.capabilityFlags, ListMap.empty)) *>
        receiveQueryResult().flatMap {
          case resultSet: ResultSetImpl => outputParameterResult.update(_ => Some(resultSet))
        } *>
        // The i-th selected variable lives in column i + 1 of the output-parameter result set.
        parameterIndexToRsIndex.update(
          _ ++ parameters.zipWithIndex.map { case ((paramIndex, _), index) => paramIndex -> (index + 1) }
        )
    else ev.unit

  /**
   * Returns the result set that holds the output parameters.
   *
   * Raises an [[SQLException]] if none was retrieved. The message depends on whether the procedure
   * declares any parameters.
   *
   * @return
   *   the result set that holds the output parameters
   */
  private def getOutputParameters(): F[ResultSetImpl] =
    outputParameterResult.get.flatMap {
      case None =>
        if paramInfo.numParameters == 0 then ev.raiseError(new SQLException("No output parameters registered."))
        else ev.raiseError(new SQLException("No output parameters returned by procedure."))
      case Some(resultSet) => resultSet.pure[F]
    }

  /**
   * Checks that `paramIndex` lies within `1 .. paramInfo.numParameters`.
   *
   * @param paramIndex
   *   the parameter index to check
   */
  private def checkBounds(paramIndex: Int): F[Unit] =
    if paramIndex < 1 || paramIndex > paramInfo.numParameters then
      ev.raiseError(
        new SQLException(s"Parameter index of ${ paramIndex } is out of range (1, ${ paramInfo.numParameters })")
      )
    else ev.unit

  /**
   * Reads an output parameter, addressed by its parameter index, from the output-parameter result set.
   *
   * The parameter index is translated to a result-set column through `parameterIndexToRsIndex`; an
   * unmapped index is used as-is.
   *
   * @param parameterIndex
   *   the index of the parameter in the procedure call; must be within `1 .. numParameters`
   * @param read
   *   extracts the value from the output-parameter result set at the resolved column index
   */
  private def readOutParameter[A](parameterIndex: Int)(read: (ResultSetImpl, Int) => A): F[A] =
    for
      resultSet <- checkBounds(parameterIndex) *> getOutputParameters()
      paramMap  <- parameterIndexToRsIndex.get
      index = paramMap.getOrElse(parameterIndex, parameterIndex)
      value <-
        (if index == CallableStatementImpl.NOT_OUTPUT_PARAMETER_INDICATOR then
           ev.raiseError[A](new SQLException(s"Parameter $parameterIndex is not registered as an output parameter"))
         else shiftF(read(resultSet, index)))
    yield value

  /**
   * Executes a CALL/stored function statement.
   *
   * INOUT values are sent first and OUT parameters are bound to their session variables. The call is
   * then sent, and every result set it returns is read.
   *
   * @param span
   *   the span to annotate
   * @return
   *   the result sets returned by the call, in order
   */
  private def executeCallStatement(span: Span[F]): F[Vector[ResultSetImpl]] =
    setInOutParamsOnServer(paramInfo) *>
      setOutParams() *>
      params.get.flatMap { params =>
        span.addAttributes(
          (attributes ++ List(
            Attribute("params", params.map((_, param) => param.toString).mkString(", ")),
            Attribute("execute", "query")
          ))*
        ) *>
          protocol.resetSequenceId *>
          protocol.send(
            ComQueryPacket(buildQuery(sql, params), protocol.initialPacket.capabilityFlags, ListMap.empty)
          ) *>
          receiveUntilOkPacket(Vector.empty)
      }

  /**
   * Runs an UPDATE/DELETE/CALL batch as one multi-statement COM_QUERY (statements joined with `;`).
   *
   * The affected-row count of each statement is collected from its OK packet, in order. An ERR packet
   * raises an exception carrying the counts collected so far.
   *
   * Multi-statement support is switched on only for the duration of the batch and switched off again
   * afterwards. This also happens when the batch fails, so an error part-way through does not leave
   * the connection accepting multiple statements per query. A failure while switching it off after an
   * error is ignored, and the original error is raised.
   *
   * @param args
   *   the batched statements, one per [[addBatch]] call
   * @return
   *   the affected-row count of each statement
   */
  private def executeMultiStatementBatch(args: Vector[String]): F[Array[Long]] =
    val disableMultiStatements =
      protocol.resetSequenceId *>
        protocol.comSetOption(EnumMySQLSetOption.MYSQL_OPTION_MULTI_STATEMENTS_OFF)

    val runBatch: F[Array[Long]] =
      protocol.resetSequenceId *>
        protocol.comSetOption(EnumMySQLSetOption.MYSQL_OPTION_MULTI_STATEMENTS_ON) *>
        protocol.resetSequenceId *>
        protocol.send(
          ComQueryPacket(
            args.mkString(";"),
            protocol.initialPacket.capabilityFlags,
            ListMap.empty
          )
        ) *>
        args
          .foldLeft(ev.pure(Vector.empty[Long])) { (accF, _) =>
            for
              acc <- accF
              result <-
                protocol
                  .receive(GenericResponsePackets.decoder(protocol.initialPacket.capabilityFlags))
                  .flatMap {
                    case result: OKPacket =>
                      lastInsertId.set(result.lastInsertId) *> ev.pure(acc :+ result.affectedRows)
                    case error: ERRPacket =>
                      ev.raiseError(error.toException("Failed to execute batch", acc))
                    case _: EOFPacket => ev.raiseError(new SQLException("Unexpected EOF packet"))
                  }
            yield result
          }
          .map(_.toArray)

    // On failure: best-effort switch-off, then re-raise the original error.
    // On success: switch off exactly once afterwards.
    ev.handleErrorWith(runBatch) { error =>
      ev.attempt(disableMultiStatements) *> ev.raiseError[Array[Long]](error)
    } <* disableMultiStatements

object CallableStatementImpl:

  /**
   * Sentinel index value (`Int.MinValue`).
   *
   * NOTE(review): the name suggests it marks a parameter that is not an output parameter; confirm against usages.
   */
  val NOT_OUTPUT_PARAMETER_INDICATOR: Int = Int.MinValue

  /**
   * Prefix (beginning with `@`) for generated server-side user-variable names.
   *
   * NOTE(review): presumably used to name the variables that receive OUT parameter values; confirm against usages.
   */
  private val PARAMETER_NAMESPACE_PREFIX = "@ldbc_mysql_outparam_"

  /**
   * CallableStatementParameter represents a parameter in a stored procedure.
   *
   * @param paramName
   *   the name of the parameter, if the metadata provided one
   * @param isIn
   *   whether the parameter is an input parameter
   * @param isOut
   *   whether the parameter is an output parameter
   * @param index
   *   the position of the parameter, taken from the metadata row number (`getRow()`) when built by [[ParamInfo.apply]]
   * @param jdbcType
   *   the JDBC type of the parameter
   * @param typeName
   *   the name of the type of the parameter, if the metadata provided one
   * @param precision
   *   the precision of the parameter
   * @param scale
   *   the scale of the parameter
   * @param nullability
   *   the nullability of the parameter
   * @param inOutModifier
   *   the parameter mode, expressed as one of the `ParameterMetaData.parameterMode*` constants
   */
  case class CallableStatementParameter(
    paramName:     Option[String],
    isIn:          Boolean,
    isOut:         Boolean,
    index:         Int,
    jdbcType:      Int,
    typeName:      Option[String],
    precision:     Int,
    scale:         Int,
    nullability:   Short,
    inOutModifier: Int
  )

  /**
   * ParamInfo represents the information about the parameters in a stored procedure.
   *
   * @param nativeSql
   *   the original SQL statement
   * @param dbInUse
   *   the database in use
   * @param isFunctionCall
   *   whether the SQL statement is a function call
   * @param numParameters
   *   the number of parameters in the SQL statement
   * @param parameterList
   *   a list of CallableStatementParameter representing each parameter
   * @param parameterMap
   *   a map from parameter name to CallableStatementParameter (unnamed parameters are keyed by "")
   */
  case class ParamInfo(
    nativeSql:      String,
    dbInUse:        Option[String],
    isFunctionCall: Boolean,
    numParameters:  Int,
    parameterList:  List[CallableStatementParameter],
    parameterMap:   ListMap[String, CallableStatementParameter]
  )

  object ParamInfo:

    /**
     * Builds a [[ParamInfo]] from parameter metadata rows.
     *
     * The rows are expected to follow the column layout of `DatabaseMetaData.getProcedureColumns`:
     *   - 4 = COLUMN_NAME
     *   - 5 = COLUMN_TYPE
     *   - 6 = DATA_TYPE
     *   - 7 = TYPE_NAME
     *   - 8 = PRECISION
     *   - 10 = SCALE
     *   - 12 = NULLABLE
     *
     * Every remaining row of `resultSet` is consumed, so its cursor is moved while reading.
     *
     * @param nativeSql
     *   the original SQL statement
     * @param database
     *   the database in use
     * @param resultSet
     *   the parameter metadata rows
     * @param isFunctionCall
     *   whether the statement invokes a stored function; if so, the first row is treated as an OUT-only parameter
     * @return
     *   the collected parameter information
     */
    def apply(
      nativeSql:      String,
      database:       Option[String],
      resultSet:      ResultSetImpl,
      isFunctionCall: Boolean
    ): ParamInfo =
      val builder = List.newBuilder[CallableStatementParameter]
      while resultSet.next() do
        val index           = resultSet.getRow()
        val paramName       = resultSet.getString(4)
        val procedureColumn = resultSet.getInt(5)
        val jdbcType        = resultSet.getInt(6)
        val typeName        = resultSet.getString(7)
        val precision       = resultSet.getInt(8)
        // SCALE is column 10 in the getProcedureColumns layout; column 19 is IS_NULLABLE (a YES/NO string).
        val scale       = resultSet.getInt(10)
        val nullability = resultSet.getShort(12)

        // Translate the metadata column kind into a ParameterMetaData mode.
        val inOutModifier = procedureColumn match
          case DatabaseMetaData.procedureColumnIn    => ParameterMetaData.parameterModeIn
          case DatabaseMetaData.procedureColumnInOut => ParameterMetaData.parameterModeInOut
          case DatabaseMetaData.procedureColumnOut | DatabaseMetaData.procedureColumnReturn =>
            ParameterMetaData.parameterModeOut
          case _ => ParameterMetaData.parameterModeUnknown

        // `inOutModifier` holds a ParameterMetaData mode, so compare it against those constants
        // rather than relying on DatabaseMetaData.procedureColumn* having the same numeric values.
        val (isOutParameter, isInParameter) =
          if index == 1 && isFunctionCall then (true, false)
          else if inOutModifier == ParameterMetaData.parameterModeInOut then (true, true)
          else if inOutModifier == ParameterMetaData.parameterModeIn then (false, true)
          else if inOutModifier == ParameterMetaData.parameterModeOut then (true, false)
          else (false, false)

        builder += CallableStatementParameter(
          Option(paramName),
          isInParameter,
          isOutParameter,
          index,
          jdbcType,
          Option(typeName),
          precision,
          scale,
          nullability,
          inOutModifier
        )

      val parameterList = builder.result()

      ParamInfo(
        nativeSql      = nativeSql,
        dbInUse        = database,
        isFunctionCall = isFunctionCall,
        numParameters  = resultSet.rowLength(),
        parameterList  = parameterList,
        parameterMap   = ListMap(parameterList.map(p => p.paramName.getOrElse("") -> p)*)
      )




© 2015 - 2025 Weber Informatics LLC | Privacy Policy