com.alibaba.druid.wall.spi.WallVisitorUtils Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of druid Show documentation
Show all versions of druid Show documentation
A JDBC datasource implementation.
The newest version!
/*
* Copyright 1999-2018 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.druid.wall.spi;
import static com.alibaba.druid.sql.visitor.SQLEvalVisitor.EVAL_VALUE;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import com.alibaba.druid.sql.ast.SQLCommentHint;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
import com.alibaba.druid.sql.ast.expr.SQLAllColumnExpr;
import com.alibaba.druid.sql.ast.expr.SQLBetweenExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExprGroup;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLBooleanExpr;
import com.alibaba.druid.sql.ast.expr.SQLCaseExpr;
import com.alibaba.druid.sql.ast.expr.SQLCaseExpr.Item;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLExistsExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr;
import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr;
import com.alibaba.druid.sql.ast.expr.SQLLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLNCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLNotExpr;
import com.alibaba.druid.sql.ast.expr.SQLNumberExpr;
import com.alibaba.druid.sql.ast.expr.SQLNumericLiteralExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.expr.SQLUnaryExpr;
import com.alibaba.druid.sql.ast.expr.SQLValuableExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLAlterStatement;
import com.alibaba.druid.sql.ast.statement.SQLAlterTableStatement;
import com.alibaba.druid.sql.ast.statement.SQLBlockStatement;
import com.alibaba.druid.sql.ast.statement.SQLCallStatement;
import com.alibaba.druid.sql.ast.statement.SQLCommentStatement;
import com.alibaba.druid.sql.ast.statement.SQLCommitStatement;
import com.alibaba.druid.sql.ast.statement.SQLCreateStatement;
import com.alibaba.druid.sql.ast.statement.SQLCreateTableStatement;
import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement;
import com.alibaba.druid.sql.ast.statement.SQLDescribeStatement;
import com.alibaba.druid.sql.ast.statement.SQLDropStatement;
import com.alibaba.druid.sql.ast.statement.SQLDropTableStatement;
import com.alibaba.druid.sql.ast.statement.SQLExplainStatement;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLInsertInto;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement.ValuesClause;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLMergeStatement;
import com.alibaba.druid.sql.ast.statement.SQLReplaceStatement;
import com.alibaba.druid.sql.ast.statement.SQLRollbackStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectGroupByClause;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLSetStatement;
import com.alibaba.druid.sql.ast.statement.SQLShowTablesStatement;
import com.alibaba.druid.sql.ast.statement.SQLStartTransactionStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.ast.statement.SQLTruncateStatement;
import com.alibaba.druid.sql.ast.statement.SQLUnionOperator;
import com.alibaba.druid.sql.ast.statement.SQLUnionQuery;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.ast.statement.SQLUseStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlOrderingExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlOutFileExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlExplainStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlHintStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlLockTableStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlOptimizeStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlRenameTableStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlShowGrantsStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlShowStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleExecuteImmediateStatement;
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleMultiInsertStatement;
import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGShowStatement;
import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerExecStatement;
import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerInsertStatement;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.visitor.ExportParameterVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitorUtils;
import com.alibaba.druid.sql.visitor.functions.Nil;
import com.alibaba.druid.support.logging.Log;
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.FnvHash;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.util.ServletPathMatcher;
import com.alibaba.druid.util.StringUtils;
import com.alibaba.druid.wall.WallConfig;
import com.alibaba.druid.wall.WallConfig.TenantCallBack;
import com.alibaba.druid.wall.WallConfig.TenantCallBack.StatementType;
import com.alibaba.druid.wall.WallContext;
import com.alibaba.druid.wall.WallProvider;
import com.alibaba.druid.wall.WallSqlTableStat;
import com.alibaba.druid.wall.WallUpdateCheckHandler;
import com.alibaba.druid.wall.WallUpdateCheckItem;
import com.alibaba.druid.wall.WallVisitor;
import com.alibaba.druid.wall.violation.ErrorCode;
import com.alibaba.druid.wall.violation.IllegalSQLObjectViolation;
public class WallVisitorUtils {
/** Logger for this utility class. */
private final static Log LOG = LogFactory.getLog(WallVisitorUtils.class);
/**
 * Key string {@code "hasTrueLike"}.
 * NOTE(review): not referenced in the visible part of this file; presumably used as an
 * attribute/flag name elsewhere — confirm its consumers before changing the value.
 */
public final static String HAS_TRUE_LIKE = "hasTrueLike";
/**
 * Hint keywords treated as "white" (allowed) hints: LOCAL, TEMPORARY, SQL_NO_CACHE, ...
 * NOTE(review): the code that consults this array is not in the visible part of this file —
 * confirm. The array is public and mutable, so any caller can change its contents.
 */
public final static String[] whiteHints = { "LOCAL", "TEMPORARY", "SQL_NO_CACHE", "SQL_CACHE", "HIGH_PRIORITY",
"LOW_PRIORITY", "STRAIGHT_JOIN", "SQL_BUFFER_RESULT", "SQL_BIG_RESULT", "SQL_SMALL_RESULT", "DELAYED" };
/**
 * Visitor hook for {@code IN (...)} list expressions.
 * <p>
 * Currently a no-op: no check is performed on {@link SQLInListExpr} nodes.
 */
public static void check(WallVisitor visitor, SQLInListExpr x) {
}
/**
 * Checks a binary operator expression.
 * <p>
 * For {@code AND}/{@code OR} chains, the expression is flattened with
 * {@link SQLBinaryOpExpr#split} and every operand is visited here. The method then returns
 * {@code false} (presumably so the calling visitor does not descend into the children a second
 * time — confirm in the WallVisitor implementations).
 * <p>
 * For {@code +}/{@code ||} chains of four or more operands, an {@link ErrorCode#EVIL_CONCAT}
 * violation is recorded once four {@code chr(literal)}/{@code char(literal)} calls have been
 * counted. A string literal longer than five characters resets the counter. This presumably
 * catches keywords assembled character by character to slip past text-based filters.
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the binary expression being visited
 * @return {@code false} for AND/OR chains (operands already visited), {@code true} otherwise
 */
public static boolean check(WallVisitor visitor, SQLBinaryOpExpr x) {
    SQLBinaryOperator operator = x.getOperator();
    if (operator == SQLBinaryOperator.BooleanOr
            || operator == SQLBinaryOperator.BooleanAnd) {
        List<SQLExpr> groupList = SQLBinaryOpExpr.split(x);
        for (SQLExpr item : groupList) {
            item.accept(visitor);
        }
        return false;
    }

    if (operator == SQLBinaryOperator.Add
            || operator == SQLBinaryOperator.Concat) {
        List<SQLExpr> groupList = SQLBinaryOpExpr.split(x);
        if (groupList.size() >= 4) {
            int chrCount = 0;
            for (SQLExpr item : groupList) {
                if (item instanceof SQLMethodInvokeExpr) {
                    SQLMethodInvokeExpr methodExpr = (SQLMethodInvokeExpr) item;
                    String methodName = methodExpr.getMethodName();
                    // equalsIgnoreCase avoids locale-sensitive toLowerCase() and tolerates a null name
                    if ("chr".equalsIgnoreCase(methodName) || "char".equalsIgnoreCase(methodName)) {
                        List<SQLExpr> parameters = methodExpr.getParameters();
                        // chr()/char() with no arguments previously threw IndexOutOfBoundsException
                        if (!parameters.isEmpty() && parameters.get(0) instanceof SQLLiteralExpr) {
                            chrCount++;
                        }
                    }
                } else if (item instanceof SQLCharExpr) {
                    String text = ((SQLCharExpr) item).getText();
                    // a longer plain literal breaks the char-by-char pattern
                    if (text != null && text.length() > 5) {
                        chrCount = 0;
                        continue;
                    }
                }

                if (chrCount >= 4) {
                    addViolation(visitor, ErrorCode.EVIL_CONCAT, "evil concat", x);
                    break;
                }
            }
        }
    }

    return true;
}
/**
 * Visitor hook for flattened binary-operator groups ({@link SQLBinaryOpExprGroup}).
 * <p>
 * Performs no check and always returns {@code true}.
 * NOTE(review): presumably {@code true} lets the caller continue visiting the group's items —
 * confirm in the WallVisitor implementations.
 */
public static boolean check(WallVisitor visitor, SQLBinaryOpExprGroup x) {
return true;
}
/**
 * Records a CREATE TABLE in the per-table statistics of the current {@link WallContext}.
 * <p>
 * Does nothing when there is no current context. It also does nothing when the statement's
 * {@code getName()} is {@code null}; that case previously caused a NullPointerException.
 *
 * @param visitor the wall visitor (unused here; kept for the uniform {@code check} signature)
 * @param x       the CREATE TABLE statement
 */
public static void check(WallVisitor visitor, SQLCreateTableStatement x) {
    WallContext context = WallContext.current();
    if (context == null) {
        return;
    }

    SQLName name = (SQLName) x.getName();
    if (name == null) {
        return;
    }

    WallSqlTableStat tableStat = context.getTableStat(name.getSimpleName());
    if (tableStat != null) {
        tableStat.incrementCreateCount();
    }
}
/**
 * Records an ALTER TABLE in the per-table statistics of the current {@link WallContext}.
 * <p>
 * Does nothing when there is no current context. It also does nothing when the statement's
 * {@code getName()} is {@code null}; that case previously caused a NullPointerException.
 *
 * @param visitor the wall visitor (unused here; kept for the uniform {@code check} signature)
 * @param x       the ALTER TABLE statement
 */
public static void check(WallVisitor visitor, SQLAlterTableStatement x) {
    WallContext context = WallContext.current();
    if (context == null) {
        return;
    }

    SQLName name = (SQLName) x.getName();
    if (name == null) {
        return;
    }

    WallSqlTableStat tableStat = context.getTableStat(name.getSimpleName());
    if (tableStat != null) {
        tableStat.incrementAlterCount();
    }
}
/**
 * Records a DROP TABLE for every dropped table in the current {@link WallContext}'s per-table
 * statistics.
 * <p>
 * Only table sources that are {@link SQLExprTableSource}s whose expression is an
 * {@link SQLName} are counted. Other expression kinds are skipped instead of triggering a
 * ClassCastException. Nothing is recorded when there is no current context.
 *
 * @param visitor the wall visitor (unused here; kept for the uniform {@code check} signature)
 * @param x       the DROP TABLE statement
 */
public static void check(WallVisitor visitor, SQLDropTableStatement x) {
    // Look the context up once rather than once per dropped table.
    WallContext context = WallContext.current();
    if (context == null) {
        return;
    }

    for (SQLTableSource item : x.getTableSources()) {
        if (!(item instanceof SQLExprTableSource)) {
            continue;
        }
        SQLExpr expr = ((SQLExprTableSource) item).getExpr();
        if (!(expr instanceof SQLName)) {
            continue;
        }
        String tableName = ((SQLName) expr).getSimpleName();
        WallSqlTableStat tableStat = context.getTableStat(tableName);
        if (tableStat != null) {
            tableStat.incrementDropCount();
        }
    }
}
/**
 * Checks a single item of a SELECT list.
 * <ul>
 * <li>A variable reference whose name is exactly {@code "@"} is reported as
 * {@link ErrorCode#EVIL_NAME}, unless {@code isTopSelectItem} accepts the expression.</li>
 * <li>When the config disallows {@code SELECT *}, an all-columns item is reported as
 * {@link ErrorCode#SELECT_NOT_ALLOW}. This applies only when its parent is a query block
 * whose FROM is a plain {@link SQLExprTableSource}.</li>
 * </ul>
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the select item being visited
 */
public static void check(WallVisitor visitor, SQLSelectItem x) {
    SQLExpr selected = x.getExpr();

    boolean bareAtVariable = selected instanceof SQLVariantRefExpr
            && !isTopSelectItem(selected)
            && "@".equals(((SQLVariantRefExpr) selected).getName());
    if (bareAtVariable) {
        addViolation(visitor, ErrorCode.EVIL_NAME, "@ not allow", x);
    }

    if (visitor.getConfig().isSelectAllColumnAllow()) {
        return;
    }
    if (!(selected instanceof SQLAllColumnExpr)) {
        return;
    }

    SQLObject owner = x.getParent();
    if (owner instanceof SQLSelectQueryBlock
            && ((SQLSelectQueryBlock) owner).getFrom() instanceof SQLExprTableSource) {
        addViolation(visitor, ErrorCode.SELECT_NOT_ALLOW, "'SELECT *' not allow", x);
    }
}
/**
 * Checks the owner (qualifier) part of a dotted property expression, e.g. the {@code a} in
 * {@code a.b}, by delegating to {@code checkSchema}.
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the property expression being visited
 */
public static void check(WallVisitor visitor, SQLPropertyExpr x) {
    SQLExpr owner = x.getOwner();
    checkSchema(visitor, owner);
}
/**
 * Checks an INSERT statement.
 * <p>
 * The method runs three steps:
 * <ol>
 * <li>Verifies the target table is not read-only.</li>
 * <li>Reports {@link ErrorCode#INSERT_NOT_ALLOW} when inserts are disabled.</li>
 * <li>Applies the multi-tenant column rewrite.</li>
 * </ol>
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the INSERT statement
 */
public static void checkInsert(WallVisitor visitor, SQLInsertInto x) {
    checkReadOnly(visitor, x.getTableSource());

    WallConfig config = visitor.getConfig();
    if (!config.isInsertAllow()) {
        addViolation(visitor, ErrorCode.INSERT_NOT_ALLOW, "insert not allow", x);
    }

    // The multi-tenant rewrite runs even when a violation was recorded above.
    checkInsertForMultiTenant(visitor, x);
}
/**
 * Checks a SELECT query block. (The misspelled name is kept because callers reference it.)
 * <ol>
 * <li>A SELECT ... INTO target must not be read-only. SELECT INTO is reported as
 * {@link ErrorCode#SELECT_INTO_NOT_ALLOW} when disallowed.</li>
 * <li>When this block is the right-hand side of a UNION, any hint whose text starts with
 * {@code "!"} is reported as {@link ErrorCode#UNION}. Such hints are presumably MySQL
 * {@code /*! ... *}{@code /} executable comments.</li>
 * <li>The WHERE clause is checked for parameterization. When the always-true check is enabled
 * and the SQL ends with a comment, a WHERE that evaluates to TRUE is reported as
 * {@link ErrorCode#ALWAYS_TRUE}, unless it is a simple constant expression.</li>
 * <li>The multi-tenant select rewrite is applied.</li>
 * </ol>
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the query block being visited
 */
public static void checkSelelct(WallVisitor visitor, SQLSelectQueryBlock x) {
    WallConfig config = visitor.getConfig();

    if (x.getInto() != null) {
        checkReadOnly(visitor, x.getInto());
        if (!config.isSelectIntoAllow()) {
            addViolation(visitor, ErrorCode.SELECT_INTO_NOT_ALLOW, "select into not allow", x);
            return;
        }
    }

    List<SQLCommentHint> hints = x.getHintsDirect();
    SQLObject parent = x.getParent();
    if (hints != null
            && parent instanceof SQLUnionQuery
            && x == ((SQLUnionQuery) parent).getRight()) {
        for (SQLCommentHint hint : hints) {
            String text = hint.getText();
            if (text != null && text.startsWith("!")) {
                addViolation(visitor, ErrorCode.UNION, "union select hint not allow", x);
                return;
            }
        }
    }

    SQLExpr where = x.getWhere();
    if (where != null) {
        checkCondition(visitor, where);
        boolean alwaysTrueCheck = config.isSelectWhereAlwayTrueCheck();
        Object whereValue = getConditionValue(visitor, where, alwaysTrueCheck);
        if (Boolean.TRUE == whereValue
                && alwaysTrueCheck
                && visitor.isSqlEndOfComment()
                && !isSimpleConstExpr(where)) { // simple constant expressions are exempt
            addViolation(visitor, ErrorCode.ALWAYS_TRUE, "select alway true condition not allow", x);
        }
    }

    checkSelectForMultiTenant(visitor, x);
}
/**
 * Checks a HAVING clause.
 * <p>
 * It reports {@link ErrorCode#ALWAYS_TRUE} when all of the following hold:
 * <ul>
 * <li>the always-true HAVING check is enabled;</li>
 * <li>the clause evaluates to TRUE;</li>
 * <li>the SQL ends with a comment;</li>
 * <li>the clause is not a simple constant expression.</li>
 * </ul>
 * A {@code null} clause is ignored.
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the HAVING expression, may be {@code null}
 */
public static void checkHaving(WallVisitor visitor, SQLExpr x) {
    if (x != null) {
        boolean alwaysTrueCheck = visitor.getConfig().isSelectHavingAlwayTrueCheck();
        Object havingValue = getConditionValue(visitor, x, alwaysTrueCheck);
        boolean reject = Boolean.TRUE == havingValue
                && alwaysTrueCheck
                && visitor.isSqlEndOfComment()
                && !isSimpleConstExpr(x);
        if (reject) {
            addViolation(visitor, ErrorCode.ALWAYS_TRUE, "having alway true condition not allow", x);
        }
    }
}
/**
 * Checks a DELETE statement.
 * <ol>
 * <li>The target table(s) must not be read-only.</li>
 * <li>Reports {@link ErrorCode#DELETE_NOT_ALLOW} when deletes are disabled.</li>
 * <li>A DELETE without WHERE increments the context's no-condition warning counter. It is
 * reported as {@link ErrorCode#NONE_CONDITION} when that check is enabled. MySQL
 * {@code DELETE ... USING} and join-based deletes are exempt.</li>
 * <li>A WHERE clause is checked for parameterization and for being always true
 * ({@link ErrorCode#ALWAYS_TRUE}).</li>
 * </ol>
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the DELETE statement
 */
public static void checkDelete(WallVisitor visitor, SQLDeleteStatement x) {
    checkReadOnly(visitor, x.getTableSource());

    WallConfig config = visitor.getConfig();
    if (!config.isDeleteAllow()) {
        // Previously reported with INSERT_NOT_ALLOW; a disabled DELETE is a DELETE violation.
        addViolation(visitor, ErrorCode.DELETE_NOT_ALLOW, "delete not allow", x);
        return;
    }

    SQLExpr where = x.getWhere();
    if (where == null) {
        // USING / join deletes are exempt from the no-condition check (presumably because the
        // join itself restricts the affected rows).
        boolean hasUsing = x instanceof MySqlDeleteStatement
                && ((MySqlDeleteStatement) x).getUsing() != null;
        boolean isJoinTableSource = x.getTableSource() instanceof SQLJoinTableSource;
        if (!hasUsing && !isJoinTableSource) {
            WallContext context = WallContext.current();
            if (context != null) {
                context.incrementDeleteNoneConditionWarnings();
            }
            if (config.isDeleteWhereNoneCheck()) {
                addViolation(visitor, ErrorCode.NONE_CONDITION, "delete none condition not allow", x);
            }
        }
        return;
    }

    checkCondition(visitor, where);
    boolean alwaysTrueCheck = config.isDeleteWhereAlwayTrueCheck();
    if (Boolean.TRUE == getConditionValue(visitor, where, alwaysTrueCheck)
            && alwaysTrueCheck
            && visitor.isSqlEndOfComment()
            && !isSimpleConstExpr(where)) {
        addViolation(visitor, ErrorCode.ALWAYS_TRUE, "delete alway true condition not allow", x);
    }
}
/**
 * Decides whether a condition counts as a "simple constant" expression. Callers use this to
 * exempt trivially-true conditions from their always-true violations.
 * <p>
 * The condition is split into parts with {@code getParts}, and the method walks those parts:
 * <ul>
 * <li>If a part accepted by {@code isFirst} evaluates to a truthy constant, the method
 * returns {@code true} immediately. The value comes from its {@code EVAL_VALUE} attribute,
 * or from the literal itself when that attribute is absent.</li>
 * <li>Otherwise every part must be simple. A part is simple if it is the whole expression,
 * a literal, or an {@code =}, {@code <>} or {@code >} comparison between two integer
 * literals.</li>
 * </ul>
 * NOTE(review): the exact meaning of {@code isFirst} is defined outside this view — confirm.
 *
 * @param sqlExpr the condition to inspect
 * @return {@code true} if the condition is a simple constant expression; {@code false} for
 *         an empty part list or any non-simple part
 */
private static boolean isSimpleConstExpr(SQLExpr sqlExpr) {
    List<SQLExpr> parts = getParts(sqlExpr);
    if (parts.isEmpty()) {
        return false;
    }

    for (SQLExpr part : parts) {
        if (isFirst(part)) {
            Object evalValue = part.getAttribute(EVAL_VALUE);
            if (evalValue == null) {
                if (part instanceof SQLBooleanExpr) {
                    evalValue = ((SQLBooleanExpr) part).getBooleanValue();
                } else if (part instanceof SQLNumericLiteralExpr) {
                    evalValue = ((SQLNumericLiteralExpr) part).getNumber();
                } else if (part instanceof SQLCharExpr) {
                    evalValue = ((SQLCharExpr) part).getText();
                } else if (part instanceof SQLNCharExpr) {
                    evalValue = ((SQLNCharExpr) part).getText();
                }
            }
            Boolean result = SQLEvalVisitorUtils.castToBoolean(evalValue);
            if (result != null && result) {
                return true;
            }
        }

        boolean isSimpleConstExpr = false;
        if (part == sqlExpr || part instanceof SQLLiteralExpr) {
            isSimpleConstExpr = true;
        } else if (part instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) part;
            if (binaryOpExpr.getOperator() == SQLBinaryOperator.Equality
                    || binaryOpExpr.getOperator() == SQLBinaryOperator.NotEqual
                    || binaryOpExpr.getOperator() == SQLBinaryOperator.GreaterThan) {
                if (binaryOpExpr.getLeft() instanceof SQLIntegerExpr
                        && binaryOpExpr.getRight() instanceof SQLIntegerExpr) {
                    isSimpleConstExpr = true;
                }
            }
        }

        if (!isSimpleConstExpr) {
            return false;
        }
    }
    return true;
}
/**
 * Enforces the "must be parameterized" rule on a condition.
 * <p>
 * When the config requires parameterized SQL, the condition is walked with the provider's
 * {@link ExportParameterVisitor}. If any parameter is exported (an inline value was found),
 * {@link ErrorCode#NOT_PARAMETERIZED} is reported. A {@code null} condition is ignored.
 *
 * @param visitor the wall visitor collecting violations
 * @param x       the condition, may be {@code null}
 */
private static void checkCondition(WallVisitor visitor, SQLExpr x) {
    if (x == null || !visitor.getConfig().isMustParameterized()) {
        return;
    }

    ExportParameterVisitor collector = visitor.getProvider().createExportParameterVisitor();
    x.accept(collector);

    if (!collector.getParameters().isEmpty()) {
        addViolation(visitor, ErrorCode.NOT_PARAMETERIZED, "sql must parameterized", x);
    }
}
/**
 * Adds the tenant column of a join's right-hand table to the select list of {@code x}.
 * <p>
 * The right side must be a plain identifier table. Its tenant column is resolved from the
 * {@link TenantCallBack} first. If that yields nothing, the configured tenant column is used
 * when the table name matches the tenant table pattern. The added column is qualified with
 * the table's alias, or with the table name when there is no alias. Does nothing when neither
 * a callback nor a pattern is configured.
 *
 * @param visitor the wall visitor; marked as having modified the SQL when a column is added
 * @param join    the join whose right side is inspected
 * @param x       the query block whose select list receives the tenant column
 */
private static void checkJoinSelectForMultiTenant(WallVisitor visitor, SQLJoinTableSource join,
        SQLSelectQueryBlock x) {
    TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack();
    String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
    if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
        return;
    }

    SQLTableSource right = join.getRight();
    if (!(right instanceof SQLExprTableSource)) {
        return;
    }
    SQLExpr tableExpr = ((SQLExprTableSource) right).getExpr();
    if (!(tableExpr instanceof SQLIdentifierExpr)) {
        return;
    }
    String tableName = ((SQLIdentifierExpr) tableExpr).getName();

    String tenantColumn = null;
    if (tenantCallBack != null) {
        tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName);
    }
    if (StringUtils.isEmpty(tenantColumn)
            && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
        tenantColumn = visitor.getConfig().getTenantColumn();
    }
    if (StringUtils.isEmpty(tenantColumn)) {
        return;
    }

    // The alias always falls back to the table name. The former "alias == null" branch that
    // produced an unqualified column was unreachable and has been removed.
    String alias = right.getAlias();
    if (alias == null) {
        alias = tableName;
    }
    SQLExpr item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
    x.getSelectList().add(new SQLSelectItem(item));
    visitor.setSqlModified(true);
}
/**
 * Tells whether a query block belongs directly to a top-level SELECT statement.
 * <p>
 * Enclosing UNION queries are skipped. The first non-union ancestor must be an
 * {@link SQLSelect} whose parent is an {@link SQLSelectStatement}. Subqueries and
 * INSERT ... SELECT bodies therefore return {@code false}.
 *
 * @param queryBlock the query block to inspect
 * @return {@code true} if the block is (part of a union that is) the body of a SELECT statement
 */
private static boolean isSelectStatmentForMultiTenant(SQLSelectQueryBlock queryBlock) {
    SQLObject ancestor = queryBlock.getParent();
    while (ancestor instanceof SQLUnionQuery) {
        ancestor = ancestor.getParent();
    }

    if (ancestor instanceof SQLSelect) {
        return ancestor.getParent() instanceof SQLSelectStatement;
    }
    return false;
}
/**
 * Appends tenant column(s) to the select list of a top-level SELECT query block.
 * <p>
 * The primary table is either the FROM table itself or the left table of a FROM join. Its
 * tenant column is resolved from the {@link TenantCallBack}, falling back to the configured
 * column when the table matches the tenant table pattern. For joins, the right-hand table is
 * handled by {@code checkJoinSelectForMultiTenant}, and its column is added before the
 * primary table's. The primary column is qualified with the table alias. In the join case a
 * missing alias falls back to the table name; for a plain FROM table with no alias the column
 * is left unqualified.
 *
 * @param visitor the wall visitor; marked as having modified the SQL when a column is added
 * @param x       the query block to rewrite
 * @throws IllegalStateException if multi-tenancy is configured and {@code x} is {@code null}
 */
private static void checkSelectForMultiTenant(WallVisitor visitor, SQLSelectQueryBlock x) {
    WallConfig config = visitor.getConfig();
    TenantCallBack callBack = config.getTenantCallBack();
    String tablePattern = config.getTenantTablePattern();
    boolean noPattern = tablePattern == null || tablePattern.length() == 0;
    if (callBack == null && noPattern) {
        return;
    }
    if (x == null) {
        throw new IllegalStateException("x is null");
    }
    if (!isSelectStatmentForMultiTenant(x)) {
        return;
    }

    // Locate the primary table candidate.
    SQLTableSource from = x.getFrom();
    SQLJoinTableSource joinSource = null;
    SQLExprTableSource primary = null;
    if (from instanceof SQLExprTableSource) {
        primary = (SQLExprTableSource) from;
    } else if (from instanceof SQLJoinTableSource) {
        joinSource = (SQLJoinTableSource) from;
        SQLTableSource left = joinSource.getLeft();
        if (left instanceof SQLExprTableSource) {
            primary = (SQLExprTableSource) left;
        }
    }

    // Resolve the primary table's tenant column (callback first, then pattern).
    String matchTableName = null;
    String alias = null;
    String tenantColumn = null;
    SQLExpr primaryExpr = primary == null ? null : primary.getExpr();
    if (primaryExpr instanceof SQLIdentifierExpr) {
        String tableName = ((SQLIdentifierExpr) primaryExpr).getName();
        if (callBack != null) {
            tenantColumn = callBack.getTenantColumn(StatementType.SELECT, tableName);
        }
        if (StringUtils.isEmpty(tenantColumn)
                && ServletPathMatcher.getInstance().matches(tablePattern, tableName)) {
            tenantColumn = config.getTenantColumn();
        }
        if (!StringUtils.isEmpty(tenantColumn)) {
            matchTableName = tableName;
            alias = primary.getAlias();
            if (alias == null && joinSource != null) {
                alias = tableName;
            }
        }
    }

    // The join's right-hand table is handled after the primary table has been resolved.
    if (joinSource != null) {
        checkJoinSelectForMultiTenant(visitor, joinSource, x);
    }

    if (matchTableName == null) {
        return;
    }

    SQLExpr column;
    if (alias == null) {
        column = new SQLIdentifierExpr(tenantColumn);
    } else {
        column = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
    }
    x.getSelectList().add(new SQLSelectItem(column));
    visitor.setSqlModified(true);
}
/**
 * Adds a {@code SET <tenantColumn> = <tenantValue>} item to an UPDATE whose target is a plain
 * identifier table with a tenant column.
 * <p>
 * The tenant column comes from the {@link TenantCallBack}, or from the config when the table
 * matches the tenant table pattern. The value comes from {@code generateTenantValue}. The
 * column is qualified with the table alias when one is present.
 *
 * @param visitor the wall visitor; marked as having modified the SQL when an item is added
 * @param x       the UPDATE statement to rewrite
 * @throws IllegalStateException if multi-tenancy is configured and {@code x} is {@code null}
 */
private static void checkUpdateForMultiTenant(WallVisitor visitor, SQLUpdateStatement x) {
    WallConfig config = visitor.getConfig();
    TenantCallBack callBack = config.getTenantCallBack();
    String tablePattern = config.getTenantTablePattern();
    if (callBack == null && (tablePattern == null || tablePattern.length() == 0)) {
        return;
    }
    if (x == null) {
        throw new IllegalStateException("x is null");
    }

    SQLTableSource target = x.getTableSource();
    if (!(target instanceof SQLExprTableSource)) {
        return;
    }
    SQLExpr targetExpr = ((SQLExprTableSource) target).getExpr();
    if (!(targetExpr instanceof SQLIdentifierExpr)) {
        return;
    }
    String tableName = ((SQLIdentifierExpr) targetExpr).getName();

    String tenantColumn = null;
    if (callBack != null) {
        tenantColumn = callBack.getTenantColumn(StatementType.UPDATE, tableName);
    }
    if (StringUtils.isEmpty(tenantColumn)
            && ServletPathMatcher.getInstance().matches(tablePattern, tableName)) {
        tenantColumn = config.getTenantColumn();
    }
    if (StringUtils.isEmpty(tenantColumn)) {
        return;
    }

    String alias = target.getAlias();
    SQLExpr column;
    if (alias == null) {
        column = new SQLIdentifierExpr(tenantColumn);
    } else {
        column = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
    }
    SQLExpr tenantValue = generateTenantValue(visitor, alias, StatementType.UPDATE, tableName);

    SQLUpdateSetItem setItem = new SQLUpdateSetItem();
    setItem.setColumn(column);
    setItem.setValue(tenantValue);
    x.addItem(setItem);
    visitor.setSqlModified(true);
}
/**
 * Adds the tenant column and tenant value to an INSERT whose target table has a tenant column.
 * <p>
 * The tenant column comes from the {@link TenantCallBack}, or from the config when the table
 * matches the tenant table pattern. The column is appended to the column list. The tenant
 * value is then appended to every VALUES row (MySQL / SQL Server statements) or to the single
 * VALUES clause (other dialects). For INSERT ... SELECT, the value is also appended to the
 * select list of every query block in the (possibly UNIONed) query.
 *
 * @param visitor the wall visitor; marked as having modified the SQL when rewritten
 * @param x       the INSERT to rewrite
 * @throws IllegalStateException if multi-tenancy is configured and {@code x} is {@code null}
 */
private static void checkInsertForMultiTenant(WallVisitor visitor, SQLInsertInto x) {
    TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack();
    String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
    if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
        return;
    }
    if (x == null) {
        throw new IllegalStateException("x is null");
    }

    SQLExprTableSource tableSource = x.getTableSource();
    if (tableSource == null) {
        return;
    }

    String alias = null;
    String matchTableName = null;
    String tenantColumn = null;
    SQLExpr tableExpr = tableSource.getExpr();
    if (tableExpr instanceof SQLIdentifierExpr) {
        String tableName = ((SQLIdentifierExpr) tableExpr).getName();
        if (tenantCallBack != null) {
            tenantColumn = tenantCallBack.getTenantColumn(StatementType.INSERT, tableName);
        }
        if (StringUtils.isEmpty(tenantColumn)
                && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
            tenantColumn = visitor.getConfig().getTenantColumn();
        }
        if (!StringUtils.isEmpty(tenantColumn)) {
            matchTableName = tableName;
            alias = tableSource.getAlias();
        }
    }
    if (matchTableName == null) {
        return;
    }

    SQLExpr item;
    if (alias != null) {
        item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
    } else {
        item = new SQLIdentifierExpr(tenantColumn);
    }
    SQLExpr value = generateTenantValue(visitor, alias, StatementType.INSERT, matchTableName);

    // add insert column
    x.getColumns().add(item);

    // add the value to the VALUES row(s)
    List<ValuesClause> valuesClauses = null;
    ValuesClause valuesClause = null;
    if (x instanceof MySqlInsertStatement) {
        valuesClauses = ((MySqlInsertStatement) x).getValuesList();
    } else if (x instanceof SQLServerInsertStatement) {
        // Previously cast to MySqlInsertStatement, which always threw ClassCastException.
        valuesClauses = ((SQLServerInsertStatement) x).getValuesList();
    } else {
        valuesClause = x.getValues();
    }
    if (valuesClauses != null) {
        for (ValuesClause clause : valuesClauses) {
            clause.addValue(value);
        }
    }
    if (valuesClause != null) {
        valuesClause.addValue(value);
    }

    // insert .. select
    SQLSelect select = x.getQuery();
    if (select != null) {
        List<SQLSelectQueryBlock> queryBlocks = splitSQLSelectQuery(select.getQuery());
        for (SQLSelectQueryBlock queryBlock : queryBlocks) {
            queryBlock.getSelectList().add(new SQLSelectItem(value));
        }
    }

    visitor.setSqlModified(true);
}
/**
 * Flattens a (possibly nested) UNION query into its individual query blocks.
 * <p>
 * Uses an explicit LIFO work list instead of recursion. Right operands are therefore emitted
 * before left operands. Queries that are neither a query block nor a union (including
 * {@code null}) are ignored.
 *
 * @param x the query to flatten
 * @return the query blocks found, never {@code null}
 */
private static List<SQLSelectQueryBlock> splitSQLSelectQuery(SQLSelectQuery x) {
    List<SQLSelectQueryBlock> groupList = new ArrayList<SQLSelectQueryBlock>();
    // ArrayList used as a stack (top = last element); avoids the legacy synchronized Stack.
    List<SQLSelectQuery> pending = new ArrayList<SQLSelectQuery>();
    pending.add(x);
    while (!pending.isEmpty()) {
        SQLSelectQuery query = pending.remove(pending.size() - 1);
        if (query instanceof SQLSelectQueryBlock) {
            groupList.add((SQLSelectQueryBlock) query);
        } else if (query instanceof SQLUnionQuery) {
            SQLUnionQuery unionQuery = (SQLUnionQuery) query;
            pending.add(unionQuery.getLeft());
            pending.add(unionQuery.getRight());
        }
    }
    return groupList;
}
/**
 * Deprecated: rewrites the WHERE of a DELETE, UPDATE or SELECT query block to AND-in a
 * tenant condition ({@code <tenantColumn> = <tenantValue>}).
 * <p>
 * Only the tenant table pattern is consulted here (no {@link TenantCallBack} column lookup).
 * The matched table is the target/FROM table, or the left table of a join. For joins, the
 * right-hand table is handled by {@code checkJoinConditionForMultiTenant}. The new WHERE is
 * {@code tenantCondition AND x}, or just {@code tenantCondition} when {@code x} is null.
 * NOTE(review): {@code x} is presumably the parent's current WHERE (see the commented-out call
 * sites that pass {@code x.getWhere()}) — confirm before reusing.
 *
 * @param visitor the wall visitor; marked as having modified the SQL when rewritten
 * @param x       the existing condition, may be {@code null}
 * @param parent  the DELETE, UPDATE or SELECT query block that owns the condition
 * @throws IllegalStateException if {@code parent} is null or of an unsupported type (only when a
 *                               tenant table pattern is configured)
 */
@Deprecated
public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, SQLObject parent) {
String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
if (tenantTablePattern == null || tenantTablePattern.length() == 0) {
return;
}
if (parent == null) {
throw new IllegalStateException("parent is null");
}
String alias = null;
SQLTableSource tableSource;
StatementType statementType = null;
// Pick the table source and statement type from the owning statement.
if (parent instanceof SQLDeleteStatement) {
tableSource = ((SQLDeleteStatement) parent).getTableSource();
statementType = StatementType.DELETE;
} else if (parent instanceof SQLUpdateStatement) {
tableSource = ((SQLUpdateStatement) parent).getTableSource();
statementType = StatementType.UPDATE;
} else if (parent instanceof SQLSelectQueryBlock) {
tableSource = ((SQLSelectQueryBlock) parent).getFrom();
statementType = StatementType.SELECT;
} else {
throw new IllegalStateException("not support parent : " + parent.getClass());
}
String matchTableName = null;
if (tableSource instanceof SQLExprTableSource) {
SQLExpr tableExpr = ((SQLExprTableSource) tableSource).getExpr();
if (tableExpr instanceof SQLIdentifierExpr) {
String tableName = ((SQLIdentifierExpr) tableExpr).getName();
if (ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
matchTableName = tableName;
alias = tableSource.getAlias();
}
}
} else if (tableSource instanceof SQLJoinTableSource) {
SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
// Left side is a plain table: match it here; the join's right side is handled separately.
if (join.getLeft() instanceof SQLExprTableSource) {
SQLExpr tableExpr = ((SQLExprTableSource) join.getLeft()).getExpr();
if (tableExpr instanceof SQLIdentifierExpr) {
String tableName = ((SQLIdentifierExpr) tableExpr).getName();
if (ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
matchTableName = tableName;
alias = join.getLeft().getAlias();
}
}
checkJoinConditionForMultiTenant(visitor, join, false, statementType);
} else {
// NOTE(review): checkLeft=true is passed here, but checkJoinConditionForMultiTenant does
// not read that flag, so a nested left join is not rewritten.
checkJoinConditionForMultiTenant(visitor, join, true, statementType);
}
}
if (matchTableName == null) {
return;
}
SQLBinaryOpExpr tenantCondition = createTenantCondition(visitor, alias, statementType, matchTableName);
SQLExpr condition;
if (x == null) {
condition = tenantCondition;
} else {
condition = new SQLBinaryOpExpr(tenantCondition, SQLBinaryOperator.BooleanAnd, x);
}
// Install the combined condition as the new WHERE of the owning statement.
if (parent instanceof SQLDeleteStatement) {
SQLDeleteStatement deleteStmt = (SQLDeleteStatement) parent;
deleteStmt.setWhere(condition);
visitor.setSqlModified(true);
} else if (parent instanceof SQLUpdateStatement) {
SQLUpdateStatement updateStmt = (SQLUpdateStatement) parent;
updateStmt.setWhere(condition);
visitor.setSqlModified(true);
} else if (parent instanceof SQLSelectQueryBlock) {
SQLSelectQueryBlock queryBlock = (SQLSelectQueryBlock) parent;
queryBlock.setWhere(condition);
visitor.setSqlModified(true);
}
}
/**
 * Deprecated: prepends a tenant condition to a join's ON condition when the join's right-hand
 * table matches the tenant table pattern.
 * <p>
 * The right side must be a plain identifier table. The condition uses the table alias, or the
 * table name when there is no alias. The new ON condition is {@code tenantCondition AND on},
 * or just {@code tenantCondition} when the join had none.
 * NOTE(review): {@code checkLeft} is not read by this method; the left side is never rewritten.
 *
 * @param visitor       the wall visitor; marked as having modified the SQL when rewritten
 * @param join          the join to rewrite
 * @param checkLeft     currently unused
 * @param statementType the statement type passed on to the tenant value generation
 */
@Deprecated
public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoinTableSource join,
        boolean checkLeft, StatementType statementType) {
    String pattern = visitor.getConfig().getTenantTablePattern();
    if (pattern == null || pattern.length() == 0) {
        return;
    }

    SQLTableSource right = join.getRight();
    if (!(right instanceof SQLExprTableSource)) {
        return;
    }
    SQLExpr rightExpr = ((SQLExprTableSource) right).getExpr();
    if (!(rightExpr instanceof SQLIdentifierExpr)) {
        return;
    }
    String tableName = ((SQLIdentifierExpr) rightExpr).getName();
    if (!ServletPathMatcher.getInstance().matches(pattern, tableName)) {
        return;
    }

    String rightAlias = right.getAlias();
    String alias = rightAlias != null ? rightAlias : tableName;

    SQLExpr existing = join.getCondition();
    SQLBinaryOpExpr tenantCondition = createTenantCondition(visitor, alias, statementType, tableName);
    SQLExpr combined = existing == null
            ? tenantCondition
            : new SQLBinaryOpExpr(tenantCondition, SQLBinaryOperator.BooleanAnd, existing);

    join.setCondition(combined);
    visitor.setSqlModified(true);
}
/**
 * Deprecated: builds the equality {@code [alias.]<tenantColumn> = <tenantValue>}.
 * <p>
 * The column name is the config's tenant column. The value comes from
 * {@code generateTenantValue}.
 *
 * @param visitor       the wall visitor supplying config and tenant value
 * @param alias         qualifier for the column, or {@code null} for an unqualified column
 * @param statementType statement type passed on to the tenant value generation
 * @param tableName     table name passed on to the tenant value generation
 * @return the tenant equality condition
 */
@Deprecated
private static SQLBinaryOpExpr createTenantCondition(WallVisitor visitor, String alias,
        StatementType statementType, String tableName) {
    String tenantColumn = visitor.getConfig().getTenantColumn();

    SQLExpr column;
    if (alias == null) {
        column = new SQLIdentifierExpr(tenantColumn);
    } else {
        column = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
    }

    SQLExpr tenantValue = generateTenantValue(visitor, alias, statementType, tableName);
    return new SQLBinaryOpExpr(column, SQLBinaryOperator.Equality, tenantValue);
}
/**
 * Produces the SQL literal for the current tenant value.
 * <p>
 * When a {@link TenantCallBack} is configured, its value for this statement type and table is
 * first stored via {@code WallProvider.setTenantValue}. The value is then read back from
 * {@code WallProvider.getTenantValue()}. A {@link Number} becomes a number literal and a
 * {@link String} becomes a character literal.
 *
 * @param visitor       the wall visitor supplying the config
 * @param alias         currently unused
 * @param statementType statement type passed to the callback
 * @param tableName     table name passed to the callback
 * @return the tenant value as an SQL literal
 * @throws IllegalStateException if the tenant value is neither a Number nor a String
 */
private static SQLExpr generateTenantValue(WallVisitor visitor, String alias, StatementType statementType,
        String tableName) {
    TenantCallBack callBack = visitor.getConfig().getTenantCallBack();
    if (callBack != null) {
        Object callbackValue = callBack.getTenantValue(statementType, tableName);
        WallProvider.setTenantValue(callbackValue);
    }

    Object tenantValue = WallProvider.getTenantValue();
    if (tenantValue instanceof Number) {
        return new SQLNumberExpr((Number) tenantValue);
    }
    if (tenantValue instanceof String) {
        return new SQLCharExpr((String) tenantValue);
    }
    throw new IllegalStateException("tenant value not support type " + tenantValue);
}
/**
 * Reports {@link ErrorCode#READ_ONLY} for every read-only table in a table source.
 * <p>
 * Joins are checked recursively on both sides. For a plain table source the simple name is
 * passed to the provider's {@code checkReadOnlyTable}; when the expression is not an
 * {@link SQLName}, {@code null} is passed. Other table-source kinds are ignored.
 *
 * @param visitor     the wall visitor collecting violations
 * @param tableSource the table source to check, may be {@code null}
 */
public static void checkReadOnly(WallVisitor visitor, SQLTableSource tableSource) {
    if (tableSource instanceof SQLJoinTableSource) {
        SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
        checkReadOnly(visitor, join.getLeft());
        checkReadOnly(visitor, join.getRight());
        return;
    }
    if (!(tableSource instanceof SQLExprTableSource)) {
        return;
    }

    SQLExpr nameExpr = ((SQLExprTableSource) tableSource).getExpr();
    String tableName = nameExpr instanceof SQLName ? ((SQLName) nameExpr).getSimpleName() : null;

    if (!visitor.getProvider().checkReadOnlyTable(tableName)) {
        addViolation(visitor, ErrorCode.READ_ONLY, "table readonly : " + tableName, tableSource);
    }
}
public static void checkUpdate(WallVisitor visitor, SQLUpdateStatement x) {
checkReadOnly(visitor, x.getTableSource());
WallConfig config = visitor.getConfig();
if (!config.isUpdateAllow()) {
addViolation(visitor, ErrorCode.UPDATE_NOT_ALLOW, "update not allow", x);
return;
}
SQLExpr where = x.getWhere();
if (where == null) {
WallContext context = WallContext.current();
if (context != null) {
context.incrementUpdateNoneConditionWarnings();
}
if (config.isUpdateWhereNoneCheck()) {
if (x instanceof MySqlUpdateStatement) {
MySqlUpdateStatement mysqlUpdate = (MySqlUpdateStatement) x;
if (mysqlUpdate.getLimit() == null) {
addViolation(visitor, ErrorCode.NONE_CONDITION, "update none condition not allow", x);
return;
}
} else {
addViolation(visitor, ErrorCode.NONE_CONDITION, "update none condition not allow", x);
return;
}
}
} else {
checkCondition(visitor, where);
if (Boolean.TRUE == getConditionValue(visitor, where, config.isUpdateWhereAlayTrueCheck())) {
if (config.isUpdateWhereAlayTrueCheck() && visitor.isSqlEndOfComment()&& !isSimpleConstExpr(where)) {
addViolation(visitor, ErrorCode.ALWAYS_TRUE, "update alway true condition not allow", x);
}
}
SQLName table = x.getTableName();
if (table == null) {
return;
}
String tableName = table.getSimpleName();
Set updateCheckColumns = config.getUpdateCheckTable(tableName);
boolean isUpdateCheckTable = updateCheckColumns != null && !updateCheckColumns.isEmpty();
WallUpdateCheckHandler updateCheckHandler = config.getUpdateCheckHandler();
if (isUpdateCheckTable && updateCheckHandler != null) {
String checkColumn = updateCheckColumns.iterator().next();
SQLExpr valueExpr = null;
for (SQLUpdateSetItem item : x.getItems()) {
if (item.columnMatch(checkColumn)) {
valueExpr = item.getValue();
break;
}
}
if (valueExpr != null) {
List conditions;
if (where instanceof SQLBinaryOpExpr) {
conditions = SQLBinaryOpExpr.split((SQLBinaryOpExpr) where, SQLBinaryOperator.BooleanAnd);
} else if (where instanceof SQLBinaryOpExprGroup) {
conditions = new ArrayList();
for (SQLExpr each : ((SQLBinaryOpExprGroup) where).getItems()) {
if (each instanceof SQLBinaryOpExpr) {
conditions.addAll(SQLBinaryOpExpr.split((SQLBinaryOpExpr) each, SQLBinaryOperator.BooleanAnd));
} else if (each instanceof SQLInListExpr) {
conditions.add(each);
}
}
} else {
conditions = new ArrayList();
conditions.add(where);
}
List filterValueExprList = new ArrayList();
for (SQLExpr condition : conditions) {
if (condition instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr binaryCondition = (SQLBinaryOpExpr) condition;
if (binaryCondition.getOperator() == SQLBinaryOperator.Equality
&& binaryCondition.conditionContainsColumn(checkColumn)) {
SQLExpr left = binaryCondition.getLeft();
SQLExpr right = binaryCondition.getRight();
if (left instanceof SQLValuableExpr || left instanceof SQLVariantRefExpr) {
filterValueExprList.add(left);
} else if (right instanceof SQLValuableExpr || right instanceof SQLVariantRefExpr) {
filterValueExprList.add(right);
}
}
} else if (condition instanceof SQLInListExpr) {
SQLInListExpr listExpr = (SQLInListExpr) condition;
if (listExpr.getExpr() instanceof SQLIdentifierExpr) {
SQLIdentifierExpr nameExpr = (SQLIdentifierExpr) listExpr.getExpr();
if (nameExpr.getName().equals(checkColumn)) {
filterValueExprList.addAll(((SQLInListExpr) condition).getTargetList());
}
}
}
}
boolean allValue = valueExpr instanceof SQLValuableExpr;
if (allValue) {
for (SQLExpr filterValue : filterValueExprList) {
if (!(filterValue instanceof SQLValuableExpr)) {
allValue = false;
break;
}
}
}
if (allValue) {
Object setValue = ((SQLValuableExpr) valueExpr).getValue();
List