pro.shuangxi.devTool.mybatis.MyBatisSqlInterceptor Maven / Gradle / Ivy
package pro.shuangxi.devTool.mybatis;
import java.sql.Statement;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.beanutils.PropertyUtils;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import pro.shuangxi.devTool.utils.LogsUtils;
/**
* 自定义 MyBatis 日志插件
*/
@Intercepts({
@Signature(
type = StatementHandler.class,//拦截器需要拦截的接口,有 4 个可选项,分别是:Executor、ParameterHandler、ResultSetHandler 以及 StatementHandler
method = "query",//拦截器所拦截接口中的方法名,也就是前面四个接口中的方法名,接口和方法要对应上。
args = {Statement.class,ResultHandler.class}//拦截器所拦截方法的参数类型,通过方法名和参数类型可以锁定唯一一个方法。
)
})
@Slf4j
public class MyBatisSqlInterceptor implements Interceptor {
// @Autowired(required = false)
// List preProcessors;
public MyBatisSqlInterceptor() {
log.debug("MyBatisSqlInterceptor插件加载完毕");
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
List logs = new ArrayList<>();
// for (PreProcessor processor : preProcessors) {
// processor.doProcess(logs,invocation);
// }
try {
StatementHandler handler = (StatementHandler) invocation.getTarget();
String sql = handler.getBoundSql().getSql();
List parameterMappings = handler.getBoundSql().getParameterMappings();
Object paramObject = handler.getBoundSql().getParameterObject();
List params = new ArrayList<>();
List paramsClass = new ArrayList<>();
if (paramObject instanceof ParamMap) {
for (int i = 0; i < parameterMappings.size(); i++) {
ParameterMapping mapping = parameterMappings.get(i);
String property = mapping.getProperty();
Map map = (Map) paramObject;
params.add(BeanUtils.getProperty(map, property));
paramsClass.add(String.class);
}
} else if (paramObject instanceof Map) {
for (int i = 0; i < parameterMappings.size(); i++) {
ParameterMapping mapping = parameterMappings.get(i);
String property = mapping.getProperty();
Map map = (Map) paramObject;
params.add(String.valueOf(map.get(property)));
paramsClass.add(mapping.getJavaType());
}
}else if (paramObject instanceof String) {
for (int i = 0; i < parameterMappings.size(); i++) {
params.add(String.valueOf(paramObject));
paramsClass.add(parameterMappings.get(i).getJavaType()
);
}
} else if (!Objects.isNull(paramObject)) {
params.add(String.valueOf(paramObject));
paramsClass.add(parameterMappings.get(0).getJavaType());
}
for (int i = 0; i < params.size(); i++) {
String param = params.get(i);
Class clazz = paramsClass.get(i);
switch (clazz.getName()) {
case "java.lang.String":
sql = sql.replaceFirst("\\?", "'" + param + "'");
break;
default:
sql = sql.replaceFirst("\\?", param);
break;
}
}
List list = getTableNames(sql);
logs.add("-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->");
logs.add("这是涉及的数据库表:");
for (String table : list) {
logs.add(table);
}
logs.add("-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->");
logs.add("这是SQL:");
logs.add(sql);
} catch (Exception e) {
log.error("插件报错:");
}
LogsUtils.doLog(logs);
return invocation.proceed();
}
@Override
public Object plugin(Object target) {
return Interceptor.super.plugin(target);
}
@Override
public void setProperties(Properties properties) {
Interceptor.super.setProperties(properties);
}
public static List getTableNames(String sql) {
sql = sql.toUpperCase();
List tableNames = new ArrayList<>();
// 匹配FROM语句中的表名
String fromPatternStr = "\\bFROM\\b\\s+([a-zA-Z0-9_]+)(?:\\s|$)";
Pattern fromPattern = Pattern.compile(fromPatternStr, Pattern.CASE_INSENSITIVE);
Matcher fromMatcher = fromPattern.matcher(sql);
while (fromMatcher.find()) {
String tableName = fromMatcher.group(1);
tableNames.add(tableName);
}
// 匹配JOIN语句中的表名
String joinPatternStr = "\\bJOIN\\b\\s+([a-zA-Z0-9_]+)(?:\\s|$)";
Pattern joinPattern = Pattern.compile(joinPatternStr, Pattern.CASE_INSENSITIVE);
Matcher joinMatcher = joinPattern.matcher(sql);
while (joinMatcher.find()) {
String tableName = joinMatcher.group(1);
tableNames.add(tableName);
}
// 匹配INSERT语句中的表名
String insertPatternStr = "\\bINSERT\\b\\s+\\bINTO\\b\\s+([a-zA-Z0-9_]+)(?:\\s|$)";
Pattern insertPattern = Pattern.compile(insertPatternStr, Pattern.CASE_INSENSITIVE);
Matcher insertMatcher = insertPattern.matcher(sql);
while (insertMatcher.find()) {
String tableName = insertMatcher.group(1);
tableNames.add(tableName);
}
return tableNames;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy