/*
 * gu.sql2java.parser.StatementCache (Maven / Gradle / Ivy)
 * Artifact: sql2java-manager
 * sql2java manager class package for accessing database
 */
package gu.sql2java.parser;
import static gu.sql2java.SimpleLog.log;
import static com.google.common.base.Preconditions.checkNotNull;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Maps;
import gu.sql2java.SqlFormatter;
import gu.sql2java.exception.RuntimeDaoException;
import gu.sql2java.parser.ParserSupport.SqlParserInfo;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.parser.CCJSqlParserDefaultVisitor;
import net.sf.jsqlparser.parser.CCJSqlParserVisitor;
import net.sf.jsqlparser.parser.SimpleNode;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.util.TablesNamesFinder;
/**
 * Caches parsed and normalized SQL statements in a {@link LoadingCache} so that repeated SQL
 * analysis is cheap.
 * <p>
 * Besides successful parse results ({@link SqlParserInfo}), the class remembers the
 * {@link RuntimeDaoException} produced for SQL text that failed to parse and for SQL text rejected
 * by the injection analyzer, so resubmitting the same text fails fast without re-parsing.
 * All of these caches evict entries that have not been accessed for one hour.
 *
 * @author guyadong
 *
 */
public class StatementCache {
    /** Idle time (hours) after which an entry is evicted from any of the caches below. */
    private static final long EXPIRE_AFTER_ACCESS_HOURS = 1;
    /**
     * Visitor applied to every node of the parsed abstract syntax tree (AST); may be {@code null}.
     */
    private final CCJSqlParserVisitor visitor;
    /** Syntax normalizer handed to the parser; may be {@code null}. */
    private final SqlSyntaxNormalizer sqlSyntaxNormalizer;
    /**
     * Parse results of successfully parsed SQL statements, keyed by the original SQL text.
     * <p>
     * An anonymous class (rather than a lambda) is used on purpose: the loader reads the blank
     * final fields {@link #visitor} and {@link #sqlSyntaxNormalizer}. javac's definite-assignment
     * check rejects that inside a lambda in a field initializer. Loading happens lazily, after the
     * constructor has assigned both fields.
     */
    private final LoadingCache<String, SqlParserInfo> statementCache = CacheBuilder.newBuilder()
            .expireAfterAccess(EXPIRE_AFTER_ACCESS_HOURS, TimeUnit.HOURS)
            .build(new CacheLoader<String, SqlParserInfo>() {
                @Override
                public SqlParserInfo load(String key) throws Exception {
                    return ParserSupport.parse0(key, visitor, sqlSyntaxNormalizer);
                }
            });
    /** SQL rejected by the injection analyzer, mapped to the exception thrown for it. */
    private final ConcurrentMap<String, RuntimeDaoException> dangerousSqls = newExpiringMap();
    /** SQL that failed to parse, mapped to the exception thrown for it. */
    private final ConcurrentMap<String, RuntimeDaoException> invalidSqls = newExpiringMap();
    private final SqlInjectionAnalyzer injectAnalyzer;

    /**
     * Creates a thread-safe map whose entries expire with the same policy as
     * {@link #statementCache}. This stops remembered failures (keyed by arbitrary SQL text) from
     * accumulating without bound.
     */
    private static <K, V> ConcurrentMap<K, V> newExpiringMap() {
        return CacheBuilder.newBuilder()
                .expireAfterAccess(EXPIRE_AFTER_ACCESS_HOURS, TimeUnit.HOURS)
                .<K, V>build()
                .asMap();
    }

    /** Creates a cache with no AST visitor and no syntax normalizer. */
    public StatementCache() {
        this((CCJSqlParserDefaultVisitor) null, null);
    }

    /**
     * @param visitor visitor applied to every AST node while parsing, may be {@code null}
     * @param sqlSyntaxNormalizer normalizer passed to the parser, may be {@code null}
     */
    public StatementCache(CCJSqlParserDefaultVisitor visitor, SqlSyntaxNormalizer sqlSyntaxNormalizer) {
        this.visitor = visitor;
        this.injectAnalyzer = new SqlInjectionAnalyzer();
        this.sqlSyntaxNormalizer = sqlSyntaxNormalizer;
    }

    /**
     * @param visitor visitor called for every AST node. It is wrapped so that child nodes are
     *                still traversed after it runs.
     * @param sqlSyntaxNormalizer normalizer passed to the parser, may be {@code null}
     */
    public StatementCache(CCJSqlParserVisitor visitor, SqlSyntaxNormalizer sqlSyntaxNormalizer) {
        this(new AstNodeVisitor(visitor), sqlSyntaxNormalizer);
    }

    /**
     * @param sqlFormatter formatter used to rewrite table names, column names and aliases,
     *                     may be {@code null}
     * @param sqlSyntaxNormalizer normalizer passed to the parser, may be {@code null}
     */
    public StatementCache(SqlFormatter sqlFormatter, SqlSyntaxNormalizer sqlSyntaxNormalizer) {
        this(new AstNodeVisitor(sqlFormatter), sqlSyntaxNormalizer);
    }

    /**
     * Enables or disables SQL injection checking on the underlying analyzer.
     * <p>
     * Remembered injection verdicts are discarded, because they were produced under the previous
     * setting.
     *
     * @param enable whether injection checking is enabled
     * @return this instance, for chaining
     */
    public StatementCache injectCheckEnable(boolean enable) {
        injectAnalyzer.injectCheckEnable(enable);
        dangerousSqls.clear();
        return this;
    }

    /**
     * Parses a SQL statement.
     * <p>
     * On success, returns the {@link SqlParserInfo} that holds the parse data. Otherwise the parse
     * exception ({@link net.sf.jsqlparser.JSQLParserException}) is wrapped in a
     * {@link RuntimeDaoException} and thrown. Failures are remembered, so the same SQL text fails
     * fast on later calls.
     *
     * @param sql SQL text to parse, must not be {@code null}
     * @param injectAnalyze {@code true} to run SQL injection analysis on the parse result
     * @return the parse result, or the value returned by the injection analyzer when
     *         {@code injectAnalyze} is {@code true}
     * @throws NullPointerException if {@code sql} is {@code null}
     * @throws RuntimeDaoException if the SQL cannot be parsed, or if injection analysis rejects it
     */
    public SqlParserInfo parse(String sql, boolean injectAnalyze) {
        // Guava caches reject null keys anyway; fail here with a clear message instead.
        checkNotNull(sql, "sql is null");
        // A remembered injection verdict only applies when the caller asked for injection analysis.
        if (injectAnalyze) {
            RuntimeDaoException rejected = dangerousSqls.get(sql);
            if (null != rejected) {
                throw rejected;
            }
        }
        RuntimeDaoException invalid = invalidSqls.get(sql);
        if (null != invalid) {
            throw invalid;
        }
        try {
            SqlParserInfo sqlParserInfo = statementCache.get(sql);
            return injectAnalyze ? injectAnalyzer.injectAnalyse(sqlParserInfo) : sqlParserInfo;
        } catch (ExecutionException e) {
            // The loader failed with a checked exception (the parse error); remember it for this SQL.
            RuntimeDaoException rde = new RuntimeDaoException(e.getCause());
            invalidSqls.put(sql, rde);
            throw rde;
        } catch (InjectionAttackException e) {
            RuntimeDaoException rde = new RuntimeDaoException(e);
            dangerousSqls.put(sql, rde);
            throw rde;
        }
    }

    /**
     * Parses a SQL statement and returns the normalized SQL ({@code nativeSql}) of the parse
     * result.
     * <p>
     * If parsing fails, the parse exception ({@link net.sf.jsqlparser.JSQLParserException}) is
     * wrapped in a {@link RuntimeDaoException} and thrown.
     *
     * @param sql SQL text to parse, must not be {@code null}
     * @param injectAnalyze {@code true} to run SQL injection analysis
     * @return the normalized SQL text
     * @throws RuntimeDaoException if the SQL cannot be parsed, or if injection analysis rejects it
     */
    public String normalize(String sql, boolean injectAnalyze) {
        return parse(sql, injectAnalyze).nativeSql;
    }

    /**
     * Shared preamble of the {@code prepareStatement} overloads: validates the arguments,
     * normalizes (and optionally injection-checks) the SQL, and logs it when {@code debug} is set.
     *
     * @return the normalized SQL text to hand to the JDBC driver
     */
    private String normalizeForStatement(Connection c, String sql, boolean injectAnalyze, boolean debug,
            String logPrefix) {
        checkNotNull(sql, "sql is null");
        checkNotNull(c, "connection is null");
        String normalized = normalize(sql, injectAnalyze);
        if (debug) {
            log("{} : {}", logPrefix, normalized);
        }
        return normalized;
    }

    /**
     * Creates a {@link PreparedStatement} via {@link Connection#prepareStatement(String, int, int)}.
     * Before doing so, the SQL is normalized and security-checked with
     * {@link #normalize(String, boolean)}.
     *
     * @param c SQL connection
     * @param sql sql statement
     * @param injectAnalyze run injection attack analysis if true
     * @param debug output SQL statement to console if true
     * @param logPrefix prefix string for debug information
     * @param resultSetType see also {@link Connection#prepareStatement(String, int, int)}
     * @param resultSetConcurrency see also {@link Connection#prepareStatement(String, int, int)}
     * @throws SQLException if the driver fails to prepare the statement
     * @see Connection#prepareStatement(String, int, int)
     */
    public PreparedStatement prepareStatement(Connection c, String sql, boolean injectAnalyze, boolean debug,
            String logPrefix,
            int resultSetType, int resultSetConcurrency) throws SQLException {
        String normalized = normalizeForStatement(c, sql, injectAnalyze, debug, logPrefix);
        return c.prepareStatement(normalized, resultSetType, resultSetConcurrency);
    }

