// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.huaweicloud.dws.client.util.JdbcUtil Maven / Gradle / Ivy

package com.huaweicloud.dws.client.util;

import static java.lang.String.format;

import com.huawei.gauss200.jdbc.copy.CopyManager;
import com.huawei.gauss200.jdbc.core.BaseConnection;
import com.huawei.gauss200.jdbc.jdbc.PgConnection;
import com.huawei.gauss200.jdbc.util.PGobject;
import com.huawei.shade.com.alibaba.fastjson.JSONObject;
import com.huawei.shade.com.alibaba.fastjson.serializer.SerializerFeature;

import com.huaweicloud.dws.client.DwsConfig;
import com.huaweicloud.dws.client.TableConfig;
import com.huaweicloud.dws.client.exception.DwsClientException;
import com.huaweicloud.dws.client.exception.ExceptionCode;
import com.huaweicloud.dws.client.model.Column;
import com.huaweicloud.dws.client.model.ConflictStrategy;
import com.huaweicloud.dws.client.model.CopyMode;
import com.huaweicloud.dws.client.model.Record;
import com.huaweicloud.dws.client.model.TableName;
import com.huaweicloud.dws.client.model.TableSchema;
import com.huaweicloud.dws.client.model.TypeColumn;
import com.huaweicloud.dws.client.model.TypeDefinition;
import com.huaweicloud.dws.client.op.Get;
import com.huaweicloud.dws.client.op.Scan;
import com.huaweicloud.dws.client.types.MySqlTypeUtils;

import lombok.extern.slf4j.Slf4j;

import java.io.StringReader;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PrimitiveIterator;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;

/**
 * @ProjectName: dws-connector
 * @ClassName: JdbcUtil
 * @Description: JDBC operation utility class
 * @Date: 2022/12/22 14:24
 * @Version: 1.0
 */
@Slf4j
public class JdbcUtil {
    /** ASCII unit-separator control character (0x1F); presumably the default COPY row terminator - confirm. */
    public static final char EOL_CHAR = 0x1F;

    /** {@link #EOL_CHAR} as a one-character string. */
    public static final String EOL = String.valueOf(EOL_CHAR);

    /** ASCII record-separator control character (0x1E); presumably the default COPY field delimiter - confirm. */
    public static final char DELIMITER_CHAR = 0x1E;

    /** {@link #DELIMITER_CHAR} as a one-character string. */
    public static final String DELIMITER = String.valueOf(DELIMITER_CHAR);

    /**
     * Loads the schema of a table, taking the key columns from the database metadata and passing
     * {@code caseSensitive = false}.
     *
     * @param conn open JDBC connection used to read database metadata
     * @param tableName name of the table to describe
     * @return the schema built from the driver's column metadata
     * @throws DwsClientException if metadata cannot be read or no column is found for the table
     */
    public static TableSchema getTableSchema(Connection conn, TableName tableName) throws DwsClientException {
        return getTableSchema(conn, tableName, null, false);
    }

    /**
     * Loads the schema of a table from JDBC metadata.
     *
     * @param conn open JDBC connection used to read database metadata
     * @param tableName name of the table to describe
     * @param uniqueKeys column names to treat as the key instead of the table's primary key;
     *        {@code null} means "read the primary key from the database"
     * @param caseSensitive forwarded to {@code IdentifierUtil.toLowerIfNoSensitive} when key names are
     *        normalised for comparison
     * @return the schema built from the driver's column metadata
     * @throws DwsClientException if metadata cannot be read, or with
     *         {@link ExceptionCode#TABLE_NOT_FOUND} when the table has no columns
     */
    public static TableSchema getTableSchema(Connection conn, TableName tableName, List<String> uniqueKeys,
        boolean caseSensitive) throws DwsClientException {
        // User-supplied unique keys replace the primary key; otherwise read the primary key from metadata.
        Set<String> primaryKeys = new HashSet<>();
        if (uniqueKeys != null) {
            for (String key : uniqueKeys) {
                primaryKeys.add(IdentifierUtil.toLowerIfNoSensitive(key, caseSensitive));
            }
        } else {
            try (ResultSet rs =
                conn.getMetaData().getPrimaryKeys(null, tableName.getSchemaName(), tableName.getTableName())) {
                while (rs.next()) {
                    // Column 4 of getPrimaryKeys() is COLUMN_NAME.
                    primaryKeys.add(IdentifierUtil.toLowerIfNoSensitive(rs.getString(4), caseSensitive));
                }
            } catch (Exception e) {
                throw DwsClientException.fromException(e);
            }
        }
        // Collect every column of the table.
        List<Column> columnList = new ArrayList<>();
        try (ResultSet rs =
            conn.getMetaData().getColumns(null, tableName.getSchemaName(), tableName.getTableName(), "%")) {
            while (rs.next()) {
                Column column = new Column();
                column.setName(rs.getString(4)); // COLUMN_NAME
                column.setType(rs.getInt(5)); // DATA_TYPE (java.sql.Types)
                column.setTypeName(rs.getString(6)); // TYPE_NAME
                column.setPrecision(rs.getInt(7)); // COLUMN_SIZE
                column.setScale(rs.getInt(9)); // DECIMAL_DIGITS
                column.setAllowNull(rs.getInt(11) == 1); // NULLABLE; 1 == DatabaseMetaData.columnNullable
                column.setComment(rs.getString(12)); // REMARKS
                column.setDefaultValue(rs.getObject(13)); // COLUMN_DEF
                column.setPrimaryKey(
                    primaryKeys.contains(IdentifierUtil.toLowerIfNoSensitive(column.getName(), caseSensitive)));
                columnList.add(column);
            }
        } catch (Exception e) {
            throw DwsClientException.fromException(e);
        }
        if (columnList.isEmpty()) {
            throw new DwsClientException(ExceptionCode.TABLE_NOT_FOUND,
                format("table name %s does not exist", tableName.getFullName()));
        }
        return new TableSchema(tableName, columnList, caseSensitive);
    }

    /**
     * Builds the {@code INSERT ... ON CONFLICT} (upsert) statement template.
     *
     * <p>Without unique key fields a plain {@code INSERT} is returned. Otherwise the conflict action is
     * {@code DO NOTHING} for {@link ConflictStrategy#INSERT_OR_IGNORE}, or when there is nothing to
     * update; in all other cases it is {@code DO UPDATE SET ...}. The compare-field {@code WHERE}
     * clause from the current table config is only appended to {@code DO UPDATE}, because
     * {@code DO NOTHING} does not accept a {@code WHERE} clause.
     *
     * @param tableName target table name, used verbatim
     * @param insertFieldNames columns listed in the {@code INSERT}
     * @param updateFieldNames columns assigned from {@code EXCLUDED} on conflict; unique key columns are
     *        skipped unless the table config's {@code isUpdateAll()} is true
     * @param uniqueKeyFields conflict target columns
     * @param strategy conflict handling strategy
     * @return the SQL template with {@code ?} placeholders
     */
    public static String getUpsertStatement(String tableName, List<String> insertFieldNames,
        List<String> updateFieldNames, List<String> uniqueKeyFields, ConflictStrategy strategy) {
        String insertSql = getInsertIntoStatement(tableName, insertFieldNames);
        // Without a key there is nothing to conflict on: plain insert.
        if (uniqueKeyFields.isEmpty()) {
            return insertSql;
        }
        TableConfig tableConfig = LocalUtil.getTableConfig();
        String uniqueColumns =
            uniqueKeyFields.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        String conflictSql = " DO NOTHING";
        if (strategy != ConflictStrategy.INSERT_OR_IGNORE) {
            // Read once instead of once per field inside the filter.
            boolean updateAll = tableConfig.isUpdateAll();
            String updateClause = updateFieldNames.stream()
                .filter(c -> updateAll || !uniqueKeyFields.contains(c))
                .map(f -> IdentifierUtil.quoteIdentifier(f) + "=EXCLUDED." + IdentifierUtil.quoteIdentifier(f))
                .collect(Collectors.joining(", "));
            // An empty SET list is invalid SQL; keep DO NOTHING in that case.
            if (!updateClause.isEmpty()) {
                conflictSql = format(" DO UPDATE SET %s ", updateClause);
                Set<String> compareField = tableConfig.getCompareField();
                if (compareField != null && !compareField.isEmpty()) {
                    conflictSql += " WHERE " + compareField.stream()
                        .map(f -> IdentifierUtil.quoteIdentifier(f) + "< EXCLUDED."
                            + IdentifierUtil.quoteIdentifier(f))
                        .collect(Collectors.joining(" AND "));
                }
            }
        }
        return insertSql + " ON CONFLICT (" + uniqueColumns + ")" + conflictSql;
    }

    /**
     * Builds the {@code INSERT INTO ... VALUES (...)} statement template.
     *
     * <p>Placeholders are plain {@code ?} except for {@code bit} and {@code varbit} columns, which get
     * an explicit cast using the column precision from the current context's table schema. A field
     * name with no matching schema column also gets a plain {@code ?}.
     *
     * @param tableName target table name, used verbatim
     * @param fieldNames column names to insert
     * @return the SQL template with one placeholder per field
     */
    public static String getInsertIntoStatement(String tableName, List<String> fieldNames) {
        TableSchema tableSchema = LocalUtil.getContext().getTableSchema();
        StringJoiner columns = new StringJoiner(", ");
        StringJoiner placeholders = new StringJoiner(", ");
        for (String fieldName : fieldNames) {
            columns.add(IdentifierUtil.quoteIdentifier(fieldName));
            Column column = tableSchema.getColumn(fieldName);
            String placeholder = "?";
            // Unknown columns keep the plain placeholder instead of failing with a NullPointerException.
            if (column != null) {
                if (Types.BIT == column.getType() && "bit".equals(column.getTypeName())) {
                    placeholder = "?::bit(" + column.getPrecision() + ")";
                } else if (Types.OTHER == column.getType() && "varbit".equals(column.getTypeName())) {
                    placeholder = "?::bit varying(" + column.getPrecision() + ")";
                }
            }
            placeholders.add(placeholder);
        }
        return "INSERT INTO " + tableName + "(" + columns + ")" + " VALUES (" + placeholders + ")";
    }

    /**
     * The abnormal all-zero datetime literal. The time/date/timestamp setters in this class compare
     * an unparseable string value against it.
     */
    public static final String DATE_0000 = "0000-00-00 00:00:00";

    /**
     * Binds one parameter of a prepared statement, choosing the setter by the column's JDBC type.
     *
     * @param ps statement to bind on
     * @param index 1-based parameter index
     * @param obj value to bind; {@code null} is bound with a plain {@code setObject}
     * @param column target column metadata; when {@code null} the value is bound with a plain
     *        {@code setObject}
     * @throws SQLException if the driver rejects the value
     */
    public static void fillPreparedStatement(PreparedStatement ps, int index, Object obj, Column column)
        throws SQLException {
        if (obj == null || column == null) {
            ps.setObject(index, obj);
            return;
        }
        final String text = String.valueOf(obj);
        final int sqlType = column.getType();
        switch (sqlType) {
            case Types.CHAR:
            case Types.VARCHAR:
            case Types.LONGNVARCHAR:
                // Character columns always receive the string form of the value.
                ps.setObject(index, text, sqlType);
                return;
            case Types.OTHER:
                if ("varbit".equals(column.getTypeName())) {
                    ps.setString(index, text);
                    return;
                }
                ps.setObject(index, obj, sqlType);
                return;
            case Types.BIT:
                setBit(ps, index, obj, column, text);
                return;
            case Types.TIMESTAMP:
            case Types.TIMESTAMP_WITH_TIMEZONE:
                setTimeStamp(ps, index, obj, column);
                return;
            case Types.DATE:
                setDate(ps, index, obj, column);
                return;
            case Types.TIME:
            case Types.TIME_WITH_TIMEZONE:
                setTime(ps, index, obj, column);
                return;
            case Types.TINYINT:
            case Types.SMALLINT:
            case Types.INTEGER:
            case Types.BIGINT:
            case Types.REAL:
            case Types.FLOAT:
            case Types.DOUBLE:
            case Types.DECIMAL:
            case Types.NUMERIC:
                ps.setObject(index, obj, sqlType);
                return;
            default:
                // Any other type: let the driver infer the mapping.
                ps.setObject(index, obj);
        }
    }

    /**
     * Binds a value for a column whose JDBC type is {@link java.sql.Types#BIT}.
     *
     * <p>For columns whose type name is {@code bit}, a {@link Boolean} is bound as the string
     * {@code "1"} or {@code "0"} and anything else as {@code value}. Other type names are bound with
     * {@code setObject} and the column's JDBC type.
     *
     * @param ps statement to bind on
     * @param index 1-based parameter index
     * @param obj original value
     * @param column target column metadata
     * @param value string form of {@code obj}
     * @throws SQLException if the driver rejects the value
     */
    public static void setBit(PreparedStatement ps, int index, Object obj, Column column, String value)
        throws SQLException {
        if (!"bit".equals(column.getTypeName())) {
            ps.setObject(index, obj, column.getType());
            return;
        }
        String bitText;
        if (obj instanceof Boolean) {
            bitText = ((Boolean) obj) ? "1" : "0";
        } else {
            bitText = value;
        }
        ps.setString(index, bitText);
    }

    /**
     * Binds a value for a TIME column.
     *
     * <p>A {@link Number} is converted with {@code new Time(long)} when the table config enables
     * {@code isNumberAsEpochMsForDatetime()}. A {@link String} is parsed with the configured
     * {@code getStringToDatetimeFormat()} pattern when one is set; if parsing fails, the value
     * {@link #DATE_0000} is bound as {@code new Time(0)} and any other text is bound unchanged.
     * Everything else is bound unchanged.
     *
     * @param ps statement to bind on
     * @param index 1-based parameter index
     * @param obj value to bind
     * @param column target column metadata
     * @throws SQLException if the driver rejects the value
     */
    public static void setTime(PreparedStatement ps, int index, Object obj, Column column) throws SQLException {
        TableConfig tableConfig = LocalUtil.getTableConfig();
        if (obj instanceof Number && tableConfig.isNumberAsEpochMsForDatetime()) {
            ps.setObject(index, new Time(((Number) obj).longValue()), column.getType());
            return;
        }
        if (!(obj instanceof String) || tableConfig.getStringToDatetimeFormat() == null) {
            ps.setObject(index, obj);
            return;
        }
        Time parsed;
        try {
            parsed = new Time(getTime((String) obj, tableConfig.getStringToDatetimeFormat()));
        } catch (ParseException ignored) {
            // Unparseable text: only the all-zero literal is mapped to the epoch, the rest goes to the driver.
            if (!DATE_0000.equals(obj)) {
                ps.setObject(index, obj);
                return;
            }
            parsed = new Time(0);
        }
        ps.setObject(index, parsed, column.getType());
    }

    /**
     * Binds a value for a DATE column.
     *
     * <p>A {@link Number} is converted with {@code new java.sql.Date(long)} when the table config
     * enables {@code isNumberAsEpochMsForDatetime()}. A {@link String} is parsed with the configured
     * {@code getStringToDatetimeFormat()} pattern when one is set; if parsing fails, the value
     * {@link #DATE_0000} is bound as {@code new java.sql.Date(0)} and any other text is bound
     * unchanged. Everything else is bound unchanged.
     *
     * @param ps statement to bind on
     * @param index 1-based parameter index
     * @param obj value to bind
     * @param column target column metadata
     * @throws SQLException if the driver rejects the value
     */
    public static void setDate(PreparedStatement ps, int index, Object obj, Column column) throws SQLException {
        TableConfig tableConfig = LocalUtil.getTableConfig();
        if (obj instanceof Number && tableConfig.isNumberAsEpochMsForDatetime()) {
            ps.setObject(index, new java.sql.Date(((Number) obj).longValue()), column.getType());
            return;
        }
        if (!(obj instanceof String) || tableConfig.getStringToDatetimeFormat() == null) {
            ps.setObject(index, obj);
            return;
        }
        java.sql.Date parsed;
        try {
            parsed = new java.sql.Date(getTime((String) obj, tableConfig.getStringToDatetimeFormat()));
        } catch (ParseException ignored) {
            // Unparseable text: only the all-zero literal is mapped to the epoch, the rest goes to the driver.
            if (!DATE_0000.equals(obj)) {
                ps.setObject(index, obj);
                return;
            }
            parsed = new java.sql.Date(0);
        }
        ps.setObject(index, parsed, column.getType());
    }

    /**
     * Parses {@code src} with a {@link SimpleDateFormat} built from {@code format}.
     *
     * @param src text to parse
     * @param format {@link SimpleDateFormat} pattern
     * @return the parsed instant in milliseconds since the epoch
     * @throws ParseException if {@code src} does not match the pattern
     */
    private static long getTime(String src, String format) throws ParseException {
        // A fresh formatter per call: SimpleDateFormat instances must not be shared between threads.
        return new SimpleDateFormat(format).parse(src).getTime();
    }

    /**
     * Binds a value for a TIMESTAMP column.
     *
     * <p>A {@link Number} is converted with {@code new Timestamp(long)} when the table config enables
     * {@code isNumberAsEpochMsForDatetime()}. A {@link String} is parsed with the configured
     * {@code getStringToDatetimeFormat()} pattern when one is set; if parsing fails, the value
     * {@link #DATE_0000} is bound as {@code new Timestamp(0)} and any other text is bound unchanged.
     * Everything else is bound unchanged.
     *
     * @param ps statement to bind on
     * @param index 1-based parameter index
     * @param obj value to bind
     * @param column target column metadata
     * @throws SQLException if the driver rejects the value
     */
    public static void setTimeStamp(PreparedStatement ps, int index, Object obj, Column column) throws SQLException {
        TableConfig tableConfig = LocalUtil.getTableConfig();
        if (obj instanceof Number && tableConfig.isNumberAsEpochMsForDatetime()) {
            ps.setObject(index, new Timestamp(((Number) obj).longValue()), column.getType());
            return;
        }
        if (!(obj instanceof String) || tableConfig.getStringToDatetimeFormat() == null) {
            ps.setObject(index, obj);
            return;
        }
        Timestamp parsed;
        try {
            parsed = new Timestamp(getTime((String) obj, tableConfig.getStringToDatetimeFormat()));
        } catch (ParseException ignored) {
            // Unparseable text: only the all-zero literal is mapped to the epoch, the rest goes to the driver.
            if (!DATE_0000.equals(obj)) {
                ps.setObject(index, obj);
                return;
            }
            parsed = new Timestamp(0);
        }
        ps.setObject(index, parsed, column.getType());
    }

    /**
     * Builds a {@code DELETE} statement template.
     *
     * @param tableName target table name, used verbatim
     * @param conditionFields columns compared with {@code = ?} and joined by {@code AND}; when empty
     *        the statement has no {@code WHERE} clause
     * @return the SQL template
     */
    public static String getDeleteStatement(String tableName, List<String> conditionFields) {
        String deleteSql = "DELETE FROM " + tableName;
        if (conditionFields.isEmpty()) {
            return deleteSql;
        }
        String conditionClause = conditionFields.stream()
            .map(f -> format("%s = ?", IdentifierUtil.quoteIdentifier(f)))
            .collect(Collectors.joining(" AND "));
        return deleteSql + " WHERE " + conditionClause;
    }

    /**
     * Builds an {@code UPDATE ... SET ... WHERE ...} statement template.
     *
     * <p>The {@code SET} placeholders (update fields that are not unique keys, in list order) come
     * before the {@code WHERE} placeholders (unique keys, in list order).
     *
     * @param tableName target table name, used verbatim
     * @param updateFieldNames candidate columns to assign; unique key columns are skipped
     * @param uniqueKeyFields columns compared with {@code = ?} in the {@code WHERE} clause
     * @return the SQL template
     */
    public static String getUpdateStatement(String tableName, List<String> updateFieldNames,
        List<String> uniqueKeyFields) {
        String conditionClause = uniqueKeyFields.stream()
            .map(f -> format("%s = ?", IdentifierUtil.quoteIdentifier(f)))
            .collect(Collectors.joining(" AND "));
        String setSql = updateFieldNames.stream()
            .filter(c -> !uniqueKeyFields.contains(c))
            .map(f -> format("%s = ?", IdentifierUtil.quoteIdentifier(f)))
            .collect(Collectors.joining(", "));
        return "UPDATE " + tableName + " SET " + setSql + " WHERE " + conditionClause;
    }

    /**
     * Builds an {@code UPDATE <table> o set ... FROM <tempTable> t WHERE ...} statement that copies
     * values from a temp table into the target table, matching rows on the unique keys.
     *
     * @param tableName target table name, aliased as {@code o}
     * @param tempTableName source table name, aliased as {@code t}
     * @param updateFieldNames candidate columns to copy; unique key columns are skipped
     * @param uniqueKeyFields columns used to match target and source rows
     * @return the SQL statement
     */
    public static String getUpdateFromStatement(String tableName, String tempTableName,
        List<String> updateFieldNames, List<String> uniqueKeyFields) {
        // NOTE(review): unlike the other builders, column names are not passed through
        // IdentifierUtil.quoteIdentifier here - confirm that is intended.
        String conditionClause =
            uniqueKeyFields.stream().map(f -> format("o.%s = t.%s", f, f)).collect(Collectors.joining(" AND "));
        String setSql = updateFieldNames.stream()
            .filter(c -> !uniqueKeyFields.contains(c))
            .map(f -> format("o.%s = t.%s", f, f))
            .collect(Collectors.joining(", "));
        return "UPDATE " + tableName + " o set " + setSql + " FROM " + tempTableName + " t WHERE " + conditionClause;
    }

    /**
     * Builds the {@code create ... temp table ... like ...} statement for a temp table.
     *
     * @param origin source table whose definition is copied
     * @param tempTable temp table name
     * @param type text inserted between {@code create} and {@code temp table}; may be empty
     * @return the DDL statement
     */
    public static String getCreateTempTableSql(String origin, String tempTable, String type) {
        StringBuilder ddl = new StringBuilder("create ");
        ddl.append(type)
            .append(" temp table ")
            .append(tempTable)
            .append(" like ")
            .append(origin)
            .append(" including all excluding partition EXCLUDING INDEXES EXCLUDING RELOPTIONS ");
        return ddl.toString();
    }

    /**
     * Same as {@link #getCreateTempTableSql(String, String, String)} with an empty {@code type}.
     *
     * @param origin source table whose definition is copied
     * @param tempTable temp table name
     * @return the DDL statement
     */
    public static String getCreateTempTableSql(String origin, String tempTable) {
        final String noType = "";
        return getCreateTempTableSql(origin, tempTable, noType);
    }

    /**
     * Builds a {@code create ... temporary table IF NOT EXISTS} statement that copies the source
     * definition with defaults, uses row orientation without compression, and distributes by hash.
     *
     * @param origin source table whose definition is copied
     * @param tempTable temp table name
     * @param type text inserted between {@code create} and {@code temporary table}
     * @param hash text placed inside {@code DISTRIBUTE BY HASH(...)}
     * @return the DDL statement, terminated by a semicolon
     */
    public static String getCreateTempTableSql1(String origin, String tempTable, String type, String hash) {
        final String template = "create %s temporary table IF NOT EXISTS  %s (like %s"
            + " including defaults) WITH (orientation=row, compression=no) DISTRIBUTE BY HASH(%s);";
        return String.format(template, type, tempTable, origin, hash);
    }

    /**
     * Builds a {@code create ... temp table ... as select * from ... where 1=2} statement; the
     * always-false predicate means no rows are selected.
     *
     * @param origin source table
     * @param tempTable temp table name
     * @param type text inserted between {@code create} and {@code temp table}; may be empty
     * @return the DDL statement
     */
    public static String getCreateTempTableAsSql(String origin, String tempTable, String type) {
        return String.join("", "create ", type, " temp table ", tempTable, " as select * from ", origin,
            " where 1=2 ");
    }

    /**
     * Same as {@link #getCreateTempTableAsSql(String, String, String)} with an empty {@code type}.
     *
     * @param origin source table
     * @param tempTable temp table name
     * @return the DDL statement
     */
    public static String getCreateTempTableAsSql(String origin, String tempTable) {
        final String noType = "";
        return getCreateTempTableAsSql(origin, tempTable, noType);
    }

    /**
     * Serialises a batch of records into the text payload of a COPY statement.
     *
     * @param records records to serialise
     * @param config client config, used by the CSV path
     * @param tableConfig table config providing the copy mode, delimiter and row terminator
     * @return the payload, or {@code null} when {@code records} is {@code null} or empty
     */
    public static String buildCopyBuffer(List<Record> records, DwsConfig config, TableConfig tableConfig) {
        if (records == null || records.isEmpty()) {
            return null;
        }
        if (tableConfig.getCopyMode() == CopyMode.DELIMITER) {
            return buildStdinCopyBuffer(records, tableConfig.getDelimiter(), tableConfig.getEof());
        }
        return buildCsvCopyBuffer(records, config, tableConfig.getDelimiter());
    }

    /**
     * Appends a single record to an existing COPY payload.
     *
     * @param record record to serialise
     * @param tableConfig table config providing the copy mode, delimiter and row terminator
     * @param temp buffer the serialised row is appended to
     */
    public static void buildCopyBuffer(Record record, TableConfig tableConfig, StringBuilder temp) {
        if (tableConfig.getCopyMode() == CopyMode.DELIMITER) {
            buildStdinCopyBuffer(record, tableConfig.getDelimiter(), tableConfig.getEof(), temp);
            return;
        }
        buildCsvCopyBuffer(record, tableConfig.getDelimiter(), temp);
    }

    /**
     * Serialises records into delimiter-separated COPY text.
     *
     * <p>The set of columns written for every row is taken from the first record's column bit set.
     * Values go through the record's data converter when one is present, otherwise through
     * {@code IdentifierUtil.replaceValue}. A value ending in a backslash is followed by a space
     * (presumably so the backslash cannot escape the delimiter - confirm).
     *
     * @param records non-empty list of records
     * @param delimiter field delimiter; may be longer than one character
     * @param eof row terminator appended after each record
     * @return the serialised rows
     */
    public static String buildStdinCopyBuffer(List<Record> records, String delimiter, String eof) {
        Record first = records.get(0);
        BitSet columnBit = first.getColumnBit();
        StringBuilder copyBuffer = new StringBuilder();
        for (Record record : records) {
            int rowStart = copyBuffer.length();
            columnBit.stream().forEach(idx -> {
                String value;
                if (record.getDataConvert() != null) {
                    value = convertData(record, idx);
                } else {
                    value = IdentifierUtil.replaceValue(String.valueOf(record.getValue(idx)));
                }
                copyBuffer.append(value);
                if (!value.isEmpty() && value.charAt(value.length() - 1) == '\\') {
                    copyBuffer.append(" ");
                }
                copyBuffer.append(delimiter);
            });
            // Drop the whole trailing delimiter, and only if this row actually wrote something;
            // trimming a fixed single char breaks multi-character delimiters and column-less rows.
            if (copyBuffer.length() > rowStart) {
                copyBuffer.setLength(copyBuffer.length() - delimiter.length());
            }
            copyBuffer.append(eof);
        }
        return copyBuffer.toString();
    }

    /**
     * Returns the string form of one column value, passing it through the record's data converter
     * when the record has one.
     *
     * @param record record holding the value
     * @param idx column index
     * @return {@code String.valueOf} of the raw or converted value
     */
    private static String convertData(Record record, int idx) {
        Object cell = record.getValue(idx);
        if (record.getDataConvert() != null) {
            cell = record.getDataConvert().convert(cell, record.getTableSchema().getColumns().get(idx), idx);
        }
        return String.valueOf(cell);
    }

    /**
     * Appends one record to a delimiter-separated COPY payload.
     *
     * @param record record to serialise; the columns written are those set in its column bit set
     * @param delimiter field delimiter; may be longer than one character
     * @param eof row terminator appended after the record
     * @param copyBuffer buffer the row is appended to
     */
    public static void buildStdinCopyBuffer(Record record, String delimiter, String eof, StringBuilder copyBuffer) {
        BitSet columnBit = record.getColumnBit();
        int rowStart = copyBuffer.length();
        columnBit.stream().forEach(idx -> {
            String value = convertData(record, idx);
            copyBuffer.append(value);
            copyBuffer.append(delimiter);
        });
        // Drop the whole trailing delimiter, and only if this row actually wrote something;
        // trimming a fixed single char breaks multi-character delimiters and column-less rows.
        if (copyBuffer.length() > rowStart) {
            copyBuffer.setLength(copyBuffer.length() - delimiter.length());
        }
        copyBuffer.append(eof);
    }

    /**
     * Serialises records into CSV COPY text, one line per record.
     *
     * <p>The set of columns written for every row is taken from the first record's column bit set.
     * Records whose table name is listed in {@code config.getLogDataTables()} are logged as JSON.
     *
     * @param records non-empty list of records
     * @param config client config providing the optional list of tables whose data is logged
     * @param delimiter field delimiter
     * @return the serialised rows, each terminated by a newline
     */
    public static String buildCsvCopyBuffer(List<Record> records, DwsConfig config, String delimiter) {
        Record first = records.get(0);
        BitSet columnBit = first.getColumnBit();
        StringBuilder copyBuffer = new StringBuilder();
        for (Record record : records) {
            if (config.getLogDataTables() != null
                && config.getLogDataTables().contains(record.getTableSchema().getTableName())) {
                log.info("add data to copy buffer, data = {}",
                    JSONObject.toJSONString(RecordUtil.toMap(record), SerializerFeature.WriteMapNullValue));
            }
            copyBuffer.append(handleCopyBuffer(delimiter, columnBit, record));
            copyBuffer.append("\n");
        }
        return copyBuffer.toString();
    }

    /**
     * Appends one record as a CSV line to an existing COPY payload.
     *
     * @param record record to serialise; the columns written are those set in its column bit set
     * @param delimiter field delimiter
     * @param copyBuffer buffer the line is appended to
     */
    public static void buildCsvCopyBuffer(Record record, String delimiter, StringBuilder copyBuffer) {
        copyBuffer.append(handleCopyBuffer(delimiter, record.getColumnBit(), record));
        copyBuffer.append("\n");
    }

    /**
     * Joins the selected column values of one record into a CSV row.
     *
     * <p>{@link String} values and {@link PGobject} values (json/jsonb compatibility) are passed
     * through {@code IdentifierUtil.replaceValueForCsv} and wrapped in double quotes; all other
     * values are written as their converted string form.
     *
     * @param delimiter field delimiter
     * @param columnBit indexes of the columns to write
     * @param record record holding the values
     * @return joiner holding the row, without a line terminator
     */
    private static StringJoiner handleCopyBuffer(String delimiter, BitSet columnBit, Record record) {
        StringJoiner row = new StringJoiner(delimiter);
        for (int idx = columnBit.nextSetBit(0); idx >= 0; idx = columnBit.nextSetBit(idx + 1)) {
            Object raw = record.getValue(idx);
            String cell = convertData(record, idx);
            if (raw instanceof String) {
                cell = format("\"%s\"", IdentifierUtil.replaceValueForCsv(cell));
            } else if (raw instanceof PGobject) {
                // json/jsonb compatibility: quote the object's text value.
                cell = format("\"%s\"", IdentifierUtil.replaceValueForCsv(((PGobject) raw).getValue()));
            }
            row.add(cell);
        }
        return row;
    }

    /**
     * Builds the delimiter-mode {@code COPY ... FROM STDIN} statement. The statement declares UTF8
     * encoding and the literal text {@code null} as the NULL marker.
     *
     * @param tableName target table name, used verbatim
     * @param fieldNames columns to load
     * @param delimiter field delimiter, embedded verbatim between single quotes
     * @param eof row terminator, embedded verbatim between single quotes
     * @return the COPY statement
     */
    public static String getCopyFromStdinStatement(String tableName, List<String> fieldNames, String delimiter,
        String eof) {
        String fieldNamesString =
            fieldNames.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        return "COPY " + tableName + "(" + fieldNamesString + ")" + " FROM STDIN DELIMITER " + "'" + delimiter + "'"
            + " ENCODING" + " 'UTF8' NULL 'null'" + " eol" + " '" + eof + "'";
    }

    /**
     * Builds the {@code COPY ... FROM STDIN} statement matching the table's copy mode: the delimiter
     * form for {@link CopyMode#DELIMITER}, the CSV form otherwise.
     *
     * @param tableName target table name, used verbatim
     * @param fieldNames columns to load
     * @param tableConfig table config providing the copy mode, delimiter and row terminator
     * @return the COPY statement
     */
    public static String getCopyFromStdinStatement(String tableName, List<String> fieldNames,
        TableConfig tableConfig) {
        CopyMode copyMode = tableConfig.getCopyMode();
        if (copyMode == CopyMode.DELIMITER) {
            return getCopyFromStdinStatement(tableName, fieldNames, tableConfig.getDelimiter(), tableConfig.getEof());
        }
        return getCopyFromStdinAsCsvStatement(tableName, fieldNames, tableConfig.getDelimiter());
    }

    /**
     * Builds the CSV-mode {@code COPY ... FROM STDIN} statement, with the literal text {@code null}
     * as the NULL marker.
     *
     * @param tableName target table name, used verbatim
     * @param fieldNames columns to load
     * @param delimiter field delimiter, embedded verbatim between single quotes
     * @return the COPY statement
     */
    public static String getCopyFromStdinAsCsvStatement(String tableName, List<String> fieldNames,
        String delimiter) {
        String fieldNamesString =
            fieldNames.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        return "COPY " + tableName + "(" + fieldNamesString + ")" + " FROM STDIN with(format 'csv', delimiter '"
            + delimiter + "', null 'null') ";
    }

    /**
     * Streams a prepared COPY payload to the server.
     *
     * @param connection driver connection used to create the {@link CopyManager}
     * @param sql {@code COPY ... FROM STDIN} statement
     * @param copyBuffer payload to send; {@code null} (which the list overload of
     *        {@code buildCopyBuffer} returns when there are no records) makes this call a no-op
     * @throws Exception if the copy fails
     */
    public static void executeCopy(BaseConnection connection, String sql, String copyBuffer) throws Exception {
        if (copyBuffer == null) {
            // Nothing to send; avoids the NullPointerException from new StringReader(null).
            return;
        }
        CopyManager perform = new CopyManager(connection);
        try (StringReader reader = new StringReader(copyBuffer)) {
            perform.copyIn(sql, reader);
        }
    }

    /**
     * Builds a {@code MERGE INTO} statement that merges {@code fromTable} (alias {@code t}) into
     * {@code table} (alias {@code o}).
     *
     * <p>For {@link ConflictStrategy#INSERT_OR_IGNORE} the update branch gets the always-false
     * predicate {@code WHERE 1 > 1}, so matched rows are left untouched. Otherwise, when the current
     * table config has compare fields, matched rows are only updated where {@code o.f < t.f} holds
     * for every compare field.
     *
     * @param table target table name
     * @param fromTable source table name
     * @param pkNames columns used in the {@code ON} join condition
     * @param updateColumns columns assigned in the {@code WHEN MATCHED} branch
     * @param insertColumns columns inserted in the {@code WHEN NOT MATCHED} branch
     * @param strategy conflict handling strategy
     * @return the MERGE statement, terminated by a semicolon
     */
    public static String getMergeIntoSql(String table, String fromTable, List<String> pkNames,
        List<String> updateColumns, List<String> insertColumns, ConflictStrategy strategy) {
        // NOTE(review): only the INSERT column list is quoted; the o./t. references are not - confirm intended.
        String uniqueColumns =
            pkNames.stream().map(item -> format("o.%s = t.%s ", item, item)).collect(Collectors.joining(" AND "));
        String updateSets =
            updateColumns.stream().map(item -> format("o.%s = t.%s ", item, item)).collect(Collectors.joining(", "));
        String insertFields =
            insertColumns.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        String insertValues = insertColumns.stream().map(key -> format("t.%s", key)).collect(Collectors.joining(", "));
        StringBuilder mergeSql = new StringBuilder();
        mergeSql.append("MERGE INTO ").append(table).append(" o \n");
        mergeSql.append("USING ").append(fromTable).append(" t ");
        mergeSql.append("ON (").append(uniqueColumns).append(") \n");
        mergeSql.append("WHEN MATCHED THEN \n");
        mergeSql.append("UPDATE SET ").append(updateSets);
        TableConfig config = LocalUtil.getTableConfig();
        Set<String> compareField = config.getCompareField();
        if (strategy == ConflictStrategy.INSERT_OR_IGNORE) {
            mergeSql.append(" WHERE 1 > 1 ");
        } else if (compareField != null && !compareField.isEmpty()) {
            // An empty compare set would otherwise leave a dangling WHERE.
            mergeSql.append(" WHERE ");
            String whereSql = compareField.stream()
                .map(item -> format("o.%s < t.%s ", item, item))
                .collect(Collectors.joining(" AND "));
            mergeSql.append(whereSql);
        }
        mergeSql.append("WHEN NOT MATCHED THEN \n");
        mergeSql.append("INSERT (").append(insertFields).append(") VALUES (");
        mergeSql.append(insertValues).append(");");
        return mergeSql.toString();
    }

    /**
     * Builds an {@code INSERT INTO ... SELECT ... FROM <fromTable> ON CONFLICT ...} statement that
     * upserts the rows of a source table into the target table.
     *
     * <p>The conflict action is {@code DO NOTHING} for {@link ConflictStrategy#INSERT_OR_IGNORE} or
     * when there are no update columns; otherwise it is {@code DO UPDATE SET ...}. The compare-field
     * {@code WHERE} clause from the current table config is only appended to {@code DO UPDATE},
     * because {@code DO NOTHING} does not accept a {@code WHERE} clause.
     *
     * @param table target table name
     * @param fromTable source table name
     * @param pkNames conflict target columns
     * @param updateColumns columns assigned from {@code EXCLUDED} on conflict
     * @param insertColumns columns selected from the source and inserted into the target
     * @param strategy conflict handling strategy
     * @return the SQL statement
     */
    public static String getUpsertFromSql(String table, String fromTable, List<String> pkNames,
        List<String> updateColumns, List<String> insertColumns, ConflictStrategy strategy) {
        String uniqueColumns = pkNames.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        String insertFields =
            insertColumns.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        StringBuilder upsertFromSql = new StringBuilder();
        upsertFromSql.append("INSERT INTO ")
            .append(table)
            .append("(")
            .append(insertFields)
            .append(")")
            .append(" SELECT ")
            .append(insertFields)
            .append(" FROM ")
            .append(fromTable);
        String conflictSql = " DO NOTHING";
        TableConfig config = LocalUtil.getTableConfig();
        Set<String> compareField = config.getCompareField();
        String whereSql = "";
        if (strategy != ConflictStrategy.INSERT_OR_IGNORE) {
            String updateClause = updateColumns.stream()
                .map(f -> IdentifierUtil.quoteIdentifier(f) + "=EXCLUDED." + IdentifierUtil.quoteIdentifier(f))
                .collect(Collectors.joining(", "));
            if (!updateClause.isEmpty()) {
                conflictSql = format(" DO UPDATE SET %s ", updateClause);
                // The WHERE clause is tied to DO UPDATE, and an empty compare set must not emit a bare WHERE.
                if (compareField != null && !compareField.isEmpty()) {
                    whereSql = " WHERE " + compareField.stream()
                        .map(f -> IdentifierUtil.quoteIdentifier(f) + "< EXCLUDED."
                            + IdentifierUtil.quoteIdentifier(f))
                        .collect(Collectors.joining(" AND "));
                }
            }
        }
        upsertFromSql.append(" ON CONFLICT ( ").append(uniqueColumns).append(")").append(conflictSql).append(whereSql);
        return upsertFromSql.toString();
    }

    /**
     * Builds a {@code SELECT} statement template filtered by equality on the condition fields.
     *
     * @param tableName table name, used verbatim
     * @param selectFields columns to select
     * @param conditionFields columns compared with {@code = ?} and joined by {@code AND}; when empty
     *        the statement has no {@code WHERE} clause
     * @return the SQL template
     */
    public static String getSelectFromStatement(String tableName, List<String> selectFields,
        List<String> conditionFields) {
        String selectSql = getBaseSelectFromStatement(tableName, selectFields);
        if (conditionFields.isEmpty()) {
            return selectSql;
        }
        String fieldExpressions = conditionFields.stream()
            .map(f -> format("%s = ?", IdentifierUtil.quoteIdentifier(f)))
            .collect(Collectors.joining(" AND "));
        return selectSql + " WHERE " + fieldExpressions;
    }

    /**
     * Builds a {@code SELECT} statement whose {@code WHERE} clause is the given predicates joined by
     * {@code AND}.
     *
     * @param tableName table name, used verbatim
     * @param selectFields columns to select
     * @param filterConditions ready-made SQL predicates, embedded verbatim; when empty the statement
     *        has no {@code WHERE} clause
     * @return the SQL statement
     */
    public static String getSelectFromStatementWithFilterCondition(String tableName, List<String> selectFields,
        List<String> filterConditions) {
        String selectSql = getBaseSelectFromStatement(tableName, selectFields);
        if (filterConditions.isEmpty()) {
            return selectSql;
        }
        String joinedConditions =
            filterConditions.stream().map(pred -> format("%s", pred)).collect(Collectors.joining(" AND "));
        return selectSql + " WHERE " + joinedConditions;
    }

    /**
     * Builds {@code SELECT <fields> FROM <table>} without any filter.
     *
     * @param tableName table name, used verbatim
     * @param selectFields columns to select; each is passed through
     *        {@code IdentifierUtil.quoteIdentifier}
     * @return the SQL statement
     */
    public static String getBaseSelectFromStatement(String tableName, List<String> selectFields) {
        String selectExpressions =
            selectFields.stream().map(IdentifierUtil::quoteIdentifier).collect(Collectors.joining(", "));
        return "SELECT " + selectExpressions + " FROM " + tableName;
    }

    /**
     * Builds one {@code select} statement that fetches several rows by primary key.
     *
     * <p>The selected columns are the union of the column bit sets of all requests. The
     * {@code where} clause contains one parenthesised primary-key equality group per request, joined
     * by {@code or}; every key column contributes one {@code ?} placeholder per request.
     *
     * @param schema table schema providing column names and primary keys
     * @param tableName table to query
     * @param recordList point lookups to batch
     * @return the SQL template
     */
    public static String getBatchSelectFromStatement(TableSchema schema, TableName tableName,
        List<Get> recordList) {
        BitSet columnMask = new BitSet(schema.getColumns().size());
        for (Get get : recordList) {
            columnMask.or(get.getRecord().getColumnBit());
        }
        StringJoiner selectColumns = new StringJoiner(",");
        for (int idx = columnMask.nextSetBit(0); idx >= 0; idx = columnMask.nextSetBit(idx + 1)) {
            selectColumns.add(IdentifierUtil.quoteIdentifier(schema.getColumn(idx).getName()));
        }
        // The key predicate is identical for every request, so build it once.
        StringJoiner keyPredicate = new StringJoiner(" and ", "( ", " ) ");
        for (Column key : schema.getPrimaryKeys()) {
            keyPredicate.add(IdentifierUtil.quoteIdentifier(key.getName()) + "=?");
        }
        String keyGroup = keyPredicate.toString();

        StringBuilder sb = new StringBuilder();
        sb.append("select ").append(selectColumns);
        sb.append(" from ").append(tableName.getFullName()).append(" where ");
        for (int i = 0; i < recordList.size(); ++i) {
            if (i > 0) {
                sb.append(" or ");
            }
            sb.append(keyGroup);
        }
        return sb.toString();
    }

    /**
     * Builds the {@code select} statement template for a scan.
     *
     * <p>The selected columns are the scan's selected columns, or every schema column when none are
     * given. Each column set in the scan record's column bit set contributes one {@code =?}
     * condition, joined by {@code and}; with no such column the statement has no {@code where}
     * clause.
     *
     * @param schema table schema providing column names
     * @param tableName table to query
     * @param scan scan request
     * @return the SQL template
     */
    public static String getScanFromStatement(TableSchema schema, TableName tableName, Scan scan) {
        BitSet projection = new BitSet(schema.getColumns().size());
        Record criteria = scan.getRecord();
        if (scan.getSelectedColumns() == null) {
            projection.set(0, schema.getColumns().size());
        } else {
            projection.or(scan.getSelectedColumns());
        }

        StringJoiner selectList = new StringJoiner(",");
        for (int idx = projection.nextSetBit(0); idx >= 0; idx = projection.nextSetBit(idx + 1)) {
            selectList.add(IdentifierUtil.quoteIdentifier(schema.getColumn(idx).getName()));
        }
        StringBuilder sql = new StringBuilder("select ");
        sql.append(selectList).append(" from ").append(tableName.getFullName());

        // Fill in the query conditions; the " where " prefix only appears when at least one exists.
        StringJoiner conditions = new StringJoiner(" and ", " where ", "");
        conditions.setEmptyValue("");
        BitSet conditionBits = criteria.getColumnBit();
        for (int idx = conditionBits.nextSetBit(0); idx >= 0; idx = conditionBits.nextSetBit(idx + 1)) {
            conditions.add(IdentifierUtil.quoteIdentifier(schema.getColumn(idx).getName()) + "=?");
        }
        return sql.append(conditions).toString();
    }

    /**
     * Binds every record to a prepared statement and executes them as one JDBC batch.
     *
     * <p>Parameter {@code i + 1} of each batch entry is the record's value for {@code keys.get(i)},
     * passed through the record's data converter when it has one. Does nothing when {@code records}
     * is {@code null} or empty.
     *
     * @param records records to write
     * @param connection connection the statement is prepared on
     * @param schema table schema used to resolve column metadata and indexes
     * @param keys column names, in the order of the statement's placeholders
     * @param sql statement template
     * @param config client config passed to {@code LogUtil.withLogSwitch}
     * @throws SQLException if preparing, binding or executing the batch fails
     */
    public static void executeBatchRecordSql(List<Record> records, Connection connection, TableSchema schema,
        List<String> keys, String sql, DwsConfig config) throws SQLException {
        if (records == null || records.isEmpty()) {
            return;
        }
        long startAddBatch = System.currentTimeMillis();
        // Resolve each key's column metadata and index once, not once per record.
        List<ColumnIndex> columnIndexList = new ArrayList<>(keys.size());
        for (String key : keys) {
            columnIndexList.add(new ColumnIndex(schema.getColumn(key), schema.getColumnIndex(key)));
        }
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            for (Record record : records) {
                for (int i = 0; i < columnIndexList.size(); i++) {
                    ColumnIndex columnIndex = columnIndexList.get(i);
                    Object value = record.getValue(columnIndex.getIndex());
                    if (record.getDataConvert() != null) {
                        value = record.getDataConvert().convert(value, columnIndex.getColumn(), columnIndex.getIndex());
                    }
                    JdbcUtil.fillPreparedStatement(statement, i + 1, value, columnIndex.getColumn());
                }
                statement.addBatch();
            }
            log.info("add batch time = {}", System.currentTimeMillis() - startAddBatch);
            long startExecute = System.currentTimeMillis();
            statement.executeBatch();
            log.info("execute batch time = {}", System.currentTimeMillis() - startExecute);
            LogUtil.withLogSwitch(config,
                () -> log.info("execute batch = {}", System.currentTimeMillis() - startExecute));
        }
    }

    /**
     * Executes a single SQL statement on a fresh {@link Statement}, which is closed afterwards.
     *
     * @param connection connection to run on
     * @param sql statement text
     * @throws SQLException if creating the statement or executing it fails
     */
    public static void executeSql(Connection connection, String sql) throws SQLException {
        try (Statement stmt = connection.createStatement()) {
            stmt.execute(sql);
        }
    }

    /**
     * Builds the {@code ALTER TABLE IF EXISTS ... ADD COLUMN} statement for one column and logs it.
     *
     * @param schema map from column name to the source type description that
     *        {@code MySqlTypeUtils.toDataType} converts
     * @param column name of the column to add, written between double quotes
     * @param table table name, used verbatim
     * @return the DDL statement
     */
    public static String getAddColumn(Map schema, String column, String table) {
        // Build the add-column SQL.
        String alterPrefix = format("ALTER TABLE IF EXISTS %s ", table);
        String addClause =
            format("ADD COLUMN \"%s\" %s", column, MySqlTypeUtils.toDataType(schema.get(column)).asSQL());
        String ddl = alterPrefix + addClause;
        log.info("add column sql: {}", ddl);
        return ddl;
    }

    /**
     * Appends a database name to a base URL, inserting a {@code /} unless the base already ends
     * with one.
     *
     * @param baseUrl URL prefix
     * @param db database name
     * @return the combined URL
     */
    public static String getDbUrl(String baseUrl, String db) {
        String separator = baseUrl.endsWith("/") ? "" : "/";
        return baseUrl + separator + db;
    }

    /**
     * Reads the attribute list of a composite type from the system catalogs.
     *
     * <p>Each attribute's JDBC type is {@link Types#STRUCT} when its {@code typtype} is {@code "c"},
     * {@link Types#DISTINCT} when it is {@code "d"}, and otherwise whatever the driver's type info
     * maps the attribute type oid to.
     *
     * @param connection connection to query; must be a {@link PgConnection}
     * @param typeName schema and name of the type
     * @return the type definition with one {@link TypeColumn} per returned row
     * @throws DwsClientException if the query fails
     */
    public static TypeDefinition getTypeDefinition(Connection connection, TableName typeName)
        throws DwsClientException {
        TypeDefinition def = new TypeDefinition(typeName);
        PgConnection pgConn = (PgConnection) connection;
        // The statement must stay open while the result set is read: closing a Statement also closes
        // its ResultSet, so both live in the same try-with-resources.
        try (PreparedStatement statement = prepareTypeDefinitionQuery(pgConn, typeName);
            ResultSet rs = statement.executeQuery()) {
            while (rs.next()) {
                TypeColumn column = new TypeColumn();
                column.setAttrName(rs.getString("attname"));
                column.setTypeName(rs.getString("typname"));
                int oid = rs.getInt("oid");
                String typType = rs.getString("typtype");
                column.setOid(oid);
                column.setTypType(typType);
                int sqlType;
                if ("c".equals(typType)) {
                    sqlType = Types.STRUCT;
                } else if ("d".equals(typType)) {
                    sqlType = Types.DISTINCT;
                } else {
                    sqlType = pgConn.getTypeInfo().getSQLType(oid);
                }
                column.setType(sqlType);
                def.addTypeColumn(column);
            }
            return def;
        } catch (Exception exception) {
            throw DwsClientException.fromException(exception);
        }
    }

    /**
     * Prepares the catalog query for a type's attributes and binds the type and schema names. The
     * caller owns the returned statement and must close it.
     */
    private static PreparedStatement prepareTypeDefinitionQuery(Connection connection, TableName typeName)
        throws SQLException {
        String sql = "SELECT a.attname, tp.typname, tp.oid, tp.typtype FROM pg_catalog.pg_type t "
            + "  LEFT OUTER JOIN pg_catalog.pg_namespace n on n.oid = t.typnamespace "
            + "  LEFT OUTER JOIN pg_catalog.pg_attribute a ON a.attrelid = t.typrelid "
            + "  LEFT OUTER JOIN pg_catalog. pg_type tp ON tp.oid = a.atttypid "
            + "  where t.typname = ? and n.nspname = ?  ORDER BY a.attnum ASC";
        PreparedStatement statement = connection.prepareStatement(sql);
        try {
            statement.setString(1, typeName.getTableName());
            statement.setString(2, typeName.getSchemaName());
            return statement;
        } catch (SQLException | RuntimeException e) {
            // Binding failed: do not leak the statement.
            try {
                statement.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy