/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * com.hundsun.lightdb.unisql.proxy.Driver (Maven / Gradle / Ivy)
 */

package com.hundsun.lightdb.unisql.proxy;

import com.hundsun.lightdb.unisql.model.UnisqlProperties;
import com.hundsun.lightdb.unisql.proxy.jdbc.DbType;
import com.hundsun.lightdb.unisql.proxy.jdbc.UnisqlConnection;
import com.hundsun.lightdb.unisql.utils.Utils;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.DriverPropertyInfo;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.Statement;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;

/**
 * JDBC driver that accepts {@code jdbc:unisql:...} URLs, delegates the physical connection to the
 * target database driver selected from the rewritten {@code jdbc:...} URL, and wraps it in a
 * {@link UnisqlConnection} configured with the {@code sourceDialect} / {@code targetDialect} URL
 * parameters.
 */
public class Driver implements java.sql.Driver {

    /**
     * Empty entry point; it performs no work.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
    }

    // Declared first so it is initialized before INSTANCE is built and before the static block
    // registers the driver; registerDriver() can therefore always rely on it.
    private static final Logger LOG = LoggerFactory.getLogger(Driver.class);

    /** URL prefix handled by this driver; it is rewritten to {@code jdbc:} for the target driver. */
    private static final String ACCEPT_PREFIX = "jdbc:unisql:";
    /** URL parameter naming the source SQL dialect. */
    private static final String SOURCE_DIALECT_KEY = "sourceDialect";
    /** URL parameter naming the target SQL dialect. */
    private static final String TARGET_DIALECT_KEY = "targetDialect";

    /** Safety cap on the number of URLs remembered as pre-checked. */
    private static final int MAX_PRE_CHECKED_URLS = 2048;

    // Both maps are initialized before the driver is registered below, so no connect() call can
    // ever observe them as null.
    /** URLs that have already passed the pre-check. */
    private static final Map<String, Boolean> PRE_CHECKED = new ConcurrentHashMap<>();
    /** URLs whose target-database information has already been logged. */
    private static final Map<String, Boolean> INFO_PRINTED = new ConcurrentHashMap<>();

    /**
     * Lazily computed: whether {@code com.mysql.cj.jdbc.Driver} could be loaded. {@code volatile}
     * so the computed value is visible to other threads; a duplicate computation is harmless.
     */
    private static volatile Boolean mysql_driver_version_6 = null;

    private static final Driver INSTANCE = new Driver();

    static {
        AccessController.doPrivileged(new PrivilegedAction<Void>() {
            @Override
            public Void run() {
                Driver.registerDriver(Driver.INSTANCE);
                return null;
            }
        });
    }

    /** Configuration parameters. */
    private final UnisqlProperties properties = UnisqlProperties.getInstance();

    /**
     * Registers {@code driver} with {@link DriverManager}.
     *
     * @param driver the driver to register
     * @return {@code true} on success; {@code false} if registration threw (the error is logged)
     */
    public static boolean registerDriver(java.sql.Driver driver) {
        try {
            DriverManager.registerDriver(driver);
            return true;
        } catch (Exception e) {
            LOG.error("registerDriver error", e);
            return false;
        }
    }

    /**
     * Opens a physical connection through the target driver and wraps it in a
     * {@link UnisqlConnection}.
     *
     * @param url  a {@code jdbc:unisql:...} URL carrying {@code sourceDialect} and
     *             {@code targetDialect} query parameters
     * @param info connection properties passed unchanged to the target driver
     * @return the wrapped connection, or {@code null} if {@code url} is not a unisql URL
     * @throws SQLException             if the target driver cannot be resolved or connecting fails
     * @throws IllegalArgumentException if a required dialect parameter is missing
     * @throws IllegalStateException    if the unisql pre-check fails
     */
    @Override
    public Connection connect(String url, Properties info) throws SQLException {
        // JDBC contract: return null (rather than throw) for URLs this driver does not handle.
        if (!acceptsURL(url)) {
            return null;
        }
        // Rewrite only the leading prefix; String.replace would also touch later occurrences.
        String replaceUrl = "jdbc:" + url.substring(ACCEPT_PREFIX.length());
        java.sql.Driver driver = getTargetDriver(replaceUrl);
        Properties prop = parseUrl(url);

        String sourceName = mustExists(prop, SOURCE_DIALECT_KEY, "url param sourceDialect not specified");
        String targetName = mustExists(prop, TARGET_DIALECT_KEY, "url param targetDialect not specified");

        DbType src = DbType.of(sourceName);
        DbType dst = DbType.of(targetName);

        Connection nativeConnection = driver.connect(replaceUrl, info);
        if (nativeConnection == null) {
            // java.sql.Driver#connect returns null when it does not accept the URL; report that
            // clearly instead of failing later with a NullPointerException.
            throw new SQLException("target driver " + driver.getClass().getName()
                    + " did not accept url: " + withoutQuery(replaceUrl));
        }
        try {
            preCheck(url, nativeConnection, src, dst);
            printInformation(url, replaceUrl, nativeConnection, src, dst);
            return new UnisqlConnection(nativeConnection, src, dst);
        } catch (Throwable t) {
            // Any failure in the unisql setup (including Errors) must close the physical
            // connection so it does not leak; the original failure is rethrown unchanged.
            try {
                nativeConnection.close();
            } catch (Exception closeError) {
                t.addSuppressed(closeError);
                LOG.warn("An exception occurred when closing the original connection, " +
                        "only print the exception and still throw the original exception.", closeError);
            }
            throw t;
        }
    }

    /**
     * Logs the target database product/version once per unisql URL.
     *
     * @param url        the unisql connection URL (used as the "already printed" key)
     * @param replaceUrl the rewritten URL handed to the target driver
     * @param con        the physical connection
     * @param src        source SQL dialect
     * @param dst        target SQL dialect
     */
    private void printInformation(String url, String replaceUrl, Connection con, DbType src, DbType dst) {
        if (INFO_PRINTED.containsKey(url)) {
            return;
        }
        String databaseNameAndVersion = "unknown";
        try {
            if (replaceUrl.startsWith("jdbc:postgresql:")) {
                // PostgreSQL: read "select version()".
                databaseNameAndVersion = queryFirstString(con, "select version()", databaseNameAndVersion);
            } else if (replaceUrl.startsWith("jdbc:oceanbase:") || replaceUrl.startsWith("jdbc:oracle:")) {
                // OceanBase or Oracle: read the banner from v$version.
                databaseNameAndVersion = queryFirstString(con, "select BANNER from v$version", databaseNameAndVersion);
            } else {
                // Default (including MySQL): use JDBC metadata.
                DatabaseMetaData metadata = con.getMetaData();
                databaseNameAndVersion = metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion();
            }
        } catch (Exception e) {
            // Target database could not be identified; log the cause, but never the query string,
            // which may carry credentials.
            LOG.warn("Cannot detect target database type, real url: {}", withoutQuery(replaceUrl), e);
        }

        LOG.info("Established unisql JDBC connection. database: {}, source dialect: {}, target dialect: {}",
                databaseNameAndVersion, src, dst);

        INFO_PRINTED.put(url, Boolean.TRUE);
    }

    /**
     * Runs {@code sql} and returns the first column of the first row, closing all JDBC resources.
     *
     * @return the value read, or {@code fallback} if the query produced no rows
     */
    private static String queryFirstString(Connection con, String sql, String fallback) throws SQLException {
        try (Statement stmt = con.createStatement();
             ResultSet rs = stmt.executeQuery(sql)) {
            return rs.next() ? rs.getString(1) : fallback;
        }
    }

    /**
     * Checks up front whether unisql can work against the target database.
     *
     * @param url              the unisql connection URL (used as the cache key)
     * @param nativeConnection the physical connection
     * @param sourceDialect1   source SQL dialect
     * @param targetDialect1   target SQL dialect
     * @throws IllegalStateException if a check runs and fails
     */
    private void preCheck(String url, Connection nativeConnection, DbType sourceDialect1, DbType targetDialect1) {
        // An application is not expected to use many distinct JDBC URLs, so the URL itself is the key.
        if (PRE_CHECKED.containsKey(url)) {
            return;
        }

        boolean checked = false;
        if (targetDialect1 == DbType.POSTGRESQL && properties.isCheckPostgreSQL()) {
            checked = tryCheckPostgresSchema(nativeConnection);
        } else if (targetDialect1 == DbType.MYSQL) {
            checked = tryCheckMySQLSchema(nativeConnection);
        }

        // Safety cap: should something connect to a huge number of URLs, stop remembering them
        // rather than growing without bound.
        if (checked && PRE_CHECKED.size() < MAX_PRE_CHECKED_URLS) {
            PRE_CHECKED.put(url, Boolean.TRUE);
        }
    }

    /**
     * Verifies that the PostgreSQL database contains the {@code unisql} schema.
     *
     * @param connection the physical connection
     * @return {@code true} if the check ran; {@code false} if it is disabled by configuration
     * @throws IllegalStateException if the query fails or the schema does not exist
     */
    private boolean tryCheckPostgresSchema(Connection connection) {
        if (!properties.isCheckPostgreSQLSchema()) {
            return false;
        }
        boolean exists;
        try (Statement stmt = connection.createStatement();
             ResultSet resultSet = stmt.executeQuery("select exists(SELECT 1 FROM " +
                     "information_schema.schemata " +
                     "WHERE schema_name = 'unisql')")) {
            exists = resultSet.next() && resultSet.getBoolean(1);
        } catch (Exception e) {
            throw new IllegalStateException("unisql schema check failed: " + e.getMessage(), e);
        }
        if (!exists) {
            LOG.error("unisql schema not created");
            throw new IllegalStateException("unisql schema not created.");
        }
        return true;
    }

    /**
     * Verifies that the MySQL server exposes the {@code unisql} database to this user.
     *
     * @param connection the physical connection
     * @return {@code true} if the check ran; {@code false} if it is disabled by configuration
     * @throws IllegalStateException if the query fails or the database is not visible
     */
    private boolean tryCheckMySQLSchema(Connection connection) {
        if (!properties.isCheckMySQLSchema()) {
            return false;
        }
        boolean exists;
        try (Statement stmt = connection.createStatement();
             ResultSet resultSet = stmt.executeQuery("show databases like 'unisql'")) {
            exists = resultSet.next();
        } catch (Exception e) {
            throw new IllegalStateException("unisql database check failed: " + e.getMessage(), e);
        }
        if (!exists) {
            LOG.error("unisql database not created or permission issue. " +
                    "you can try to execute 'call unisql.GrantUnisqlPermissions()' as root then try again");
            throw new IllegalStateException("unisql database not created or permission issue.");
        }
        return true;
    }

    /**
     * Returns the non-empty value of {@code key} in {@code prop}.
     *
     * @throws IllegalArgumentException with {@code msg} if the value is missing or empty
     */
    private String mustExists(Properties prop, String key, String msg) {
        String property = prop.getProperty(key);
        if (property == null || property.isEmpty()) {
            throw new IllegalArgumentException(msg);
        }
        return property;
    }

    /**
     * Parses the query string of {@code url} into properties. Parameters without {@code =} map to
     * an empty string; values are URL-decoded as UTF-8.
     *
     * @throws IllegalArgumentException if the URL has no query string
     */
    private Properties parseUrl(String url) {
        int qPos = url.indexOf('?');
        if (qPos == -1) {
            throw new IllegalArgumentException("url param sourceDialect not specified");
        }
        String urlArgs = url.substring(qPos + 1);

        Properties urlProps = new Properties();
        for (String token : urlArgs.split("&")) {
            if (token.isEmpty()) {
                continue;
            }
            int pos = token.indexOf('=');
            if (pos == -1) {
                urlProps.setProperty(token, "");
            } else {
                urlProps.setProperty(token.substring(0, pos), decode(token, pos));
            }
        }
        return urlProps;
    }

    /** URL-decodes (UTF-8) the part of {@code token} after index {@code pos}. */
    private String decode(String token, int pos) {
        try {
            return URLDecoder.decode(token.substring(pos + 1), "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(
                    "Unable to decode URL entry via UTF-8. This should not happen", e);
        }
    }

    /** Returns {@code url} without its query string, so URL parameters (e.g. passwords) are not logged. */
    private static String withoutQuery(String url) {
        int q = url.indexOf('?');
        return q == -1 ? url : url.substring(0, q);
    }

    /**
     * Selects the target JDBC driver for a rewritten {@code jdbc:...} URL. Known vendors are
     * instantiated directly by class name; anything else is resolved via {@link DriverManager}.
     *
     * @param rawUrl the rewritten URL
     * @return the driver, or {@code null} if {@code rawUrl} is {@code null}
     * @throws SQLException if the driver class is missing, cannot be instantiated, or no registered
     *                      driver accepts the URL
     */
    protected java.sql.Driver getTargetDriver(String rawUrl) throws SQLException {
        if (rawUrl == null) {
            return null;
        } else if (rawUrl.startsWith("jdbc:opengauss:")) {
            return createDriver("org.opengauss.Driver");
        } else if (rawUrl.startsWith("jdbc:postgresql:")) {
            return createDriver("org.postgresql.Driver");
        } else if (rawUrl.startsWith("jdbc:dm:")) {
            return createDriver("dm.jdbc.driver.DmDriver");
        } else if (rawUrl.startsWith("jdbc:oceanbase:")) {
            // https://www.oceanbase.com/docs/common-oceanbase-database-cn-10000000001698776
            return createDriver("com.alipay.oceanbase.jdbc.Driver");
        } else if (rawUrl.startsWith("jdbc:mariadb:")) {
            return createDriver("org.mariadb.jdbc.Driver");
        } else if (rawUrl.startsWith("jdbc:oracle:") || rawUrl.startsWith("JDBC:oracle:")) {
            return createDriver("oracle.jdbc.OracleDriver");
        } else if (rawUrl.startsWith("jdbc:mysql:")) {
            // Read the volatile flag once; prefer Connector/J 6+ when its class is loadable.
            Boolean useCjDriver = mysql_driver_version_6;
            if (useCjDriver == null) {
                useCjDriver = Utils.loadClass("com.mysql.cj.jdbc.Driver") != null;
                mysql_driver_version_6 = useCjDriver;
            }
            return createDriver(useCjDriver ? "com.mysql.cj.jdbc.Driver" : "com.mysql.jdbc.Driver");
        } else {
            return DriverManager.getDriver(rawUrl);
        }
    }

    /**
     * Loads {@code className} and instantiates it through its no-arg constructor.
     *
     * @throws SQLException if the class is not found or cannot be instantiated
     */
    private java.sql.Driver createDriver(String className) throws SQLException {
        Class<?> rawDriverClass = Utils.loadClass(className);
        if (rawDriverClass == null) {
            throw new SQLException("jdbc-driver's class not found. '" + className + "'");
        }
        try {
            // Replaces the deprecated Class#newInstance, which could rethrow undeclared checked exceptions.
            return (java.sql.Driver) rawDriverClass.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new SQLException("create driver instance error, driver className '" + className + "'", e);
        }
    }

    /** Returns {@code true} iff {@code url} is non-null and starts with {@code jdbc:unisql:}. */
    @Override
    public boolean acceptsURL(String url) throws SQLException {
        return url != null && url.startsWith(ACCEPT_PREFIX);
    }

    @Override
    public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) throws SQLException {
        return new DriverPropertyInfo[0];
    }

    @Override
    public int getMajorVersion() {
        return 0;
    }

    @Override
    public int getMinorVersion() {
        return 0;
    }

    @Override
    public boolean jdbcCompliant() {
        return false;
    }

    /**
     * This driver does not log through java.util.logging.
     *
     * @throws SQLFeatureNotSupportedException always, as required by the JDBC contract in that case
     */
    @Override
    public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException {
        throw new SQLFeatureNotSupportedException("unisql driver does not use java.util.logging");
    }
}