com.github.dreamroute.pager.starter.interceptor.PagerInterceptor Maven / Gradle / Ivy
package com.github.dreamroute.pager.starter.interceptor;
import static com.github.dreamroute.pager.starter.anno.PagerContainer.ID;
import static com.github.dreamroute.pager.starter.interceptor.ProxyUtil.getOriginObj;
import static java.util.Arrays.stream;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toSet;
import static org.apache.commons.lang3.StringUtils.isEmpty;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import cn.hutool.core.annotation.AnnotationUtil;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.text.CharSequenceUtil;
import com.github.dreamroute.pager.starter.anno.Pager;
import com.github.dreamroute.pager.starter.anno.PagerContainer;
import com.github.dreamroute.pager.starter.anno.PagerContainerBaseInfo;
import com.github.dreamroute.pager.starter.api.PageRequest;
import com.github.dreamroute.pager.starter.exception.PaggerException;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.MappedStatement.Builder;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.transaction.Transaction;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.util.CollectionUtils;
/**
* 分页插件,原理:通过注解标注需要分页的接口方法,拦截该方法,抽取原生sql,然后做如下几个动作:
*
* <ul>
* <li>根据原生sql语句生成一个统计的sql,并且执行查询,获得统计结果;</li>
* <li>根据上一步结果判断,如果统计不为0,那么改写原生sql,加上分页参数,执行查询操作,获取查询结果;否则无需进行查询直接返回结果;</li>
* <li>将上述两个结果封装成分页结果。</li>
* </ul>
*
* @author w.dehi
*/
@Intercepts({
@Signature(
type = Executor.class,
method = "query",
args = {
MappedStatement.class, Object.class,
RowBounds.class, ResultHandler.class
}),
@Signature(
type = Executor.class,
method = "query",
args = {
MappedStatement.class, Object.class,
RowBounds.class, ResultHandler.class,
CacheKey.class, BoundSql.class
})
})
@Slf4j
public class PagerInterceptor implements Interceptor, ApplicationListener {
/**
 * Pagination metadata for every mapper method annotated with {@link Pager}, keyed by
 * {@code mapperName + "." + methodName} so it can be looked up with {@link MappedStatement#getId()}.
 * Filled in {@link #onApplicationEvent(ContextRefreshedEvent)}.
 */
private final ConcurrentHashMap<String, PagerContainerBaseInfo> pagerContainer = new ConcurrentHashMap<>();
/** Marker value meaning "single table" (original comment: 单表). */
private static final int SINGLE = 1;
/** NOTE(review): used outside this view; presumably the alias of the generated count column. */
private static final String COUNT_NAME = "_$count$_";
/** SQL {@code WHERE} keyword padded with spaces; presumably spliced into rewritten SQL. */
private static final String WHERE = " WHERE ";
/** SQL {@code FROM} keyword padded with spaces; presumably spliced into rewritten SQL. */
private static final String FROM = " FROM ";
/** MyBatis configuration taken from the context's {@link SqlSessionFactory} when the context is refreshed. */
private Configuration config;
/**
* Spring启动完毕之后,就将需要分页的Mapper抽取出来,存入缓存
*/
public void onApplicationEvent(ContextRefreshedEvent event) {
SqlSessionFactory sqlSessionFactory = event.getApplicationContext().getBean(SqlSessionFactory.class);
config = sqlSessionFactory.getConfiguration();
Collection> mappers = config.getMapperRegistry().getMappers();
if (mappers != null && !mappers.isEmpty()) {
for (Class> mapper : mappers) {
String mapperName = mapper.getName();
stream(mapper.getDeclaredMethods())
.filter(method -> AnnotationUtil.hasAnnotation(method, Pager.class))
.forEach(method -> {
PagerContainerBaseInfo container = new PagerContainerBaseInfo();
String dictinctBy = AnnotationUtil.getAnnotationValue(method, Pager.class, "distinctBy");
if (isNotBlank(dictinctBy)) {
container.setDistinctBy(dictinctBy);
}
pagerContainer.put(mapperName + "." + method.getName(), container);
});
}
}
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
MappedStatement ms = (MappedStatement) invocation.getArgs()[0];
Object param = invocation.getArgs()[1];
// 如果参数不加@Param,那么这里是单个,如果加了,那么这里是个Map,取其一即可
Object objParam = param;
String paramAlias = null;
// 如果是@Param风格,那么需要获取到对象参数以及@Param的value
if (param instanceof ParamMap) {
IllegalArgumentException ex = new IllegalArgumentException(
"接口" + ms.getId() + "参数有误, 分页接口参数必有且仅能有一个,并且是继承了PageRequest的,需要把多个参数封装在一个对象中");
ParamMap> p = (ParamMap>) param;
// 如果不是分页查询,直接返回,避免其他查询走下方流程
boolean pageSelect = p.values().stream().anyMatch(PageRequest.class::isInstance);
if (!pageSelect) {
return invocation.proceed();
}
if (p.size() != 2) {
throw ex;
}
objParam = p.values().stream().findAny().orElseThrow(() -> ex);
paramAlias = p.keySet().stream()
.filter(e -> !e.toLowerCase().matches("param\\d+"))
.findAny()
.orElseThrow(() -> ex);
}
PagerContainerBaseInfo pc = pagerContainer.get(ms.getId());
// 拦截请求的条件:1. @Page标记接口,2.参数是:PageRequest
if (pc == null || !(objParam instanceof PageRequest)) {
return invocation.proceed();
}
BoundSql boundSql = ms.getBoundSql(param);
String beforeSql = boundSql.getSql();
PagerContainer p = parseSql(beforeSql, ms.getId());
List beforePmList = boundSql.getParameterMappings();
p.setOriginPmList(beforePmList);
List afterPmList =
parseParameterMappings(config, beforePmList, paramAlias, p.isSingleTable());
p.setAfterPmList(afterPmList);
Executor executor = (Executor) (getOriginObj(invocation.getTarget()));
Transaction transaction = executor.getTransaction();
// 处理统计信息
BoundSql countBoundSql = new BoundSql(config, p.getCountSql(), p.getOriginPmList(), param);
copyProps(boundSql, countBoundSql, config);
MappedStatement m = new Builder(
config,
ms.getId() + "(分页统计)",
new StaticSqlSource(config, p.getCountSql()),
SqlCommandType.SELECT)
.build();
StatementHandler countHandler =
config.newStatementHandler(executor, m, param, RowBounds.DEFAULT, null, countBoundSql);
Statement countStmt = prepareStatement(transaction, countHandler);
((PreparedStatement) countStmt).execute();
ResultSet rs = countStmt.getResultSet();
ResultWrapper
© 2015 - 2025 Weber Informatics LLC | Privacy Policy