    /**
     * Creates a {@link PreparedStatement} via {@link Connection#prepareStatement(String, int)}.
     * Before doing so, the SQL is normalized and security-checked with
     * {@link #normalize(String, boolean)}.
     *
     * @param c SQL connection
     * @param sql sql statement
     * @param injectAnalyze run injection attack analysis if true
     * @param debug output SQL statement to console if true
     * @param logPrefix prefix string for debug information
     * @param autoGeneratedKeys see also {@link Connection#prepareStatement(String, int)}
     * @throws SQLException if the driver fails to prepare the statement
     * @see Connection#prepareStatement(String, int)
     */
    public PreparedStatement prepareStatement(Connection c, String sql, boolean injectAnalyze, boolean debug,
            String logPrefix, int autoGeneratedKeys) throws SQLException {
        String normalized = normalizeForStatement(c, sql, injectAnalyze, debug, logPrefix);
        return c.prepareStatement(normalized, autoGeneratedKeys);
    }

    /**
     * Creates a {@link PreparedStatement} via {@link Connection#prepareStatement(String)}.
     * Before doing so, the SQL is normalized and security-checked with
     * {@link #normalize(String, boolean)}.
     *
     * @param c SQL connection
     * @param sql sql statement
     * @param injectAnalyze run injection attack analysis if true
     * @param debug output SQL statement to console if true
     * @param logPrefix prefix string for debug information
     * @throws SQLException if the driver fails to prepare the statement
     * @see Connection#prepareStatement(String)
     */
    public PreparedStatement prepareStatement(Connection c, String sql, boolean injectAnalyze, boolean debug, String logPrefix)
            throws SQLException {
        String normalized = normalizeForStatement(c, sql, injectAnalyze, debug, logPrefix);
        return c.prepareStatement(normalized);
    }

    /**
     * {@link CCJSqlParserDefaultVisitor} wrapper. It calls an optional delegate
     * {@link CCJSqlParserVisitor} on each AST node, then lets the default visitor continue the
     * traversal.
     *
     * @author guyadong
     *
     */
    private static class AstNodeVisitor extends CCJSqlParserDefaultVisitor {
        /** Delegate invoked for each visited node; may be {@code null}. */
        private final CCJSqlParserVisitor visitor;

        AstNodeVisitor(CCJSqlParserVisitor visitor) {
            this.visitor = visitor;
        }

        /**
         * Adapts a {@link NodeVisitor}: each AST node's value is dispatched by runtime type to the
         * matching {@code visit} overload. A {@code null} finder yields no delegate.
         */
        AstNodeVisitor(NodeVisitor finder) {
            this(null == finder ? null : (node, data) -> {
                Object value = node.jjtGetValue();
                if (value instanceof Column) {
                    finder.visit((Column) value);
                } else if (value instanceof Table) {
                    // Table is itself a FromItem, so it must be matched before the FromItem branch
                    finder.visit((Table) value);
                } else if (value instanceof SelectExpressionItem) {
                    finder.visit((SelectExpressionItem) value);
                } else if (value instanceof FromItem) {
                    finder.visit((FromItem) value);
                }
                return data;
            });
        }

        /** Adapts a {@link SqlFormatter}; a {@code null} formatter yields no delegate. */
        AstNodeVisitor(SqlFormatter sqlFormatter) {
            this(null == sqlFormatter ? null : new NodeVisitor(sqlFormatter));
        }

        @Override
        public Object visit(SimpleNode node, Object data) {
            if (null != visitor) {
                visitor.visit(node, data);
            }
            return super.visit(node, data);
        }
    }

    /**
     * {@link TablesNamesFinder}-based visitor that rewrites table names, column names and aliases
     * through a {@link SqlFormatter}. With a {@code null} formatter nothing is rewritten.
     *
     * @author guyadong
     *
     */
    private static class NodeVisitor extends TablesNamesFinder {
        private final SqlFormatter sqlFormatter;

        NodeVisitor(SqlFormatter sqlFormatter) {
            this.sqlFormatter = sqlFormatter;
            // NOTE(review): presumably enables column processing in TablesNamesFinder; confirm
            // against the JSqlParser version in use.
            init(true);
        }

        /** Rewrites an alias name with the formatter; a {@code null} alias is ignored. */
        private void visit(Alias alias) {
            if (null != sqlFormatter && null != alias) {
                alias.setName(sqlFormatter.alias(alias.getName()));
            }
        }

        /** Rewrites the alias of a FROM item; a {@code null} item is ignored. */
        void visit(FromItem fromItem) {
            if (null != fromItem) {
                visit(fromItem.getAlias());
            }
        }

        @Override
        public void visit(Column column) {
            // Columns named true/false (case-insensitive) are treated as boolean values: leave them alone
            if (null != sqlFormatter && !ParserSupport.isBoolean(column)) {
                column.setColumnName(sqlFormatter.columname(column.getColumnName()));
            }
            super.visit(column);
        }

        @Override
        public void visit(SelectExpressionItem item) {
            visit(item.getAlias());
            super.visit(item);
        }

        @Override
        public void visit(Table table) {
            if (null != sqlFormatter) {
                // NOTE(review): table names are passed through columname(); confirm SqlFormatter
                // has no dedicated table-name hook that should be used here instead.
                table.setName(sqlFormatter.columname(table.getName()));
                visit(table.getAlias());
            }
            super.visit(table);
        }
    }
}