All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.augustnagro.magnum.util.scala Maven / Gradle / Ivy

The newest version!
package com.augustnagro.magnum

import com.augustnagro.magnum.SqlException

import java.lang.System.Logger.Level
import java.sql.{Connection, PreparedStatement, ResultSet, Statement}
import java.util.concurrent.TimeUnit
import javax.sql.DataSource
import scala.collection.mutable.ReusableBuilder
import scala.util.{Failure, Success, Try, Using, boundary}
import scala.deriving.Mirror
import scala.compiletime.{
  constValue,
  constValueTuple,
  erasedValue,
  error,
  summonInline
}
import scala.compiletime.ops.any.==
import scala.compiletime.ops.boolean.&&
import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag
import scala.quoted.*

def connect[T](transactor: Transactor)(f: DbCon ?=> T): T =
  Using.resource(transactor.dataSource.getConnection): con =>
    transactor.connectionConfig(con)
    f(using DbCon(con, transactor.sqlLogger))

def connect[T](dataSource: DataSource)(f: DbCon ?=> T): T =
  connect(Transactor(dataSource))(f)

def transact[T](transactor: Transactor)(f: DbTx ?=> T): T =
  Using.resource(transactor.dataSource.getConnection): con =>
    transactor.connectionConfig(con)
    con.setAutoCommit(false)
    try
      val res = f(using DbTx(con, transactor.sqlLogger))
      con.commit()
      res
    catch
      case t =>
        con.rollback()
        throw t

def transact[T](dataSource: DataSource)(f: DbTx ?=> T): T =
  transact(Transactor(dataSource))(f)

def transact[T](dataSource: DataSource, connectionConfig: Connection => Unit)(
    f: DbTx ?=> T
): T =
  val transactor =
    Transactor(dataSource = dataSource, connectionConfig = connectionConfig)
  transact(transactor)(f)

extension (inline sc: StringContext)
  inline def sql(inline args: Any*): Frag =
    ${ sqlImpl('{ sc }, '{ args }) }

private def sqlImpl(sc: Expr[StringContext], args: Expr[Seq[Any]])(using
    Quotes
): Expr[Frag] =
  import quotes.reflect.*
  val argsExprs: Seq[Expr[Any]] = args match
    case Varargs(ae) => ae
//  val stringExprs: Seq[Expr[String]] = sc match
//    case '{ StringContext(${ Varargs(strings) }: _*) } => strings

  val interpolatedVarargs = Varargs(argsExprs.map {
    case '{ $arg: SqlLiteral } => '{ $arg.queryRepr }
    case '{ $arg: tp } =>
      val codecExpr = summonWriter[tp]
      '{ $codecExpr.queryRepr }
  })

  val paramExprs = argsExprs.filter {
    case '{ $arg: SqlLiteral } => false
    case _                     => true
  }

  val queryExpr = '{ $sc.s($interpolatedVarargs: _*) }
  val exprParams = Expr.ofSeq(paramExprs)

  '{
    val argValues = $exprParams
    val writer: FragWriter = (ps: PreparedStatement, pos: Int) => {
      ${ sqlWriter('{ ps }, '{ pos }, '{ argValues }, paramExprs, '{ 0 }) }
    }
    Frag($queryExpr, argValues, writer)
  }
end sqlImpl

private def sqlWriter(
    psExpr: Expr[PreparedStatement],
    posExpr: Expr[Int],
    args: Expr[Seq[Any]],
    argsExprs: Seq[Expr[Any]],
    iExpr: Expr[Int]
)(using Quotes): Expr[Int] =
  import quotes.reflect.*
  argsExprs match
    case head +: tail =>
      head match
        case '{ $arg: tp } =>
          val codecExpr = summonWriter[tp]
          '{
            val i = $iExpr
            val argValue = $args(i).asInstanceOf[tp]
            val pos = $posExpr
            val codec = $codecExpr
            codec.writeSingle(argValue, $psExpr, pos)
            val newPos = pos + codec.cols.length
            val newI = i + 1
            ${ sqlWriter(psExpr, '{ newPos }, args, tail, '{ newI }) }
          }
        case _ =>
          report.errorAndAbort("Args must be explicit", head)
    case Seq() => posExpr
end sqlWriter

private def summonWriter[T: Type](using Quotes): Expr[DbCodec[T]] =
  import quotes.reflect.*

  Expr
    .summon[DbCodec[T]]
    .orElse(
      TypeRepr.of[T].widen.asType match
        case '[tpe] =>
          Expr
            .summon[DbCodec[tpe]]
            .map(codec => '{ $codec.asInstanceOf[DbCodec[T]] })
    )
    .getOrElse:
      report.info(
        s"Could not find given DbCodec for ${TypeRepr.of[T].show}. Using PreparedStatement::setObject instead."
      )
      '{ DbCodec.AnyCodec.asInstanceOf[DbCodec[T]] }

def batchUpdate[T](values: Iterable[T])(f: T => Update)(using
    con: DbCon
): BatchUpdateResult =
  val it = values.iterator
  if !it.hasNext then return BatchUpdateResult.Success(0)
  val firstUpdate = f(it.next())
  val firstFrag = firstUpdate.frag

  Using.Manager(use =>
    val ps = use(con.connection.prepareStatement(firstFrag.sqlString))
    firstFrag.writer.write(ps, 1)
    ps.addBatch()

    while it.hasNext do
      val frag = f(it.next()).frag
      assert(
        frag.sqlString == firstFrag.sqlString,
        "all queries must be the same for batch PreparedStatement"
      )
      frag.writer.write(ps, 1)
      ps.addBatch()
    batchUpdateResult(ps.executeBatch())
  ) match
    case Success(res) => res
    case Failure(t) =>
      throw SqlException(
        con.sqlLogger.exceptionMsg(
          SqlExceptionEvent(firstFrag.sqlString, firstFrag.params, t)
        ),
        t
      )
  end match
end batchUpdate

private val Log = System.getLogger("com.augustnagro.magnum")

private def parseParams(params: Any): Iterator[Iterator[Any]] =
  params match
    case p: Product => Iterator(p.productIterator)
    case it: Iterable[?] =>
      it.headOption match
        case Some(h: Product) =>
          it.asInstanceOf[Iterable[Product]]
            .iterator
            .map(_.productIterator)
        case _ =>
          Iterator(it.iterator)
    case x => Iterator(Iterator(x))

private def paramsString(params: Iterator[Iterator[Any]]): String =
  params.map(_.mkString("(", ", ", ")")).mkString("", ",\n", "\n")

private def timed[T](f: => T): (T, FiniteDuration) =
  val start = System.currentTimeMillis()
  val res = f
  val execTime = FiniteDuration(
    System.currentTimeMillis() - start,
    TimeUnit.MILLISECONDS
  )
  (res, execTime)

private def batchUpdateResult(updateCounts: Array[Int]): BatchUpdateResult =
  boundary:
    val updatedRows = updateCounts.foldLeft(0L)((res, c) =>
      c match
        case rowCount if rowCount >= 0 =>
          res + rowCount
        case Statement.SUCCESS_NO_INFO =>
          boundary.break(BatchUpdateResult.SuccessNoInfo)
        case errorCode =>
          throw RuntimeException(s"Received JDBC error code $errorCode")
    )
    BatchUpdateResult.Success(updatedRows)

private def assertECIsSubsetOfE[EC: Type, E: Type](using Quotes): Unit =
  import quotes.reflect.*
  val eRepr = TypeRepr.of[E]
  val ecRepr = TypeRepr.of[EC]
  val eFields = eRepr.typeSymbol.caseFields
  val ecFields = ecRepr.typeSymbol.caseFields

  for ecField <- ecFields do
    if !eFields.exists(f =>
        f.name == ecField.name &&
          f.signature.resultSig == ecField.signature.resultSig
      )
    then
      report.error(
        s"""${ecRepr.show} must be an effective subset of ${eRepr.show}.
           |Are there any fields on ${ecRepr.show} you forgot to update on ${eRepr.show}?
           |""".stripMargin
      )

private def tableExprs[EC: Type, E: Type, ID: Type](using
    Quotes
): TableExprs =
  import quotes.reflect.*
  assertECIsSubsetOfE[EC, E]

  val idIndex = idAnnotIndex[E]
  val table: Expr[Table] =
    DerivingUtil.tableAnnot[E] match
      case Some(table) => table
      case None =>
        report.errorAndAbort(
          s"${TypeRepr.of[E].show} must have @Table annotation"
        )
  val nameMapper: Expr[SqlNameMapper] = '{ $table.nameMapper }

  Expr.summon[Mirror.Of[E]] match
    case Some('{
          $eMirror: Mirror.Of[E] {
            type MirroredLabel = eLabel
            type MirroredElemLabels = eMels
          }
        }) =>
      Expr.summon[Mirror.Of[EC]] match
        case Some('{
              $ecMirror: Mirror.Of[EC] {
                type MirroredElemLabels = ecMels
              }
            }) =>
          val tableNameScala = Type.valueOfConstant[eLabel].get.toString
          val tableNameScalaExpr = Expr(tableNameScala)
          val tableNameSql = DerivingUtil.sqlTableNameAnnot[E] match
            case Some(sqlName) => '{ $sqlName.name }
            case None => '{ $nameMapper.toTableName($tableNameScalaExpr) }
          val eElemNames = elemNames[eMels]()
          val eElemNamesSql = eElemNames.map(elemName =>
            sqlNameAnnot[E](elemName) match
              case Some(sqlName) => '{ $sqlName.name }
              case None =>
                '{ $nameMapper.toColumnName(${ Expr(elemName) }) }
          )
          val ecElemNames = elemNames[ecMels]()
          val ecElemNamesSql = ecElemNames.map(elemName =>
            sqlNameAnnot[E](elemName) match
              case Some(sqlName) => '{ $sqlName.name }
              case None =>
                '{ $nameMapper.toColumnName(${ Expr(elemName) }) }
          )
          TableExprs(
            table,
            tableNameScalaExpr,
            tableNameSql,
            eElemNames,
            eElemNamesSql,
            ecElemNames,
            ecElemNamesSql,
            idIndex
          )
        case _ =>
          report.errorAndAbort(
            s"A Mirror is required to derive RepoDefaults for ${TypeRepr.of[EC].show}"
          )
    case _ =>
      report.errorAndAbort(
        s"A Mirror is required to derive RepoDefaults for ${TypeRepr.of[E].show}"
      )
  end match
end tableExprs

private def idAnnotIndex[E: Type](using q: Quotes): Expr[Int] =
  import q.reflect.*
  val idAnnot = TypeRepr.of[Id].typeSymbol
  val index = TypeRepr
    .of[E]
    .typeSymbol
    .primaryConstructor
    .paramSymss
    .head
    .indexWhere(sym => sym.hasAnnotation(idAnnot)) match
    case -1 => 0
    case x  => x
  Expr(index)

private def elemNames[Mels: Type](res: List[String] = Nil)(using
    Quotes
): List[String] =
  import quotes.reflect.*
  Type.of[Mels] match
    case '[mel *: melTail] =>
      val melString = Type.valueOfConstant[mel].get.toString
      elemNames[melTail](melString :: res)
    case '[EmptyTuple] =>
      res.reverse

private def sqlNameAnnot[T: Type](elemName: String)(using
    Quotes
): Option[Expr[SqlName]] =
  import quotes.reflect.*
  val annot = TypeRepr.of[SqlName].typeSymbol
  TypeRepr
    .of[T]
    .typeSymbol
    .primaryConstructor
    .paramSymss
    .head
    .find(sym => sym.name == elemName && sym.hasAnnotation(annot))
    .flatMap(sym => sym.getAnnotation(annot))
    .map(term => term.asExprOf[SqlName])

private def handleQuery[A](sql: String, params: Any)(
    attempt: Try[(A, FiniteDuration)]
)(using con: DbCon): A =
  attempt match
    case Success((res, execTime)) =>
      con.sqlLogger.log(SqlSuccessEvent(sql, params, execTime))
      res
    case Failure(t) =>
      val msg = con.sqlLogger.exceptionMsg(SqlExceptionEvent(sql, params, t))
      throw SqlException(msg, t)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy