Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
ru.yandex.mysqlDiff.jdbc.JdbcModelExtractor.scala Maven / Gradle / Ivy
package ru.yandex.mysqlDiff.jdbc
import java.sql._
import model._
import scalax.control.ManagedResource
import scala.util.Sorting._
import scala.collection.mutable.ArrayBuffer
// XXX: drop it
import vendor.mysql._
import util._
/**
 * Exception declared alongside the JDBC model extractor. It adds nothing of
 * its own: message and cause are forwarded unchanged to java.lang.Exception.
 *
 * @param msg   detail message
 * @param cause underlying cause (passed straight through)
 */
class JdbcModelExtractorException(msg: String, cause: Throwable)
    extends Exception(msg, cause)
/*
* TBD:
* Extract table engine, default charset
* Extract keys
*/
object JdbcModelExtractor {
import JdbcUtils._
import vendor.mysql._
import MysqlContext._
// http://bugs.mysql.com/36699
// Server version parsed from the literal "5.0.51". Judging by the name this is
// presumably the first MySQL release that reports column defaults correctly
// (see the bug above) -- TODO confirm.
// NOTE(review): no use of this value is visible in this file.
private val PROPER_COLUMN_DEF_MIN_MYSQL_VERSION = MysqlServerVersion.parse("5.0.51")
/**
 * Holder that evaluates `create` on the first call to `get` and remembers
 * the result in `value`. No synchronization is performed.
 *
 * @param create by-name expression producing the value
 */
class Lazy[T](create: => T) {
    /** The remembered value; None until `get` has produced one. */
    var value: Option[T] = None

    /** Returns the remembered value, evaluating `create` first if there is none yet. */
    def get = value match {
        case Some(existing) => existing
        case None =>
            val fresh = create
            value = Some(fresh)
            fresh
    }

    /** True once a value is stored. */
    def isCreated = !value.isEmpty
}
/**
 * Groups runs of ADJACENT elements that share the same key.
 *
 * Unlike a map-based grouping, equal keys that are not next to each other
 * end up in separate groups, and the order of both the groups and the
 * elements inside each group follows the input order.
 *
 * Implemented as a single left fold, so the stack depth does not grow with
 * the input length (the previous version recursed once per element).
 *
 * @param seq input elements
 * @param f   key function, applied once per element in input order
 * @return (key, elements) pairs, one per run of equal keys
 */
private def groupBy[A, B](seq: Seq[A])(f: A => B): Seq[(B, Seq[A])] = {
    // Accumulate newest-first (groups and members alike), reverse at the end.
    val reversed = seq.foldLeft(List[(B, List[A])]()) { (groups, a) =>
        val key = f(a)
        groups match {
            // same key as the run currently being built: extend it
            case (`key`, members) :: done => (key, a :: members) :: done
            // first element, or the key changed: start a new run
            case _ => (key, List(a)) :: groups
        }
    }
    reversed.reverse.map { case (key, members) => (key, members.reverse) }
}
/** Row type and row mapper for MySQL's INFORMATION_SCHEMA.COLUMNS. */
object MetaDao {
    /** One row of INFORMATION_SCHEMA.COLUMNS (only a subset of its columns). */
    case class MysqlColumnInfo(
        tableCatalog: String,
        tableSchema: String,
        tableName: String,
        columnName: String,
        ordinalPosition: Int,
        columnDefault: String,
        isNullable: Boolean,
        dataType: String,
        characterMaximumLength: Double,
        characterOctetLength: Double,
        numericPrecision: Int,
        numericScale: Int,
        characterSetName: String,
        collationName: String,
        columnType: String,
        /* skipped some columns */
        columnComment: String
    )

    /**
     * Builds a MysqlColumnInfo from the row the given ResultSet is positioned on.
     * Columns are read by name, in the order of the case class fields.
     * Per the JDBC contract, getInt/getDouble yield 0 for SQL NULL.
     */
    private def mapColumnsRow(rs: ResultSet) = {
        val tableCatalog = rs.getString("table_catalog")
        val tableSchema = rs.getString("table_schema")
        val tableName = rs.getString("table_name")
        val columnName = rs.getString("column_name")
        val ordinalPosition = rs.getInt("ordinal_position")
        val columnDefault = rs.getString("column_default")
        val isNullable = rs.getBoolean("is_nullable")
        val dataType = rs.getString("data_type")
        val characterMaximumLength = rs.getDouble("character_maximum_length")
        val characterOctetLength = rs.getDouble("character_octet_length")
        val numericPrecision = rs.getInt("numeric_precision")
        val numericScale = rs.getInt("numeric_scale")
        val characterSetName = rs.getString("character_set_name")
        val collationName = rs.getString("collation_name")
        val columnType = rs.getString("column_type")
        val columnComment = rs.getString("column_comment")
        MysqlColumnInfo(
            tableCatalog, tableSchema, tableName,
            columnName, ordinalPosition, columnDefault,
            isNullable, dataType,
            characterMaximumLength, characterOctetLength,
            numericPrecision, numericScale,
            characterSetName, collationName,
            columnType, columnComment
        )
    }
}
import MetaDao._
/**
 * Queries MySQL's INFORMATION_SCHEMA over the supplied connection.
 * The connection itself is never closed here; every statement and result
 * set opened by a query is closed before the method returns.
 *
 * @deprecated
 */
class MetaDao(conn: Connection) {
    /**
     * Prepares q, binds params as strings (1-based, in order), runs the
     * query and hands the ResultSet to f. Result set and statement are
     * closed afterwards, also when f throws.
     */
    private def query[T](q: String, params: String*)(f: ResultSet => T): T = {
        val ps = conn.prepareStatement(q)
        try {
            for ((p, i) <- params.toList.zipWithIndex) ps.setString(i + 1, p)
            val rs = ps.executeQuery()
            try {
                f(rs)
            } finally {
                rs.close()
            }
        } finally {
            ps.close()
        }
    }

    // XXX: move outside: MySQL-specific
    /** Maps the current INFORMATION_SCHEMA.TABLES row to (table name, options); only ENGINE is read. */
    def mapTableOptions(rs: ResultSet) =
        (rs.getString("TABLE_NAME"), List(TableOption("ENGINE", rs.getString("ENGINE"))))

    /** Options of every table of the given schema, as (table name, options) pairs. */
    def findTablesOptions(schema: String): Seq[(String, Seq[TableOption])] = {
        val q = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ?"
        val options = query(q, schema) { rs => read(rs)(mapTableOptions _) }
        options
    }

    /**
     * Options of a single table.
     *
     * @throws SQLException if the query fails or returns no row for the table
     */
    def findTableOptions(schema: String, tableName: String): Seq[TableOption] = {
        val q = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?"
        val options = query(q, schema, tableName) { rs =>
            // Without this check the row mapper would run on a result set that is
            // not positioned on a row and fail with a driver-specific message.
            if (!rs.next())
                throw new SQLException(
                    "table not found in INFORMATION_SCHEMA.TABLES: " + schema + "." + tableName)
            mapTableOptions(rs)._2
        }
        options
    }

    /** INFORMATION_SCHEMA.COLUMNS rows of the whole schema, grouped per table (rows arrive ordered by table name). */
    def findMysqlTablesColumns(schema: String): Seq[(String, Seq[MysqlColumnInfo])] = {
        val q = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? ORDER BY table_name"
        val columns = query(q, schema) { rs => read(rs)(mapColumnsRow _) }
        groupBy[MysqlColumnInfo, String](columns)(_.tableName)
    }

    /** INFORMATION_SCHEMA.COLUMNS rows of a single table. */
    def findMysqlColumns(schema: String, tableName: String) = {
        val q = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ?"
        query(q, schema, tableName) { rs => read(rs)(mapColumnsRow _) }
    }
}
/**
 * Base class that assembles a TableModel from JDBC DatabaseMetaData combined
 * with rows of MySQL's INFORMATION_SCHEMA.COLUMNS. Subclasses decide how the
 * table options and the INFORMATION_SCHEMA rows are obtained.
 *
 * @param conn connection all metadata is read from; it is not closed here
 */
abstract class SchemaExtractor(conn: Connection) {
    protected val dao = new MetaDao(conn)
    protected val dao2 = new jdbc.MetaDao(LiteDataSource.singleConnection(conn))

    /**
     * Database name taken from the JDBC URL: everything from the first '?' on
     * is dropped, then everything up to and including the last '/'.
     * Fails (via require) when the remaining name is empty.
     */
    lazy val currentDb = {
        val db = conn.getMetaData.getURL.replaceFirst("\\?.*", "").replaceFirst(".*/", "")
        require(db.length > 0)
        db
    }

    def currentCatalog = currentDb
    def currentSchema: String = null

    /** Reads an int column of the current row; None when the value was SQL NULL. */
    private def getIntOption(rs: ResultSet, columnName: String): Option[Int] = {
        val v = rs.getInt(columnName)
        if (rs.wasNull) None
        else Some(v)
    }

    /**
     * Decodes a JDBC metadata flag: "YES" => Some(true), "NO" => Some(false).
     * Anything else (empty string, null, unexpected text) means "unknown" and
     * yields None instead of failing with a MatchError.
     */
    private def yesNoOption(flag: String): Option[Boolean] = flag match {
        case "YES" => Some(true)
        case "NO" => Some(false)
        case _ => None
    }

    /**
     * Extracts the model of one table: columns, primary key, indexes, foreign
     * keys and table options.
     *
     * @throws java.util.NoSuchElementException if a column reported by JDBC
     *         metadata has no matching row from getMysqlColumns
     */
    def extractTable(tableName: String): TableModel = {
        // Fetched before the metadata result set is opened, so a failure here
        // cannot leave that result set open.
        val mysqlColumns = getMysqlColumns(tableName)
        val columns = conn.getMetaData.getColumns(currentCatalog, currentSchema, tableName, "%")
        val columnsList =
            try {
                read(columns) { row =>
                    val colName = row.getString("COLUMN_NAME")
                    val colType = row.getString("TYPE_NAME")

                    // the size is kept only for data types that allow a length
                    val colTypeSize = getIntOption(row, "COLUMN_SIZE")
                        .filter(x => MysqlContext.dataTypes.make(colType, None, Nil).isLengthAllowed)

                    val nullable = yesNoOption(row.getString("IS_NULLABLE")).map(Nullability(_))
                    val autoIncrement = yesNoOption(row.getString("IS_AUTOINCREMENT")).map(AutoIncrement(_))

                    val mysqlColumn = mysqlColumns.find(_.columnName == colName).getOrElse(
                        throw new java.util.NoSuchElementException(
                            "column " + tableName + "." + colName + " not found in INFORMATION_SCHEMA.COLUMNS"))

                    // The default always comes from INFORMATION_SCHEMA, never from
                    // JDBC's COLUMN_DEF: http://bugs.mysql.com/36699
                    val defaultValueFromDb = mysqlColumn.columnDefault

                    val characterSet = Some(mysqlColumn.characterSetName)
                        .filter(x => x != null && x != "")
                        .map(MysqlCharacterSet(_))
                    val collate = Some(mysqlColumn.collationName)
                        .filter(x => x != null && x != "")
                        .map(MysqlCollate(_))

                    val dataType = MysqlContext.dataTypes.make(colType, colTypeSize, Nil ++ characterSet ++ collate)
                    val defaultValue = parseDefaultValueFromDb(defaultValueFromDb, dataType).map(DefaultValue(_))
                    val props = new ColumnProperties(List[ColumnProperty]() ++ nullable ++ defaultValue ++ autoIncrement)
                    new ColumnModel(colName, dataType, props)
                }
            } finally {
                columns.close()
            }

        val pk = getPrimaryKey(tableName)
        val fks = getFks(tableName)
        val fkColumnLists = fks.map(_.localColumns.toList)

        val indexes = getIndexes(tableName)
            // MySQL adds PK to indexes, so exclude
            .filter(i => pk.isEmpty || i.columns.toList != pk.get.columns.toList)
            // also drop indexes whose column list equals the local columns of a foreign key
            .filter(i => !(fkColumnLists contains i.columns.toList))

        new TableModel(tableName, columnsList.toList, pk, indexes ++ fks, getTableOptions(tableName))
    }

    def getPrimaryKey(tableName: String): Option[PrimaryKeyModel] =
        dao2.findPrimaryKey(currentCatalog, currentSchema, tableName)

    def getIndexes(tableName: String): Seq[IndexModel] =
        dao2.findIndexes(currentCatalog, currentSchema, tableName)

    def getFks(tableName: String): Seq[ForeignKeyModel] =
        dao2.findImportedKeys(currentCatalog, currentSchema, tableName)

    /** Not including PK */
    def getKeys(tableName: String): Seq[KeyModel] =
        getIndexes(tableName) ++ getFks(tableName)

    /** Table options (e.g. ENGINE) of the given table. */
    def getTableOptions(tableName: String): Seq[TableOption]

    /** INFORMATION_SCHEMA.COLUMNS rows of the given table. */
    def getMysqlColumns(tableName: String): Seq[MysqlColumnInfo]
}
/**
 * SchemaExtractor that answers each lookup with a dedicated DAO query for
 * the requested table, scoped to the current database.
 */
class SingleTableSchemaExtractor(conn: Connection) extends SchemaExtractor(conn) {
    override def getTableOptions(tableName: String) = {
        val db = currentDb
        dao.findTableOptions(db, tableName)
    }

    override def getMysqlColumns(tableName: String) = {
        val db = currentDb
        dao.findMysqlColumns(db, tableName)
    }
}
/**
 * SchemaExtractor for whole-database extraction: table names, table options
 * and INFORMATION_SCHEMA.COLUMNS rows are each fetched once for the whole
 * database on first use and then served from memory.
 */
class AllTablesSchemaExtractor(conn: Connection) extends SchemaExtractor(conn) {
    /** Extracts every table and wraps the result in a DatabaseModel. */
    def extract(): DatabaseModel =
        new DatabaseModel(extractTables())

    private val cachedTableNames = new Lazy(dao2.findTableNames(currentCatalog, currentSchema))
    def tableNames = cachedTableNames.get

    private val cachedTablesOptions = new Lazy(dao.findTablesOptions(currentDb))
    def getTableOptions(tableName: String): Seq[TableOption] =
        lookup(cachedTablesOptions.get, tableName, "INFORMATION_SCHEMA.TABLES")

    val cachedMysqlColumnDefaultValues = new Lazy(dao.findMysqlTablesColumns(currentDb))
    def getMysqlColumns(tableName: String) =
        lookup(cachedMysqlColumnDefaultValues.get, tableName, "INFORMATION_SCHEMA.COLUMNS")

    /**
     * Returns the value stored for tableName in a (table name, value) list.
     * A missing table is still reported as NoSuchElementException, but with a
     * message that names the table and the source instead of "None.get".
     */
    private def lookup[T](entries: Seq[(String, T)], tableName: String, source: String): T =
        entries.find(_._1 == tableName) match {
            case Some((_, found)) => found
            case None =>
                throw new java.util.NoSuchElementException(
                    "table " + tableName + " not found in " + source)
        }

    /** Extracts the model of every table listed by tableNames, in that order. */
    def extractTables(): Seq[TableModel] = {
        tableNames.map(extractTable _)
    }
}
/**
 * Turns the textual column default reported by the database into a SqlValue.
 * Checks are applied in this order:
 *  - null text            => Some(NullValue)
 *  - character data types => StringValue, with one pair of surrounding single quotes removed if present
 *  - the text "NULL"      => None
 *  - date/time data types => NowValue for "CURRENT_TIMESTAMP", otherwise StringValue
 *  - numeric data types   => whatever sqlParserCombinator.parseValue returns
 *  - anything else        => StringValue
 */
protected def parseDefaultValueFromDb(s: String, dataType: DataType): Option[SqlValue] = s match {
    case null => Some(NullValue) // None
    case _ if dataType.isAnyChar =>
        val text =
            if (s matches "'.*'") s.replaceFirst("^'", "").replaceFirst("'$", "")
            else s
        Some(StringValue(text))
    case "NULL" => None
    case _ if dataType.isAnyDateTime =>
        if (s == "CURRENT_TIMESTAMP") Some(NowValue)
        else Some(StringValue(s))
    case _ if dataType.isAnyNumber => Some(sqlParserCombinator.parseValue(s))
    case _ => Some(StringValue(s))
}
// XXX: move to jdbc.MetaDao
/**
 * Drains a ResultSet: calls f once for every remaining row, in order, and
 * returns the results as a list. The ResultSet is advanced with next() but
 * is NOT closed here.
 *
 * Rows are collected in a buffer, so each append is cheap; the previous
 * version appended to an immutable List for every row.
 *
 * @param rs result set positioned before the first row to read
 * @param f  row mapper, invoked while rs is positioned on the row
 * @return mapped rows in result-set order (empty if there are no rows)
 */
def read[T](rs: ResultSet)(f: ResultSet => T): List[T] = {
    val rows = new ArrayBuffer[T]
    while (rs.next()) {
        rows += f(rs)
    }
    rows.toList
}
/** Extracts the models of all tables, using a connection handed out by a JdbcTemplate over ds. */
def extractTables(ds: LiteDataSource): Seq[TableModel] =
    new JdbcTemplate(ds).map(c => new AllTablesSchemaExtractor(c).extractTables())
/** Extracts the model of a single table, using a connection handed out by a JdbcTemplate over ds. */
def extractTable(tableName: String, ds: LiteDataSource): TableModel =
    new JdbcTemplate(ds).map(c => new SingleTableSchemaExtractor(c).extractTable(tableName))
/** Extracts the whole database model, using a connection handed out by a JdbcTemplate over ds. */
def extract(ds: LiteDataSource): DatabaseModel =
    new JdbcTemplate(ds).map(c => new AllTablesSchemaExtractor(c).extract())
/** Extracts all table models of the database behind the given JDBC URL. */
def search(url: String): Seq[TableModel] = {
    val ds = LiteDataSource.driverManager(url)
    extractTables(ds)
}
/** Builds a DatabaseModel from every table found behind the given JDBC URL. */
def parse(jdbcUrl: String): DatabaseModel = {
    val tables = search(jdbcUrl)
    new DatabaseModel(tables)
}
/** Extracts the model of one table from the database behind the given JDBC URL. */
def parseTable(tableName: String, jdbcUrl: String) = {
    val template = new JdbcTemplate(LiteDataSource.driverManager(jdbcUrl))
    template.map(c => new SingleTableSchemaExtractor(c).extractTable(tableName))
}
/**
 * Command-line entry point.
 *
 * One argument (JDBC URL): the whole database is extracted. Two arguments
 * (JDBC URL, table name): only that table is extracted. The serialized model
 * is printed to standard output. Any other argument count prints the usage
 * line to standard error and exits with status 1.
 */
def main(args: scala.Array[String]) {
    def usage() {
        Console.err.println("usage: JdbcModelExtractor jdbc-url [table-name]")
    }

    val model =
        if (args.length == 1) {
            parse(args(0))
        } else if (args.length == 2) {
            // args(0) is the JDBC URL, args(1) the table name
            new DatabaseModel(List(parseTable(args(1), args(0))))
        } else {
            usage(); exit(1)
        }

    print(ModelSerializer.serializeDatabaseToText(model))
}
}
object JdbcModelExtractorTests extends org.specs.Specification {
import vendor.mysql.MysqlTestDataSourceParameters._
/** Runs one SQL statement through the imported test jdbcTemplate. */
private def execute(q: String): Unit =
    jdbcTemplate.execute(q)
/** Drops the given table if it exists, so each example starts from a clean state. */
private def dropTable(tableName: String): Unit =
    execute("DROP TABLE IF EXISTS " + tableName)
/** Extracts the model of one table with a SingleTableSchemaExtractor on a connection from the test jdbcTemplate. */
def extractTable(name: String) =
    jdbcTemplate.map(c => new JdbcModelExtractor.SingleTableSchemaExtractor(c).extractTable(name))
// Column names, types, VARCHAR length and the primary key of a two-column table.
// Uses specs matchers like every other example in this spec (instead of
// Predef.assert), so failures are reported as expectation failures with
// both values shown.
"Simple Table" in {
    dropTable("bananas")
    execute("CREATE TABLE bananas (id INT, color VARCHAR(100), PRIMARY KEY(id))")
    val table = extractTable("bananas")

    table.name must_== "bananas"

    val idColumn = table.columns(0)
    idColumn.name must_== "id"
    idColumn.dataType.name must_== "INT"

    val colorColumn = table.columns(1)
    colorColumn.name must_== "color"
    colorColumn.dataType.name must_== "VARCHAR"
    colorColumn.dataType.length must_== Some(100)

    table.primaryKey.get.columns.toList must_== List("id")
}
// A named index, a unique two-column key and a plain two-column key.
"Indexes" in {
    dropTable("users")
    execute("CREATE TABLE users (first_name VARCHAR(20), last_name VARCHAR(20), age INT, INDEX age_k(age), UNIQUE KEY(first_name, last_name), KEY(age, last_name))")
    val users = extractTable("users")

    // the explicitly named single-column index, found by name
    val ageIndex = users.indexes.find(idx => idx.name.get == "age_k").get
    List("age") must_== ageIndex.columns.toList
    ageIndex.isUnique must_== false

    // the UNIQUE KEY, found by its column list
    val nameIndex = users.indexWithColumns("first_name", "last_name")
    nameIndex.isUnique must_== true

    // the plain two-column KEY, found by its column list
    val ageLastNameIndex = users.indexWithColumns("age", "last_name")
    ageLastNameIndex.isUnique must_== false
}
// The primary key is reported as primaryKey only, not again among the indexes.
"PK is not in indexes list" in {
    dropTable("files")
    execute("CREATE TABLE files (id INT, PRIMARY KEY(id))")
    val files = extractTable("files")
    files.indexes.length must_== 0
    files.primaryKey.get.columns.toList must_== List("id")
}
// Single-column and two-column foreign keys from citizen to city and person.
"Foreign keys" in {
    // the referencing table is dropped first, then the referenced ones
    dropTable("citizen")
    dropTable("city")
    dropTable("person")

    execute("CREATE TABLE city (id INT PRIMARY KEY, name VARCHAR(10)) ENGINE=InnoDB")
    execute("CREATE TABLE person(id1 INT, id2 INT, PRIMARY KEY(id1, id2)) ENGINE=InnoDB")
    // http://community.livejournal.com/levin_matveev/20802.html
    execute("CREATE TABLE citizen (id INT PRIMARY KEY, city_id INT, pid1 INT, pid2 INT, " +
        "FOREIGN KEY (city_id) REFERENCES city(id), " +
        "FOREIGN KEY (pid1, pid2) REFERENCES person(id1, id2)" +
        ") ENGINE=InnoDB")

    val citizen = extractTable("citizen")
    val city = extractTable("city")
    val person = extractTable("person")

    citizen.fks must haveSize(2)

    // citizen.city_id -> city.id
    val cityFk = citizen.fks.find(fk => fk.localColumns.toList == List("city_id")).get
    cityFk.localColumns must beLike { case Seq("city_id") => true }
    cityFk.externalColumns must beLike { case Seq("id") => true }
    cityFk.externalTableName must_== "city"

    // citizen.(pid1, pid2) -> person.(id1, id2)
    val personFk = citizen.fks.find(fk => fk.localColumns.toList == List("pid1", "pid2")).get
    personFk.localColumns must beLike { case Seq("pid1", "pid2") => true }
    personFk.externalColumns must beLike { case Seq("id1", "id2") => true }
    personFk.externalTableName must_== "person"

    // no sure
    //citizen.indexes must haveSize(0)

    // the referenced tables have no foreign keys of their own
    city.fks must haveSize(0)
    person.fks must haveSize(0)
}
// The ENGINE given at creation time is reported among the table options.
"table options" in {
    dropTable("dogs")
    execute("CREATE TABLE dogs (id INT) ENGINE=InnoDB")
    val dogs = extractTable("dogs")
    val expectedOption = TableOption("ENGINE", "InnoDB")
    dogs.options must contain(expectedOption)
}
// A TIMESTAMP column declared with DEFAULT NOW() is extracted with NowValue as its default.
"DEFAULT NOW()" in {
    dropTable("cars")
    execute("CREATE TABLE cars (id INT, created TIMESTAMP DEFAULT NOW())")
    val cars = extractTable("cars")
    val createdColumn = cars.column("created")
    createdColumn.defaultValue must_== Some(NowValue)
}
// Empty and non-empty string defaults on nullable and NOT NULL VARCHAR columns.
// The expectations for columns without an explicit DEFAULT (a, d) remain disabled.
"MySQL string DEFAULT values" in {
    dropTable("jets")
    execute("CREATE TABLE jets (a VARCHAR(2), b VARCHAR(2) DEFAULT '', c VARCHAR(2) DEFAULT 'x', " +
        "d VARCHAR(2) NOT NULL, e VARCHAR(2) NOT NULL DEFAULT '', f VARCHAR(2) NOT NULL DEFAULT 'y')")
    val jets = extractTable("jets")

    // nullable columns
    //jets.column("a").defaultValue must_== None
    jets.column("b").defaultValue must_== Some(StringValue(""))
    jets.column("c").defaultValue must_== Some(StringValue("x"))

    // NOT NULL columns
    //jets.column("d").defaultValue must_== None
    jets.column("e").defaultValue must_== Some(StringValue(""))
    jets.column("f").defaultValue must_== Some(StringValue("y"))
}
// A column declared without AUTO_INCREMENT is extracted with autoIncrement == Some(false).
"unspecified AUTO_INCREMENT" in {
    dropTable("ships")
    execute("CREATE TABLE ships (id INT NOT NULL, name VARCHAR(10), PRIMARY KEY(id))")
    val ships = extractTable("ships")
    val idColumn = ships.column("id")
    idColumn.properties.autoIncrement must_== Some(false)
    //ships.column("name").properties.autoIncrement must_== None
}
// An explicit CHARACTER SET / COLLATE on a column shows up in its data type options.
// Column a (declared without either) is created but deliberately not inspected.
"fetches CHARACTER SET and COLLATE" in {
    dropTable("qwqw")
    execute("CREATE TABLE qwqw (a VARCHAR(2), b VARCHAR(2) CHARACTER SET utf8 COLLATE utf8_bin)")
    val table = extractTable("qwqw")
    val b = table.column("b")
    b.dataType.options must contain(MysqlCharacterSet("utf8"))
    b.dataType.options must contain(MysqlCollate("utf8_bin"))
}
// DATETIME and TEXT columns must be extracted under their own type name and without a length.
"DATETIME without length" in {
    List("DATETIME", "TEXT").foreach { t =>
        val tableName = t + "_without_length_test"
        dropTable(tableName)
        execute("CREATE TABLE " + tableName + "(a " + t + ")")
        val extractedType = extractTable(tableName).column("a").dataType
        val expected = (t, None)
        (extractedType.name, extractedType.length) must_== expected
    }
}
}
// vim: set ts=4 sw=4 et: