All downloads are free. Search and download functionality uses the official Maven repository.

com.github.mybatisintercept.util.ASTDruidUtil Maven / Gradle / Ivy

package com.github.mybatisintercept.util;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlShowStatement;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;

import java.util.List;
import java.util.function.BiPredicate;
import java.util.function.Predicate;

/**
 * Static helpers that parse a SQL string with Alibaba Druid, inspect or modify the
 * resulting AST, and (where applicable) render it back to SQL text.
 */
public class ASTDruidUtil {

    /** Result of the column lookups when the requested column is not in the column list. */
    private static final int COLUMN_NOT_FOUND = -1;

    /**
     * Result of {@link #getColumnParameterizedIndex(String, String, String)} when the column is
     * listed but its value in the first VALUES row is not a {@code SQLVariantRefExpr}.
     */
    private static final int COLUMN_NOT_PARAMETERIZED = -2;

    /**
     * Adds {@code injectCondition} to {@code sql} by delegating to
     * {@code ASTDruidConditionUtil.addCondition} with {@link SQLBinaryOperator#BooleanAnd}.
     *
     * <p>NOTE(review): the generic arguments of {@code skip}, {@code isJoinUniqueKey} and
     * {@code excludeInjectCondition} are not recoverable from this file, so the parameters are
     * left as raw types; confirm against {@code ASTDruidConditionUtil.addCondition}.
     *
     * @return whatever {@code ASTDruidConditionUtil.addCondition} returns for these arguments
     */
    public static String addAndCondition(String sql, String injectCondition, ExistInjectConditionStrategyEnum existInjectConditionStrategyEnum, String dbType, BiPredicate skip, Predicate isJoinUniqueKey, List excludeInjectCondition) {
        return ASTDruidConditionUtil.addCondition(sql, injectCondition, SQLBinaryOperator.BooleanAnd, false, existInjectConditionStrategyEnum, dbType, skip, isJoinUniqueKey, excludeInjectCondition);
    }

    /**
     * Converts a Java value into a Druid expression node.
     *
     * @param value {@code null}, a {@link String}, or any other object
     * @return {@code SQLNullExpr} for {@code null}, {@code SQLCharExpr} for strings; any other
     *         value is stringified with {@link String#valueOf(Object)} and parsed as a SQL expression
     */
    static SQLExpr toValueExpr(Object value) {
        if (value == null) {
            return new SQLNullExpr();
        } else if (value instanceof String) {
            return new SQLCharExpr((String) value);
        } else {
            return SQLUtils.toSQLExpr(String.valueOf(value));
        }
    }

    /**
     * Appends {@code valueExpr} to the select list of {@code query}.
     *
     * @return {@code true} when {@code query} is {@code null} (nothing to do) or when its query is
     *         a {@code SQLSelectQueryBlock} and the item was appended; {@code false} for any other
     *         query type
     */
    private static boolean addSelectItem(SQLSelect query, SQLExpr valueExpr) {
        if (query == null) {
            return true;
        }
        SQLSelectQuery select = query.getQuery();
        if (select instanceof SQLSelectQueryBlock) {
            ((SQLSelectQueryBlock) select).addSelectItem(valueExpr);
            return true;
        }
        return false;
    }

    /**
     * Ensures an INSERT or REPLACE statement writes {@code value} into {@code columnName} and
     * returns the rewritten SQL.
     *
     * @param rawSql     SQL text that must parse to exactly one statement
     * @param columnName column to add or overwrite (matched case-insensitively after normalization)
     * @param value      value to write; converted with {@link #toValueExpr(Object)}
     * @param dbType     Druid dialect name used for parsing and rendering
     * @return the statement rendered back to SQL
     * @throws IllegalStateException if {@code rawSql} is not exactly one statement, is neither
     *                               INSERT nor REPLACE, or its sub-query cannot take an extra select item
     */
    public static String addColumnValues(String rawSql, String columnName, Object value, String dbType) {
        List<SQLStatement> sqlStatements = SQLUtils.parseStatements(rawSql, dbType);
        if (sqlStatements.size() != 1) {
            throw new IllegalStateException("addColumnValues sqlStatements.size() != 1. sql = " + rawSql);
        }
        SQLStatement sqlStatement = sqlStatements.get(0);
        SQLExpr valueExpr = toValueExpr(value);

        if (sqlStatement instanceof SQLInsertStatement) {
            SQLInsertStatement statement = (SQLInsertStatement) sqlStatement;
            SQLSelect query = statement.getQuery();
            applyColumnValue(statement.getColumns(), statement.getValuesList(),
                    query != null, query, columnName, valueExpr, rawSql);
        } else if (sqlStatement instanceof SQLReplaceStatement) {
            SQLReplaceStatement statement = (SQLReplaceStatement) sqlStatement;
            SQLQueryExpr query = statement.getQuery();
            applyColumnValue(statement.getColumns(), statement.getValuesList(),
                    query != null, query == null ? null : query.getSubQuery(), columnName, valueExpr, rawSql);
        } else {
            throw new IllegalStateException("addColumnValues sqlStatements.no support. sql = " + rawSql);
        }
        return SQLUtils.toSQLString(sqlStatement, dbType);
    }

    /**
     * Shared INSERT/REPLACE logic for {@link #addColumnValues(String, String, Object, String)}:
     * adds the column (and its value) when missing, or overwrites its value in every VALUES row.
     *
     * @param columns    the statement's live column list (mutated in place)
     * @param valuesList the statement's VALUES rows
     * @param hasQuery   whether the statement is fed by a sub-query instead of VALUES
     * @param select     the sub-query's select, may be {@code null} even when {@code hasQuery} is true
     */
    private static void applyColumnValue(List<SQLExpr> columns,
                                         List<SQLInsertStatement.ValuesClause> valuesList,
                                         boolean hasQuery,
                                         SQLSelect select,
                                         String columnName,
                                         SQLExpr valueExpr,
                                         String rawSql) {
        int columnIndex = columnIndex(columns, columnName);
        if (columnIndex == COLUMN_NOT_FOUND) {
            // The user did not list the column: append it together with its value.
            columns.add(new SQLIdentifierExpr(columnName));
            if (hasQuery) {
                // e.g. insert into `base_area` (`id`, `name`) select id,name from copy
                if (!addSelectItem(select, valueExpr)) {
                    throw new IllegalStateException("not support addColumnValues. sql = " + rawSql + ", columnName = " + columnName);
                }
            } else {
                // e.g. insert into `base_area` (`id`, `name`) values (1, '2')
                for (SQLInsertStatement.ValuesClause valuesClause : valuesList) {
                    valuesClause.addValue(valueExpr);
                }
            }
        } else if (!hasQuery) {
            // The user listed the column: overwrite its value in every row.
            for (SQLInsertStatement.ValuesClause valuesClause : valuesList) {
                valuesClause.getValues().set(columnIndex, valueExpr);
            }
        }
        // Otherwise the user listed the column and feeds it from a sub-query: left untouched.
    }

    /**
     * Decides whether a statement is one that can carry a WHERE condition.
     *
     * @return for INSERT/REPLACE, {@code true} only if the statement contains a
     *         {@code SQLSelectQueryBlock}; {@code false} for {@code MySqlShowStatement} and
     *         {@code SQLSetStatement}; {@code true} for everything else
     */
    private static boolean isSupportWhere(SQLStatement statement) {
        if (statement instanceof SQLInsertStatement || statement instanceof SQLReplaceStatement) {
            // e.g. INSERT INTO user_copy (id, name) SELECT id, name FROM user
            boolean[] existSelect = new boolean[1];
            statement.accept(new SQLASTVisitorAdapter() {
                @Override
                public boolean visit(SQLSelectQueryBlock x) {
                    existSelect[0] = true;
                    return false;
                }
            });
            return existSelect[0];
        } else if (statement instanceof MySqlShowStatement || statement instanceof SQLSetStatement) {
            return false;
        } else {
            return true;
        }
    }

    /**
     * Parses {@code rawSql} and returns its only statement.
     *
     * @return the statement, or {@code null} if parsing fails or the text does not contain
     *         exactly one statement
     */
    private static SQLStatement parseSingleStatementOrNull(String rawSql, String dbType) {
        List<SQLStatement> statements;
        try {
            statements = SQLUtils.parseStatements(rawSql, dbType);
        } catch (Exception ignored) {
            // Best effort: SQL that cannot be parsed is reported to callers as "not applicable".
            return null;
        }
        return statements.size() == 1 ? statements.get(0) : null;
    }

    /**
     * Applies {@code skip} to the normalized schema and simple table name of {@code table}.
     *
     * @return {@code true} when the predicate does NOT ask to skip the table
     */
    private static boolean isNotSkipped(SQLExprTableSource table, BiPredicate<String, String> skip) {
        String schema = SQLUtils.normalize(table.getSchema(), null);
        String tableName = SQLUtils.normalize(table.getName().getSimpleName(), null);
        return !skip.test(schema, tableName);
    }

    /**
     * Tells whether {@code rawSql} is a single UPDATE statement that should not be skipped.
     *
     * @param skip predicate receiving (normalized schema, normalized table name); {@code true} means skip
     * @return {@code false} if the SQL cannot be parsed, is not exactly one statement, is not an
     *         UPDATE, or updates an unrecognized table source; for a single-table UPDATE the
     *         negation of {@code skip}; for a joined UPDATE always {@code true}
     */
    public static boolean isNoSkipUpdate(String rawSql, String dbType, BiPredicate<String, String> skip) {
        SQLStatement sqlStatement = parseSingleStatementOrNull(rawSql, dbType);
        if (!(sqlStatement instanceof SQLUpdateStatement)) {
            // Unparseable, multi-statement, or not an UPDATE.
            return false;
        }
        SQLTableSource tableSource = ((SQLUpdateStatement) sqlStatement).getTableSource();
        if (tableSource instanceof SQLExprTableSource) {
            // Single table.
            return isNotSkipped((SQLExprTableSource) tableSource, skip);
        } else if (tableSource instanceof SQLJoinTableSource) {
            // TODO: multi-table updates are never skipped for now.
            return true;
        } else {
            // Other, unknown syntax.
            return false;
        }
    }

    /**
     * Tells whether {@code rawSql} is a single INSERT or REPLACE statement that should not be skipped.
     *
     * @param skip predicate receiving (normalized schema, normalized table name); {@code true} means skip
     * @return {@code false} if the SQL cannot be parsed, is not exactly one statement, or is
     *         neither INSERT nor REPLACE; otherwise the negation of {@code skip} for the target table
     */
    public static boolean isNoSkipInsertOrReplace(String rawSql, String dbType, BiPredicate<String, String> skip) {
        SQLStatement sqlStatement = parseSingleStatementOrNull(rawSql, dbType);
        if (sqlStatement instanceof SQLInsertStatement) {
            return isNotSkipped(((SQLInsertStatement) sqlStatement).getTableSource(), skip);
        } else if (sqlStatement instanceof SQLReplaceStatement) {
            return isNotSkipped(((SQLReplaceStatement) sqlStatement).getTableSource(), skip);
        } else {
            return false;
        }
    }

    /**
     * Tells whether {@code rawSql} parses to exactly one statement that can carry a WHERE condition.
     *
     * @return {@code false} if the SQL cannot be parsed or is not exactly one statement; otherwise
     *         the result of {@link #isSupportWhere(SQLStatement)}
     */
    public static boolean isSingleStatementAndSupportWhere(String rawSql, String dbType) {
        SQLStatement sqlStatement = parseSingleStatementOrNull(rawSql, dbType);
        return sqlStatement != null && isSupportWhere(sqlStatement);
    }

    /**
     * Looks up the parameter index bound to {@code columnName} in an INSERT or REPLACE statement.
     * Only the first VALUES row is inspected.
     *
     * @return the index reported by the column's {@code SQLVariantRefExpr}; {@code -1} if the
     *         column is not listed; {@code -2} if it is listed but its value is not a
     *         {@code SQLVariantRefExpr} (or there is no value for it)
     * @throws IllegalStateException if {@code rawSql} is not exactly one statement or is neither
     *                               INSERT nor REPLACE
     */
    public static int getColumnParameterizedIndex(String rawSql, String columnName, String dbType) {
        List<SQLStatement> sqlStatements = SQLUtils.parseStatements(rawSql, dbType);
        // Single statement only.
        if (sqlStatements.size() != 1) {
            throw new IllegalStateException("getColumns sqlStatements.size() != 1. sql = " + rawSql);
        }
        SQLStatement sqlStatement = sqlStatements.get(0);
        if (sqlStatement instanceof SQLInsertStatement) {
            SQLInsertStatement statement = (SQLInsertStatement) sqlStatement;
            return columnParameterizedIndex(statement.getColumns(), statement.getValuesList(), columnName);
        } else if (sqlStatement instanceof SQLReplaceStatement) {
            SQLReplaceStatement statement = (SQLReplaceStatement) sqlStatement;
            return columnParameterizedIndex(statement.getColumns(), statement.getValuesList(), columnName);
        } else {
            throw new IllegalStateException("getColumns not support. sql = " + rawSql);
        }
    }

    /**
     * Returns the element at {@code index}, or {@code null} when {@code values} is {@code null},
     * empty, or has no element at that position (a row shorter than the column list).
     */
    private static SQLExpr valueAt(List<SQLExpr> values, int index) {
        if (values == null || index < 0 || index >= values.size()) {
            return null;
        }
        return values.get(index);
    }

    /**
     * Resolves the parameter index of {@code columnName} using the first VALUES row.
     *
     * @return the {@code SQLVariantRefExpr} index, {@code -1} if the column is not listed, or
     *         {@code -2} if its value is not a {@code SQLVariantRefExpr}
     */
    private static int columnParameterizedIndex(List<SQLExpr> columns, List<SQLInsertStatement.ValuesClause> valuesList, String columnName) {
        int position = columnIndex(columns, columnName);
        if (position == COLUMN_NOT_FOUND) {
            return COLUMN_NOT_FOUND;
        }
        List<SQLExpr> firstRow = valuesList == null || valuesList.isEmpty() ? null : valuesList.get(0).getValues();
        SQLExpr value = valueAt(firstRow, position);
        if (value instanceof SQLVariantRefExpr) {
            return ((SQLVariantRefExpr) value).getIndex();
        }
        return COLUMN_NOT_PARAMETERIZED;
    }

    /**
     * Finds the position of {@code columnName} in {@code columns}, comparing the normalized
     * column text case-insensitively.
     *
     * @return the zero-based position of the first match, or {@code -1} if there is none
     */
    private static int columnIndex(List<SQLExpr> columns, String columnName) {
        int i = 0;
        for (SQLExpr column : columns) {
            String name = SQLUtils.normalize(column.toString(), null);
            if (columnName.equalsIgnoreCase(name)) {
                return i;
            }
            i++;
        }
        return COLUMN_NOT_FOUND;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy