scalikejdbc.StatementExecutor.scala Maven / Gradle / Ivy
package scalikejdbc
import java.sql.PreparedStatement
import org.slf4j.LoggerFactory
import scala.language.reflectiveCalls
import scala.util.control.NonFatal
import JavaUtilDateConverterImplicits._
/**
* Companion object.
*/
object StatementExecutor {
val eol: String = System.getProperty("line.separator")
private trait Executor {
def apply[A](execute: () => A): A
}
private class NakedExecutor extends Executor {
override def apply[A](execute: () => A): A = execute()
}
private val LocalDateEpoch = java.time.LocalDate.ofEpochDay(0)
object PrintableQueryBuilder extends PrintableQueryBuilder {
// Find ? placeholders, but ignore ?? because that's an escaped question mark.
private val substituteRegex = "(? null
case ParameterBinder(v) => normalize(v)
case None => null
case Some(p) => normalize(p)
case p: String => p
case p: java.util.Date => p.toSqlTimestamp.toString
case p =>
ClassNameUtil.getClassName(param.getClass) match {
case "org.joda.time.DateTime" =>
param.asInstanceOf[{ def toDate: java.util.Date }].toDate.toSqlTimestamp.toString
case "org.joda.time.LocalDateTime" =>
param.asInstanceOf[{ def toDate: java.util.Date }].toDate.toSqlTimestamp
case "org.joda.time.LocalDate" =>
param.asInstanceOf[{ def toDate: java.util.Date }].toDate.toSqlDate
case "org.joda.time.LocalTime" =>
val millis = param.asInstanceOf[{ def toDateTimeToday: { def getMillis: Long } }].toDateTimeToday.getMillis
new java.sql.Time(millis)
case _ => p
}
}
}
(normalize(param) match {
case null => "null"
case result: String =>
settingsProvider.loggingSQLAndTime(GlobalSettings.loggingSQLAndTime).maxColumnSize.collect {
case maxSize if result.length > maxSize =>
"'" + result.take(maxSize) + "... (" + result.length + ")" + "'"
}.getOrElse {
"'" + result + "'"
}
case result => result.toString
}).replaceAll("\r", "\\\\r")
.replaceAll("\n", "\\\\n")
}
var i = 0
@annotation.tailrec
def trimSpaces(s: String, i: Int = 0): String = i match {
case i if i > 10 => s
case i => trimSpaces(s.replaceAll(" ", " "), i + 1)
}
val sqlWithPlaceholders = trimSpaces(SQLTemplateParser.trimComments(template)
.replaceAll("[\r\n\t]", " "))
sqlWithPlaceholders.split('\'').zipWithIndex.map {
// Even numbered parts are outside quotes, odd numbered are inside
case (target, quoteCount) if (quoteCount % 2 == 0) => {
substituteRegex.replaceAllIn(target, m => {
i += 1
if (params.size >= i) {
toPrintable(params(i - 1)).replace("\\", "\\\\").replace("$", "\\$")
} else {
// In this case, SQLException will be thrown later.
// At least, throwing java.lang.IndexOutOfBoundsException here is meaningless.
m.source.toString()
}
})
}
case (s, quoteCount) if (quoteCount % 2 == 1) =>
// If the statement is valid, we can always expect an odd number of elements
// Thus, we can add two quotes here.
"'" + s + "'"
case (s, _) => s
}.mkString
} catch {
case NonFatal(e) =>
log.debug(s"Failed to build a printable SQL statement with ${template}, params: ${params}", e)
template
}
}
}
}
/**
* java.sql.Statement Executor.
*
* @param underlying preparedStatement
* @param template SQL template
* @param singleParams parameters for single execution (= not batch execution)
* @param isBatch is batch flag
*/
case class StatementExecutor(
underlying: PreparedStatement,
template: String,
connectionAttributes: DBConnectionAttributes,
singleParams: collection.Seq[Any] = Nil,
tags: collection.Seq[String] = Nil,
isBatch: Boolean = false,
settingsProvider: SettingsProvider = SettingsProvider.default) extends LogSupport with UnixTimeInMillisConverterImplicits with AutoCloseable {
import StatementExecutor._
private[this] lazy val batchParamsList = scala.collection.mutable.ArrayBuffer.empty[collection.Seq[Any]]
initialize()
/**
* Initializes this instance.
*/
private def initialize(): Unit = {
bindParams(singleParams)
if (isBatch) {
batchParamsList.clear()
}
}
/**
* Binds parameters to the underlying java.sql.PreparedStatement object.
* @param params parameters
*/
def bindParams(params: collection.Seq[Any]): Unit = {
val paramsWithIndices = params.map {
case option: Option[_] => option.orNull[Any]
case other => other
}.zipWithIndex
for ((param, idx) <- paramsWithIndices) {
bind(param, idx + 1)
}
if (isBatch) {
batchParamsList += params
}
}
@annotation.tailrec
private[this] def bind(param: Any, i: Int): Unit = {
param match {
case null => underlying.setObject(i, null)
case AsIsParameterBinder(None) => bind(null, i)
case AsIsParameterBinder(Some(value)) => bind(value, i)
case AsIsParameterBinder(value) => bind(value, i)
case binder: ParameterBinder => binder(underlying, i)
case p: java.sql.Array => underlying.setArray(i, p)
case p: BigDecimal => underlying.setBigDecimal(i, p.bigDecimal)
case p: BigInt => underlying.setBigDecimal(i, new java.math.BigDecimal(p.bigInteger))
case p: Boolean => underlying.setBoolean(i, p)
case p: Byte => underlying.setByte(i, p)
case p: java.sql.Date => underlying.setDate(i, p)
case p: Double => underlying.setDouble(i, p)
case p: Float => underlying.setFloat(i, p)
case p: Int => underlying.setInt(i, p)
case p: Long => underlying.setLong(i, p)
case p: Short => underlying.setShort(i, p)
case p: java.sql.SQLXML => underlying.setSQLXML(i, p)
case p: String => underlying.setString(i, p)
case p: java.sql.Time => underlying.setTime(i, p)
case p: java.sql.Timestamp => underlying.setTimestamp(i, p)
case p: java.net.URL => underlying.setURL(i, p)
case p: java.util.Date => underlying.setTimestamp(i, p.toSqlTimestamp)
case p: java.time.ZonedDateTime =>
underlying.setTimestamp(i, java.sql.Timestamp.from(p.toInstant))
case p: java.time.OffsetDateTime =>
underlying.setTimestamp(i, java.sql.Timestamp.from(p.toInstant))
case p: java.time.Instant =>
underlying.setTimestamp(i, java.sql.Timestamp.from(p))
case p: java.time.LocalDateTime =>
underlying.setTimestamp(i, java.sql.Timestamp.valueOf(p))
case p: java.time.LocalDate =>
underlying.setDate(i, java.sql.Date.valueOf(p))
case p: java.time.LocalTime =>
val millis = p.atDate(StatementExecutor.LocalDateEpoch).atZone(java.time.ZoneId.systemDefault).toInstant.toEpochMilli
val time = new java.sql.Time(millis)
underlying.setTime(i, time)
case p: java.io.InputStream => underlying.setBinaryStream(i, p)
case p =>
ClassNameUtil.getClassName(param.getClass) match {
case "org.joda.time.DateTime" =>
val t = p.asInstanceOf[{ def toDate: java.util.Date }].toDate.toSqlTimestamp
underlying.setTimestamp(i, t)
case "org.joda.time.LocalDateTime" =>
val t = p.asInstanceOf[{ def toDate: java.util.Date }].toDate.toSqlTimestamp
underlying.setTimestamp(i, t)
case "org.joda.time.LocalDate" =>
val t = p.asInstanceOf[{ def toDate: java.util.Date }].toDate.toSqlDate
underlying.setDate(i, t)
case "org.joda.time.LocalTime" =>
val millis = p.asInstanceOf[{ def toDateTimeToday: { def getMillis: Long } }].toDateTimeToday.getMillis
underlying.setTime(i, new java.sql.Time(millis))
case _ =>
log.debug("The parameter(" + p + ") is bound as an Object.")
underlying.setObject(i, p)
}
}
}
/**
* SQL String value
*/
private[this] lazy val sqlString: String = {
def singleSqlString(params: collection.Seq[Any]): String = {
val sql = PrintableQueryBuilder.build(template, settingsProvider, params)
try {
settingsProvider.sqlFormatter(GlobalSettings.sqlFormatter).formatter match {
case Some(formatter) =>
formatter.format(sql)
case None =>
sql
}
} catch {
case e: Exception =>
log.debug("Caught an exception when formatting SQL because of " + e.getMessage)
sql
}
}
if (isBatch) {
settingsProvider.loggingSQLAndTime(GlobalSettings.loggingSQLAndTime).maxBatchParamSize.collect {
case maxSize if batchParamsList.size > maxSize =>
batchParamsList.take(maxSize).map(params => singleSqlString(params)).mkString(";" + eol + " ") + ";" + eol +
" ... (total: " + batchParamsList.size + " times)"
}.getOrElse {
batchParamsList.map(params => singleSqlString(params)).mkString(";" + eol + " ")
}
} else {
singleSqlString(singleParams)
}
}
/**
* Returns stack trace information as String value
* @return stack trace
*/
private[this] def stackTraceInformation: String = {
val loggingSQLAndTime = settingsProvider.loggingSQLAndTime(GlobalSettings.loggingSQLAndTime)
val stackTrace = Thread.currentThread.getStackTrace
val lines = (if (loggingSQLAndTime.printUnprocessedStackTrace) {
stackTrace.tail
} else {
stackTrace.dropWhile { trace =>
val className = trace.getClassName
className != getClass.toString &&
(className.startsWith("java.lang.") || className.startsWith("scalikejdbc."))
}
}).take(loggingSQLAndTime.stackTraceDepth).map { trace => " " + trace.toString }
" [Stack Trace]" + eol +
" ..." + eol +
lines.mkString(eol) + eol +
" ..." + eol
}
/**
* Logging SQL and timing (this trait depends on this instance)
*/
private[this] trait LoggingSQLAndTiming extends Executor with LogSupport {
abstract override def apply[A](execute: () => A): A = {
val loggingSQLAndTime = settingsProvider.loggingSQLAndTime(GlobalSettings.loggingSQLAndTime)
def messageInSingleLine(spentMillis: Long): String = "[SQL Execution] " + sqlString + "; (" + spentMillis + " ms)"
def messageInMultiLines(spentMillis: Long): String = {
"SQL execution completed" + eol +
eol +
" [SQL Execution]" + eol +
" " + sqlString + "; (" + spentMillis + " ms)" + eol +
eol +
stackTraceInformation
}
val before = System.currentTimeMillis
val result = super.apply(execute)
val after = System.currentTimeMillis
val spentMillis = after - before
// logging SQL and time
if (loggingSQLAndTime.enabled) {
if (loggingSQLAndTime.warningEnabled &&
spentMillis >= loggingSQLAndTime.warningThresholdMillis) {
if (loggingSQLAndTime.singleLineMode) {
log.withLevel(loggingSQLAndTime.warningLogLevel)(messageInSingleLine(spentMillis))
} else {
log.withLevel(loggingSQLAndTime.warningLogLevel)(messageInMultiLines(spentMillis))
}
} else {
if (loggingSQLAndTime.singleLineMode) {
log.withLevel(loggingSQLAndTime.logLevel)(messageInSingleLine(spentMillis))
} else {
log.withLevel(loggingSQLAndTime.logLevel)(messageInMultiLines(spentMillis))
}
}
}
// call event handler
settingsProvider.queryCompletionListener(GlobalSettings.queryCompletionListener).apply(template, singleParams, spentMillis)
settingsProvider.taggedQueryCompletionListener(GlobalSettings.taggedQueryCompletionListener).apply(template, singleParams, spentMillis, tags)
// result from super.apply()
result
}
}
private[this] trait LoggingSQLIfFailed extends Executor with LogSupport {
abstract override def apply[A](execute: () => A): A = try {
super.apply(execute)
} catch {
case e: Exception =>
if (settingsProvider.loggingSQLErrors(GlobalSettings.loggingSQLErrors)) {
if (settingsProvider.loggingSQLAndTime(GlobalSettings.loggingSQLAndTime).singleLineMode) {
log.error("[SQL Execution Failed] " + sqlString + " (Reason: " + e.getMessage + ")")
} else {
log.error("SQL execution failed (Reason: " + e.getMessage + "):" + eol + eol + " " + sqlString + eol)
}
} else {
log.debug("Logging SQL errors is disabled.")
}
// call event handler
settingsProvider.queryFailureListener(GlobalSettings.queryFailureListener).apply(template, singleParams, e)
settingsProvider.taggedQueryFailureListener(GlobalSettings.taggedQueryFailureListener).apply(template, singleParams, e, tags)
throw e
}
}
/**
* Executes SQL statement
*/
private[this] val statementExecute = new NakedExecutor with LoggingSQLAndTiming with LoggingSQLIfFailed
def generatedKeysResultSet: java.sql.ResultSet = underlying.getGeneratedKeys
def addBatch(): Unit = underlying.addBatch()
def execute(): Boolean = statementExecute(() => underlying.execute())
def execute(x1: String): Boolean = statementExecute(() => underlying.execute(x1))
def execute(x1: String, x2: Array[Int]): Boolean = statementExecute(() => underlying.execute(x1, x2))
def execute(x1: String, x2: Array[String]): Boolean = statementExecute(() => underlying.execute(x1, x2))
def execute(x1: String, x2: Int): Boolean = statementExecute(() => underlying.execute(x1, x2))
def executeBatch(): Array[Int] = statementExecute(() => underlying.executeBatch())
def executeLargeBatch(): Array[Long] = statementExecute(() => underlying.executeLargeBatch())
def executeQuery(): java.sql.ResultSet = statementExecute(() => underlying.executeQuery())
def executeQuery(x1: String): java.sql.ResultSet = statementExecute(() => underlying.executeQuery(x1))
def executeUpdate(): Int = statementExecute(() => underlying.executeUpdate())
def executeUpdate(x1: String): Int = statementExecute(() => underlying.executeUpdate(x1))
def executeUpdate(x1: String, x2: Array[Int]): Int = statementExecute(() => underlying.executeUpdate(x1, x2))
def executeUpdate(x1: String, x2: Array[String]): Int = statementExecute(() => underlying.executeUpdate(x1, x2))
def executeUpdate(x1: String, x2: Int): Int = statementExecute(() => underlying.executeUpdate(x1, x2))
def executeLargeUpdate(): Long = statementExecute(() => underlying.executeLargeUpdate())
def executeLargeUpdate(sql: String): Long = statementExecute(() => underlying.executeLargeUpdate(sql))
def executeLargeUpdate(sql: String, columnIndexes: Array[Int]): Long = statementExecute(() => underlying.executeLargeUpdate(sql, columnIndexes))
def executeLargeUpdate(sql: String, columnNames: Array[String]): Long = statementExecute(() => underlying.executeLargeUpdate(sql, columnNames))
def executeLargeUpdate(sql: String, autoGeneratedKeys: Int): Long = statementExecute(() => underlying.executeLargeUpdate(sql, autoGeneratedKeys))
def close(): Unit = underlying.close()
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy