// com.kanlon.utils.SelfDruidSqlUtils Maven / Gradle / Ivy
// Show all versions of elasticsearch-sql Show documentation
package com.kanlon.utils;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.PagerUtils;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.SQLSetQuantifier;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static com.alibaba.druid.sql.SQLUtils.parseStatements;
import static com.alibaba.druid.sql.SQLUtils.toSQLExpr;
import static com.alibaba.druid.sql.SQLUtils.toSQLString;
/**
* 自定义的druid相关的解析的sql工具类,基于SQLUtils
*
* @author zhangcanlong
* @since 2022/08/10 21:25
**/
@Slf4j
public class SelfDruidSqlUtils {
/**
 * Utility class: not instantiable, all members are static and are used through the class name.
 */
private SelfDruidSqlUtils() {}
/**
 * Column read from information_schema to obtain a table's approximate row count.
 **/
public static final String TABLE_ROW_COLUMN = "table_rows";
/**
 * Alias of the {@code count(1)} column in the SQL generated to count the total rows of a query.
 */
public static final String NUM_STR = "num";
/**
 * Placeholder that temporarily replaces the {@code global} keyword of a {@code global in} predicate
 * (inserted by replaceGlobalInSql, removed again by replaceGlobalInColumnSql).
 */
public static final String GLOBAL_IN_REPLACE_COLUMN_STR = "_dashboard_global_in_column";
/**
 * Placeholder that temporarily replaces {@code global not} of a {@code global not in} predicate
 * (inserted by replaceGlobalInSql, removed again by replaceGlobalInColumnSql).
 */
public static final String GLOBAL_NOT_IN_REPLACE_COLUMN_STR = "_dashboard_global_not_in_column";
/**
 * Sentinel key of a parsed column map: when the select list contains an {@code owner.*} item, this key maps to that
 * owner (table) name, which is used to qualify a column only if the column itself is not found in the map.
 */
public static final String SQL_PARSE_ALL_TABLE_NAME = "ds_AllTableColumns__tableName";
/**
 * Double-quote character. Declared {@code final}: as a mutable public static field any caller could change how
 * quoted aliases are recognised for the whole application.
 */
public static final char DOUBLE_QUOTE = '"';
/**
 * Empty string.
 */
public static final String EMPTY = "";
/**
 * Asterisk, as used in {@code select *} / {@code t.*}.
 */
public static final String ASTERISK = "*";
/**
 * Dot separating an owner (database/table) from a name.
 */
public static final String DOT = ".";
/**
 * Backtick used to quote MySQL identifiers.
 */
public static final String BACKTICK = "`";
/**
 * Restores {@code global} / {@code global not} in SQL produced by replaceGlobalInSql: each placeholder column
 * suffix is turned back into the keyword it stood for.
 *
 * @param sql SQL containing the placeholders (the counterpart of SelfDruidSqlUtils#replaceGlobalInSql)
 * @return the SQL with the keywords restored; {@code null} or an empty string is returned unchanged
 */
public static String replaceGlobalInColumnSql(String sql) {
    if (!StringUtils.isEmpty(sql)) {
        // The placeholders are plain identifiers, so a literal (non-regex) replacement is sufficient.
        String withNotRestored = sql.replace(GLOBAL_NOT_IN_REPLACE_COLUMN_STR, " global not ");
        return withNotRestored.replace(GLOBAL_IN_REPLACE_COLUMN_STR, " global ");
    }
    return sql;
}
/**
 * Matches the whitespace before {@code global not in} plus the keywords (any letter case; {@code in} must be a
 * whole word).
 */
private static final Pattern GLOBAL_NOT_IN_KEYWORD_PATTERN =
        Pattern.compile("\\s+global\\s{1,4}not\\s{1,4}in\\b", Pattern.CASE_INSENSITIVE);
/**
 * Matches the whitespace before {@code global in} plus the keywords (any letter case; {@code in} must be a whole
 * word).
 */
private static final Pattern GLOBAL_IN_KEYWORD_PATTERN =
        Pattern.compile("\\s+global\\s{1,4}in\\b", Pattern.CASE_INSENSITIVE);

/**
 * Replaces {@code global in} / {@code global not in} with placeholders so that the SQL can be parsed; the
 * replacement is undone by SelfDruidSqlUtils#replaceGlobalInColumnSql.
 * <p>
 * The whitespace before {@code global} is consumed, so the placeholder is glued onto the preceding token:
 * {@code a global in (1)} becomes {@code a_dashboard_global_in_column in (1)}, i.e. an ordinary IN on a renamed
 * column. Matching ignores letter case. {@code in} must be a whole word, so ClickHouse-style
 * {@code t1 GLOBAL INNER JOIN t2} is left untouched. Otherwise its {@code IN} prefix would match and the table
 * name would become {@code t1_dashboard_global_in_column}.
 *
 * @param sql SQL to rewrite
 * @return the rewritten SQL; {@code null} or an empty string is returned unchanged
 */
public static String replaceGlobalInSql(String sql) {
    if (StringUtils.isEmpty(sql)) {
        return sql;
    }
    String replaced = GLOBAL_NOT_IN_KEYWORD_PATTERN.matcher(sql)
            .replaceAll(Matcher.quoteReplacement(GLOBAL_NOT_IN_REPLACE_COLUMN_STR + " in"));
    return GLOBAL_IN_KEYWORD_PATTERN.matcher(replaced)
            .replaceAll(Matcher.quoteReplacement(GLOBAL_IN_REPLACE_COLUMN_STR + " in"));
}
/**
 * Builds a MySQL query that reads a table's approximate row count (column {@code table_rows}) from
 * information_schema.
 * <p>
 * Backticks are removed and the name is lower-cased.
 * <ul>
 * <li>{@code db.table} (split at the first dot) is looked up in {@code information_schema.tables} by schema and
 * table name.</li>
 * <li>A bare name is looked up in {@code INFORMATION_SCHEMA.PARTITIONS} by table name only.</li>
 * </ul>
 * NOTE(review): the bare-name lookup is not restricted to a schema, so a name existing in several databases (or a
 * partitioned table) can return several rows.
 *
 * @param tableName table name, optionally qualified as {@code db.table}
 * @return the row-count SQL, or an empty string for a blank name
 **/
public static String getMysqlTableRowNumSql(String tableName) {
    if (StringUtils.isBlank(tableName)) {
        return EMPTY;
    }
    // Locale.ROOT: lower-casing must not depend on the JVM locale (e.g. Turkish turns "I" into a dotless "ı").
    String normalizedName = tableName.replace(BACKTICK, EMPTY).toLowerCase(Locale.ROOT);
    int dotIndex = normalizedName.indexOf(DOT);
    if (dotIndex >= 0) {
        String dbName = escapeMysqlStringLiteral(normalizedName.substring(0, dotIndex));
        String actualTableName = escapeMysqlStringLiteral(normalizedName.substring(dotIndex + 1));
        return "SELECT " + TABLE_ROW_COLUMN + " FROM information_schema.tables WHERE LOWER(TABLE_SCHEMA) = '" + dbName + "' AND LOWER(table_name)='" + actualTableName + "'";
    }
    return "SELECT " + TABLE_ROW_COLUMN + " FROM INFORMATION_SCHEMA.PARTITIONS WHERE LOWER(TABLE_NAME)='" + escapeMysqlStringLiteral(normalizedName) + "'";
}

/**
 * Escapes a value for use inside a single-quoted MySQL string literal: backslashes are doubled and single quotes
 * are doubled, so a name containing {@code '} can neither break nor extend the generated SQL.
 * NOTE(review): assumes the default sql_mode; under NO_BACKSLASH_ESCAPES a name containing a backslash would no
 * longer match.
 *
 * @param value raw value
 * @return the escaped value (without surrounding quotes)
 */
private static String escapeMysqlStringLiteral(String value) {
    return value.replace("\\", "\\\\").replace("'", "''");
}
/**
 * Resolves a column name (typically a select-list alias) to the expression it stands for.
 * <ul>
 * <li>If the name is a key of the map, its value is returned.</li>
 * <li>Otherwise, if the map contains the SQL_PARSE_ALL_TABLE_NAME sentinel (the query selected {@code owner.*}),
 * the name is qualified with that owner: {@code owner.column}.</li>
 * <li>Otherwise the name is returned unchanged.</li>
 * </ul>
 *
 * @param columnExpressMap column name to expression map, as built by getColumnExpressMap
 * @param column           column name to resolve
 * @return the actual column expression
 */
public static String getActualColumn(Map<String, String> columnExpressMap, String column) {
    if (columnExpressMap.containsKey(SQL_PARSE_ALL_TABLE_NAME)) {
        // Unknown columns are attributed to the single table selected with "owner.*".
        return columnExpressMap.getOrDefault(column, columnExpressMap.get(SQL_PARSE_ALL_TABLE_NAME) + DOT + column);
    }
    return columnExpressMap.getOrDefault(column, column);
}
/**
 * Rewrites, in place, column references of a condition expression from alias form to their actual expressions
 * (resolved with getActualColumn).
 * <ul>
 * <li>{@code x IN (...)}: {@code x} is rewritten when it is a plain identifier.</li>
 * <li>Binary operations (comparisons, AND/OR, ...): a plain-identifier left operand is rewritten. A left operand
 * that is not an identifier is processed recursively, and the right operand is always processed recursively.</li>
 * </ul>
 * A plain identifier on the right-hand side of a comparison is not rewritten, and other expression kinds are left
 * untouched.
 *
 * @param sqlExpr         condition expression to rewrite (modified in place)
 * @param actualColumnMap column name to expression map, as built by getColumnExpressMap
 * @param dbType          database type used to parse the replacement expressions
 */
public static void handlerSqlExprColumnName(SQLExpr sqlExpr, Map<String, String> actualColumnMap, DbType dbType) {
    if (sqlExpr instanceof SQLInListExpr) {
        SQLInListExpr inListExpr = (SQLInListExpr) sqlExpr;
        SQLExpr testedExpr = inListExpr.getExpr();
        if (testedExpr instanceof SQLIdentifierExpr) {
            String actualName = SelfDruidSqlUtils.getActualColumn(actualColumnMap, ((SQLIdentifierExpr) testedExpr).getName());
            inListExpr.setExpr(SQLUtils.toSQLExpr(actualName, dbType));
        }
    } else if (sqlExpr instanceof SQLBinaryOpExpr) {
        SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
        SQLExpr leftExpr = binaryOpExpr.getLeft();
        SQLExpr rightExpr = binaryOpExpr.getRight();
        if (leftExpr instanceof SQLIdentifierExpr) {
            String actualName = SelfDruidSqlUtils.getActualColumn(actualColumnMap, ((SQLIdentifierExpr) leftExpr).getName());
            binaryOpExpr.setLeft(SQLUtils.toSQLExpr(actualName, dbType));
        } else {
            handlerSqlExprColumnName(leftExpr, actualColumnMap, dbType);
        }
        handlerSqlExprColumnName(rightExpr, actualColumnMap, dbType);
    }
}
/**
 * Parses a query and maps each select-list column name to the SQL text of its expression. A single SELECT and
 * UNION [ALL] are supported; for a UNION only the left-most branch is inspected (it defines the result columns).
 * Only the last statement of {@code sql} is used.
 * <ul>
 * <li>Aliased item: key = alias with every {@code '} and {@code `} removed (and double quotes removed when they
 * wrap the alias); value = expression SQL.</li>
 * <li>{@code owner.column} without alias: key = column name without {@code '}/{@code `}; value = expression SQL.</li>
 * <li>{@code owner.*}: key = SQL_PARSE_ALL_TABLE_NAME, value = owner name. The key is removed again when the
 * select list has more than one such item, because the owner of an unknown column would then be ambiguous.</li>
 * <li>Anything else: key = value = expression SQL.</li>
 * </ul>
 *
 * @param sql    SQL to parse
 * @param dbType database type
 * @return column name to expression map
 * @throws RuntimeException when the SQL cannot be parsed or is not a SELECT / UNION
 */
public static Map<String, String> getColumnExpressMap(String sql, com.alibaba.druid.DbType dbType) throws RuntimeException {
    Map<String, String> columnExpressMap = new HashMap<>(16);
    try {
        List<SQLStatement> sqlStatements = parseStatements(sql, dbType);
        SQLSelectQuery sqlSelectQuery = ((SQLSelectStatement) sqlStatements.get(sqlStatements.size() - 1)).getSelect().getQuery();
        // The column names of a UNION come from its left-most branch; walk down to it instead of re-parsing text.
        while (sqlSelectQuery instanceof SQLUnionQuery) {
            sqlSelectQuery = ((SQLUnionQuery) sqlSelectQuery).getLeft();
        }
        if (!(sqlSelectQuery instanceof SQLSelectQueryBlock)) {
            throw new RuntimeException("无法解析sql!请更换sql或者联系系统管理员咨询支持的sql类型!");
        }
        // Number of "owner.*" items; e.g. "t.*" counts as one.
        int allColumnSelectCnt = 0;
        for (SQLSelectItem selectItem : ((SQLSelectQueryBlock) sqlSelectQuery).getSelectList()) {
            SQLExpr expr = selectItem.getExpr();
            String columnAlias = selectItem.getAlias();
            if (StringUtils.isNotEmpty(columnAlias)) {
                columnExpressMap.put(stripAliasQuotes(columnAlias), SQLUtils.toSQLString(expr));
            } else if (expr instanceof SQLPropertyExpr) {
                SQLPropertyExpr propertyExpr = (SQLPropertyExpr) expr;
                if (Objects.equals(propertyExpr.getName(), ASTERISK)) {
                    allColumnSelectCnt++;
                    columnExpressMap.put(SQL_PARSE_ALL_TABLE_NAME, propertyExpr.getOwnerName());
                } else {
                    columnExpressMap.put(propertyExpr.getName().replaceAll("['`]", EMPTY), SQLUtils.toSQLString(expr));
                }
            } else {
                String exprSql = SQLUtils.toSQLString(expr);
                columnExpressMap.put(exprSql, exprSql);
            }
        }
        // With several "owner.*" items an unknown column cannot be attributed to one table.
        if (allColumnSelectCnt > 1) {
            columnExpressMap.remove(SQL_PARSE_ALL_TABLE_NAME);
        }
    } catch (Exception e) {
        log.error("解析SQL错误!要解析的sql为【{}】", sql);
        throw new RuntimeException("解析SQL错误!请确认SQL中字段关键字使用``和''合理括起来了!" + e.getMessage(), e);
    }
    return columnExpressMap;
}

/**
 * Removes quoting from a select-item alias. Every {@code '} and {@code `} is dropped. Double quotes are dropped
 * when the remaining text both starts and ends with one.
 *
 * @param alias raw alias as returned by the parser
 * @return the unquoted alias
 */
private static String stripAliasQuotes(String alias) {
    String stripped = alias.replaceAll("['`]", EMPTY);
    // isEmpty guard: an alias made only of quotes (e.g. '') would otherwise make charAt(0) throw.
    if (!stripped.isEmpty() && stripped.charAt(0) == DOUBLE_QUOTE && stripped.charAt(stripped.length() - 1) == DOUBLE_QUOTE) {
        stripped = stripped.replace(String.valueOf(DOUBLE_QUOTE), EMPTY);
    }
    return stripped;
}
/**
 * Lists the table names referenced by every statement of a SQL string.
 * {@code global in} is first replaced with a placeholder (see replaceGlobalInSql) so that the SQL can be parsed.
 *
 * @param sql    SQL whose table names are wanted
 * @param dbType database type
 * @return table names in visiting order (may contain duplicates across statements); empty for blank SQL
 * @throws RuntimeException when the SQL cannot be parsed
 **/
public static List<String> getTableNamesBySql(String sql, DbType dbType) throws RuntimeException {
    List<String> tableNameList = new ArrayList<>(10);
    if (StringUtils.isBlank(sql)) {
        return tableNameList;
    }
    try {
        String parsableSql = replaceGlobalInSql(sql);
        for (SQLStatement stmt : SQLUtils.parseStatements(parsableSql, dbType)) {
            // Dialect-specific visitor, so that dialect-only AST nodes are visited as well.
            SchemaStatVisitor schemaStatVisitor = dbType == null ? new SchemaStatVisitor() : SQLUtils.createSchemaStatVisitor(dbType);
            stmt.accept(schemaStatVisitor);
            for (TableStat.Name name : schemaStatVisitor.getTables().keySet()) {
                tableNameList.add(name.toString());
            }
        }
    } catch (Exception e) {
        // Log the caller's SQL, not the placeholder-rewritten one.
        log.error("解析SQL错误!要解析的sql为【{}】", sql);
        throw new RuntimeException("解析SQL错误!" + e.getMessage(), e);
    }
    return tableNameList;
}
/**
 * Replaces the select list of a query with a single expression.
 * <ul>
 * <li>A UNION is first wrapped as {@code select * from (sql) temp_t} and the outer select list is replaced.</li>
 * <li>GROUP BY items referring to select-list aliases are rewritten to the aliased expressions, because the aliases
 * disappear together with the old select list.</li>
 * </ul>
 * NOTE(review): for a UNION with a WITH clause, the WITH ends up both inside the derived table and in front of the
 * result.
 *
 * @param sql          SQL whose select list is replaced (only the last statement is used)
 * @param express      the new select expression
 * @param alias        alias of the new expression; {@code null}/empty for none
 * @param haveDistinct whether to mark the query DISTINCT
 * @param dbType       database type
 * @return WITH clause (if any) + {@code "\n"} + the rewritten query
 * @throws RuntimeException when the SQL cannot be parsed or is not a SELECT / UNION
 **/
public static String replaceSelectItem(String sql, String express, String alias, boolean haveDistinct, DbType dbType) throws RuntimeException {
    SQLSelectQueryBlock sqlSelectQueryBlock;
    SQLWithSubqueryClause withSubqueryClause;
    try {
        List<SQLStatement> sqlStatements = parseStatements(sql, dbType);
        SQLSelect sqlSelect = ((SQLSelectStatement) sqlStatements.get(sqlStatements.size() - 1)).getSelect();
        Map<String, String> columnMap = getColumnExpressMap(sql, dbType);
        SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
        withSubqueryClause = sqlSelect.getWithSubQuery();
        if (sqlSelectQuery instanceof SQLSelectQueryBlock) {
            sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
        } else if (sqlSelectQuery instanceof SQLUnionQuery) {
            // A UNION has no single select list; select from it as a derived table instead.
            String wrappedSql = "select * from (" + sql + ") temp_t";
            List<SQLStatement> wrappedStatements = parseStatements(wrappedSql, dbType);
            sqlSelectQueryBlock = (SQLSelectQueryBlock) ((SQLSelectStatement) wrappedStatements.get(wrappedStatements.size() - 1)).getSelect().getQuery();
        } else {
            throw new RuntimeException("无法解析sql!请更换sql或者联系系统管理员咨询支持的sql类型!");
        }
        SQLSelectGroupByClause selectGroupByClause = sqlSelectQueryBlock.getGroupBy();
        if (selectGroupByClause != null) {
            List<SQLExpr> groupByItems = selectGroupByClause.getItems();
            for (int i = 0; i < groupByItems.size(); ++i) {
                String actualColumn = SelfDruidSqlUtils.getActualColumn(columnMap, SQLUtils.toSQLString(groupByItems.get(i)));
                groupByItems.set(i, SQLUtils.toSQLExpr(actualColumn, dbType));
            }
        }
        // Let the parser build the new select item(s) from a throw-away query.
        String tempSql = StringUtils.isEmpty(alias)
                ? "select " + express + " from temp1"
                : "select " + express + " as " + alias + " from temp1";
        List<SQLStatement> tempSqlStatements = parseStatements(tempSql, dbType);
        SQLSelectQueryBlock tempQueryBlock = (SQLSelectQueryBlock) ((SQLSelectStatement) tempSqlStatements.get(tempSqlStatements.size() - 1)).getSelect().getQuery();
        sqlSelectQueryBlock.getSelectList().clear();
        for (SQLSelectItem selectItem : tempQueryBlock.getSelectList()) {
            // addSelectItem also re-parents the item onto the target query block.
            sqlSelectQueryBlock.addSelectItem(selectItem);
        }
        if (haveDistinct) {
            sqlSelectQueryBlock.setDistinct();
        }
    } catch (Exception e) {
        log.error("替换查询列,解析SQL错误!要解析的sql为【{}】", sql);
        throw new RuntimeException("解析SQL错误!" + e.getMessage(), e);
    }
    return (withSubqueryClause == null ? "" : SQLUtils.toSQLString(withSubqueryClause)) + "\n" + SQLUtils.toSQLString(sqlSelectQueryBlock, dbType);
}
/**
 * Splits {@code "expr AS alias"} at an AS (any letter case) that is followed only by a single alias token: a
 * backtick-, double- or single-quoted name, or a run of characters without whitespace, parentheses or quotes.
 * The AS inside {@code cast(a as int)}, or the letters "as" in a name such as {@code assets}, do not split.
 */
private static final Pattern SELECT_ITEM_ALIAS_PATTERN = Pattern.compile(
        "^(.+?)\\s+as\\s+(`[^`]+`|\"[^\"]+\"|'[^']+'|[^\\s()`\"']+)\\s*$",
        Pattern.CASE_INSENSITIVE | Pattern.DOTALL);

/**
 * Appends a select item to a query unless a column with the same name is already selected.
 * <p>
 * The item's name is its alias (unquoted) when it has one, otherwise the expression text. It is compared with the
 * keys of getColumnExpressMap.
 *
 * @param sql           the query
 * @param selectItemSql item to add, e.g. {@code "count(1) as cnt"} or {@code "t.a"}
 * @param dbType        database type
 * @return the query with the item added, or {@code sql} itself when either argument is empty or the column exists
 * @throws RuntimeException when the SQL cannot be parsed
 */
public static String addSqlSelectItem(String sql, String selectItemSql, DbType dbType) throws RuntimeException {
    if (StringUtils.isEmpty(sql) || StringUtils.isEmpty(selectItemSql)) {
        return sql;
    }
    String needAddSelectSql = selectItemSql;
    String needAddSelectAlias = null;
    Matcher aliasMatcher = SELECT_ITEM_ALIAS_PATTERN.matcher(selectItemSql.trim());
    if (aliasMatcher.matches()) {
        needAddSelectSql = aliasMatcher.group(1);
        needAddSelectAlias = aliasMatcher.group(2);
    }
    // getColumnExpressMap stores aliases without quotes/backticks, so compare the unquoted alias.
    String actualSelectColumn = needAddSelectAlias == null
            ? needAddSelectSql
            : needAddSelectAlias.replaceAll("[`'\"]", EMPTY);
    if (getColumnExpressMap(sql, dbType).containsKey(actualSelectColumn)) {
        return sql;
    }
    return SQLUtils.addSelectItem(sql, needAddSelectSql, needAddSelectAlias, dbType);
}
/**
 * Removes select items from a single-SELECT query.
 * <p>
 * An item is identified by its alias (normalised with getRealAlias) or, when it has no alias, by the SQL text of
 * its expression. Items whose identifier is in the set are removed. Only the last statement of {@code sql} is used.
 *
 * @param sql                       the query
 * @param needRemoveSelectColumnSet identifiers of the items to remove; {@code null}/empty returns {@code sql}
 * @param dbType                    database type
 * @return WITH clause (if any) + {@code "\n"} + the query without those items
 * @throws RuntimeException when the query is not a plain SELECT (e.g. a UNION)
 */
public static String removeSqlSelectItem(String sql, Set<String> needRemoveSelectColumnSet, DbType dbType) throws RuntimeException {
    if (CollectionUtils.isEmpty(needRemoveSelectColumnSet)) {
        return sql;
    }
    List<SQLStatement> sqlStatements = parseStatements(sql, dbType);
    SQLSelect sqlSelect = ((SQLSelectStatement) sqlStatements.get(sqlStatements.size() - 1)).getSelect();
    SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
    SQLWithSubqueryClause withSubQueryClause = sqlSelect.getWithSubQuery();
    if (!(sqlSelectQuery instanceof SQLSelectQueryBlock)) {
        throw new RuntimeException("不支持去除该SQL类型的查询项!该SQL类型为:" + sqlSelectQuery.getClass());
    }
    ((SQLSelectQueryBlock) sqlSelectQuery).getSelectList().removeIf(selectItem -> {
        String alias = SelfDruidSqlUtils.getRealAlias(selectItem.getAlias());
        String itemKey = StringUtils.isBlank(alias) ? SQLUtils.toSQLString(selectItem.getExpr()) : alias;
        return needRemoveSelectColumnSet.contains(itemKey);
    });
    return (withSubQueryClause == null ? "" : SQLUtils.toSQLString(withSubQueryClause)) + "\n" + SQLUtils.toSQLString(sqlSelectQuery, dbType);
}
/**
 * Keeps only the requested select items of a single-SELECT query; use with care, prefer filtering on the output
 * columns.
 * <p>
 * Steps:
 * <ol>
 * <li>Each item is identified by its alias (via getRealAlias) or, without alias, by its expression SQL.</li>
 * <li>The first item per requested identifier is kept; items not requested, or repeating one already kept, are
 * removed.</li>
 * <li>Requested identifiers still missing are appended as new items, except those that are keys of
 * getColumnExpressMap for the original SQL.</li>
 * </ol>
 * NOTE(review): an item selected as {@code t.b} when {@code "b"} is requested is removed (its identifier is
 * {@code "t.b"}) and not re-added ({@code "b"} is a key of the column map), so that column is lost.
 * <p>
 * The caller's set is not modified.
 *
 * @param sql              the query (only the last statement is used)
 * @param needSelectColumn identifiers of the columns to keep
 * @param dbType           database type
 * @return WITH clause (if any) + {@code "\n"} + the rewritten query
 * @throws RuntimeException when the query is not a plain SELECT or cannot be parsed
 */
public static String onlySaveAppointSqlSelectItem(String sql, Set<String> needSelectColumn, DbType dbType) throws RuntimeException {
    List<SQLStatement> sqlStatements = parseStatements(sql, dbType);
    SQLSelect sqlSelect = ((SQLSelectStatement) sqlStatements.get(sqlStatements.size() - 1)).getSelect();
    SQLSelectQuery sqlSelectQuery = sqlSelect.getQuery();
    SQLWithSubqueryClause withSubQueryClause = sqlSelect.getWithSubQuery();
    if (!(sqlSelectQuery instanceof SQLSelectQueryBlock)) {
        throw new RuntimeException("不支持只保留该SQL类型的查询项!该SQL类型为:" + sqlSelectQuery.getClass());
    }
    SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) sqlSelectQuery;
    // Work on a copy: the caller's set must not be mutated (and may be immutable).
    Set<String> notYetSelected = new LinkedHashSet<>(needSelectColumn);
    Iterator<SQLSelectItem> iterator = sqlSelectQueryBlock.getSelectList().iterator();
    while (iterator.hasNext()) {
        SQLSelectItem sqlSelectItem = iterator.next();
        String alias = SelfDruidSqlUtils.getRealAlias(sqlSelectItem.getAlias());
        String itemKey = StringUtils.isBlank(alias) ? SQLUtils.toSQLString(sqlSelectItem.getExpr()) : alias;
        // remove() is true only for the first item of a still-missing requested column: that one is kept.
        if (!notYetSelected.remove(itemKey)) {
            iterator.remove();
        }
    }
    notYetSelected.removeAll(SelfDruidSqlUtils.getColumnExpressMap(sql, dbType).keySet());
    for (String sqlSelectColumn : notYetSelected) {
        sqlSelectQueryBlock.addSelectItem(new SQLSelectItem(SQLUtils.toSQLExpr(sqlSelectColumn, dbType)));
    }
    return (withSubQueryClause == null ? "" : SQLUtils.toSQLString(withSubQueryClause)) + "\n" + SQLUtils.toSQLString(sqlSelectQuery, dbType);
}
/**
 * Tells whether the last statement of a SQL string is a UNION query.
 * For a UNION, ORDER BY may only reference selected columns.
 *
 * @param sqlContent SQL content
 * @param dbType     Druid database type
 * @return {@code true} for a UNION SELECT; {@code false} for a plain SELECT, a non-SELECT statement, or blank SQL
 */
public static boolean isUnionSql(String sqlContent, DbType dbType) {
    if (StringUtils.isBlank(sqlContent)) {
        return false;
    }
    List<SQLStatement> sqlStatements = parseStatements(sqlContent, dbType);
    if (sqlStatements.isEmpty()) {
        // e.g. SQL that contains only comments
        return false;
    }
    SQLStatement lastStatement = sqlStatements.get(sqlStatements.size() - 1);
    return lastStatement instanceof SQLSelectStatement
            && ((SQLSelectStatement) lastStatement).getSelect().getQuery() instanceof SQLUnionQuery;
}
/**
 * Tells whether the last statement of a SQL string (a plain SELECT) has a GROUP BY clause. A UNION always answers
 * {@code false}: the clauses of its branches are not inspected.
 *
 * @param sql    SQL to inspect
 * @param dbType database type
 * @return {@code true} if a GROUP BY clause is present
 * @throws RuntimeException when the SQL cannot be parsed or is neither a SELECT nor a UNION
 **/
public static boolean isContainGroupBy(String sql, DbType dbType) throws RuntimeException {
    try {
        List<SQLStatement> statements = parseStatements(sql, dbType);
        SQLStatement lastStatement = statements.get(statements.size() - 1);
        SQLSelectQuery query = ((SQLSelectStatement) lastStatement).getSelect().getQuery();
        if (query instanceof SQLUnionQuery) {
            return false;
        }
        if (!(query instanceof SQLSelectQueryBlock)) {
            throw new RuntimeException("无法解析sql!请更换sql或者联系系统管理员咨询支持的sql类型!");
        }
        return ((SQLSelectQueryBlock) query).getGroupBy() != null;
    } catch (Exception e) {
        log.error("检查sql中是否包含group by 条件,解析SQL错误!要解析的sql为【{}】", sql);
        throw new RuntimeException("解析SQL错误!" + e.getMessage(), e);
    }
}
/**
 * Adds a WHERE condition to every statement of a SQL string. For a UNION, each branch receives the condition. To
 * filter a UNION as a whole, use SQLSelectStatement.addWhere (or addWhereForGlobalSql) instead.
 * Multiple statements are joined with {@code ;}.
 *
 * @param sql       SQL statement(s)
 * @param condition WHERE condition
 * @param dbType    database type
 * @return the SQL with the condition added, or {@code sql} unchanged when handleConditionAndSql returns
 * {@code null}
 */
public static String addWhereForAllSql(String sql, String condition, DbType dbType) {
    TemplateBuilderSqlExpr template = handleConditionAndSql(sql, condition, dbType);
    if (template == null) {
        return sql;
    }
    List<SQLStatement> statements = template.getStatements();
    StringBuilder output = template.getBuilder();
    int lastIndex = statements.size() - 1;
    for (int index = 0; index <= lastIndex; index++) {
        SQLStatement statement = statements.get(index);
        addWhere(statement, template.getConditionExpr());
        output.append(toSQLString(statement, dbType));
        boolean needsSeparator = template.isMultiStatement() && index < lastIndex;
        if (needsSeparator) {
            output.append(";");
        }
    }
    return output.toString();
}
/**
 * Adds a HAVING condition to every SELECT statement of a SQL string.
 * <p>
 * Despite the method name, for a UNION the condition is added to the HAVING clause of each branch, not to an outer
 * query. A branch without GROUP BY gets a HAVING-only clause. Multiple statements are joined with {@code ;}.
 *
 * @param sql       SQL statement(s)
 * @param condition HAVING condition
 * @param dbType    database type
 * @return the SQL with the condition added, or {@code sql} unchanged when handleConditionAndSql returns
 * {@code null}
 * @throws IllegalArgumentException when a statement is not a SELECT / UNION query
 */
public static String addHavingForGlobalSql(String sql, String condition, DbType dbType) {
    TemplateBuilderSqlExpr templateBuilderSqlExpr = handleConditionAndSql(sql, condition, dbType);
    if (templateBuilderSqlExpr == null) {
        return sql;
    }
    List<SQLStatement> statements = templateBuilderSqlExpr.getStatements();
    StringBuilder builder = templateBuilderSqlExpr.getBuilder();
    for (int i = 0; i < statements.size(); ++i) {
        SQLStatement statement = statements.get(i);
        if (!(statement instanceof SQLSelectStatement)) {
            throw new IllegalArgumentException("要添加having条件的sql为非查询sql,不能添加!");
        }
        SQLSelectQuery query = ((SQLSelectStatement) statement).getSelect().getQuery();
        if (!(query instanceof SQLUnionQuery || query instanceof SQLSelectQueryBlock)) {
            throw new IllegalArgumentException("要添加having条件的sql为非查询 union all 或普通select sql,不能添加!sql内容为:" + sql);
        }
        addHavingForQuery(condition, dbType, query);
        builder.append(toSQLString(statement, dbType));
        // Separate statements with ";" except after the last one.
        if (i != statements.size() - 1 && templateBuilderSqlExpr.isMultiStatement()) {
            builder.append(";");
        }
    }
    return builder.toString();
}
/**
 * Adds a WHERE condition to every SELECT statement of a SQL string.
 * For a UNION the condition is meant to filter the union as a whole rather than each branch.
 * Multiple statements are joined with {@code ;}.
 *
 * @param sql       SQL statement(s)
 * @param condition WHERE condition
 * @param dbType    database type
 * @return the SQL with the condition added, or {@code sql} unchanged when handleConditionAndSql returns
 * {@code null}
 * @throws IllegalArgumentException when a statement is not a SELECT / UNION query
 */
public static String addWhereForGlobalSql(String sql, String condition, DbType dbType) {
    TemplateBuilderSqlExpr templateBuilderSqlExpr = handleConditionAndSql(sql, condition, dbType);
    if (templateBuilderSqlExpr == null) {
        return sql;
    }
    List<SQLStatement> statements = templateBuilderSqlExpr.getStatements();
    StringBuilder builder = templateBuilderSqlExpr.getBuilder();
    for (int i = 0; i < statements.size(); ++i) {
        SQLStatement statement = statements.get(i);
        if (!(statement instanceof SQLSelectStatement)) {
            throw new IllegalArgumentException("要添加条件的sql为非查询sql,不能添加!");
        }
        SQLSelectStatement sqlSelectStatement = (SQLSelectStatement) statement;
        SQLSelectQuery query = sqlSelectStatement.getSelect().getQuery();
        if (query instanceof SQLUnionQuery) {
            // NOTE(review): per the original author, the first addWhere on a UNION only wraps it in an outer
            // "select *" (so the 1=1 is a throw-away); the second call then puts the real condition on that outer
            // query. Relies on Druid's SQLSelect.addWhere behaviour — confirm when upgrading Druid.
            sqlSelectStatement.addWhere(toSQLExpr("1=1"));
            sqlSelectStatement.addWhere(templateBuilderSqlExpr.getConditionExpr());
        } else if (query instanceof SQLSelectQueryBlock) {
            sqlSelectStatement.addWhere(templateBuilderSqlExpr.getConditionExpr());
        } else {
            throw new IllegalArgumentException("要添加条件的sql为非查询 union all 或普通select sql,不能添加!sql内容为:" + sql);
        }
        builder.append(toSQLString(statement, dbType));
        // Separate statements with ";" except after the last one.
        if (i != statements.size() - 1 && templateBuilderSqlExpr.isMultiStatement()) {
            builder.append(";");
        }
    }
    return builder.toString();
}
/**
 * Adds a WHERE condition to a SELECT statement. For a UNION every branch receives the condition. To filter the
 * union as a whole, use SQLSelectStatement.addWhere instead.
 *
 * @param stmt           the SELECT statement (modified in place)
 * @param whereCondition condition to AND into the WHERE clause
 * @throws IllegalArgumentException when {@code stmt} is not a SELECT statement
 **/
public static void addWhere(SQLStatement stmt, SQLExpr whereCondition) {
    if (!(stmt instanceof SQLSelectStatement)) {
        throw new IllegalArgumentException("add where not support " + stmt.getClass().getName());
    }
    SQLSelect select = ((SQLSelectStatement) stmt).getSelect();
    addWhereForSqlSelectQuery(select.getQuery(), whereCondition);
}
/**
 * Adds a GROUP BY item to every statement of a SQL string. Each statement must be a plain SELECT; UNION is not
 * supported. Multiple statements are joined with {@code ;}.
 * NOTE(review): unlike the addWhere / addHaving variants, output is built in a fresh StringBuilder rather than
 * templateBuilderSqlExpr.getBuilder() — confirm the template builder never carries content that should be kept.
 *
 * @param sql              SQL statement(s)
 * @param groupByCondition GROUP BY item
 * @param dbType           database type
 * @return the SQL with the GROUP BY item added, or {@code sql} unchanged when handleConditionAndSql returns
 * {@code null}
 * @throws IllegalArgumentException when a statement is not a plain SELECT
 */
public static String addGroupBy(String sql, String groupByCondition, DbType dbType) {
    TemplateBuilderSqlExpr templateBuilderSqlExpr = handleConditionAndSql(sql, groupByCondition, dbType);
    if (templateBuilderSqlExpr == null) {
        return sql;
    }
    List<SQLStatement> statements = templateBuilderSqlExpr.getStatements();
    SQLExpr conditionExpr = templateBuilderSqlExpr.getConditionExpr();
    boolean isMultiStatement = templateBuilderSqlExpr.isMultiStatement();
    StringBuilder builder = new StringBuilder();
    for (int i = 0; i < statements.size(); ++i) {
        SQLStatement statement = statements.get(i);
        addGroupBy(statement, conditionExpr);
        builder.append(toSQLString(statement, dbType));
        // Separate statements with ";" except after the last one.
        if (i != statements.size() - 1 && isMultiStatement) {
            builder.append(";");
        }
    }
    return builder.toString();
}
/**
 * Lists the ORDER BY expressions of a query, in ORDER BY order.
 * Like the other helpers of this class, only the last statement of {@code sql} is inspected.
 *
 * @param sql    SQL to inspect
 * @param dbType database type
 * @return ORDER BY expressions as SQL text (empty when there is no ORDER BY or {@code sql} is empty)
 * @throws IllegalArgumentException when the SQL contains no statement or the query kind is unsupported
 * @throws RuntimeException         when the last statement is not a SELECT
 */
public static Set<String> listOrderByFieldSet(String sql, DbType dbType) {
    Set<String> orderByFieldSet = new LinkedHashSet<>(16);
    if (StringUtils.isEmpty(sql)) {
        return orderByFieldSet;
    }
    List<SQLStatement> statements = parseStatements(sql, dbType);
    if (CollectionUtils.isEmpty(statements)) {
        throw new IllegalArgumentException("要添加sql不包含sql语句!");
    }
    SQLStatement sqlStatement = statements.get(statements.size() - 1);
    if (!(sqlStatement instanceof SQLSelectStatement)) {
        throw new RuntimeException("不支持该sql类型!");
    }
    SQLSelectQuery query = ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
    SQLOrderBy sqlOrderBy;
    if (query instanceof SQLSelectQueryBlock) {
        sqlOrderBy = ((SQLSelectQueryBlock) query).getOrderBy();
    } else if (query instanceof SQLUnionQuery) {
        sqlOrderBy = ((SQLUnionQuery) query).getOrderBy();
    } else {
        throw new IllegalArgumentException("list order by not support " + query.getClass().getName());
    }
    if (sqlOrderBy != null && sqlOrderBy.getItems() != null) {
        for (SQLSelectOrderByItem orderByItem : sqlOrderBy.getItems()) {
            orderByFieldSet.add(SQLUtils.toSQLString(orderByItem.getExpr()));
        }
    }
    return orderByFieldSet;
}
/**
 * Appends an ORDER BY item to every statement of a SQL string. Multiple statements are joined with {@code ;}.
 *
 * @param sql              SELECT statement(s)
 * @param orderByCondition ORDER BY item, e.g. {@code "a desc"}
 * @param dbType           database type
 * @return the SQL with the ORDER BY item appended; {@code sql} unchanged when either argument is empty
 * @throws IllegalArgumentException when the SQL contains no statement or a statement is not a SELECT
 */
public static String addOrderBy(String sql, String orderByCondition, DbType dbType) {
    if (StringUtils.isEmpty(sql) || StringUtils.isEmpty(orderByCondition)) {
        return sql;
    }
    List<SQLStatement> statements = parseStatements(sql, dbType);
    if (CollectionUtils.isEmpty(statements)) {
        throw new IllegalArgumentException("要添加sql不包含sql语句!");
    }
    boolean isMultiStatement = statements.size() > 1;
    StringBuilder builder = new StringBuilder();
    for (int i = 0; i < statements.size(); ++i) {
        SQLStatement statement = statements.get(i);
        // A fresh item per statement, so no AST node is shared between statement trees.
        SQLSelectOrderByItem orderByItem = SQLUtils.toOrderByItem(orderByCondition, dbType);
        addOrderBy(statement, orderByItem);
        builder.append(toSQLString(statement, dbType));
        // Separate statements with ";" except after the last one.
        if (i != statements.size() - 1 && isMultiStatement) {
            builder.append(";");
        }
    }
    return builder.toString();
}
/**
 * Appends an ORDER BY item to a SELECT statement.
 * <ul>
 * <li>For a UNION the item goes to the union-level ORDER BY, which is created when absent.</li>
 * <li>For a plain query block it is delegated to the query-level helper.</li>
 * <li>Other query kinds are left untouched.</li>
 * </ul>
 *
 * @param stmt        the statement (modified in place)
 * @param orderByItem item to append
 * @throws IllegalArgumentException for any non-SELECT statement (DELETE, UPDATE, ...)
 **/
public static void addOrderBy(SQLStatement stmt, SQLSelectOrderByItem orderByItem) {
    if (!(stmt instanceof SQLSelectStatement)) {
        throw new IllegalArgumentException("add order by not support " + stmt.getClass().getName());
    }
    SQLSelectQuery query = ((SQLSelectStatement) stmt).getSelect().getQuery();
    if (query instanceof SQLUnionQuery) {
        SQLUnionQuery unionQuery = (SQLUnionQuery) query;
        SQLOrderBy unionOrderBy = unionQuery.getOrderBy() == null ? new SQLOrderBy() : unionQuery.getOrderBy();
        unionOrderBy.addItem(orderByItem);
        unionQuery.setOrderBy(unionOrderBy);
    } else if (query instanceof SQLSelectQueryBlock) {
        addOrderBy(query, orderByItem);
    }
}
/**
 * Removes every ORDER BY from the last statement of a SQL string, including those of UNION branches.
 *
 * @param sql    SELECT SQL
 * @param dbType database type
 * @return WITH clause (if any) + {@code "\n"} + the query without ORDER BY
 * @throws IllegalArgumentException when the SQL contains no statement or the last statement is not a SELECT
 **/
public static String clearOrderBy(String sql, DbType dbType) {
    List<SQLStatement> statements = parseStatements(sql, dbType);
    if (CollectionUtils.isEmpty(statements)) {
        throw new IllegalArgumentException("要添加sql不包含sql语句!");
    }
    SQLStatement stmt = statements.get(statements.size() - 1);
    if (!(stmt instanceof SQLSelectStatement)) {
        throw new IllegalArgumentException("clear order condition not support " + stmt.getClass().getName());
    }
    SQLSelect select = ((SQLSelectStatement) stmt).getSelect();
    SQLWithSubqueryClause withSubqueryClause = select.getWithSubQuery();
    SQLSelectQuery query = select.getQuery();
    clearOrderBy(query);
    return (withSubqueryClause == null ? "" : SQLUtils.toSQLString(withSubqueryClause)) + "\n" + SQLUtils.toSQLString(query, dbType);
}
/**
 * Builds SQL that counts the rows returned by a query; the count column is aliased NUM_STR.
 * <p>
 * ORDER BY is removed, and the SQL is re-paginated with offset 0 and {@code Integer.MAX_VALUE} rows (via
 * getLimitSqlFromSql). Then one of two forms is built:
 * <ul>
 * <li>When each source row yields one result row, the select list is simply replaced with {@code count(1)}.</li>
 * <li>Otherwise the query is counted as a derived table: {@code select count(1) from (sql) t}. This applies to
 * GROUP BY / HAVING, SELECT DISTINCT, or an aggregate anywhere in the statement. For example,
 * {@code select sum(a) from t} returns one row, not one per row of {@code t}.</li>
 * </ul>
 *
 * @param sqlContent SQL to count
 * @param dbType     database type
 * @return the counting SQL; {@code null}/empty input is returned unchanged
 * @throws RuntimeException when parsing fails
 */
public static String getTotalNumSqlFromSql(String sqlContent, DbType dbType) throws RuntimeException {
    if (StringUtils.isEmpty(sqlContent)) {
        return sqlContent;
    }
    String countableSql = SelfDruidSqlUtils.clearOrderBy(sqlContent, dbType);
    countableSql = getLimitSqlFromSql(countableSql, dbType, 0, Integer.MAX_VALUE);
    if (isCountSubqueryRequired(countableSql, dbType)) {
        return "select count(1) as " + NUM_STR + " from (" + countableSql + ") t";
    }
    return SelfDruidSqlUtils.replaceSelectItem(countableSql, " count(1) ", NUM_STR, false, dbType);
}

/**
 * Tells whether counting sql's rows requires a derived table instead of replacing its select list with
 * {@code count(1)}.
 * A UNION answers {@code false} because replaceSelectItem already selects from it as a derived table.
 *
 * @param sql    SQL to inspect (last statement)
 * @param dbType database type
 * @return {@code true} for GROUP BY / HAVING, DISTINCT (also DISTINCTROW / UNIQUE) or any aggregate function
 * @throws RuntimeException when the SQL cannot be parsed or is neither a SELECT nor a UNION
 */
private static boolean isCountSubqueryRequired(String sql, DbType dbType) {
    SQLStatement statement;
    SQLSelectQueryBlock queryBlock;
    try {
        List<SQLStatement> statements = parseStatements(sql, dbType);
        statement = statements.get(statements.size() - 1);
        SQLSelectQuery query = ((SQLSelectStatement) statement).getSelect().getQuery();
        if (query instanceof SQLUnionQuery) {
            return false;
        }
        if (!(query instanceof SQLSelectQueryBlock)) {
            throw new RuntimeException("无法解析sql!请更换sql或者联系系统管理员咨询支持的sql类型!");
        }
        queryBlock = (SQLSelectQueryBlock) query;
    } catch (Exception e) {
        log.error("计算总条数时解析SQL错误!要解析的sql为【{}】", sql);
        throw new RuntimeException("解析SQL错误!" + e.getMessage(), e);
    }
    // A HAVING clause is stored in the group-by clause, so this covers HAVING without GROUP BY too.
    if (queryBlock.getGroupBy() != null) {
        return true;
    }
    int distinctOption = queryBlock.getDistionOption();
    if (distinctOption == SQLSetQuantifier.DISTINCT || distinctOption == SQLSetQuantifier.DISTINCTROW
            || distinctOption == SQLSetQuantifier.UNIQUE) {
        return true;
    }
    return containsAggregateFunction(statement, dbType);
}

/**
 * Tells whether a statement contains an aggregate function call anywhere (select list, subqueries, ...).
 * A false positive only costs an extra derived table when counting, so if the visitor fails the answer is
 * {@code true}.
 *
 * @param statement statement to inspect
 * @param dbType    database type used to pick the dialect visitor
 * @return {@code true} if an aggregate function was found or the statement could not be visited
 */
private static boolean containsAggregateFunction(SQLStatement statement, DbType dbType) {
    try {
        SchemaStatVisitor visitor = dbType == null ? new SchemaStatVisitor() : SQLUtils.createSchemaStatVisitor(dbType);
        statement.accept(visitor);
        return !visitor.getAggregateFunctions().isEmpty();
    } catch (RuntimeException e) {
        log.warn("无法判断sql是否包含聚合函数,按包含处理", e);
        return true;
    }
}
/**
 * Applies pagination to a SQL string through Druid's PagerUtils.limit.
 *
 * @param sqlContent SQL to paginate
 * @param dbType     database type
 * @param offset     number of rows to skip
 * @param count      maximum number of rows per page
 * @return the paginated SQL; {@code null}/empty input is returned unchanged
 */
public static String getLimitSqlFromSql(String sqlContent, com.alibaba.druid.DbType dbType, int offset, int count) {
    return StringUtils.isEmpty(sqlContent)
            ? sqlContent
            : PagerUtils.limit(sqlContent, dbType, offset, count);
}
/**
 * Removes template sections delimited by {@code ${start_tpl_var_*}} ... {@code ${end_tpl_var_*}}, markers
 * included. Such sections wrap configuration variables; when a variable is empty or not filled in, the whole
 * section is dropped from the query.
 * <p>
 * Markers are paired by position: the i-th start marker found with the i-th end marker found, regardless of the
 * names after the prefix. NOTE(review): nested sections with different names (start_a, start_b, end_b, end_a) are
 * therefore mis-paired: start_a is paired with end_b, which leaves a stray {@code ${end_tpl_var_a}}.
 *
 * @param str text to clean
 * @return the text without the marked sections
 */
public static String removeVarSubRangeStr(String str) {
    List<String> startMarkers = findSubStrPattern(str, "\\$\\{start_tpl_var_[a-zA-Z-_0-9]+\\}");
    List<String> endMarkers = findSubStrPattern(str, "\\$\\{end_tpl_var_[a-zA-Z-_0-9]+\\}");
    int pairCount = Math.min(startMarkers.size(), endMarkers.size());
    String result = str;
    for (int i = 0; i < pairCount; ++i) {
        result = subRangeString(result, startMarkers.get(i), endMarkers.get(i));
    }
    return result;
}
/**
 * Adds a HAVING condition to a query.
 * <ul>
 * <li>For a UNION it recurses into both sides, so each branch gets the condition.</li>
 * <li>For a query block the condition is ANDed with an existing HAVING, or becomes the HAVING. A group-by clause
 * is created if the block has none.</li>
 * </ul>
 * The condition is parsed with {@code dbType}, like the existing HAVING it is combined with.
 *
 * @param condition HAVING condition SQL
 * @param dbType    database type
 * @param query     query to modify in place
 * @throws IllegalArgumentException when the query is neither a UNION nor a query block
 */
private static void addHavingForQuery(String condition, DbType dbType, SQLSelectQuery query) {
    if (query instanceof SQLUnionQuery) {
        SQLUnionQuery sqlUnionQuery = (SQLUnionQuery) query;
        addHavingForQuery(condition, dbType, sqlUnionQuery.getLeft());
        addHavingForQuery(condition, dbType, sqlUnionQuery.getRight());
        return;
    }
    if (!(query instanceof SQLSelectQueryBlock)) {
        throw new IllegalArgumentException("要添加having条件的sql为非查询 union all 或普通select sql,不能添加!sql内容为:" + query);
    }
    SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) query;
    SQLSelectGroupByClause sqlSelectGroupByClause = Optional.ofNullable(sqlSelectQueryBlock.getGroupBy()).orElseGet(SQLSelectGroupByClause::new);
    SQLExpr existingHaving = sqlSelectGroupByClause.getHaving();
    String havingSql = existingHaving == null
            ? condition
            : "(" + SQLUtils.toSQLString(existingHaving, dbType) + " ) and (" + condition + ")";
    sqlSelectGroupByClause.setHaving(SQLUtils.toSQLExpr(havingSql, dbType));
    sqlSelectQueryBlock.setGroupBy(sqlSelectGroupByClause);
}
/**
 * Removes ORDER BY from a query in place. For a UNION, its own ORDER BY and those of both branches (recursively)
 * are removed. Other query kinds are left untouched.
 *
 * @param query the query to modify
 **/
private static void clearOrderBy(SQLSelectQuery query) {
    if (query instanceof SQLUnionQuery) {
        SQLUnionQuery unionQuery = (SQLUnionQuery) query;
        if (unionQuery.getOrderBy() != null) {
            unionQuery.setOrderBy(null);
        }
        clearOrderBy(unionQuery.getLeft());
        clearOrderBy(unionQuery.getRight());
    } else if (query instanceof SQLSelectQueryBlock) {
        SQLSelectQueryBlock block = (SQLSelectQueryBlock) query;
        if (block.getOrderBy() != null) {
            block.setOrderBy(null);
        }
    }
}
/**
 * Appends a GROUP BY item to a plain SELECT statement, creating the group-by clause when absent.
 * UNION queries are not supported.
 *
 * @param stmt             the statement (modified in place)
 * @param groupByCondition item to append
 * @throws IllegalArgumentException when the statement is not a SELECT over a single query block
 **/
private static void addGroupBy(SQLStatement stmt, SQLExpr groupByCondition) {
    SQLSelectQuery query = (stmt instanceof SQLSelectStatement)
            ? ((SQLSelectStatement) stmt).getSelect().getQuery()
            : null;
    if (!(query instanceof SQLSelectQueryBlock)) {
        // Non-SELECT statements (DELETE, UPDATE, ...) and UNION queries are rejected alike.
        throw new IllegalArgumentException("add groupBy not support " + stmt.getClass().getName());
    }
    SQLSelectQueryBlock queryBlock = (SQLSelectQueryBlock) query;
    SQLSelectGroupByClause groupByClause = queryBlock.getGroupBy() != null
            ? queryBlock.getGroupBy()
            : new SQLSelectGroupByClause();
    groupByClause.addItem(groupByCondition);
    queryBlock.setGroupBy(groupByClause);
}
/**
 * Repeatedly deletes the text from an occurrence of {@code str1} up to and including the next following
 * {@code str2}. It stops when {@code str1} is no longer found or has no {@code str2} after it.
 * NOTE: with both markers empty this would never terminate; callers pass non-empty regex matches.
 *
 * @param body text to process
 * @param str1 start marker
 * @param str2 end marker
 * @return the text with the ranges removed; {@code null}/empty input is returned unchanged
 */
private static String subRangeString(String body, String str1, String str2) {
    if (StringUtils.isEmpty(body)) {
        return body;
    }
    String remaining = body;
    int startIndex;
    while ((startIndex = remaining.indexOf(str1)) != -1) {
        int endIndex = remaining.indexOf(str2, startIndex);
        if (endIndex == -1) {
            break;
        }
        remaining = remaining.substring(0, startIndex) + remaining.substring(endIndex + str2.length());
    }
    return remaining;
}
/**
 * Finds every substring of {@code needFindString} that matches the regular expression
 * {@code pattern}, in order of appearance.
 *
 * @param needFindString the text to search; {@code null} or empty yields an empty list without
 *                       compiling the pattern
 * @param pattern        the regular expression (compiled on every call)
 * @return a new mutable list of the matched substrings; never {@code null}
 * @throws java.util.regex.PatternSyntaxException if the text is non-empty and {@code pattern} is not a
 *                                                valid regular expression
 */
private static List<String> findSubStrPattern(String needFindString, String pattern) {
    List<String> matches = new ArrayList<>(10);
    if (StringUtils.isEmpty(needFindString)) {
        return matches;
    }
    Matcher matcher = Pattern.compile(pattern).matcher(needFindString);
    while (matcher.find()) {
        matches.add(matcher.group());
    }
    return matches;
}
/**
 * Adds a WHERE condition to a select query, in place.
 * <p>
 * A plain query block receives the condition through {@link SQLSelectQueryBlock#addWhere(SQLExpr)}.
 * For a UNION query the condition is pushed into the left and right sides recursively, so every
 * underlying query block gets it. NOTE(review): the very same {@code whereCondition} instance is
 * attached to every block; it is not cloned.
 *
 * @param query          the query to modify
 * @param whereCondition the condition to add
 * @throws IllegalArgumentException if {@code query} (or any UNION branch) is neither a query block
 *                                  nor a UNION query
 */
private static void addWhereForSqlSelectQuery(SQLSelectQuery query, SQLExpr whereCondition) {
    if (query instanceof SQLSelectQueryBlock) {
        ((SQLSelectQueryBlock) query).addWhere(whereCondition);
        return;
    }
    if (!(query instanceof SQLUnionQuery)) {
        throw new IllegalArgumentException("add where not support " + query.getClass().getName());
    }
    SQLUnionQuery union = (SQLUnionQuery) query;
    addWhereForSqlSelectQuery(union.getLeft(), whereCondition);
    addWhereForSqlSelectQuery(union.getRight(), whereCondition);
}
/**
 * Appends a sort item to the ORDER BY clause of a plain select query block, creating the clause
 * when the block has none yet.
 *
 * @param sqlSelectQuery the select query, modified in place
 * @param orderByItem    the item to append to ORDER BY
 * @throws IllegalArgumentException if {@code sqlSelectQuery} is not a {@link SQLSelectQueryBlock}
 *                                  (e.g. a UNION query)
 */
private static void addOrderBy(SQLSelectQuery sqlSelectQuery, SQLSelectOrderByItem orderByItem) {
    if (!(sqlSelectQuery instanceof SQLSelectQueryBlock)) {
        throw new IllegalArgumentException("add order by not support " + sqlSelectQuery.getClass().getName());
    }
    SQLSelectQueryBlock block = (SQLSelectQueryBlock) sqlSelectQuery;
    SQLOrderBy existing = block.getOrderBy();
    SQLOrderBy target = existing != null ? existing : new SQLOrderBy();
    target.addItem(orderByItem);
    block.setOrderBy(target);
}
/**
 * Parses {@code sql} and {@code condition} with the given dialect and bundles the results, together
 * with a new empty {@link StringBuilder}, into a {@link TemplateBuilderSqlExpr}.
 *
 * @param sql       SQL text, possibly containing several statements
 * @param condition condition expression text
 * @param dbType    SQL dialect used to parse both {@code sql} and {@code condition}
 * @return the bundled parse results, or {@code null} when {@code sql} or {@code condition} is
 *         {@code null} or empty (callers must check for {@code null})
 * @throws IllegalArgumentException if {@code sql} parses to no statements
 */
private static TemplateBuilderSqlExpr handleConditionAndSql(String sql, String condition, DbType dbType) {
    if (StringUtils.isEmpty(sql) || StringUtils.isEmpty(condition)) {
        return null;
    }
    List<SQLStatement> statements = parseStatements(sql, dbType);
    if (CollectionUtils.isEmpty(statements)) {
        throw new IllegalArgumentException("要添加sql不包含sql语句!");
    }
    SQLExpr conditionExpr = toSQLExpr(condition, dbType);
    boolean multiStatement = statements.size() > 1;
    return new TemplateBuilderSqlExpr(statements, conditionExpr, new StringBuilder(), multiStatement);
}
/**
 * Strips the surrounding quotes from a parsed SQL alias. For {@code a as '中国'} the parser reports
 * the alias {@code '中国'}, and this method returns {@code 中国}.
 * <p>
 * The alias counts as quoted only when its first character is a double quote, a single quote or a
 * backtick AND its last character is that same quote. Inside the quotes, a backslash escapes the
 * next character and the backslash itself is dropped. Anything else is returned unchanged.
 *
 * @param alias the alias as reported by the parser; may be {@code null}
 * @return the unquoted alias, or {@code alias} itself when it is {@code null}, shorter than two
 *         characters, or not enclosed in a matching pair of quotes
 */
public static String getRealAlias(String alias) {
    // Fewer than two chars cannot hold an opening and a closing quote; a lone quote character
    // previously triggered a NegativeArraySizeException.
    if (alias == null || alias.length() < 2) {
        return alias;
    }
    char quote = alias.charAt(0);
    boolean startsWithQuote = quote == '"' || quote == '\'' || quote == '`';
    int last = alias.length() - 1;
    // Require a matching closing quote so an unquoted tail character is never silently dropped.
    if (!startsWithQuote || alias.charAt(last) != quote) {
        return alias;
    }
    StringBuilder unquoted = new StringBuilder(last - 1);
    for (int i = 1; i < last; i++) {
        char ch = alias.charAt(i);
        if (ch == '\\') {
            // Keep the escaped character verbatim. i stays <= last, so this never overruns; a
            // backslash right before the closing quote yields that quote, as before.
            i++;
            ch = alias.charAt(i);
        }
        unquoted.append(ch);
    }
    return unquoted.toString();
}
/**
 * Temporary holder for the intermediate values built by {@code handleConditionAndSql}.
 * <p>
 * Lombok {@code @Data} generates the getters and setters (the boolean is exposed as
 * {@code isMultiStatement()}) plus {@code equals}, {@code hashCode} and {@code toString}.
 * {@code @AllArgsConstructor} generates a constructor taking the fields in declaration order.
 *
 * @author zhangcanlong
 * @since 2021/10/21
 */
@Data
@AllArgsConstructor
static class TemplateBuilderSqlExpr {
    /** Statements parsed from the input SQL. */
    private List<SQLStatement> statements;
    /** Parsed condition expression. */
    private SQLExpr conditionExpr;
    /** Buffer created empty by handleConditionAndSql; presumably used by callers to assemble output SQL (confirm). */
    private StringBuilder builder;
    /** Set by handleConditionAndSql to {@code statements.size() > 1}. */
    private boolean multiStatement;
}
}