Please wait. This can take some minutes ...
Downloading a project requires significant server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it how often you want.
studio.raptor.sqlparser.wall.WallProvider Maven / Gradle / Ivy
/*
* Copyright 1999-2017 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package studio.raptor.sqlparser.wall;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import studio.raptor.sqlparser.SQLUtils;
import studio.raptor.sqlparser.ast.SQLStatement;
import studio.raptor.sqlparser.dialect.mysql.ast.statement.MySqlHintStatement;
import studio.raptor.sqlparser.parser.Lexer;
import studio.raptor.sqlparser.parser.NotAllowCommentException;
import studio.raptor.sqlparser.parser.ParserException;
import studio.raptor.sqlparser.parser.SQLStatementParser;
import studio.raptor.sqlparser.parser.Token;
import studio.raptor.sqlparser.util.LRUCache;
import studio.raptor.sqlparser.visitor.ExportParameterVisitor;
import studio.raptor.sqlparser.visitor.ParameterizedOutputVisitorUtils;
import studio.raptor.sqlparser.wall.spi.WallVisitorUtils;
import studio.raptor.sqlparser.wall.violation.ErrorCode;
import studio.raptor.sqlparser.wall.violation.IllegalSQLObjectViolation;
import studio.raptor.sqlparser.wall.violation.SyntaxErrorViolation;
/**
 * Base class for SQL firewall ("wall") providers.
 *
 * <p>A provider parses SQL with a dialect-specific parser ({@link #createParser(String)}), runs a
 * dialect-specific {@link WallVisitor} over the parsed statements and records the outcome. SQL that
 * passes is cached in a bounded white list keyed by its parameterized form. SQL that fails is
 * cached in a bounded black list keyed by the raw text. A second black cache keyed by the
 * parameterized form lets textual variants share one {@link WallSqlStat}. Later checks of cached
 * SQL can therefore skip the full parse (except in multi-tenant mode, where the caches are
 * bypassed).
 *
 * <p>Per-table and per-function statistics are aggregated in concurrent maps. Each map stops
 * accepting new keys once it holds more than {@value #MAX_STAT_ENTRIES} entries.
 *
 * <p>Thread-safety: the white/black caches are guarded by {@link #lock}. Counters are
 * {@link AtomicLong}s and the statistics maps are {@link ConcurrentMap}s.
 */
public abstract class WallProvider {

  /** Per-thread flag set while inside {@link #doPrivileged(PrivilegedAction)}. */
  private static final ThreadLocal<Boolean> privileged = new ThreadLocal<Boolean>();

  /** Per-thread tenant value, see {@link #getTenantValue()} / {@link #setTenantValue(Object)}. */
  private static final ThreadLocal<Object> tenantValueLocal = new ThreadLocal<Object>();

  /** SQL at least this long (in chars) is never added to the white/black caches (8k). */
  private static final int MAX_SQL_LENGTH = 8192;

  /** Once a statistics map holds more than this many keys, new keys are no longer tracked. */
  private static final int MAX_STAT_ENTRIES = 10000;

  public final WallDenyStat commentDeniedStat = new WallDenyStat();
  protected final WallConfig config;
  protected final AtomicLong checkCount = new AtomicLong();
  protected final AtomicLong hardCheckCount = new AtomicLong();
  protected final AtomicLong whiteListHitCount = new AtomicLong();
  protected final AtomicLong blackListHitCount = new AtomicLong();
  protected final AtomicLong syntaxErrorCount = new AtomicLong();
  protected final AtomicLong violationCount = new AtomicLong();
  protected final AtomicLong violationEffectRowCount = new AtomicLong();

  private final Map<String, Object> attributes =
      new ConcurrentHashMap<String, Object>(1, 0.75f, 1);

  /** Guards {@link #whiteList}, {@link #blackList} and {@link #blackMergedList}. */
  private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

  private final ConcurrentMap<String, WallFunctionStat> functionStats =
      new ConcurrentHashMap<String, WallFunctionStat>(16, 0.75f, 1);
  private final ConcurrentMap<String, WallTableStat> tableStats =
      new ConcurrentHashMap<String, WallTableStat>(16, 0.75f, 1);

  protected String dbType = null;
  private String name;

  private boolean whiteListEnable = true;
  /** Parameterized SQL -> stat; created lazily under the write lock, nulled by clears. */
  private LRUCache<String, WallSqlStat> whiteList;
  private final int whiteSqlMaxSize = 1000;

  private boolean blackListEnable = true;
  /** Raw SQL -> stat; created lazily under the write lock, nulled by clears. */
  private LRUCache<String, WallSqlStat> blackList;
  /** Parameterized SQL -> stat shared by all raw variants in {@link #blackList}. */
  private LRUCache<String, WallSqlStat> blackMergedList;
  private final int blackSqlMaxSize = 200;

  public WallProvider(WallConfig config) {
    this.config = config;
  }

  public WallProvider(WallConfig config, String dbType) {
    this.config = config;
    this.dbType = dbType;
  }

  /**
   * Returns whether the current thread is running inside {@link #doPrivileged(PrivilegedAction)}.
   * (The misspelled name is kept for backward compatibility; see {@link #isPrivileged()}.)
   *
   * @return {@code true} if the privileged flag is set for this thread
   */
  public static boolean ispPrivileged() {
    Boolean value = privileged.get();
    return value != null && value;
  }

  /**
   * Correctly spelled alias of {@link #ispPrivileged()}.
   *
   * @return {@code true} if the privileged flag is set for this thread
   */
  public static boolean isPrivileged() {
    return ispPrivileged();
  }

  /**
   * Runs {@code action} with the current thread marked as privileged, then restores the previous
   * state even if the action throws.
   *
   * @param action the action to run
   * @param <T> the action's result type
   * @return the action's result
   */
  public static <T> T doPrivileged(PrivilegedAction<T> action) {
    final Boolean original = privileged.get();
    privileged.set(Boolean.TRUE);
    try {
      return action.run();
    } finally {
      // Remove rather than set(null) so no stale entry lingers on pooled threads.
      if (original == null) {
        privileged.remove();
      } else {
        privileged.set(original);
      }
    }
  }

  /** Returns the tenant value bound to the current thread, or {@code null}. */
  public static Object getTenantValue() {
    return tenantValueLocal.get();
  }

  /** Binds a tenant value to the current thread. */
  public static void setTenantValue(Object value) {
    tenantValueLocal.set(value);
  }

  public String getName() {
    return name;
  }

  public void setName(String name) {
    this.name = name;
  }

  /** Returns the live, concurrent attribute map of this provider. */
  public Map<String, Object> getAttributes() {
    return attributes;
  }

  /** Resets all counters and clears the white/black caches and table/function statistics. */
  public void reset() {
    checkCount.set(0);
    hardCheckCount.set(0);
    violationCount.set(0);
    whiteListHitCount.set(0);
    blackListHitCount.set(0);
    // Previously left untouched, so a reset reported stale syntax-error / effect-row totals.
    syntaxErrorCount.set(0);
    violationEffectRowCount.set(0);
    clearWhiteList();
    clearBlackList();
    functionStats.clear();
    tableStats.clear();
  }

  /** Returns the live per-table statistics map (keys are lower-cased table names). */
  public ConcurrentMap<String, WallTableStat> getTableStats() {
    return tableStats;
  }

  /** Returns the live per-function statistics map. */
  public ConcurrentMap<String, WallFunctionStat> getFunctionStats() {
    return functionStats;
  }

  /**
   * Looks up the cached stat for {@code sql}, trying the white list first and then the black list.
   *
   * @return the cached stat, or {@code null} if the SQL is in neither cache
   */
  public WallSqlStat getSqlStat(String sql) {
    WallSqlStat sqlStat = getWhiteSql(sql);
    if (sqlStat == null) {
      sqlStat = getBlackSql(sql);
    }
    return sqlStat;
  }

  /**
   * Returns (creating if needed) the statistics entry for a table. The name is lower-cased and
   * one pair of surrounding backticks is removed.
   *
   * @return the stat, or {@code null} if the table map is full and has no entry for the name
   */
  public WallTableStat getTableStat(String tableName) {
    String lowerCaseName = tableName.toLowerCase();
    // length >= 2 guard: a lone "`" would otherwise hit substring(1, 0) and throw.
    if (lowerCaseName.length() >= 2
        && lowerCaseName.startsWith("`")
        && lowerCaseName.endsWith("`")) {
      lowerCaseName = lowerCaseName.substring(1, lowerCaseName.length() - 1);
    }
    return getTableStatWithLowerName(lowerCaseName);
  }

  /**
   * Adds an update count to {@code sqlStat}. The count is also credited to every table the
   * statement touched, as delete, update or insert data (checked in that order of precedence).
   */
  public void addUpdateCount(WallSqlStat sqlStat, long updateCount) {
    sqlStat.addUpdateCount(updateCount);

    Map<String, WallSqlTableStat> sqlTableStats = sqlStat.getTableStats();
    if (sqlTableStats == null) {
      return;
    }

    for (Map.Entry<String, WallSqlTableStat> entry : sqlTableStats.entrySet()) {
      WallTableStat tableStat = getTableStat(entry.getKey());
      if (tableStat == null) {
        continue;
      }

      WallSqlTableStat sqlTableStat = entry.getValue();
      if (sqlTableStat.getDeleteCount() > 0) {
        tableStat.addDeleteDataCount(updateCount);
      } else if (sqlTableStat.getUpdateCount() > 0) {
        tableStat.addUpdateDataCount(updateCount);
      } else if (sqlTableStat.getInsertCount() > 0) {
        tableStat.addInsertDataCount(updateCount);
      }
    }
  }

  /**
   * Adds a fetched-row count to {@code sqlStat}. The count is also credited to every table the
   * statement selected from.
   */
  public void addFetchRowCount(WallSqlStat sqlStat, long fetchRowCount) {
    sqlStat.addAndFetchRowCount(fetchRowCount);

    Map<String, WallSqlTableStat> sqlTableStats = sqlStat.getTableStats();
    if (sqlTableStats == null) {
      return;
    }

    for (Map.Entry<String, WallSqlTableStat> entry : sqlTableStats.entrySet()) {
      WallTableStat tableStat = getTableStat(entry.getKey());
      if (tableStat == null) {
        continue;
      }

      if (entry.getValue().getSelectCount() > 0) {
        tableStat.addFetchRowCount(fetchRowCount);
      }
    }
  }

  /**
   * Returns (creating if needed) the table stat for an already-normalized name.
   *
   * @return the stat, or {@code null} if the map is full and has no entry for the name
   */
  public WallTableStat getTableStatWithLowerName(String lowerCaseName) {
    WallTableStat stat = tableStats.get(lowerCaseName);
    if (stat == null) {
      if (tableStats.size() > MAX_STAT_ENTRIES) {
        return null;
      }
      // Use putIfAbsent's result so a concurrent clear() cannot make us return null.
      WallTableStat created = new WallTableStat();
      WallTableStat existing = tableStats.putIfAbsent(lowerCaseName, created);
      stat = existing != null ? existing : created;
    }
    return stat;
  }

  /** Returns (creating if needed) the stat for a function; the name is lower-cased first. */
  public WallFunctionStat getFunctionStat(String functionName) {
    return getFunctionStatWithLowerName(functionName.toLowerCase());
  }

  /**
   * Returns (creating if needed) the function stat for an already-lower-cased name.
   *
   * @return the stat, or {@code null} if the map is full and has no entry for the name
   */
  public WallFunctionStat getFunctionStatWithLowerName(String lowerCaseName) {
    WallFunctionStat stat = functionStats.get(lowerCaseName);
    if (stat == null) {
      if (functionStats.size() > MAX_STAT_ENTRIES) {
        return null;
      }
      WallFunctionStat created = new WallFunctionStat();
      WallFunctionStat existing = functionStats.putIfAbsent(lowerCaseName, created);
      stat = existing != null ? existing : created;
    }
    return stat;
  }

  public WallConfig getConfig() {
    return config;
  }

  /**
   * Records {@code sql} as allowed and returns its stat.
   *
   * <p>The cache key is the parameterized form of the SQL. Possible outcomes:
   *
   * <ul>
   *   <li>White list disabled: a fresh, uncached stat is returned and its execute count is not
   *       incremented.
   *   <li>Parameterization fails: a fresh, uncached stat is returned with its execute count
   *       incremented once.
   *   <li>Otherwise: the cached (or newly cached) stat's execute count is incremented.
   * </ul>
   */
  public WallSqlStat addWhiteSql(String sql, Map<String, WallSqlTableStat> tableStats,
      Map<String, WallSqlFunctionStat> functionStats, boolean syntaxError) {
    if (!whiteListEnable) {
      return new WallSqlStat(tableStats, functionStats, syntaxError);
    }

    String mergedSql;
    try {
      mergedSql = ParameterizedOutputVisitorUtils.parameterize(sql, dbType);
    } catch (Exception ex) {
      WallSqlStat stat = new WallSqlStat(tableStats, functionStats, syntaxError);
      stat.incrementAndGetExecuteCount();
      return stat;
    }

    // Fast path: read-only lookup. Never assign whiteList under the read lock (data race).
    WallSqlStat mergedStat;
    lock.readLock().lock();
    try {
      mergedStat = whiteList == null ? null : whiteList.get(mergedSql);
    } finally {
      lock.readLock().unlock();
    }

    if (mergedStat == null) {
      WallSqlStat newStat = new WallSqlStat(tableStats, functionStats, syntaxError);
      newStat.setSample(sql);

      lock.writeLock().lock();
      try {
        // Re-check: the list may have been created by another thread, or nulled by
        // clearWhiteList()/clearCache() after the read lock was released.
        if (whiteList == null) {
          whiteList = new LRUCache<String, WallSqlStat>(whiteSqlMaxSize);
        }
        mergedStat = whiteList.get(mergedSql);
        if (mergedStat == null) {
          whiteList.put(mergedSql, newStat);
          mergedStat = newStat;
        }
      } finally {
        lock.writeLock().unlock();
      }
    }

    mergedStat.incrementAndGetExecuteCount();
    return mergedStat;
  }

  /**
   * Records {@code sql} as denied and returns its stat. The stat is shared between raw SQL
   * variants that parameterize to the same text. If parameterization fails, the raw SQL is used
   * as the merged key. The execute count is incremented only when the raw SQL was not yet in the
   * black list.
   */
  public WallSqlStat addBlackSql(String sql, Map<String, WallSqlTableStat> tableStats,
      Map<String, WallSqlFunctionStat> functionStats, List<Violation> violations,
      boolean syntaxError) {
    if (!blackListEnable) {
      return new WallSqlStat(tableStats, functionStats, violations, syntaxError);
    }

    String mergedSql;
    try {
      mergedSql = ParameterizedOutputVisitorUtils.parameterize(sql, dbType);
    } catch (Exception ex) {
      // Best effort: fall back to the raw SQL as the merged key.
      mergedSql = sql;
    }

    lock.writeLock().lock();
    try {
      if (blackList == null) {
        blackList = new LRUCache<String, WallSqlStat>(blackSqlMaxSize);
      }
      if (blackMergedList == null) {
        blackMergedList = new LRUCache<String, WallSqlStat>(blackSqlMaxSize);
      }

      WallSqlStat wallStat = blackList.get(sql);
      if (wallStat == null) {
        wallStat = blackMergedList.get(mergedSql);
        if (wallStat == null) {
          wallStat = new WallSqlStat(tableStats, functionStats, violations, syntaxError);
          blackMergedList.put(mergedSql, wallStat);
          wallStat.setSample(sql);
        }
        wallStat.incrementAndGetExecuteCount();
        blackList.put(sql, wallStat);
      }
      return wallStat;
    } finally {
      lock.writeLock().unlock();
    }
  }

  /** Returns an unmodifiable snapshot of the white-list keys (parameterized SQL). */
  public Set<String> getWhiteList() {
    Set<String> snapshot = new HashSet<String>();
    lock.readLock().lock();
    try {
      if (whiteList != null) {
        snapshot.addAll(whiteList.keySet());
      }
    } finally {
      lock.readLock().unlock();
    }
    return Collections.unmodifiableSet(snapshot);
  }

  /** Returns an unmodifiable snapshot of white-list keys plus merged black-list keys. */
  public Set<String> getSqlList() {
    Set<String> snapshot = new HashSet<String>();
    lock.readLock().lock();
    try {
      if (whiteList != null) {
        snapshot.addAll(whiteList.keySet());
      }
      if (blackMergedList != null) {
        snapshot.addAll(blackMergedList.keySet());
      }
    } finally {
      lock.readLock().unlock();
    }
    return Collections.unmodifiableSet(snapshot);
  }

  /** Returns an unmodifiable snapshot of the raw black-list keys. */
  public Set<String> getBlackList() {
    Set<String> snapshot = new HashSet<String>();
    lock.readLock().lock();
    try {
      if (blackList != null) {
        snapshot.addAll(blackList.keySet());
      }
    } finally {
      lock.readLock().unlock();
    }
    return Collections.unmodifiableSet(snapshot);
  }

  /** Drops the white list and both black caches. */
  public void clearCache() {
    lock.writeLock().lock();
    try {
      whiteList = null;
      blackList = null;
      blackMergedList = null;
    } finally {
      lock.writeLock().unlock();
    }
  }

  /** Drops the white list. */
  public void clearWhiteList() {
    lock.writeLock().lock();
    try {
      whiteList = null;
    } finally {
      lock.writeLock().unlock();
    }
  }

  /**
   * Drops the black list together with its merged (parameterized) companion cache. Previously
   * the merged cache survived: stale stats were reused and still reported by
   * {@link #getSqlList()}.
   */
  public void clearBlackList() {
    lock.writeLock().lock();
    try {
      blackList = null;
      blackMergedList = null;
    } finally {
      lock.writeLock().unlock();
    }
  }

  /**
   * Looks up {@code sql} in the white list, first by raw text and then by parameterized form.
   *
   * @return the stat, or {@code null} if not found, if no white list exists, or if
   *     parameterization fails
   */
  public WallSqlStat getWhiteSql(String sql) {
    WallSqlStat stat;
    lock.readLock().lock();
    try {
      if (whiteList == null) {
        return null;
      }
      stat = whiteList.get(sql);
    } finally {
      lock.readLock().unlock();
    }

    if (stat != null) {
      return stat;
    }

    String mergedSql;
    try {
      mergedSql = ParameterizedOutputVisitorUtils.parameterize(sql, dbType);
    } catch (Exception ex) {
      // Not parameterizable: treat as a miss.
      return null;
    }

    lock.readLock().lock();
    try {
      // The list may have been cleared while the lock was released.
      return whiteList == null ? null : whiteList.get(mergedSql);
    } finally {
      lock.readLock().unlock();
    }
  }

  /** Looks up the raw {@code sql} in the black list; returns {@code null} if absent. */
  public WallSqlStat getBlackSql(String sql) {
    lock.readLock().lock();
    try {
      return blackList == null ? null : blackList.get(sql);
    } finally {
      lock.readLock().unlock();
    }
  }

  /** Returns whether {@code sql} (raw or parameterized) is in the white list. */
  public boolean whiteContains(String sql) {
    return getWhiteSql(sql) != null;
  }

  public abstract SQLStatementParser createParser(String sql);

  public abstract WallVisitor createWallVisitor();

  public abstract ExportParameterVisitor createExportParameterVisitor();

  /**
   * Checks {@code sql} in a freshly created {@link WallContext}.
   *
   * @return {@code true} if the check produced no violations
   */
  public boolean checkValid(String sql) {
    WallContext originalContext = WallContext.current();
    try {
      // NOTE(review): create() replaces a pre-existing context, which is not restored
      // afterwards. Only a context that did not exist before is cleared. Confirm this is intended.
      WallContext.create(dbType);
      WallCheckResult result = checkInternal(sql);
      return result.getViolations().isEmpty();
    } finally {
      if (originalContext == null) {
        WallContext.clearContext();
      }
    }
  }

  public void incrementCommentDeniedCount() {
    commentDeniedStat.incrementAndGetDenyCount();
  }

  /** Returns {@code true} if the (lower-cased) function is NOT in the deny list, or is null. */
  public boolean checkDenyFunction(String functionName) {
    if (functionName == null) {
      return true;
    }
    return !getConfig().getDenyFunctions().contains(functionName.toLowerCase());
  }

  /**
   * Returns {@code true} if schema checking is disabled, the name is null, or the (lower-cased)
   * schema is NOT in the deny list.
   */
  public boolean checkDenySchema(String schemaName) {
    if (schemaName == null) {
      return true;
    }
    if (!config.isSchemaCheck()) {
      return true;
    }
    return !getConfig().getDenySchemas().contains(schemaName.toLowerCase());
  }

  /** Returns {@code true} if the normalized table is NOT in the deny list, or is null. */
  public boolean checkDenyTable(String tableName) {
    if (tableName == null) {
      return true;
    }
    return !getConfig().getDenyTables().contains(WallVisitorUtils.form(tableName));
  }

  /** Returns {@code true} if the normalized table is NOT configured read-only, or is null. */
  public boolean checkReadOnlyTable(String tableName) {
    if (tableName == null) {
      return true;
    }
    return !getConfig().isReadOnly(WallVisitorUtils.form(tableName));
  }

  public WallDenyStat getCommentDenyStat() {
    return commentDeniedStat;
  }

  /**
   * Checks {@code sql}, reusing the current {@link WallContext} if one exists. A context is
   * created otherwise, and that created context is cleared afterwards.
   */
  public WallCheckResult check(String sql) {
    WallContext originalContext = WallContext.current();
    try {
      WallContext.createIfNotExists(dbType);
      return checkInternal(sql);
    } finally {
      if (originalContext == null) {
        WallContext.clearContext();
      }
    }
  }

  /**
   * Core check. The steps are:
   *
   * <ol>
   *   <li>Short-circuit for privileged threads.
   *   <li>Look the SQL up in the white/black caches (skipped in multi-tenant mode).
   *   <li>Otherwise parse it, visit the statements, and cache the outcome.
   * </ol>
   */
  private WallCheckResult checkInternal(String sql) {
    checkCount.incrementAndGet();

    WallContext context = WallContext.current();

    if (config.isDoPrivilegedAllow() && ispPrivileged()) {
      WallCheckResult checkResult = new WallCheckResult();
      checkResult.setSql(sql);
      return checkResult;
    }

    // First step: consult the caches, unless a tenant table pattern is configured.
    boolean multiTenant = config.getTenantTablePattern() != null
        && config.getTenantTablePattern().length() > 0;
    if (!multiTenant) {
      WallCheckResult checkResult = checkWhiteAndBlackList(sql);
      if (checkResult != null) {
        checkResult.setSql(sql);
        return checkResult;
      }
    }

    hardCheckCount.incrementAndGet();
    final List<Violation> violations = new ArrayList<Violation>();
    List<SQLStatement> statementList = new ArrayList<SQLStatement>();
    boolean syntaxError = false;
    boolean endOfComment = false;
    try {
      SQLStatementParser parser = createParser(sql);
      parser.getLexer().setCommentHandler(WallCommentHandler.instance);

      if (!config.isCommentAllow()) {
        parser.getLexer().setAllowComment(false); // deny comment
      }
      if (!config.isCompleteInsertValuesCheck()) {
        parser.setParseCompleteValues(false);
        parser.setParseValuesSize(config.getInsertValuesCheckSize());
      }

      parser.parseStatementList(statementList);

      final Token lastToken = parser.getLexer().token();
      if (lastToken != Token.EOF && config.isStrictSyntaxCheck()) {
        violations.add(new IllegalSQLObjectViolation(ErrorCode.SYNTAX_ERROR,
            "not terminal sql, token " + lastToken, sql));
      }
      endOfComment = parser.getLexer().isEndOfComment();
    } catch (NotAllowCommentException e) {
      violations.add(new IllegalSQLObjectViolation(ErrorCode.COMMENT_STATEMENT_NOT_ALLOW,
          "comment not allow", sql));
      incrementCommentDeniedCount();
    } catch (ParserException e) {
      syntaxErrorCount.incrementAndGet();
      syntaxError = true;
      if (config.isStrictSyntaxCheck()) {
        violations.add(new SyntaxErrorViolation(e, sql));
      }
    } catch (Exception e) {
      if (config.isStrictSyntaxCheck()) {
        violations.add(new SyntaxErrorViolation(e, sql));
      }
    }

    if (statementList.size() > 1 && !config.isMultiStatementAllow()) {
      violations.add(new IllegalSQLObjectViolation(ErrorCode.MULTI_STATEMENT,
          "multi-statement not allow", sql));
    }

    WallVisitor visitor = createWallVisitor();
    visitor.setSqlEndOfComment(endOfComment);

    // Only the leading run of MySQL hint statements is skipped. The original flag was never
    // reset, so later hint statements escaped the visitor whenever the SQL started with a hint.
    boolean inLeadingHints = true;
    for (SQLStatement stmt : statementList) {
      if (inLeadingHints && stmt instanceof MySqlHintStatement) {
        continue;
      }
      inLeadingHints = false;
      try {
        stmt.accept(visitor);
      } catch (ParserException e) {
        violations.add(new SyntaxErrorViolation(e, sql));
      }
    }

    if (visitor.getViolations().size() > 0) {
      violations.addAll(visitor.getViolations());
    }

    // Read the context stats once; guard consistently against a missing context.
    Map<String, WallSqlTableStat> tableStats = null;
    Map<String, WallSqlFunctionStat> functionStats = null;
    if (context != null) {
      tableStats = context.getTableStats();
      functionStats = context.getFunctionStats();
    }

    WallSqlStat sqlStat = null;
    if (violations.size() > 0) {
      violationCount.incrementAndGet();
      if (sql.length() < MAX_SQL_LENGTH) {
        sqlStat = addBlackSql(sql, tableStats, functionStats, violations, syntaxError);
      }
    } else if (sql.length() < MAX_SQL_LENGTH) {
      sqlStat = addWhiteSql(sql, tableStats, functionStats, syntaxError);
    }

    recordStats(tableStats, functionStats);

    WallCheckResult result;
    if (sqlStat != null) {
      if (context != null) {
        context.setSqlStat(sqlStat);
      }
      result = new WallCheckResult(sqlStat, statementList);
    } else {
      result = new WallCheckResult(null, violations, tableStats, functionStats, statementList,
          syntaxError);
    }

    String resultSql = visitor.isSqlModified()
        ? SQLUtils.toSQLString(statementList, dbType)
        : sql;
    result.setSql(resultSql);

    return result;
  }

  /**
   * Returns a result built from a cached stat if {@code sql} is in the black list (checked first)
   * or the white list. Updates hit/violation/syntax-error counters and aggregate stats.
   *
   * @return the cached result, or {@code null} on a cache miss
   */
  private WallCheckResult checkWhiteAndBlackList(String sql) {
    if (blackListEnable) {
      WallSqlStat sqlStat = getBlackSql(sql);
      if (sqlStat != null) {
        blackListHitCount.incrementAndGet();
        violationCount.incrementAndGet();

        if (sqlStat.isSyntaxError()) {
          syntaxErrorCount.incrementAndGet();
        }

        sqlStat.incrementAndGetExecuteCount();
        recordStats(sqlStat.getTableStats(), sqlStat.getFunctionStats());

        return new WallCheckResult(sqlStat);
      }
    }

    if (whiteListEnable) {
      WallSqlStat sqlStat = getWhiteSql(sql);
      if (sqlStat != null) {
        whiteListHitCount.incrementAndGet();
        sqlStat.incrementAndGetExecuteCount();

        if (sqlStat.isSyntaxError()) {
          syntaxErrorCount.incrementAndGet();
        }

        recordStats(sqlStat.getTableStats(), sqlStat.getFunctionStats());

        WallContext context = WallContext.current();
        if (context != null) {
          context.setSqlStat(sqlStat);
        }
        return new WallCheckResult(sqlStat);
      }
    }

    return null;
  }

  /**
   * Merges per-statement table/function stats into the provider-wide maps. Null maps are ignored.
   * Keys are skipped once the corresponding provider map is full.
   */
  void recordStats(Map<String, WallSqlTableStat> tableStats,
      Map<String, WallSqlFunctionStat> functionStats) {
    if (tableStats != null) {
      for (Map.Entry<String, WallSqlTableStat> entry : tableStats.entrySet()) {
        WallTableStat tableStat = getTableStat(entry.getKey());
        if (tableStat != null) {
          tableStat.addSqlTableStat(entry.getValue());
        }
      }
    }

    if (functionStats != null) {
      for (Map.Entry<String, WallSqlFunctionStat> entry : functionStats.entrySet()) {
        // Function keys are used as-is (not lower-cased here).
        WallFunctionStat functionStat = getFunctionStatWithLowerName(entry.getKey());
        if (functionStat != null) {
          functionStat.addSqlFunctionStat(entry.getValue());
        }
      }
    }
  }

  public long getWhiteListHitCount() {
    return whiteListHitCount.get();
  }

  public long getBlackListHitCount() {
    return blackListHitCount.get();
  }

  public long getSyntaxErrorCount() {
    return syntaxErrorCount.get();
  }

  public long getCheckCount() {
    return checkCount.get();
  }

  public long getViolationCount() {
    return violationCount.get();
  }

  public long getHardCheckCount() {
    return hardCheckCount.get();
  }

  public long getViolationEffectRowCount() {
    return violationEffectRowCount.get();
  }

  public void addViolationEffectRowCount(long rowCount) {
    violationEffectRowCount.addAndGet(rowCount);
  }

  public boolean isWhiteListEnable() {
    return whiteListEnable;
  }

  public void setWhiteListEnable(boolean whiteListEnable) {
    this.whiteListEnable = whiteListEnable;
  }

  public boolean isBlackListEnable() {
    return blackListEnable;
  }

  public void setBlackListEnable(boolean blackListEnable) {
    this.blackListEnable = blackListEnable;
  }

  /**
   * Lexer comment handler installed by the wall.
   *
   * <p>If the token before the comment is one of a fixed set of statement keywords, it returns
   * {@code true}. Otherwise it increments the current context's comment count and returns
   * {@code false}. NOTE(review): presumably the return value tells the lexer whether the comment
   * is accepted; confirm against {@link Lexer.CommentHandler}.
   */
  public static class WallCommentHandler implements Lexer.CommentHandler {

    public static final WallCommentHandler instance = new WallCommentHandler();

    @Override
    public boolean handle(Token lastToken, String comment) {
      if (lastToken == null) {
        return false;
      }

      switch (lastToken) {
        case SELECT:
        case INSERT:
        case DELETE:
        case UPDATE:
        case TRUNCATE:
        case SET:
        case CREATE:
        case ALTER:
        case DROP:
        case SHOW:
        case REPLACE:
          return true;
        default:
          break;
      }

      WallContext context = WallContext.current();
      if (context != null) {
        context.incrementCommentCount();
      }

      return false;
    }
  }
}