All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scalikejdbc.DBConnection.scala Maven / Gradle / Ivy

The newest version!
package scalikejdbc

import java.sql.{ DatabaseMetaData, Connection }
import scalikejdbc.metadata._
import scala.concurrent.{ ExecutionContext, Future }
import scala.util.control.Exception._
import scala.util.control.ControlThrowable
import java.util.Locale.{ ENGLISH => en }

/**
 * Basic Database Accessor which holds a JDBC connection.
 */
trait DBConnection extends LogSupport with LoanPattern {

  type RSTraversable = ResultSetTraversable

  /**
   * Connection wil be closed automatically by default.
   */
  private[this] var autoCloseEnabled: Boolean = true

  /**
   * Provides default TxBoundary type class instance.
   */
  private[this] def defaultTxBoundary[A]: TxBoundary[A] = TxBoundary.Exception.exceptionTxBoundary[A]

  /**
   * Switches auto close mode.
   * @param autoClose auto close enabled if true
   */
  def autoClose(autoClose: Boolean): DBConnection = {
    this.autoCloseEnabled = autoClose
    this
  }

  /**
   * returns the additional attributes of current JDBC connection.
   */
  def connectionAttributes: DBConnectionAttributes = DBConnectionAttributes()

  /**
   * Returns current JDBC connection.
   */
  def conn: Connection

  /**
   * Returns is the current transaction is active.
   * @return result
   */
  def isTxNotActive: Boolean = {
    if (GlobalSettings.jtaDataSourceCompatible) {
      // JTA managed connection should be used as-is
      false
    } else {
      conn == null || conn.isClosed || conn.isReadOnly
    }
  }

  /**
   * Returns is the current transaction hasn't started yet.
   * @return result
   */
  def isTxNotYetStarted: Boolean = {
    if (GlobalSettings.jtaDataSourceCompatible) {
      // JTA managed connection should be used as-is
      false
    } else {
      conn != null && conn.getAutoCommit
    }
  }

  /**
   * Returns is the current transaction already started.
   * @return result
   */
  def isTxAlreadyStarted: Boolean = {
    if (GlobalSettings.jtaDataSourceCompatible) {
      true
    } else {
      conn != null && !conn.getAutoCommit
    }
  }

  private[this] def setAutoCommit(conn: Connection, readOnly: Boolean): Unit = {
    if (!GlobalSettings.jtaDataSourceCompatible) conn.setAutoCommit(readOnly)
  }

  private[this] def setReadOnly(conn: Connection, readOnly: Boolean): Unit = {
    if (!GlobalSettings.jtaDataSourceCompatible) conn.setReadOnly(readOnly)
  }

  private[this] def newTx(conn: Connection): Tx = {
    setReadOnly(conn, false)
    if (!GlobalSettings.jtaDataSourceCompatible && (isTxNotActive || isTxAlreadyStarted)) {
      throw new IllegalStateException(ErrorMessage.CANNOT_START_A_NEW_TRANSACTION)
    }
    new Tx(conn)
  }

  /**
   * Starts a new transaction and returns it.
   * @return tx
   */
  def newTx: Tx = newTx(conn)

  /**
   * Returns the current transaction.
   * If the transaction has not started yet, IllegalStateException will be thrown.
   * @return tx
   */
  def currentTx: Tx = {
    if (!GlobalSettings.jtaDataSourceCompatible && (isTxNotActive || isTxNotYetStarted)) {
      throw new IllegalStateException(ErrorMessage.TRANSACTION_IS_NOT_ACTIVE)
    }
    new Tx(conn)
  }

  /**
   * Returns the current transaction.
   * If the transaction has not started yet, IllegalStateException will be thrown.
   * @return tx
   */
  def tx: Tx = {
    handling(classOf[IllegalStateException]) by { e =>
      throw new IllegalStateException(
        ErrorMessage.TRANSACTION_IS_NOT_ACTIVE + " If you want to start a new transaction, use #newTx instead."
      )
    } apply currentTx
  }

  /**
   * Close the connection.
   */
  def close(): Unit = {
    ignoring(classOf[Throwable]) {
      conn.close()
    }
    log.debug("A Connection is closed.")
  }

  /**
   * Begins a new transaction.
   */
  def begin(): Unit = newTx.begin()

  /**
   * Begins a new transaction if the other one does not already start.
   */
  def beginIfNotYet(): Unit = {
    ignoring(classOf[IllegalStateException]) apply {
      begin()
    }
  }

  /**
   * Commits the current transaction.
   */
  def commit(): Unit = tx.commit()

  /**
   * Rolls back the current transaction.
   */
  def rollback(): Unit = tx.rollback()

  /**
   * Rolls back the current transaction if the transaction is still active.
   */
  def rollbackIfActive(): Unit = {
    ignoring(classOf[IllegalStateException]) apply {
      tx.rollbackIfActive()
    }
  }

  /**
   * Returns read-only session.
   * @return session
   */
  def readOnlySession(): DBSession = {
    setReadOnly(conn, true)
    DBSession(conn, isReadOnly = true, connectionAttributes = connectionAttributes)
  }

  /**
   * Provides read-only session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def readOnly[A](execution: DBSession => A): A = {
    if (autoCloseEnabled) using(conn)(_ => execution(readOnlySession()))
    else execution(readOnlySession())
  }

  /**
   * Provides read-only session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def readOnlyWithConnection[A](execution: Connection => A): A = {
    readOnly(s => execution(s.conn))
  }

  /**
   * Returns auto-commit session.
   * @return session
   */
  def autoCommitSession(): DBSession = {
    setReadOnly(conn, false)
    setAutoCommit(conn, true)
    DBSession(conn, connectionAttributes = connectionAttributes)
  }

  /**
   * Provides auto-commit session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def autoCommit[A](execution: DBSession => A): A = {
    if (autoCloseEnabled) using(conn)(_ => execution(autoCommitSession()))
    else execution(autoCommitSession())
  }

  /**
   * Provides auto-commit session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def autoCommitWithConnection[A](execution: Connection => A): A = {
    autoCommit(s => execution(s.conn))
  }

  /**
   * Returns within-tx session.
   * @return session
   */
  def withinTxSession(tx: Tx = currentTx): DBSession = {
    if (!GlobalSettings.jtaDataSourceCompatible && !tx.isActive) {
      throw new IllegalStateException(ErrorMessage.TRANSACTION_IS_NOT_ACTIVE)
    }
    DBSession(conn, tx = Some(tx), connectionAttributes = connectionAttributes)
  }

  /**
   * Provides within-tx session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def withinTx[A](execution: DBSession => A): A = {
    execution(withinTxSession(currentTx))
  }

  /**
   * Provides within-tx session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def withinTxWithConnection[A](execution: Connection => A): A = {
    withinTx(s => execution(s.conn))
  }

  private[this] def begin(tx: Tx): Unit = {
    tx.begin()
    if (!GlobalSettings.jtaDataSourceCompatible && !tx.isActive) {
      throw new IllegalStateException(ErrorMessage.TRANSACTION_IS_NOT_ACTIVE)
    }
  }

  private[this] def rollbackIfThrowable[A](f: => A): A = try {
    f
  } catch {
    case e: ControlThrowable =>
      tx.commit()
      throw e
    case originalException: Throwable =>
      try {
        tx.rollback()
      } catch {
        case rollbackException: Throwable => {
          // log rollback exception for application operators
          log.error("Could not successfully complete local transaction", rollbackException)
        }
      }
      // original exception is likely more valuable to calling code to act upon
      throw originalException
  }

  /**
   * Provides local-tx session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def localTx[A](execution: DBSession => A)(implicit boundary: TxBoundary[A] = defaultTxBoundary[A]): A = {
    val doClose = if (autoCloseEnabled) () => conn.close() else () => ()
    val tx = newTx
    begin(tx)
    val txResult = try {
      rollbackIfThrowable[A] {
        val session = DBSession(conn, tx = Option(tx), connectionAttributes = connectionAttributes)
        val result: A = execution(session)
        boundary.finishTx(result, tx)
      }
    } catch {
      case e: Throwable => doClose(); throw e
    }
    boundary.closeConnection(txResult, doClose)
  }

  /**
   * Easy way to checkout the current connection to be used in a transaction
   * that needs to be committed/rolled back depending on Future results.
   * @param execution block that takes a session and returns a future
   * @tparam A future result type
   * @return future result
   */
  def futureLocalTx[A](execution: DBSession => Future[A])(implicit ec: ExecutionContext): Future[A] = {
    // Enable TxBoundary implicits
    import scalikejdbc.TxBoundary.Future._
    localTx(execution)
  }

  /**
   * Provides local-tx session block.
   * @param execution block
   * @tparam A  return type
   * @return result value
   */
  def localTxWithConnection[A](execution: Connection => A)(implicit boundary: TxBoundary[A] = defaultTxBoundary[A]): A = {
    localTx(s => execution(s.conn))
  }

  /**
   * Splits the name to schema and table name
   *
   * @param name name
   * @return schema and table
   */
  private[this] def toSchemaAndTable(name: String): (String, String) = {
    val schema = {
      if (name.split("\\.").size > 1) name.split("\\.").head
      else null
    }
    val table = if (name.split("\\.").size > 1) name.split("\\.")(1) else name
    (schema, table)
  }

  /**
   * Returns all the table information that match the pattern
   *
   * @param tableNamePattern table name pattern (with schema optionally)
   * @return table information
   */
  def getTableNames(tableNamePattern: String = "%", tableTypes: Array[String] = Array("TABLE", "VIEW")): List[String] = {
    readOnlyWithConnection { conn =>
      val meta = conn.getMetaData
      getSchemaAndTableName(meta, tableNamePattern.replaceAll("\\*", "%"), tableTypes).map {
        case (schema, tableNamePattern) =>
          new RSTraversable(meta.getTables(null, schema, tableNamePattern, tableTypes))
            .map { rs =>
              val schemaName = rs.string("TABLE_SCHEM")
              if (schema != null && schema.nonEmpty && schemaName != null) {
                schemaName + "." + rs.string("TABLE_NAME")
              } else {
                rs.string("TABLE_NAME")
              }
            }.toList
      }.getOrElse(List.empty[String])
    }
  }

  /**
   * Returns all the column names on the matched table name
   */
  def getColumnNames(tableName: String, tableTypes: Array[String] = Array("TABLE", "VIEW")): List[String] = {
    readOnlyWithConnection { conn =>
      val meta = conn.getMetaData
      getSchemaAndTableName(meta, tableName, tableTypes).map {
        case (schema, tableName) =>
          new RSTraversable(meta.getColumns(null, schema, tableName, "%")).map(_.string("COLUMN_NAME")).toList.distinct
      }
    }.getOrElse(Nil)
  }

  /**
   * Returns table information if exists
   *
   * @param table table name (with schema optionally)
   * @return table information
   */
  def getTable(table: String, tableTypes: Array[String] = Array("TABLE", "VIEW")): Option[Table] = {
    readOnlyWithConnection { conn =>
      val meta = conn.getMetaData

      getSchemaAndTableName(meta, table, tableTypes).flatMap {
        case (schema, tableName) =>
          _getTable(meta, schema, tableName, tableTypes)
      }
    }
  }

  /**
   * Returns table information if exists.
   *
   * https://docs.oracle.com/javase/8/docs/api/java/sql/DatabaseMetaData.html#getIndexInfo-java.lang.String-java.lang.String-java.lang.String-boolean-boolean-
   *
   * @param meta database meta data
   * @param schema schema name
   * @param table table name
   * @param tableTypes target table types
   * @return table information
   */
  private[this] def _getTable(meta: DatabaseMetaData, schema: String, table: String, tableTypes: Array[String] = Array("TABLE", "VIEW")): Option[Table] = {
    val tableList = new RSTraversable(meta.getTables(null, schema, table, tableTypes)).map {
      rs => (rs.string("TABLE_SCHEM"), rs.string("TABLE_NAME"), rs.string("REMARKS"))
    }

    tableList.headOption.map {
      case (schema, table, remarks) =>
        val pkNames: Traversable[String] = new RSTraversable(meta.getPrimaryKeys(null, schema, table)).map(rs => rs.string("COLUMN_NAME"))

        Table(
          name = table,
          schema = schema,
          description = remarks,
          columns = new RSTraversable(meta.getColumns(null, schema, table, "%")).map { rs =>
            Column(
              name = try rs.string("COLUMN_NAME") catch { case e: ResultSetExtractorException => null },
              typeCode = try rs.int("DATA_TYPE") catch { case e: ResultSetExtractorException => -1 },
              typeName = rs.string("TYPE_NAME"),
              size = try rs.int("COLUMN_SIZE") catch { case e: ResultSetExtractorException => -1 },
              isRequired = try {
                rs.string("IS_NULLABLE") != null && rs.string("IS_NULLABLE") == "NO"
              } catch { case e: ResultSetExtractorException => false },
              isPrimaryKey = try {
                pkNames.exists(_ == rs.string("COLUMN_NAME"))
              } catch { case e: ResultSetExtractorException => false },
              isAutoIncrement = try {
                // Oracle throws java.sql.SQLException: Invalid column name
                rs.string("IS_AUTOINCREMENT") != null && rs.string("IS_AUTOINCREMENT") == "YES"
              } catch { case e: ResultSetExtractorException => false },
              description = try rs.string("REMARKS") catch { case e: ResultSetExtractorException => null },
              defaultValue = try rs.string("COLUMN_DEF") catch { case e: ResultSetExtractorException => null }
            )
          }.toList.distinct,
          foreignKeys = {
            try {
              new RSTraversable(meta.getImportedKeys(null, schema, table)).map { rs =>
                ForeignKey(
                  name = rs.string("FKCOLUMN_NAME"),
                  foreignColumnName = rs.string("PKCOLUMN_NAME"),
                  foreignTableName = rs.string("PKTABLE_NAME")
                )
              }.toList.distinct
            } catch { case e: ResultSetExtractorException => Nil }
          },
          indices = {
            try {
              new RSTraversable(meta.getIndexInfo(null, schema, table, false, true))
                .foldLeft(Map[String, Index]()) {
                  case (map, rs) =>
                    val indexName: String = rs.string("INDEX_NAME")
                    val index: Index = map.get(indexName) match {
                      case Some(idx) =>
                        rs.stringOpt("COLUMN_NAME") match {
                          case Some(columnName) => idx.copy(columnNames = idx.columnNames :+ columnName)
                          case _ => idx
                        }
                      case _ =>
                        Index(
                          name = indexName,
                          columnNames = rs.stringOpt("COLUMN_NAME").toList,
                          isUnique = !rs.boolean("NON_UNIQUE"),
                          qualifier = rs.stringOpt("INDEX_QUALIFIER"),
                          indexType = {
                            rs.shortOpt("TYPE") match {
                              case Some(t) => IndexType.from(t)
                              case _ => IndexType.tableIndexOther
                            }
                          },
                          ordinalPosition = rs.shortOpt("ORDINAL_POSITION"),
                          ascOrDesc = rs.stringOpt("ASC_OR_DESC"),
                          cardinality = rs.longOpt("CARDINALITY"),
                          pages = rs.longOpt("PAGES"),
                          filterCondition = rs.stringOpt("FILTER_CONDITION")
                        )
                    }
                    map.updated(indexName, index)
                }.map { case (k, v) => v }.toList.distinct
            } catch {
              case e: ResultSetExtractorException =>
                log.error("Failed to fetch index information", e)
                Nil
            }
          }
        )
    }
  }

  /**
   * Returns table name list
   *
   * @param tableNamePattern table name pattern
   * @param tableTypes table types
   * @return table name list
   */
  def showTables(tableNamePattern: String = "%", tableTypes: Array[String] = Array("TABLE", "VIEW")): String = {
    getTableNames(tableNamePattern, tableTypes).mkString("\n")
  }

  /**
   * Returns describe style string value for the table
   *
   * @param table table name (with schema optionally)
   * @return described information
   */
  def describe(table: String): String = {
    getTable(table).map(t => t.toDescribeStyleString).getOrElse("Not found.")
  }

  /**
   * Returns schema name and table name
   *
   * @param meta database meta data
   * @param tablePattern table name (with schema optionally)
   * @param tableTypes target table types
   * @return schema name and table name
   */
  private[this] def getSchemaAndTableName(meta: DatabaseMetaData, tablePattern: String, tableTypes: Array[String]): Option[(String, String)] = {
    def _getSchemaAndTableName(meta: DatabaseMetaData, tablePattern: String, tableTypes: Array[String]): Option[(String, String)] = {
      val (_schema, table) = toSchemaAndTable(tablePattern)
      val schema = if (meta.getURL.startsWith("jdbc:h2")) {
        // H2 Database 1.4 cannot accept null for metadata retrieving columns
        // in tables that name is same as information schema (e.g.) rules
        Option(_schema).getOrElse("")
      } else {
        _schema
      }

      if (new RSTraversable(meta.getTables(null, schema, table, tableTypes)).isEmpty) {
        None
      } else {
        Some((schema, table))
      }
    }

    _getSchemaAndTableName(meta, tablePattern, tableTypes)
      .orElse(_getSchemaAndTableName(meta, tablePattern.toUpperCase(en), tableTypes))
      .orElse(_getSchemaAndTableName(meta, tablePattern.toLowerCase(en), tableTypes))
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy