All downloads are free. Search and download functionality uses the official Maven repository.

pro.shuangxi.devTool.mybatis.MyBatisSqlInterceptor Maven / Gradle / Ivy

There is a newer version: 1.0.6
Show newest version
package pro.shuangxi.devTool.mybatis;

import java.sql.Statement;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import pro.shuangxi.devTool.utils.LogsUtils;

/**
 * 自定义 MyBatis 日志插件
 */
@Intercepts({
        @Signature(
                type = StatementHandler.class,//拦截器需要拦截的接口,有 4 个可选项,分别是:Executor、ParameterHandler、ResultSetHandler 以及 StatementHandler
                method = "query",//拦截器所拦截接口中的方法名,也就是前面四个接口中的方法名,接口和方法要对应上。
                args = {Statement.class,ResultHandler.class}//拦截器所拦截方法的参数类型,通过方法名和参数类型可以锁定唯一一个方法。
        )
})
@Slf4j
public class MyBatisSqlInterceptor implements Interceptor {

//    @Autowired(required = false)
//    List preProcessors;

    public MyBatisSqlInterceptor() {
        log.debug("MyBatisSqlInterceptor插件加载完毕");
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        List logs = new ArrayList<>();
//        for (PreProcessor processor : preProcessors) {
//            processor.doProcess(logs,invocation);
//        }
        try {
            StatementHandler handler = (StatementHandler) invocation.getTarget();
            String sql = handler.getBoundSql().getSql();
            List parameterMappings = handler.getBoundSql().getParameterMappings();
            Object paramObject = handler.getBoundSql().getParameterObject();
            List params = new ArrayList<>();
            List paramsClass = new ArrayList<>();

            if (paramObject instanceof ParamMap) {
                for (int i = 0; i < parameterMappings.size(); i++) {
                    ParameterMapping mapping = parameterMappings.get(i);
                    String property = mapping.getProperty();
                    Map map = (Map) paramObject;

                    params.add(BeanUtils.getProperty(map, property));
                    paramsClass.add(String.class);
                }

            } else if (paramObject instanceof Map) {
                for (int i = 0; i < parameterMappings.size(); i++) {
                    ParameterMapping mapping = parameterMappings.get(i);
                    String property = mapping.getProperty();
                    Map map = (Map) paramObject;
                    params.add(String.valueOf(map.get(property)));
                    paramsClass.add(mapping.getJavaType());
                }
            }else if (paramObject instanceof String) {
                for (int i = 0; i < parameterMappings.size(); i++) {
                    params.add(String.valueOf(paramObject));
                    paramsClass.add(parameterMappings.get(i).getJavaType()
                    );
                }
            } else if (!Objects.isNull(paramObject)) {
                params.add(String.valueOf(paramObject));
                paramsClass.add(parameterMappings.get(0).getJavaType());
            }

            for (int i = 0; i < params.size(); i++) {
                String param = params.get(i);
                Class clazz = paramsClass.get(i);
                switch (clazz.getName()) {
                    case "java.lang.String":
                        sql = sql.replaceFirst("\\?", "'" + param + "'");
                        break;
                    default:
                        sql = sql.replaceFirst("\\?", param);
                        break;

                }
            }
            List list = getTableNames(sql);
            logs.add("-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->");
            logs.add("这是涉及的数据库表:");
            for (String table : list) {
                logs.add(table);
            }
            logs.add("-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->-->");
            logs.add("这是SQL:");
            logs.add(sql);

        } catch (Exception e) {
            log.error("插件报错:");
        }

        LogsUtils.doLog(logs);
        return invocation.proceed();
    }
 
    @Override
    public Object plugin(Object target) {
        return Interceptor.super.plugin(target);
    }
 
 
    @Override
    public void setProperties(Properties properties) {
        Interceptor.super.setProperties(properties);
    }


    public static List getTableNames(String sql) {
        sql = sql.toUpperCase();
        List tableNames = new ArrayList<>();


        // 匹配FROM语句中的表名
        String fromPatternStr = "\\bFROM\\b\\s+([a-zA-Z0-9_]+)(?:\\s|$)";
        Pattern fromPattern = Pattern.compile(fromPatternStr, Pattern.CASE_INSENSITIVE);
        Matcher fromMatcher = fromPattern.matcher(sql);
        while (fromMatcher.find()) {
            String tableName = fromMatcher.group(1);
            tableNames.add(tableName);
        }

        // 匹配JOIN语句中的表名
        String joinPatternStr = "\\bJOIN\\b\\s+([a-zA-Z0-9_]+)(?:\\s|$)";
        Pattern joinPattern = Pattern.compile(joinPatternStr, Pattern.CASE_INSENSITIVE);
        Matcher joinMatcher = joinPattern.matcher(sql);
        while (joinMatcher.find()) {
            String tableName = joinMatcher.group(1);
            tableNames.add(tableName);
        }

        // 匹配INSERT语句中的表名
        String insertPatternStr = "\\bINSERT\\b\\s+\\bINTO\\b\\s+([a-zA-Z0-9_]+)(?:\\s|$)";
        Pattern insertPattern = Pattern.compile(insertPatternStr, Pattern.CASE_INSENSITIVE);
        Matcher insertMatcher = insertPattern.matcher(sql);
        while (insertMatcher.find()) {
            String tableName = insertMatcher.group(1);
            tableNames.add(tableName);
        }

        return tableNames;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy