All downloads are free. Search and download functionality uses the official Maven repository.

com.github.mybatisintercept.util.ASTDruidConditionUtil Maven / Gradle / Ivy

package com.github.mybatisintercept.util;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlShowStatement;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;
import com.alibaba.druid.sql.visitor.SQLEvalVisitor;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.BiPredicate;
import java.util.function.Predicate;

public class ASTDruidConditionUtil {
    private static final Method DB_TYPE_METHOD;

    static {
        Method dbTypeMethod;
        try {
            Class clazz = Class.forName("com.alibaba.druid.DbType");
            dbTypeMethod = clazz.getDeclaredMethod("of", String.class);
        } catch (Exception e) {
            dbTypeMethod = null;
        }
        DB_TYPE_METHOD = dbTypeMethod;
    }

    /**
     * Parses {@code injectCondition} as a SQL expression and returns the normalized names of the
     * columns it references, in visiting order.
     * <p>
     * Plain identifiers ({@code col}) and qualified names ({@code t.col}) contribute their name.
     * For {@code x IN (SELECT ...)} only the left-hand column is recorded and the sub-query is not
     * entered; other query blocks are not entered either.
     * NOTE(review): qualified names are visited with their children, so the owner identifier
     * ({@code t} in {@code t.col}) is presumably collected as well — confirm this is intended.
     *
     * @param injectCondition SQL boolean expression, e.g. {@code tenant_id = 1}
     * @return collected column names (may contain duplicates)
     */
    public static List getColumnList(String injectCondition) {
        SQLExpr conditionExpr = SQLUtils.toSQLExpr(injectCondition, getDbType(null));
        List<String> columns = new ArrayList<>();
        conditionExpr.accept(new SQLASTVisitorAdapter() {

            private void recordColumn(String rawName) {
                columns.add(normalize(rawName));
            }

            @Override
            public boolean visit(SQLIdentifierExpr x) {
                recordColumn(x.getName());
                return true;
            }

            @Override
            public boolean visit(SQLPropertyExpr x) {
                recordColumn(x.getName());
                return true;
            }

            @Override
            public boolean visit(SQLSelectQueryBlock statement) {
                // Never look inside sub-queries.
                return false;
            }

            @Override
            public boolean visit(SQLInSubQueryExpr statement) {
                SQLExpr lhs = statement.getExpr();
                if (lhs instanceof SQLPropertyExpr) {
                    recordColumn(((SQLPropertyExpr) lhs).getName());
                } else if (lhs instanceof SQLIdentifierExpr) {
                    recordColumn(((SQLIdentifierExpr) lhs).getName());
                }
                // The sub-query on the right-hand side is not visited.
                return false;
            }
        });
        return columns;
    }

    /**
     * Same as the full {@code addCondition} overload, with no exclude-condition list
     * ({@code excludeInjectCondition} is passed as {@code null}).
     *
     * @return the rewritten SQL, or {@code sql} unchanged when nothing was injected
     */
    public static String addCondition(String sql, String injectCondition, SQLBinaryOperator op,
                                      boolean appendConditionToLeft, ExistInjectConditionStrategyEnum existInjectConditionStrategyEnum,
                                      String dbType, BiPredicate skip, Predicate isJoinUniqueKey) {
        // null = the caller supplied no exclude conditions
        final List noExcludeConditions = null;
        return addCondition(sql, injectCondition, op, appendConditionToLeft,
                existInjectConditionStrategyEnum, dbType, skip, isJoinUniqueKey, noExcludeConditions);
    }

    /**
     * Parses {@code sql} (which must contain exactly one statement), injects
     * {@code injectCondition} into it and returns the regenerated SQL.
     *
     * @param sql                              statement to rewrite
     * @param injectCondition                  boolean SQL expression to inject; {@code null}/empty returns {@code sql} as-is
     * @param op                               operator used to combine the injected condition with existing ones
     * @param appendConditionToLeft            whether the injected condition goes on the left of {@code op}
     * @param existInjectConditionStrategyEnum how an already-present condition is handled
     * @param dbType                           druid database type name
     * @param skip                             {@code (schema, tableName) -> true} for tables to leave untouched
     * @param isJoinUniqueKey                  optional predicate; when it matches, injection for that table is skipped
     * @param excludeInjectCondition           optional conditions whose sub-trees are not traversed
     * @return the rewritten SQL, or the original {@code sql} when nothing changed
     * @throws IllegalArgumentException if {@code sql} does not parse to exactly one statement
     */
    public static String addCondition(String sql, String injectCondition, SQLBinaryOperator op,
                                      boolean appendConditionToLeft, ExistInjectConditionStrategyEnum existInjectConditionStrategyEnum,
                                      String dbType, BiPredicate skip, Predicate isJoinUniqueKey, List excludeInjectCondition) {
        if (injectCondition == null || injectCondition.isEmpty()) {
            return sql;
        }
        // Typed list: on a raw List, get(0) yields Object and does not compile as SQLStatement.
        List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, dbType);
        if (stmtList.size() != 1) {
            throw new IllegalArgumentException("not support statement :" + sql);
        }
        SQLStatement ast = stmtList.get(0);

        SQLExpr injectConditionExpr = SQLUtils.toSQLExpr(injectCondition, getDbType(dbType));
        List excludeInjectConditionExprList = getExcludeInjectConditionExprList(excludeInjectCondition, dbType);
        boolean change = addCondition(sql, ast, op, injectConditionExpr, appendConditionToLeft,
                existInjectConditionStrategyEnum, wrapDialectSkip(dbType, skip), isJoinUniqueKey,
                excludeInjectConditionExprList);
        // Only regenerate when the AST was modified, so untouched SQL keeps its original formatting.
        return change ? SQLUtils.toSQLString(ast, dbType) : sql;
    }

    /**
     * Parses each exclude condition into a druid expression.
     *
     * @param excludeInjectCondition SQL expressions to parse; may be {@code null}
     * @param dbType                 druid database type name used for parsing
     * @return parsed expressions in input order, or {@code null} when the input is {@code null}
     */
    private static List<SQLExpr> getExcludeInjectConditionExprList(List<String> excludeInjectCondition, String dbType) {
        if (excludeInjectCondition == null) {
            return null;
        }
        // Typed element: iterating a raw List as String does not compile.
        List<SQLExpr> list = new ArrayList<>(excludeInjectCondition.size());
        for (String condition : excludeInjectCondition) {
            list.add(SQLUtils.toSQLExpr(condition, getDbType(dbType)));
        }
        return list;
    }

    /**
     * Combines the caller's table skip rule with dialect-specific rules.
     * <p>
     * For MySQL and MariaDB the pseudo table {@code dual} is always skipped. The dialect name is
     * matched case-insensitively, so {@code "mysql"}, {@code "MySQL"} and {@code "MARIADB"} all
     * match. Any other dialect, including a {@code null} dbType (which previously threw a
     * NullPointerException from {@code switch}), gets the caller's rule unchanged.
     *
     * @param dbType druid database type name; may be {@code null}
     * @param skip   caller rule, {@code (schema, tableName) -> true} to skip a table; may be {@code null}
     * @return a non-null skip rule
     */
    private static BiPredicate<String, String> wrapDialectSkip(String dbType, BiPredicate<String, String> skip) {
        boolean mysqlFamily = "mysql".equalsIgnoreCase(dbType) || "mariadb".equalsIgnoreCase(dbType);
        if (mysqlFamily) {
            return (schema, tableName) -> "dual".equalsIgnoreCase(tableName)
                    || (skip != null && skip.test(schema, tableName));
        }
        return skip == null ? (schema, tableName) -> false : skip;
    }

    /**
     * Returns the normalized owner of a qualified column such as {@code t.col} (here {@code t}),
     * or {@code null} when {@code expr} is {@code null}.
     */
    private static String getAlias(SQLPropertyExpr expr) {
        return expr == null ? null : normalize(expr.getOwnernName());
    }

    /**
     * Returns the alias of a table source, falling back to its normalized table name.
     * For joins, the left-most table source is used.
     *
     * @param tableSource table source; {@code null} for statements without FROM,
     *                    e.g. {@code SELECT @rownum := 0, @rowtotal := NULL}
     * @return the alias as written, the table name, or {@code null}
     */
    private static String getAlias(SQLTableSource tableSource) {
        SQLTableSource current = tableSource;
        while (current instanceof SQLJoinTableSource) {
            current = ((SQLJoinTableSource) current).getLeft();
        }
        if (current == null) {
            return null;
        }
        String alias = current.getAlias();
        return alias == null ? getTableName(current) : alias;
    }

    /**
     * Whether the (left-most, for joins) table source is a derived table rather than a real table:
     * a sub-query, a union query, or a WITH clause (MySQL 8).
     */
    private static boolean isSubqueryOrUnion(SQLTableSource from) {
        SQLTableSource current = from;
        while (current instanceof SQLJoinTableSource) {
            current = ((SQLJoinTableSource) current).getLeft();
        }
        return current instanceof SQLSubqueryTableSource
                || current instanceof SQLUnionQueryTableSource
                || current instanceof SQLWithSubqueryClause;
    }

    /**
     * Returns the normalized simple table name of a table source (the left-most one for joins),
     * or {@code null} if it is not a plain table reference or there is no FROM
     * (e.g. {@code SELECT @rownum := 0, @rowtotal := NULL}).
     */
    private static String getTableName(SQLTableSource tableSource) {
        SQLTableSource current = tableSource;
        while (current instanceof SQLJoinTableSource) {
            current = ((SQLJoinTableSource) current).getLeft();
        }
        if (!(current instanceof SQLExprTableSource)) {
            return null;
        }
        SQLName name = ((SQLExprTableSource) current).getName();
        return name == null ? null : normalize(name.getSimpleName());
    }

    /**
     * Returns the normalized schema of a table source (the left-most one for joins),
     * or {@code null} if it is not a plain table reference or there is no FROM
     * (e.g. {@code SELECT @rownum := 0, @rowtotal := NULL}).
     */
    private static String getTableSchema(SQLTableSource tableSource) {
        SQLTableSource current = tableSource;
        while (current instanceof SQLJoinTableSource) {
            current = ((SQLJoinTableSource) current).getLeft();
        }
        return current instanceof SQLExprTableSource
                ? normalize(((SQLExprTableSource) current).getSchema())
                : null;
    }

    /**
     * Normalizes an identifier through druid's {@code SQLUtils.normalize} with no dialect given.
     */
    private static String normalize(String name) {
        final String normalized = SQLUtils.normalize(name, null);
        return normalized;
    }

    /**
     * Returns the ON condition of the innermost join along the left spine, i.e. the join that
     * attaches the second table to the first one. Returns {@code null} for non-join sources.
     */
    private static SQLExpr getCondition(SQLTableSource tableSource) {
        if (!(tableSource instanceof SQLJoinTableSource)) {
            return null;
        }
        SQLJoinTableSource innermost = (SQLJoinTableSource) tableSource;
        while (innermost.getLeft() instanceof SQLJoinTableSource) {
            innermost = (SQLJoinTableSource) innermost.getLeft();
        }
        return innermost.getCondition();
    }

    /**
     * Returns the join type of the innermost join along the left spine (the join that attaches
     * the second table to the first one), or {@code null} for non-join sources.
     */
    private static SQLJoinTableSource.JoinType getJoinType(SQLTableSource tableSource) {
        if (!(tableSource instanceof SQLJoinTableSource)) {
            return null;
        }
        SQLJoinTableSource innermost = (SQLJoinTableSource) tableSource;
        while (innermost.getLeft() instanceof SQLJoinTableSource) {
            innermost = (SQLJoinTableSource) innermost.getLeft();
        }
        return innermost.getJoinType();
    }

    /**
     * Reports whether {@code condition} contains a column qualified by {@code alias}
     * (e.g. {@code alias.col}, compared case-insensitively).
     * <p>
     * Both operands of every nested binary expression are inspected. The previous version
     * re-tested the left operand in the right-hand branch, so aliases appearing only on the right
     * side (e.g. {@code a.id = alias.a_id}) were never found.
     *
     * @param alias     alias to look for; {@code null} never matches
     * @param condition condition to search; anything other than a binary expression yields {@code false}
     * @return {@code true} if some operand is a property expression owned by {@code alias}
     */
    private static boolean existAlias(String alias, SQLExpr condition) {
        if (alias == null || !(condition instanceof SQLBinaryOpExpr)) {
            return false;
        }
        Deque<SQLBinaryOpExpr> pending = new ArrayDeque<>();
        pending.add((SQLBinaryOpExpr) condition);
        while (!pending.isEmpty()) {
            SQLBinaryOpExpr binaryOpExpr = pending.removeFirst();
            for (SQLExpr operand : Arrays.asList(binaryOpExpr.getLeft(), binaryOpExpr.getRight())) {
                if (operand instanceof SQLBinaryOpExpr) {
                    pending.add((SQLBinaryOpExpr) operand);
                } else if (operand instanceof SQLPropertyExpr
                        && alias.equalsIgnoreCase(getAlias((SQLPropertyExpr) operand))) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Walks {@code readonlyAst} once. It reports whether a non-skipped table already carries the
     * inject condition, and it collects plain table sources into {@code tableAliasMap}.
     * <p>
     * Places that are checked:
     * <ul>
     *   <li>the WHERE clause of each SELECT query block, for the leading table of its FROM;</li>
     *   <li>joined tables inside a SELECT: the ON condition, or the enclosing WHERE for comma joins
     *       ({@code FROM a, b WHERE ...});</li>
     *   <li>the WHERE clause of DELETE / UPDATE statements, for every table they touch.</li>
     * </ul>
     * Sub-query/union table sources and tables accepted by {@code skip} are ignored.
     * <p>
     * SELECT nesting is tracked with a depth counter. With a single boolean, a nested query block
     * (e.g. a scalar sub-query in the select list, which is visited before FROM) cleared the flag,
     * and the outer query's joins were then no longer checked.
     * <p>
     * Once a match is found the visitor does not descend into that node, so {@code tableAliasMap}
     * may be incomplete when this method returns {@code true}.
     *
     * @param readonlyAst               statement to inspect; the visitor only reads it
     * @param injectConditionColumnList columns of the inject condition, forwarded to {@code existInjectCondition}
     * @param skip                      {@code (schema, tableName) -> true} for tables to ignore
     * @param tableAliasMap             output: alias (or table name when unaliased) to table source
     * @return {@code true} if a checked table already has the inject condition
     */
    private static boolean existInjectConditionAndPreparedCollect(SQLStatement readonlyAst,
                                                                  List injectConditionColumnList,
                                                                  BiPredicate skip,
                                                                  Map tableAliasMap) {
        boolean[] exist = new boolean[1];
        readonlyAst.accept(new SQLASTVisitorAdapter() {
            /** Number of SELECT query blocks currently open. */
            private int selectDepth;

            @Override
            public void endVisit(SQLExprTableSource tableSource) {
                String alias = getAlias(tableSource);
                if (alias != null) {
                    tableAliasMap.put(alias, tableSource);
                }
            }

            @Override
            public boolean visit(SQLSelectQueryBlock statement) {
                selectDepth++;
                SQLTableSource from = statement.getFrom();
                if (from == null || isSubqueryOrUnion(from) || skip.test(getTableSchema(from), getTableName(from))) {
                    return true;
                }
                if (existInjectCondition(injectConditionColumnList, getAlias(from), statement.getWhere())) {
                    exist[0] = true;
                    return false;
                }
                return true;
            }

            @Override
            public void endVisit(SQLSelectQueryBlock x) {
                selectDepth--;
            }

            @Override
            public boolean visit(SQLJoinTableSource statement) {
                // Joins are only checked inside a SELECT; UPDATE/DELETE tables are handled below.
                if (selectDepth <= 0) {
                    return true;
                }
                SQLTableSource right = statement.getRight();
                if (right == null || isSubqueryOrUnion(right) || skip.test(getTableSchema(right), getTableName(right))) {
                    return true;
                }
                String alias = getAlias(right);
                boolean found;
                if (statement.getJoinType() == SQLJoinTableSource.JoinType.COMMA) {
                    // FROM a, b WHERE ...: the join predicate lives in the enclosing WHERE clause.
                    SQLObject parent = statement.getParent();
                    found = parent instanceof SQLSelectQueryBlock
                            && existInjectCondition(injectConditionColumnList, alias, ((SQLSelectQueryBlock) parent).getWhere());
                } else {
                    found = existInjectCondition(injectConditionColumnList, alias, statement.getCondition());
                }
                if (found) {
                    exist[0] = true;
                    return false;
                }
                return true;
            }

            @Override
            public boolean visit(SQLDeleteStatement statement) {
                return !markIfAnyTableHasCondition(statement.getTableSource(), statement.getWhere());
            }

            @Override
            public boolean visit(SQLUpdateStatement statement) {
                return !markIfAnyTableHasCondition(statement.getTableSource(), statement.getWhere());
            }

            /**
             * Checks every plain table reachable from {@code root} (through joins) against
             * {@code where}. Sets {@code exist[0]} and returns {@code true} on the first match.
             */
            private boolean markIfAnyTableHasCondition(SQLTableSource root, SQLExpr where) {
                // LinkedList (not ArrayDeque) because the root or join sides may be null.
                LinkedList<SQLTableSource> pending = new LinkedList<>();
                pending.add(root);
                while (!pending.isEmpty()) {
                    SQLTableSource tableSource = pending.removeFirst();
                    if (tableSource == null) {
                        continue;
                    }
                    if (tableSource instanceof SQLJoinTableSource) {
                        SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
                        pending.add(join.getLeft());
                        pending.add(join.getRight());
                    } else if (!isSubqueryOrUnion(tableSource)
                            && !skip.test(getTableSchema(tableSource), getTableName(tableSource))
                            && existInjectCondition(injectConditionColumnList, getAlias(tableSource), where)) {
                        exist[0] = true;
                        return true;
                    }
                }
                return false;
            }
        });
        return exist[0];
    }

    /**
     * Records every plain table source in {@code ast} into {@code tableAliasMap}, keyed by its
     * alias, or by its normalized table name when unaliased. Sources with neither are ignored.
     */
    private static void preparedCollect(SQLStatement ast, Map tableAliasMap) {
        SQLASTVisitorAdapter collector = new SQLASTVisitorAdapter() {
            @Override
            public void endVisit(SQLExprTableSource tableSource) {
                String key = getAlias(tableSource);
                if (key == null) {
                    return;
                }
                tableAliasMap.put(key, tableSource);
            }
        };
        ast.accept(collector);
    }

    /**
     * Returns the normalized names of all plain table sources in {@code ast}. The set may contain
     * {@code null} for sources whose name cannot be resolved.
     */
    private static Set getTableNameSet(SQLStatement ast) {
        Set<String> names = new HashSet<>(2);
        SQLASTVisitorAdapter collector = new SQLASTVisitorAdapter() {
            @Override
            public void endVisit(SQLExprTableSource tableSource) {
                names.add(getTableName(tableSource));
            }
        };
        ast.accept(collector);
        return names;
    }

    /**
     * Adds the names of the tables referenced by {@code injectCondition} (e.g. inside an
     * {@code IN (SELECT ...)}) to {@code cantSpecifyTargetTableNameSet}. The set is later used to
     * avoid MySQL error 1093 ("You can't specify target table ... for update in FROM clause").
     */
    private static void preparedCollectInject(SQLExpr injectCondition, Set cantSpecifyTargetTableNameSet) {
        InjectMarkSQLASTVisitor tableNameCollector = new InjectMarkSQLASTVisitor() {
            @Override
            public boolean visit(SQLExprTableSource tableSource) {
                String referencedTable = getTableName(tableSource);
                cantSpecifyTargetTableNameSet.add(referencedTable);
                return true;
            }
        };
        injectCondition.accept(tableNameCollector);
    }

    private static boolean addCondition(String sql, SQLStatement ast, SQLBinaryOperator op, SQLExpr injectCondition,
                                        boolean appendConditionToLeft, ExistInjectConditionStrategyEnum existInjectConditionStrategyEnum,
                                        BiPredicate skip, Predicate isJoinUniqueKey, Collection excludeInjectCondition) {
        if (ast instanceof MySqlShowStatement || ast instanceof SQLSetStatement) {
            return false;
        }

        boolean needPrepared = true;
        Set cantSpecifyTargetTableNameSet = new HashSet<>(1);
        Map tableAliasMap = new HashMap<>(3);
        List injectConditionColumnList;
        switch (existInjectConditionStrategyEnum) {
            case ANY_TABLE_MATCH_THEN_SKIP_SQL: {
                injectConditionColumnList = flatColumnList(injectCondition);
                if (existInjectConditionAndPreparedCollect(ast, injectConditionColumnList, (schema, tableName) -> false, tableAliasMap)) {
                    return false;
                }
                needPrepared = false;
                break;
            }
            case RULE_TABLE_MATCH_THEN_SKIP_SQL: {
                injectConditionColumnList = flatColumnList(injectCondition);
                if (existInjectConditionAndPreparedCollect(ast, injectConditionColumnList, skip, tableAliasMap)) {
                    return false;
                }
                needPrepared = false;
                break;
            }
            case RULE_TABLE_MATCH_THEN_SKIP_ITEM: {
                injectConditionColumnList = flatColumnList(injectCondition);
                break;
            }
            default:
            case ALWAYS_APPEND: {
                injectConditionColumnList = Collections.emptyList();
                break;
            }
        }

        // 减少1次遍历
        if (needPrepared) {
            preparedCollect(ast, tableAliasMap);
        }
        preparedCollectInject(injectCondition, cantSpecifyTargetTableNameSet);

        boolean[] change = new boolean[1];
        ast.accept(new SQLASTVisitorAdapter() {
            private final SQLCondition sqlJoin = new SQLCondition(sql);
            private final SQLColumn sqlColumn = new SQLColumn();
            private boolean select;
            private boolean update;
            private boolean delete;
            private boolean updateSetItem;
            private String updateTableName;

            /**
             * Returns true while a SELECT query block is being visited, and also whenever neither
             * UPDATE nor DELETE mode is active.
             */
            private boolean isSelect() {
                return select || !(update || delete);
            }

            /**
             * Whether the statement being visited is the kind of data-modifying statement that MySQL
             * error 1093 ("You can't specify target table 'xx' for update in FROM clause") refers to,
             * i.e. an UPDATE or a DELETE.
             *
             * @return true for UPDATE / DELETE
             */
            private boolean isCantSpecifyTargetTableUpdate() {
                return delete || update;
            }

            /**
             * Does not descend into a binary expression that matches one of the exclude conditions.
             */
            @Override
            public boolean visit(SQLBinaryOpExpr expr) {
                boolean excluded = existExcludeInjectConditionList(excludeInjectCondition, expr);
                return !excluded;
            }

            /**
             * Enters SELECT mode; endVisit(SQLSelectQueryBlock) clears it again. The block's children
             * are always visited.
             */
            @Override
            public boolean visit(SQLSelectQueryBlock statement) {
                select = true;
                // always descend
                return true;
            }

            /**
             * Appends the inject condition for {@code alias} to the WHERE clause of {@code statement},
             * unless one of the skip rules applies.
             * <p>
             * The unique-key check now short-circuits ({@code &&} instead of {@code &}). The caller's
             * {@code isJoinUniqueKey} predicate is therefore only consulted when a parameterized column
             * exists; the result is unchanged, but the predicate is no longer invoked needlessly.
             *
             * @param statement     query block whose WHERE clause may be rewritten
             * @param alias         alias (or table name) to qualify the injected columns with
             * @param tableSchema   schema of the target table, may be {@code null}
             * @param tableName     name of the target table, may be {@code null}
             * @param typeEnum      where the condition goes (WHERE / COMMA join)
             * @param joinTypeEnum  classification of the table within its FROM, may be {@code null}
             * @param fromCondition ON condition of the FROM join, used as a fallback for unique keys
             * @return {@code true} when the WHERE clause was modified
             */
            private boolean addWhere(SQLSelectQueryBlock statement, String alias, String tableSchema, String tableName, SQLCondition.TypeEnum typeEnum, JoinTypeEnum joinTypeEnum, SQLExpr fromCondition) {
                SQLExpr where = statement.getWhere();
                if (isInjectCondition(where)) {
                    return false;
                }
                // 1. Rule-based skip: the table already carries the condition.
                if (existInjectConditionStrategyEnum == ExistInjectConditionStrategyEnum.RULE_TABLE_MATCH_THEN_SKIP_ITEM
                        && existInjectCondition(injectConditionColumnList, alias, where)) {
                    return false;
                }
                // 2. Unique-key skip: the row is already pinned by a unique-key equality.
                if (isJoinUniqueKey != null) {
                    List joinUniqueKeyEqualityList = getJoinUniqueKeyEquality(where);
                    if (joinUniqueKeyEqualityList.isEmpty()) {
                        joinUniqueKeyEqualityList = getJoinUniqueKeyEqualityRetry(typeEnum, joinTypeEnum, fromCondition);
                    }
                    if (!joinUniqueKeyEqualityList.isEmpty()) {
                        sqlJoin.reset(typeEnum, joinTypeEnum, alias, tableSchema, tableName, joinUniqueKeyEqualityList);
                        if (sqlJoin.existParameterizedColumn() && isJoinUniqueKey.test(sqlJoin)) {
                            return false;
                        }
                    }
                }
                // 3. Skip MySQL error 1093: You can't specify target table 'xx' for update in FROM clause.
                String cantSpecifyTargetTableName = updateSetItem ? updateTableName : tableName;
                if (isCantSpecifyTargetTableForUpdateInFromClause(cantSpecifyTargetTableName, cantSpecifyTargetTableNameSet, isCantSpecifyTargetTableUpdate())) {
                    return false;
                }
                // 4. Append the condition.
                statement.setWhere(mergeCondition(op, injectCondition, alias, appendConditionToLeft, where));
                return true;
            }

            /**
             * Fallback source of unique-key equalities. For a WHERE clause whose leading table was
             * classified as RIGHT_OUTER_JOIN, the FROM join's ON condition is examined; otherwise the
             * result is empty.
             */
            private List getJoinUniqueKeyEqualityRetry(SQLCondition.TypeEnum typeEnum, JoinTypeEnum joinTypeEnum, SQLExpr fromCondition) {
                boolean whereOfRightJoin = typeEnum == SQLCondition.TypeEnum.WHERE
                        && joinTypeEnum == JoinTypeEnum.RIGHT_OUTER_JOIN;
                if (!whereOfRightJoin) {
                    return Collections.emptyList();
                }
                return getJoinUniqueKeyEquality(fromCondition);
            }

            /**
             * After a query block's children are visited, tries to inject the condition into its
             * WHERE clause for the leading FROM table. SELECT mode is always left afterwards.
             */
            @Override
            public void endVisit(SQLSelectQueryBlock statement) {
                try {
                    if (injectIntoQueryBlockWhere(statement)) {
                        change[0] = true;
                    }
                } finally {
                    select = false;
                }
            }

            /**
             * Resolves the leading table of {@code statement}'s FROM and delegates to addWhere.
             * Returns false when there is no FROM, the FROM is derived, or the table is skipped.
             */
            private boolean injectIntoQueryBlockWhere(SQLSelectQueryBlock statement) {
                if (isInjectCondition(statement)) {
                    return false;
                }
                SQLTableSource from = statement.getFrom();
                if (from == null || isSubqueryOrUnion(from)) {
                    return false;
                }
                String tableSchema = getTableSchema(from);
                String tableName = getTableName(from);
                if (skip.test(tableSchema, tableName)) {
                    return false;
                }
                String alias = getAlias(from);
                JoinTypeEnum joinTypeEnum = null;
                SQLExpr fromCondition = null;
                if (from instanceof SQLJoinTableSource) {
                    SQLJoinTableSource join = (SQLJoinTableSource) from;
                    joinTypeEnum = isLeftJoin(alias, join) ? JoinTypeEnum.LEFT_OUTER_JOIN : JoinTypeEnum.RIGHT_OUTER_JOIN;
                    fromCondition = getCondition(join);
                }
                return addWhere(statement, alias, tableSchema, tableName, SQLCondition.TypeEnum.WHERE, joinTypeEnum, fromCondition);
            }

            /**
             * Decides whether the query's leading table ({@code selectAlias}) should be treated as the
             * preserved side of a left join.
             * <p>
             * Returns {@code true} as soon as some non-RIGHT join's condition references
             * {@code selectAlias}. For a comma join without ON, the enclosing WHERE clause is checked
             * instead. Otherwise the result is {@code true} unless a RIGHT OUTER JOIN appears anywhere
             * in the join tree.
             */
            private boolean isLeftJoin(String selectAlias, SQLJoinTableSource from) {
                boolean anyRight = false;
                // Typed deque: with a raw LinkedList, removeFirst() returns Object and does not compile.
                Deque<SQLJoinTableSource> pending = new ArrayDeque<>();
                pending.add(from);
                while (!pending.isEmpty()) {
                    SQLJoinTableSource join = pending.removeFirst();
                    SQLJoinTableSource.JoinType joinType = join.getJoinType();
                    SQLTableSource left = join.getLeft();
                    SQLTableSource right = join.getRight();

                    if (joinType == SQLJoinTableSource.JoinType.RIGHT_OUTER_JOIN) {
                        anyRight = true;
                    } else {
                        SQLExpr condition = join.getCondition();
                        if (condition == null && joinType == SQLJoinTableSource.JoinType.COMMA) {
                            // FROM a, b: the join predicate lives in the enclosing WHERE clause.
                            SQLObject parent = join.getParent();
                            if (parent instanceof SQLSelectQueryBlock) {
                                condition = ((SQLSelectQueryBlock) parent).getWhere();
                            }
                        }
                        if (existAlias(selectAlias, condition)) {
                            return true;
                        }
                    }
                    if (left instanceof SQLJoinTableSource) {
                        pending.add((SQLJoinTableSource) left);
                    }
                    if (right instanceof SQLJoinTableSource) {
                        pending.add((SQLJoinTableSource) right);
                    }
                }
                return !anyRight;
            }

            /**
             * After a join's children are visited, tries to inject the condition for its right-hand
             * table. For comma joins ({@code FROM t1, t2 WHERE t1.id = t2.xx_id}) the condition goes
             * into the enclosing WHERE clause; for other joins ({@code ON t1.id = t2.xx_id}) it goes
             * into the ON condition.
             */
            @Override
            public void endVisit(SQLJoinTableSource statement) {
                if (!isSelect()) {
                    return;
                }
                SQLTableSource right = statement.getRight();
                if (right == null || isSubqueryOrUnion(right)) {
                    return;
                }
                String tableSchema = getTableSchema(right);
                String tableName = getTableName(right);
                if (skip.test(tableSchema, tableName)) {
                    return;
                }
                String alias = getAlias(right);
                boolean added;
                if (statement.getJoinType() == SQLJoinTableSource.JoinType.COMMA) {
                    SQLObject parent = statement.getParent();
                    added = parent instanceof SQLSelectQueryBlock
                            && addWhere((SQLSelectQueryBlock) parent, alias, tableSchema, tableName, SQLCondition.TypeEnum.COMMA, JoinTypeEnum.COMMA, null);
                } else {
                    added = addJoin(statement, alias, tableSchema, tableName);
                }
                if (added) {
                    change[0] = true;
                }
            }

            /**
             * Appends the inject condition for {@code alias} to the ON condition of {@code join},
             * unless one of the skip rules applies.
             *
             * @return {@code true} when the ON condition was modified
             */
            private boolean addJoin(SQLJoinTableSource join, String alias, String tableSchema, String tableName) {
                SQLExpr onCondition = join.getCondition();
                if (isInjectCondition(onCondition)) {
                    return false;
                }

                // 1. Rule-based skip: the table already carries the condition.
                boolean skipByRule = existInjectConditionStrategyEnum == ExistInjectConditionStrategyEnum.RULE_TABLE_MATCH_THEN_SKIP_ITEM
                        && existInjectCondition(injectConditionColumnList, alias, onCondition);
                if (skipByRule) {
                    return false;
                }
                // 2. Unique-key skip: the joined row is pinned by a unique-key equality.
                if (isJoinUniqueKey != null) {
                    List equalities = getJoinUniqueKeyEquality(onCondition);
                    if (!equalities.isEmpty()) {
                        sqlJoin.reset(SQLCondition.TypeEnum.JOIN, codeOfJoinTypeEnum(join.getJoinType()), alias, tableSchema, tableName, equalities);
                        if (isJoinUniqueKey.test(sqlJoin)) {
                            return false;
                        }
                    }
                }
                // 3. Skip MySQL error 1093: You can't specify target table 'xx' for update in FROM clause.
                String targetTableName = updateSetItem ? updateTableName : tableName;
                if (isCantSpecifyTargetTableForUpdateInFromClause(targetTableName, cantSpecifyTargetTableNameSet, isCantSpecifyTargetTableUpdate())) {
                    return false;
                }
                // 4. Append the condition.
                join.setCondition(mergeCondition(op, injectCondition, alias, appendConditionToLeft, onCondition));
                return true;
            }

            /**
             * Maps druid's join type to the project's {@code JoinTypeEnum} by constant name;
             * {@code null} maps to {@code null}.
             */
            private JoinTypeEnum codeOfJoinTypeEnum(SQLJoinTableSource.JoinType joinType) {
                return joinType == null ? null : JoinTypeEnum.codeOf(joinType.name());
            }

            private List getJoinUniqueKeyEquality(SQLExpr condition) {
                if (!(condition instanceof SQLBinaryOpExpr)) {
                    return Collections.emptyList();
                }
                SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) condition;
                SQLBinaryOperator operator = binaryOpExpr.getOperator();
                if (operator == SQLBinaryOperator.BooleanAnd) {
                    List leftColumn = getJoinUniqueKeyEquality(binaryOpExpr.getLeft());
                    List rightColumn = getJoinUniqueKeyEquality(binaryOpExpr.getRight());
                    List list = new ArrayList<>(leftColumn.size() + rightColumn.size());
                    list.addAll(leftColumn);
                    list.addAll(rightColumn);
                    return list;
                }
                if (operator != SQLBinaryOperator.Equality) {
                    return Collections.emptyList();
                }
                SQLExpr left = binaryOpExpr.getLeft();
                SQLExpr right = binaryOpExpr.getRight();

                if (left instanceof SQLPropertyExpr) {
                    SQLPropertyExpr itemColumn = (SQLPropertyExpr) left;
                    String columnOwnerName = normalize(itemColumn.getOwnernName());
                    String columnName = normalize(itemColumn.getName());
                    SQLExprTableSource tableSource = getTableSourceByAlias(columnOwnerName);
                    if (tableSource != null && columnName != null) {
                        String tableName = getTableName(tableSource);
                        String tableSchema = getTableSchema(tableSource);
                        sqlColumn.resetLeftColumn(null, columnOwnerName, tableSchema, tableName, columnName);
                    } else {
                        sqlColumn.resetLeftColumn(null, columnOwnerName, null, null, columnName);
                    }
                } else if (left instanceof SQLValuableExpr) {
                    sqlColumn.resetLeftColumn(value(((SQLValuableExpr) left).getValue()), null, null, null, null);
                    return Collections.singletonList(sqlColumn.clone());
                } else if (left instanceof SQLVariantRefExpr) {
                    sqlColumn.resetLeftColumn(value(left), null, null, null, null);
                    return Collections.singletonList(sqlColumn.clone());
                } else if (left instanceof SQLIdentifierExpr) {
                    SQLIdentifierExpr itemColumn = (SQLIdentifierExpr) left;
                    String columnName = normalize(itemColumn.getName());
                    SQLExprTableSource tableSource = getTableSourceByAlias(null);
                    if (tableSource != null && columnName != null) {
                        String tableName = getTableName(tableSource);
                        String tableSchema = getTableSchema(tableSource);
                        sqlColumn.resetLeftColumn(null, null, tableSchema, tableName, columnName);
                    } else {
                        sqlColumn.resetLeftColumn(null, null, null, null, columnName);
                    }
                } else {
                    return Collections.emptyList();
                }

                if (right instanceof SQLPropertyExpr) {
                    SQLPropertyExpr itemColumn = (SQLPropertyExpr) right;
                    String columnOwnerName = normalize(itemColumn.getOwnernName());
                    String columnName = normalize(itemColumn.getName());
                    SQLExprTableSource tableSource = tableAliasMap.get(columnOwnerName);
                    if (tableSource != null && columnName != null) {
                        String tableName = getTableName(tableSource);
                        String tableSchema = getTableSchema(tableSource);
                        sqlColumn.resetRightColumn(null, columnOwnerName, tableSchema, tableName, columnName);
                        return Collections.singletonList(sqlColumn.clone());
                    } else {
                        return Collections.emptyList();
                    }
                } else if (right instanceof SQLValuableExpr) {
                    sqlColumn.resetRightColumn(value(((SQLValuableExpr) right).getValue()), null, null, null, null);
                    return Collections.singletonList(sqlColumn.clone());
                } else if (right instanceof SQLVariantRefExpr) {
                    sqlColumn.resetRightColumn(value(right), null, null, null, null);
                    return Collections.singletonList(sqlColumn.clone());
                } else if (right instanceof SQLIdentifierExpr) {
                    SQLIdentifierExpr itemColumn = (SQLIdentifierExpr) right;
                    String columnName = normalize(itemColumn.getName());
                    SQLExprTableSource tableSource = getTableSourceByAlias(null);
                    if (tableSource != null && columnName != null) {
                        String tableName = getTableName(tableSource);
                        String tableSchema = getTableSchema(tableSource);
                        sqlColumn.resetLeftColumn(null, null, tableSchema, tableName, columnName);
                    } else {
                        sqlColumn.resetLeftColumn(null, null, null, null, columnName);
                    }
                    return Collections.singletonList(sqlColumn.clone());
                } else {
                    return Collections.emptyList();
                }
            }

            /**
             * Maps Druid's evaluated-null sentinel and variable references onto the SQLColumn
             * markers; any other value is returned unchanged.
             */
            private Object value(Object value) {
                if (value == SQLEvalVisitor.EVAL_VALUE_NULL) {
                    return SQLColumn.NULL;
                }
                return value instanceof SQLVariantRefExpr ? SQLColumn.VAR_REF : value;
            }

            /**
             * Looks up the table source registered under {@code alias}. If there is no such entry
             * (or {@code alias} is null) and exactly one table source is known, that one is used.
             *
             * @return the matching table source, or {@code null} when none can be determined
             */
            private SQLExprTableSource getTableSourceByAlias(String alias) {
                SQLExprTableSource byAlias = alias == null ? null : tableAliasMap.get(alias);
                if (byAlias != null || tableAliasMap.size() != 1) {
                    return byAlias;
                }
                return tableAliasMap.values().iterator().next();
            }

            @Override
            public boolean visit(SQLDeleteStatement statement) {
                // Enter DELETE mode; endVisit(SQLDeleteStatement) clears the flag in its finally block.
                delete = true;
                // Continue into the statement's children.
                return true;
            }

            /**
             * After a DELETE has been visited, walks its table source (expanding joins breadth-first)
             * and merges the inject condition into the WHERE clause for each eligible table.
             * Always leaves DELETE mode, even if processing throws.
             */
            @Override
            public void endVisit(SQLDeleteStatement statement) {
                try {
                    LinkedList<SQLTableSource> pending = new LinkedList<>();
                    pending.add(statement.getTableSource());
                    while (!pending.isEmpty()) {
                        SQLTableSource current = pending.removeFirst();
                        if (current instanceof SQLJoinTableSource) {
                            SQLJoinTableSource join = (SQLJoinTableSource) current;
                            pending.add(join.getLeft());
                            pending.add(join.getRight());
                            continue;
                        }
                        if (current == null) {
                            continue;
                        }
                        // Re-read on every table: an earlier iteration may already have replaced the WHERE.
                        SQLExpr currentWhere = statement.getWhere();
                        if (isInjectCondition(currentWhere) || isSubqueryOrUnion(current)) {
                            continue;
                        }
                        String tableName = getTableName(current);
                        if (skip.test(getTableSchema(current), tableName)) {
                            continue;
                        }
                        String alias = getAlias(current);
                        boolean alreadyConstrained =
                                existInjectConditionStrategyEnum == ExistInjectConditionStrategyEnum.RULE_TABLE_MATCH_THEN_SKIP_ITEM
                                        && existInjectCondition(injectConditionColumnList, alias, currentWhere);
                        // 2. Skip MySQL error 1093 - You can't specify target table 'xx' for update in FROM clause
                        if (alreadyConstrained
                                || isCantSpecifyTargetTableForUpdateInFromClause(tableName, cantSpecifyTargetTableNameSet, true)) {
                            continue;
                        }
                        statement.setWhere(mergeCondition(op, injectCondition, alias, appendConditionToLeft, currentWhere));
                        change[0] = true;
                    }
                } finally {
                    delete = false;
                }
            }

            @Override
            public boolean visit(SQLUpdateStatement statement) {
                // Enter UPDATE mode; endVisit(SQLUpdateStatement) clears the flag in its finally block.
                update = true;
                // Continue into the statement's children.
                return true;
            }

            @Override
            public boolean visit(SQLUpdateSetItem x) {
                // Resolve the table owning the SET target column, then flag that we are inside a
                // SET item; endVisit(SQLUpdateSetItem) resets both.
                SQLExpr target = x.getColumn();
                this.updateTableName = getUpdateTableName(target);
                this.updateSetItem = true;
                return true;
            }

            @Override
            public void endVisit(SQLUpdateSetItem x) {
                // Leaving the SET item: drop the context captured in visit(SQLUpdateSetItem).
                this.updateSetItem = false;
                this.updateTableName = null;
            }

            /**
             * After an UPDATE has been visited, walks its table source (expanding joins breadth-first)
             * and merges the inject condition into the WHERE clause for each eligible table.
             * Always leaves UPDATE mode, even if processing throws.
             */
            @Override
            public void endVisit(SQLUpdateStatement statement) {
                try {
                    LinkedList<SQLTableSource> pending = new LinkedList<>();
                    pending.add(statement.getTableSource());
                    while (!pending.isEmpty()) {
                        SQLTableSource current = pending.removeFirst();
                        if (current instanceof SQLJoinTableSource) {
                            SQLJoinTableSource join = (SQLJoinTableSource) current;
                            pending.add(join.getLeft());
                            pending.add(join.getRight());
                            continue;
                        }
                        if (current == null) {
                            continue;
                        }
                        // Re-read on every table: an earlier iteration may already have replaced the WHERE.
                        SQLExpr currentWhere = statement.getWhere();
                        if (isInjectCondition(currentWhere)) {
                            continue;
                        }
                        String alias = getAlias(current);
                        if (isSubqueryOrUnion(current)) {
                            continue;
                        }
                        String tableName = getTableName(current);
                        if (skip.test(getTableSchema(current), tableName)) {
                            continue;
                        }
                        boolean alreadyConstrained =
                                existInjectConditionStrategyEnum == ExistInjectConditionStrategyEnum.RULE_TABLE_MATCH_THEN_SKIP_ITEM
                                        && existInjectCondition(injectConditionColumnList, alias, currentWhere);
                        // 2. Skip MySQL error 1093 - You can't specify target table 'xx' for update in FROM clause
                        if (alreadyConstrained
                                || isCantSpecifyTargetTableForUpdateInFromClause(tableName, cantSpecifyTargetTableNameSet, true)) {
                            continue;
                        }
                        statement.setWhere(mergeCondition(op, injectCondition, alias, appendConditionToLeft, currentWhere));
                        change[0] = true;
                    }
                } finally {
                    update = false;
                }
            }

            /**
             * Resolves the table name targeted by a SET column. Only an {@code owner.column}
             * reference supplies an alias. Any other expression passes a null alias to
             * getTableSourceByAlias.
             */
            private String getUpdateTableName(SQLExpr column) {
                String alias = column instanceof SQLPropertyExpr ? getAlias((SQLPropertyExpr) column) : null;
                return getTableName(getTableSourceByAlias(alias));
            }
        });
        return change[0];
    }

    /**
     * Tells whether injecting into the current table would hit MySQL error
     * 1093 - You can't specify target table 'xx' for update in FROM clause.
     *
     * @param tableName                   name of the table currently taken from the FROM clause
     * @param injectConditionTableNameSet tables referenced by the nested query of the inject condition
     * @param update                      whether the statement modifies the table (callers in this file
     *                                    pass {@code true} for both UPDATE and DELETE)
     * @return {@code true} when {@code update} is set and {@code tableName} is one of the nested tables;
     *         the set is not consulted when {@code update} is {@code false}
     */
    private static boolean isCantSpecifyTargetTableForUpdateInFromClause(String tableName, Set injectConditionTableNameSet, boolean update) {
        if (!update) {
            return false;
        }
        return injectConditionTableNameSet.contains(tableName);
    }

    /**
     * Collects every {@link SQLName} node of the expression tree, root first, in breadth-first order.
     * Only {@link SQLExpr} nodes are expanded into their children.
     *
     * @param injectCondition root expression to scan; a {@code null} root yields an empty list
     * @return the names found, in visiting order
     */
    private static List flatColumnList(SQLExpr injectCondition) {
        List<SQLName> names = new ArrayList<>(2);
        // LinkedList tolerates null entries (ArrayDeque would not); nulls simply match nothing below.
        LinkedList<SQLObject> queue = new LinkedList<>();
        queue.add(injectCondition);
        while (!queue.isEmpty()) {
            SQLObject node = queue.removeFirst();
            if (node instanceof SQLName) {
                names.add((SQLName) node);
            }
            if (!(node instanceof SQLExpr)) {
                continue;
            }
            List<? extends SQLObject> children = ((SQLExpr) node).getChildren();
            if (children != null) {
                queue.addAll(children);
            }
        }
        return names;
    }

    /**
     * Returns {@code true} if {@code where} matches at least one of the exclude expressions
     * (as decided by {@code existExcludeInjectCondition}). A {@code null} collection never matches.
     */
    private static boolean existExcludeInjectConditionList(Collection excludeInjectList, SQLExpr where) {
        if (excludeInjectList == null) {
            return false;
        }
        for (Object candidate : excludeInjectList) {
            if (existExcludeInjectCondition((SQLExpr) candidate, where)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Node-level comparison used when matching an exclude expression against a WHERE tree.
     * <ul>
     *   <li>Two {@code null}s are equal. Exactly one {@code null}, or different runtime classes, are not.</li>
     *   <li>{@link SQLPropertyExpr} / {@link SQLIdentifierExpr}: normalized names compared ignoring case.
     *       A property expression's owner is not compared here.</li>
     *   <li>{@link SQLBinaryOpExpr}: only the operator is compared.</li>
     *   <li>Any other node of the same class counts as equal, so literal values are not compared.</li>
     * </ul>
     */
    private static boolean isEquals(SQLObject exclude, SQLObject where) {
        if (exclude == null || where == null) {
            return exclude == null && where == null;
        }
        if (exclude.getClass() != where.getClass()) {
            return false;
        }
        // Both nodes have the same class from here on, so one instanceof test covers both.
        if (exclude instanceof SQLPropertyExpr) {
            String excludeName = normalize(((SQLPropertyExpr) exclude).getName());
            String whereName = normalize(((SQLPropertyExpr) where).getName());
            return equalsIgnoreCase(excludeName, whereName);
        }
        if (exclude instanceof SQLIdentifierExpr) {
            String excludeName = normalize(((SQLIdentifierExpr) exclude).getName());
            String whereName = normalize(((SQLIdentifierExpr) where).getName());
            return equalsIgnoreCase(excludeName, whereName);
        }
        if (exclude instanceof SQLBinaryOpExpr) {
            return ((SQLBinaryOpExpr) exclude).getOperator() == ((SQLBinaryOpExpr) where).getOperator();
        }
        return true;
    }

    /**
     * Null-safe, case-insensitive string equality.
     *
     * @return {@code true} if both are {@code null} or {@code a.equalsIgnoreCase(b)};
     *         a {@code null} never equals a non-{@code null}
     */
    public static boolean equalsIgnoreCase(String a, String b) {
        if (a == null) {
            return b == null;
        }
        return a.equalsIgnoreCase(b);
    }

    /**
     * Checks whether {@code where} has the same shape as {@code exclude}.
     * <p>
     * Both trees are walked breadth-first in lock-step. They match only if every pair of visited
     * nodes passes {@code isEquals} and the two pending queues always have the same length.
     * NOTE(review): only a whole-tree match counts. A WHERE that merely contains {@code exclude}
     * as a sub-expression does not match; confirm that this is intended.
     *
     * @return {@code false} for a {@code null} WHERE
     */
    private static boolean existExcludeInjectCondition(SQLExpr exclude, SQLExpr where) {
        if (where == null || !isEquals(exclude, where)) {
            return false;
        }

        LinkedList<SQLObject> excludeQueue = new LinkedList<>();
        LinkedList<SQLObject> whereQueue = new LinkedList<>();
        excludeQueue.add(exclude);
        whereQueue.add(where);
        for (;;) {
            int remaining = excludeQueue.size();
            if (remaining != whereQueue.size()) {
                return false;
            }
            if (remaining == 0) {
                return true;
            }
            SQLObject excludeNode = excludeQueue.removeFirst();
            SQLObject whereNode = whereQueue.removeFirst();
            if (!isEquals(excludeNode, whereNode)) {
                return false;
            }
            if (excludeNode instanceof SQLExpr) {
                List<? extends SQLObject> excludeChildren = ((SQLExpr) excludeNode).getChildren();
                if (excludeChildren != null) {
                    excludeQueue.addAll(excludeChildren);
                }
            }
            if (whereNode instanceof SQLExpr) {
                List<? extends SQLObject> whereChildren = ((SQLExpr) whereNode).getChildren();
                if (whereChildren != null) {
                    whereQueue.addAll(whereChildren);
                }
            }
        }
    }

    /**
     * Checks whether {@code where} already references every column of the inject condition
     * for the given table.
     * <p>
     * The WHERE tree is walked breadth-first over {@link SQLExpr} nodes. A column reference counts
     * when its normalized name equals (ignoring case) the simple name of an inject-condition column.
     * It must also be either unqualified or qualified with an owner equal (ignoring case) to
     * {@code aliasOrTableName}. Columns qualified with a different owner are ignored, and so is
     * their subtree.
     *
     * @param injectConditionColumnList columns taken from the inject condition
     * @param aliasOrTableName          alias (or table name) of the table being checked
     * @param where                     WHERE expression to search; may be {@code null}
     * @return {@code true} if every inject-condition column was found. Returns {@code false} for a
     *         {@code null} WHERE and {@code true} for an empty column list with a non-null WHERE.
     */
    private static boolean existInjectCondition(List<? extends SQLName> injectConditionColumnList, String aliasOrTableName, SQLExpr where) {
        if (where == null) {
            return false;
        }
        int injectConditionColumnSize = injectConditionColumnList.size();
        // Bit i is set once injectConditionColumnList.get(i) has been seen. Completion is tested with
        // cardinality(): BitSet.length() is "index of highest set bit + 1" and would report success
        // as soon as only the last column had been found.
        BitSet exist = new BitSet(injectConditionColumnSize);
        LinkedList<SQLObject> temp = new LinkedList<>();
        temp.add(where);
        while (!temp.isEmpty()) {
            SQLObject item = temp.removeFirst();
            // Also skips null entries, which the LinkedList queue tolerates.
            if (!(item instanceof SQLExpr)) {
                continue;
            }
            String itemColumnName;
            String itemColumnOwnerName;
            if (item instanceof SQLPropertyExpr) {
                SQLPropertyExpr itemColumn = (SQLPropertyExpr) item;
                itemColumnOwnerName = itemColumn.getOwnernName();
                itemColumnName = itemColumn.getName();
            } else if (item instanceof SQLIdentifierExpr) {
                itemColumnName = ((SQLIdentifierExpr) item).getName();
                itemColumnOwnerName = null;
            } else {
                itemColumnName = null;
                itemColumnOwnerName = null;
            }

            if (itemColumnOwnerName != null && !itemColumnOwnerName.equalsIgnoreCase(aliasOrTableName)) {
                // Column belongs to another table: it cannot satisfy this table's condition.
                continue;
            }
            if (itemColumnName != null) {
                String normalizeItemColumnName = normalize(itemColumnName);
                // Index-based so the bit index always equals the column's position in the list.
                for (int i = 0; i < injectConditionColumnSize; i++) {
                    if (exist.get(i)) {
                        continue;
                    }
                    SQLName injectConditionColumn = injectConditionColumnList.get(i);
                    if (injectConditionColumn.getSimpleName().equalsIgnoreCase(normalizeItemColumnName)) {
                        exist.set(i);
                        if (exist.cardinality() == injectConditionColumnSize) {
                            return true;
                        }
                    }
                }
            }

            List<? extends SQLObject> next = ((SQLExpr) item).getChildren();
            if (next != null && !next.isEmpty()) {
                temp.addAll(next);
            }
        }
        return exist.cardinality() == injectConditionColumnSize;
    }

    /**
     * Combines the inject condition with an existing WHERE expression.
     * <p>
     * When {@code alias} is non-null, column references in the inject condition are qualified with it
     * by {@code mergeConditionIfExistAlias}, which builds the rewritten nodes from clones.
     * The expression that is actually inserted is tagged with the inject marker.
     * The existing WHERE is always kept. It is combined with the injected condition via {@code op},
     * whether or not {@code injectCondition} is a binary expression and whether or not
     * {@code alias} is null.
     *
     * @param op              operator joining the injected condition and the existing WHERE
     * @param injectCondition condition to inject
     * @param alias           alias/table name used to qualify the inject condition's columns;
     *                        {@code null} inserts {@code injectCondition} as-is
     * @param left            {@code true} to place the injected condition as the left operand
     * @param where           existing WHERE expression; may be {@code null}
     * @return the new WHERE expression (just the injected condition when {@code where} is null)
     */
    private static SQLExpr mergeCondition(SQLBinaryOperator op, SQLExpr injectCondition, String alias, boolean left, SQLExpr where) {
        SQLExpr injected;
        if (alias == null) {
            injected = injectCondition;
        } else if (injectCondition instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) injectCondition;
            injected = mergeConditionIfExistAlias(binaryOpExpr.getLeft(), binaryOpExpr.getRight(), binaryOpExpr.getOperator(), alias, null);
        } else {
            // With a null right operand mergeConditionIfExistAlias returns only the rewritten left operand.
            injected = mergeConditionIfExistAlias(injectCondition, null, op, alias, null);
        }
        // Mark the node that actually ends up in the tree, not the (possibly cloned-from) original.
        injected.accept(InjectMarkSQLASTVisitor.INSTANCE);
        if (where == null) {
            return injected;
        }
        return left ? new SQLBinaryOpExpr(injected, op, where) : new SQLBinaryOpExpr(where, op, injected);
    }

    /**
     * Rebuilds {@code left operator right}, qualifying column references with {@code conditionAlias}
     * on the side(s) selected by {@code aliasLeft}. The input nodes are not modified by this method;
     * rewritten parts are new {@link SQLPropertyExpr}s or clones.
     *
     * @param left           left operand; may be {@code null}
     * @param right          right operand; may be {@code null}
     * @param operator       operator used when both rebuilt operands are non-null
     * @param conditionAlias alias to put in front of column names
     * @param aliasLeft      {@code null}: qualify both sides; {@code TRUE}: only the left side;
     *                       {@code FALSE}: only the right side
     * @return the rebuilt binary expression. If one rebuilt operand is {@code null}, the other is
     *         returned alone ({@code null} if both are).
     */
    private static SQLExpr mergeConditionIfExistAlias(SQLExpr left, SQLExpr right, SQLBinaryOperator operator, String conditionAlias, Boolean aliasLeft) {
        SQLExpr newLeft;
        SQLExpr newRight;
        boolean leftAppend = aliasLeft == null || aliasLeft;
        boolean rightAppend = aliasLeft == null || !aliasLeft;

        if (left instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr expr = (SQLBinaryOpExpr) left;
            if (leftAppend) {
                // Nested binary expression: qualify both of its sides.
                newLeft = mergeConditionIfExistAlias(expr.getLeft(), expr.getRight(), expr.getOperator(), conditionAlias, null);
            } else {
                newLeft = left.clone();
            }
        } else if (left instanceof SQLName) {
            if (leftAppend) {
                // Any SQLName (including an already-qualified owner.column) is replaced by alias.simpleName.
                newLeft = new SQLPropertyExpr(conditionAlias, ((SQLName) left).getSimpleName());
            } else {
                newLeft = left.clone();
            }
        } else if (left instanceof SQLInSubQueryExpr) {
            // "expr IN (SELECT ...)": only the tested expression is (optionally) qualified; the subquery is cloned.
            newLeft = left.clone();
            ((SQLInSubQueryExpr) newLeft).setExpr(mergeConditionIfExistAlias(((SQLInSubQueryExpr) left).getExpr(), null, operator, conditionAlias, leftAppend ? null : false));
        } else {
            newLeft = left == null ? null : left.clone();
        }

        if (right instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr expr = (SQLBinaryOpExpr) right;
            if (rightAppend) {
                newRight = mergeConditionIfExistAlias(expr.getLeft(), expr.getRight(), expr.getOperator(), conditionAlias, null);
            } else {
                newRight = right.clone();
            }
        } else if (right instanceof SQLIdentifierExpr) {
            // NOTE(review): unlike the left side (which tests SQLName), only a bare SQLIdentifierExpr is
            // qualified here; a right-hand SQLPropertyExpr is cloned with its original owner. Confirm intended.
            if (rightAppend) {
                newRight = new SQLPropertyExpr(conditionAlias, ((SQLName) right).getSimpleName());
            } else {
                newRight = right.clone();
            }
        } else if (right instanceof SQLInSubQueryExpr) {
            newRight = right.clone();
            ((SQLInSubQueryExpr) newRight).setExpr(mergeConditionIfExistAlias(((SQLInSubQueryExpr) right).getExpr(), null, operator, conditionAlias, rightAppend ? null : false));
        } else {
            newRight = right == null ? null : right.clone();
        }
        // Drop a missing operand instead of building "x op null".
        SQLExpr binaryOpExpr;
        if (newLeft == null) {
            binaryOpExpr = newRight;
        } else if (newRight == null) {
            binaryOpExpr = newLeft;
        } else {
            binaryOpExpr = new SQLBinaryOpExpr(newLeft, operator, newRight);
        }
        return binaryOpExpr;
    }

    /**
     * Converts a database-type name into the argument form expected by Druid's parser APIs.
     * <p>
     * If {@code com.alibaba.druid.DbType.of(String)} was found during class initialisation
     * ({@code DB_TYPE_METHOD}), its result is returned. Otherwise, or if the reflective call fails,
     * the plain string is returned (presumably for Druid versions whose APIs take a String db type).
     *
     * @param type database type name; may be {@code null}
     * @param <T>  the caller's expected db-type representation ({@code DbType} or {@code String})
     * @return the converted value, or {@code type} itself as a fallback
     */
    @SuppressWarnings("unchecked")
    private static <T> T getDbType(String type) {
        if (DB_TYPE_METHOD != null) {
            try {
                return (T) DB_TYPE_METHOD.invoke(null, type);
            } catch (IllegalAccessException | InvocationTargetException ignored) {
                // Deliberate best effort: fall through to the plain String form below.
            }
        }
        return (T) type;
    }

    /** Attribute key that InjectMarkSQLASTVisitor puts on the nodes of an injected condition. */
    private static final String INJECT_CONDITION_MARK_NAME = "inject";

    /**
     * Tells whether the node carries the inject marker attribute.
     *
     * @param injectCondition node to inspect; may be {@code null}
     * @return {@code false} for {@code null}, otherwise whether the marker attribute is present
     */
    private static boolean isInjectCondition(SQLObject injectCondition) {
        if (injectCondition == null) {
            return false;
        }
        return injectCondition.containsAttribute(INJECT_CONDITION_MARK_NAME);
    }

    /**
     * Visitor that tags every {@link SQLBinaryOpExpr} and {@link SQLSelectQueryBlock} it reaches with
     * the inject marker attribute ({@code Boolean.TRUE}), so that {@code isInjectCondition} recognises
     * them. It keeps no state, so the shared {@link #INSTANCE} is reused.
     */
    static class InjectMarkSQLASTVisitor extends SQLASTVisitorAdapter {
        static final InjectMarkSQLASTVisitor INSTANCE = new InjectMarkSQLASTVisitor();

        @Override
        public boolean visit(SQLSelectQueryBlock x) {
            injectMark(x);
            // Keep descending so nested expressions are marked too.
            return true;
        }

        @Override
        public boolean visit(SQLBinaryOpExpr x) {
            injectMark(x);
            return true;
        }

        private void injectMark(SQLObject sqlObject) {
            sqlObject.putAttribute(INJECT_CONDITION_MARK_NAME, Boolean.TRUE);
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy