
com.github.sarxos.mse.SchemaEvolver Maven / Gradle / Ivy
package com.github.sarxos.mse;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.commons.collections4.iterators.ReverseListIterator;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * MySQL schema upgrade tool.
 * <p>
 * Schema versions are represented by directories whose names consist of seven dot-separated
 * numeric parts, e.g. {@code 00.00.00.00.00.00.001}. When {@link #evolve(String)} runs, every
 * version directory newer than the version recorded in the {@code version} table (row
 * {@code id = 1}) has its {@code upgrade.sql} evaluated in ascending version order, and the
 * recorded version is updated after each processed directory. Downgrades are not implemented yet.
 * <p>
 * Every evaluated script runs with the MySQL session variables {@code UNIQUE_CHECKS} and
 * {@code FOREIGN_KEY_CHECKS} set to 0; their previous values (and {@code SQL_MODE}) are restored
 * afterwards, also when the script fails.
 *
 * @author Bartosz Firyn (sarxos)
 */
public class SchemaEvolver {

    /**
     * Logger.
     */
    private static final Logger LOG = LoggerFactory.getLogger(SchemaEvolver.class);

    /**
     * Number of dot-separated numeric parts in every schema version string.
     */
    private static final int VERSION_PARTS = 7;

    /**
     * Compares two schema version strings numerically, part by part, from left to right.
     * <p>
     * Both versions must consist of exactly seven dot-separated integer parts.
     *
     * @author Bartosz Firyn (sarxos)
     */
    private static final class VersionComparator implements Comparator<String> {

        /**
         * {@inheritDoc}
         *
         * @throws IllegalStateException when either version does not have exactly seven parts
         * @throws NumberFormatException when a version part is not a valid integer
         */
        @Override
        public int compare(String a, String b) {
            String[] ap = a.split("\\.");
            String[] bp = b.split("\\.");
            if (ap.length != VERSION_PARTS || bp.length != VERSION_PARTS) {
                throw new IllegalStateException("Invalid schema number, cannot compare " + a + " with " + b);
            }
            for (int i = 0; i < VERSION_PARTS; i++) {
                // Integer.compare instead of subtraction: idiomatic and immune to overflow
                int cmp = Integer.compare(Integer.parseInt(ap[i]), Integer.parseInt(bp[i]));
                if (cmp != 0) {
                    return cmp;
                }
            }
            return 0;
        }
    }

    /**
     * Accepts only directories whose names are valid schema versions: exactly seven non-empty,
     * dot-separated parts made of digits.
     *
     * @author Bartosz Firyn (sarxos)
     */
    private static final class EvolutionDirFilter implements FilenameFilter {

        @Override
        public boolean accept(File dir, String name) {
            // limit -1 keeps trailing empty strings, so "1.2.3.4.5.6." is not silently truncated
            String[] parts = name.split("\\.", -1);
            if (parts.length != VERSION_PARTS) {
                return false;
            }
            for (String part : parts) {
                // an empty part (e.g. "1..3.4.5.6.7") would make VersionComparator fail later
                if (part.isEmpty()) {
                    return false;
                }
                for (int i = 0; i < part.length(); i++) {
                    if (!Character.isDigit(part.charAt(i))) {
                        return false;
                    }
                }
            }
            return new File(dir, name).isDirectory();
        }
    }

    /**
     * File name filter used to filter schema directories.
     */
    private static final FilenameFilter EVF = new EvolutionDirFilter();

    /**
     * Schema version comparator.
     */
    private static final Comparator<String> VC = new VersionComparator();

    /**
     * Used to distinguish upgrade. Also the base name of the per-version upgrade script.
     */
    private static final String UPGRADE = "upgrade";

    /**
     * Used to distinguish downgrade. Also the base name of the per-version downgrade script.
     */
    private static final String DOWNGRADE = "downgrade";

    /**
     * Name of the optional routines file evaluated from the schema directory before upgrades.
     */
    private static final String ROUTINE_FILE = "routines.sql";

    /**
     * Classpath resources evaluated at the beginning of every evolution, in this order.
     */
    private static final String[] DEFAULT_ROUTINES = {
        "routines/AddColumn.sql",
        "routines/DropColumn.sql",
        "routines/DropFK.sql",
        "routines/DropIndex.sql",
        "routines/ModColumn.sql",
        "routines/SetCharacterSet.sql",
        "routines/SetCollate.sql",
    };

    /**
     * Matches runs of whitespace which are collapsed into a single space in SQL lines.
     */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /**
     * MySQL database connection.
     */
    private final Connection connection;

    /**
     * Catalog (database) name of the connection, used in log messages.
     */
    private final String dbname;

    /**
     * Create MySQL schema evolver.
     *
     * @param connection the database connection, must not be null
     * @throws IllegalArgumentException when connection is null
     * @throws IllegalStateException when the catalog name cannot be read from the connection
     */
    public SchemaEvolver(Connection connection) {
        if (connection == null) {
            throw new IllegalArgumentException("Connection cannot be null");
        }
        this.connection = connection;
        try {
            this.dbname = connection.getCatalog();
        } catch (SQLException e) {
            throw new IllegalStateException("Unable to read catalog name from connection", e);
        }
    }

    /**
     * Read resource from classpath and return associated UTF-8 reader.
     *
     * @param resource the resource to read
     * @return {@link Reader} associated with resource {@link InputStream}
     * @throws FileNotFoundException when the resource cannot be found on the classpath
     */
    private Reader read(String resource) throws IOException {
        InputStream is = getClass().getClassLoader().getResourceAsStream(resource);
        if (is == null) {
            // getResourceAsStream signals a missing resource with null, not an exception
            throw new FileNotFoundException("Classpath resource not found: " + resource);
        }
        return new InputStreamReader(is, StandardCharsets.UTF_8);
    }

    /**
     * Evolve database schema.
     * <p>
     * Verifies that {@code path} is a directory, evaluates the bundled default routines and the
     * optional {@code routines.sql} file from {@code path}, reads (or initializes) the current
     * schema version and then evaluates {@code upgrade.sql} from every version directory newer
     * than the current version, recording each processed version in the database.
     *
     * @param path the path to schema directories
     * @throws FileNotFoundException when path is not a directory or a bundled resource is missing
     * @throws IOException when files cannot be read or schema directories cannot be listed
     * @throws SQLException when something wrong happen in SQL
     */
    public void evolve(String path) throws IOException, SQLException {

        LOG.info("Starting schema evolution script");

        // verify path points to a directory before anything is executed in the database
        File dir = new File(path);
        if (!dir.isDirectory()) {
            throw new FileNotFoundException(path);
        }

        LOG.info("Preloading default routines");

        for (String routine : DEFAULT_ROUTINES) {
            evaluate(read(routine));
        }

        LOG.debug("Checking for routines file");

        File routines = new File(dir, ROUTINE_FILE);
        if (routines.isFile() && routines.canRead()) {
            LOG.info("Routines file found, evaluating");
            evaluate(routines);
            LOG.info("Routines file evaluated");
        }

        List<String> versions = listVersions(dir);
        if (versions.isEmpty()) {
            LOG.info("No upgrades to be executed");
            return;
        }

        String current = getCurrentVersion();
        String newest = versions.get(versions.size() - 1);

        LOG.info("Current {} schema version is {} and the newest one is {}", dbname, current, newest);

        int result = VC.compare(current, newest);
        String direction;
        Iterator<String> vi;

        // return if schema is the newest one
        if (result == 0) {
            LOG.info("The {} schema version is already the newest one", dbname);
            return;
        } else if (result < 0) {
            LOG.info("The {} schema version should be upgraded", dbname);
            direction = UPGRADE;
            vi = versions.listIterator();
        } else {
            // NOTE: here every available version is older than current, so the loop below
            // skips all of them; downgrade is effectively a no-op for now
            LOG.info("The {} schema version should be downgraded (not implemented yet)", dbname);
            direction = DOWNGRADE;
            vi = new ReverseListIterator<>(versions);
        }

        while (vi.hasNext()) {

            String version = vi.next();

            boolean skip = UPGRADE.equals(direction)
                ? VC.compare(version, current) <= 0
                : VC.compare(current, version) >= 0;

            if (skip) {
                LOG.info("The {} schema skipping {} vs current {}", dbname, version, current);
                continue;
            }

            LOG.info("Processing {} schema version {} {}", dbname, version, direction);

            File sqlfile = new File(new File(dir, version), direction + ".sql");
            if (sqlfile.exists()) {
                evaluate(sqlfile);
            } else {
                LOG.warn("No {} has been found for schema {}", sqlfile, dbname);
            }

            updateVersion(version);
        }
    }

    /**
     * List names of schema version directories in the given directory, sorted from the oldest to
     * the newest version.
     *
     * @param dir the directory containing schema version directories
     * @return sorted, possibly empty, list of version directory names
     * @throws IOException when the directory content cannot be listed
     */
    private static List<String> listVersions(File dir) throws IOException {
        File[] files = dir.listFiles(EVF);
        if (files == null) {
            // listFiles reports an I/O error by returning null instead of throwing
            throw new IOException("Unable to list schema directories in " + dir);
        }
        List<String> versions = new ArrayList<>(files.length);
        for (File file : files) {
            versions.add(file.getName());
        }
        Collections.sort(versions, VC);
        return versions;
    }

    /**
     * Evaluate SQL file (read as UTF-8).
     *
     * @param file the file object, ignored when null
     * @throws IOException when file cannot be read
     * @throws SQLException when there is an SQL syntax in given file
     */
    private void evaluate(File file) throws IOException, SQLException {
        if (file == null) {
            return;
        }
        try (Reader reader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)) {
            evaluate(reader);
        }
    }

    /**
     * Evaluate SQL.
     * <p>
     * {@code UNIQUE_CHECKS} and {@code FOREIGN_KEY_CHECKS} are disabled while the statements are
     * executed. The previous values of these and {@code SQL_MODE} are restored afterwards, also
     * when a statement fails. The reader is always closed.
     *
     * @param reader the SQL instructions reader
     * @throws IOException when reader cannot read instructions
     * @throws SQLException when there is an SQL syntax in given file
     */
    private void evaluate(Reader reader) throws IOException, SQLException {
        try (BufferedReader br = new BufferedReader(reader)) {

            // remember current session settings so they can be restored afterwards
            List<ImmutablePair<String, Object>> params = new ArrayList<>(3);
            params.add(getParam("UNIQUE_CHECKS", 0));
            params.add(getParam("FOREIGN_KEY_CHECKS", 0));
            params.add(getParam("SQL_MODE", "TRADITIONAL"));

            boolean completed = false;
            try {
                // disable unique and foreign key check to speed up process
                setParam("UNIQUE_CHECKS", 0);
                setParam("FOREIGN_KEY_CHECKS", 0);
                executeStatements(br);
                completed = true;
            } finally {
                // on failure, a restore error must not hide the original exception
                restoreParams(params, completed);
            }
        }
    }

    /**
     * Split the reader content into statements (honoring {@code DELIMITER} lines and skipping
     * {@code --} comment lines) and execute each of them.
     *
     * @param br the SQL instructions reader
     * @throws IOException when reader cannot read instructions
     * @throws SQLException when a statement fails or the script is not properly terminated
     */
    private void executeStatements(BufferedReader br) throws IOException, SQLException {

        String delimiter = ";";
        StringBuilder sb = new StringBuilder();
        String line;

        while ((line = br.readLine()) != null) {

            line = line.trim();
            if (line.isEmpty() || line.startsWith("--")) {
                continue;
            }

            // collapse runs of whitespace (tabs included) into single spaces
            line = WHITESPACE.matcher(line).replaceAll(" ").trim();

            if (StringUtils.startsWithIgnoreCase(line, "DELIMITER")) {
                String[] tokens = line.split(" ");
                if (tokens.length < 2) {
                    throw new SQLException("Syntax error, DELIMITER value is missing on: " + line);
                }
                String pending = sb.toString().trim();
                if (!pending.isEmpty()) {
                    // do not silently drop a statement that was never terminated
                    throw new SQLException("Syntax error, delimiter is missing on: " + pending);
                }
                delimiter = tokens[1].trim();
                continue;
            }

            boolean terminated = line.endsWith(delimiter);
            if (terminated) {
                line = line.substring(0, line.length() - delimiter.length());
            }

            sb.append(line).append(' ');

            if (terminated) {
                String sql = sb.toString();
                sb.setLength(0);
                LOG.info("{} mysql> {}", dbname, sql);
                try (Statement stmt = connection.createStatement()) {
                    stmt.execute(sql);
                }
            }
        }

        String sql = sb.toString().trim();
        if (!sql.isEmpty()) {
            throw new SQLException("Syntax error, delimiter is missing on: " + sql);
        }
    }

    /**
     * Restore previously saved session parameters.
     *
     * @param params the saved parameter names and values
     * @param propagate when true the first restore failure is thrown, otherwise failures are only
     *            logged so that an exception already being propagated is not replaced
     * @throws SQLException when a parameter cannot be restored and {@code propagate} is true
     */
    private void restoreParams(List<ImmutablePair<String, Object>> params, boolean propagate) throws SQLException {
        for (ImmutablePair<String, Object> param : params) {
            try {
                setParam(param.getLeft(), param.getRight());
            } catch (SQLException e) {
                if (propagate) {
                    throw e;
                }
                LOG.warn("Unable to restore {} in {} schema", param.getLeft(), dbname, e);
            }
        }
    }

    /**
     * Set connection parameter value.
     *
     * @param name the parameter name (internal constant, concatenated into the SQL)
     * @param value the new parameter value
     * @throws SQLException when parameter name is invalid
     */
    private void setParam(String name, Object value) throws SQLException {
        LOG.info("{} mysql> SET {} = {} ", dbname, name, value instanceof String ? "'" + value + "'" : value);
        try (PreparedStatement stmt = connection.prepareStatement("SET " + name + " = ?")) {
            stmt.setObject(1, value);
            stmt.execute();
        }
    }

    /**
     * Get connection parameter.
     *
     * @param param the parameter name
     * @param defValue the parameter default value
     * @return Pair of parameter name and its value, or the default one if the query returns no row
     * @throws SQLException when parameter name is invalid
     */
    private ImmutablePair<String, Object> getParam(String param, Object defValue) throws SQLException {
        String query = "SELECT @@" + param;
        Object value;
        LOG.info("{} mysql> {}", dbname, query);
        try (Statement stmt = connection.createStatement()) {
            try (ResultSet rs = stmt.executeQuery(query)) {
                // read by index: the query selects exactly one column
                value = rs.next() ? rs.getObject(1) : defValue;
            }
        }
        return new ImmutablePair<>(param, value);
    }

    /**
     * Get current schema version from database. When the version table cannot be queried, the
     * initial schema is created and its version is returned.
     *
     * @return Currently installed database schema version
     * @throws SQLException when the initial schema cannot be created
     * @throws IOException when the initial schema script cannot be read
     * @throws IllegalStateException when the version row is missing or holds a null version
     */
    private String getCurrentVersion() throws SQLException, IOException {
        try (Statement stmt = connection.createStatement()) {
            try (ResultSet rs = stmt.executeQuery("SELECT v.version FROM version v WHERE v.id = 1")) {
                if (rs.next()) {
                    String version = rs.getString(1);
                    if (version != null) {
                        return version;
                    }
                }
            } catch (SQLException e) {
                LOG.trace(e.getMessage(), e);
                LOG.info("No version table detected in {} schema", dbname);
                return createInitialVersion();
            }
        }
        throw new IllegalStateException("Unable to read version from database " + dbname);
    }

    /**
     * Create initial schema by evaluating the {@code sql/initial.sql} classpath resource.
     *
     * @return Always return 00.00.00.00.00.00.000
     * @throws SQLException when the initial script fails
     * @throws IOException when the initial script cannot be read
     */
    private String createInitialVersion() throws SQLException, IOException {
        LOG.info("Creating initial schema in {}", dbname);
        evaluate(read("sql/initial.sql"));
        return "00.00.00.00.00.00.000";
    }

    /**
     * Update version in database (row {@code id = 1} of the {@code version} table).
     *
     * @param version the new schema version to set
     * @throws SQLException when the update fails
     */
    private void updateVersion(String version) throws SQLException {
        LOG.info("The {} schema update version to {}", dbname, version);
        try (PreparedStatement stmt = connection.prepareStatement("UPDATE version v SET v.version = ? WHERE v.id = 1")) {
            stmt.setString(1, version);
            if (stmt.executeUpdate() == 0) {
                LOG.warn("No version row (id = 1) has been updated in {} schema", dbname);
            }
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy