All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ru.yandex.mysqlDiff.util.jdbc.scala Maven / Gradle / Ivy

package ru.yandex.mysqlDiff.util

import java.sql._
import javax.sql.DataSource
import scala.collection.mutable.ArrayBuffer

trait LiteDataSource {
    def openConnection(): Connection
    def closeConnection(c: Connection) = c.close()
    /** Close data source */
    def close() = ()
}

object LiteDataSource extends Logging {
    def apply(f: () => Connection): LiteDataSource = new LiteDataSource {
        override def openConnection() = f()
    }
    
    def apply(ds: DataSource): LiteDataSource = apply(() => ds.getConnection())
    
    def driverManager(url: String, user: String, password: String): LiteDataSource = {
        require(url != null && url.length > 0, "url must be not empty")
        apply(() => DriverManager.getConnection(url, user, password))
    }
    
    def driverManager(url: String): LiteDataSource = {
        require(url != null && url.length > 0, "url must be not empty")
        apply(() => DriverManager.getConnection(url))
    }
    
    def singleConnection(c: Connection) = new SingleConnectionLiteDataSource(c)
}

class SingleConnectionLiteDataSource(c: Connection) extends LiteDataSource {
    override def openConnection() = c
    override def closeConnection(c: Connection) = ()
}

class CacheConnectionLiteDataSource(ds: LiteDataSource) extends LiteDataSource {
    private var cached: Option[Connection] = None
    
    override def openConnection() = synchronized {
        cached match {
            case Some(c) => c
            case None =>
                val c = ds.openConnection()
                cached = Some(c)
                c
        }
    }
    
    override def closeConnection(c: Connection) = ()
    
    override def close() = synchronized {
        cached match {
            case Some(c) =>
                ds.closeConnection(c)
                cached = None
            case None =>
        }
    }
}

trait JdbcOperations extends Logging {
    val ds: LiteDataSource
    import ds._
    
    // copy from scalax.resource
    def foreach(f: Connection => Unit): Unit = acquireFor(f)
    def flatMap[B](f: Connection => B): B = acquireFor(f)
    def map[B](f: Connection => B): B = acquireFor(f)

    /** Acquires the resource for the duration of the supplied function. */
    private def acquireFor[B](f: Connection => B): B = {
        val c = openConnection()
        try {
            f(c)
        } finally {
            closeConnectionQuietly(c)
        }
    }
    
    trait Query {
        def prepareStatement(conn: Connection): PreparedStatement
        
        def execute[T](rse: ResultSet => T) =
            JdbcOperations.this.execute { conn =>
                val ps = prepareStatement(conn)
                try {
                    val rs = ps.executeQuery()
                    rse(rs)
                } finally {
                    close(ps)
                }
            }
        
        private def read[T](rs: ResultSet, rm: ResultSet => T): Seq[T] = {
            val r = new ArrayBuffer[T]
            while (rs.next()) {
                r += rm(rs)
            }
            r
        }
        
        def seq[T](rm: ResultSet => T): Seq[T] =
            execute { rs => read(rs, rm) }
        
        private def singleColumnSeq[T](rm: ResultSet => T): Seq[T] = execute { rs =>
            if (rs.getMetaData.getColumnCount != 1) throw new Exception("expecting single column") // XXX
            read(rs, rm)
        }
        
        def ints(): Seq[Int] = singleColumnSeq { rs => rs.getInt(1) }
        
        private class SqlResultSeq[T](s: Seq[T]) {
            def singleRow: T = {
                if (s.length != 1) throw new Exception("expecting 1 row")
                s.first
            }
        }
        
        implicit private def sqlResultSeq[T](s: Seq[T]) = new SqlResultSeq[T](s)
        
        def single[T](rm: ResultSet => T): T = seq(rm).singleRow
        
        def int() = ints().singleRow
    }
    
    /** Close quietly */
    private def close(o: { def close() }) {
        if (o != null)
            try {
                o.close()
            } catch {
                case e => logger.warn("failed to close something: " + e, e)
            }
    }
    
    def closeConnectionQuietly(c: Connection) =
        try {
            closeConnection(c)
        } catch {
            case e => logger.warn("failed to close connection: " + e, e)
        }
    
    def execute[T](ccb: Connection => T): T = {
        val conn = openConnection()
        try {
            ccb(conn)
        } finally {
            closeConnectionQuietly(conn)
        }
    }
    
    def execute[T](psc: Connection => PreparedStatement, f: PreparedStatement => T): T = {
        execute { conn =>
            val ps = psc(conn)
            try {
                f(ps)
            } finally {
                close(ps)
            }
        }
    }
    
    def execute(q: String) {
        execute { conn =>
            conn.createStatement().execute(q)
        }
    }
    
    class ParamsQuery(q: String, params: Any*) extends Query {
        override def prepareStatement(conn: Connection) = {
            val ps = conn.prepareStatement(q)
            for ((value, i) <- params.toList.zipWithIndex) {
                val c = i + 1
                value match {
                    case s: String => ps.setString(c, s)
                    // XXX: add more
                    case x => ps.setObject(c, x)
                }
            }
            ps
        }
    }
    
    def query(q: String, params: Any*) = new ParamsQuery(q, params: _*)
}

class JdbcTemplate(override val ds: LiteDataSource) extends JdbcOperations {
    
    def this(ds: () => Connection) = this(LiteDataSource(ds))
    def this(ds: DataSource) = this(LiteDataSource(ds))

}

// vim: set ts=4 sw=4 et:




© 2015 - 2024 Weber Informatics LLC | Privacy Policy