Please wait. This may take a few minutes ...
Downloading a project requires significant resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
com.github.dreamroute.pager.starter.interceptor.PagerInterceptor Maven / Gradle / Ivy
package com.github.dreamroute.pager.starter.interceptor;
import cn.hutool.core.annotation.AnnotationUtil;
import com.github.dreamroute.pager.starter.anno.Pager;
import com.github.dreamroute.pager.starter.anno.PagerContainer;
import com.github.dreamroute.pager.starter.api.PageRequest;
import com.github.dreamroute.pager.starter.exception.PaggerException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.transaction.Transaction;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.util.CollectionUtils;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import static com.github.dreamroute.pager.starter.anno.PagerContainer.ID;
import static com.github.dreamroute.pager.starter.interceptor.ProxyUtil.getOriginObj;
import static java.util.Arrays.stream;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toSet;
/**
* 分页插件
*
* @author w.dehi
*/
@Intercepts({
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
public class PagerInterceptor implements Interceptor, ApplicationListener {
private ConcurrentHashMap pagerContainer;
/**
* 单表
*/
private static final int SINGLE = 1;
private static final String COUNT_NAME = "_$count$_";
private static final String WHERE = " WHERE ";
private static final String FROM = " FROM ";
@Override
public void onApplicationEvent(ContextRefreshedEvent event) {
SqlSessionFactory sqlSessionFactory = event.getApplicationContext().getBean(SqlSessionFactory.class);
// 将此方法移动到Spring容器初始化之后执行的原因是:如果放在下方的intercept方法中来执行,
// 那么就会有并发问题(获取ms的sqlSource然后修改sqlSource),那么就需要对该方法加锁,影响性能
parsePagerContainer(sqlSessionFactory.getConfiguration());
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
Object param = args[1];
Configuration config = ms.getConfiguration();
PagerContainer pc = pagerContainer.get(ms.getId());
// 拦截请求的条件:1. @Page标记接口,2.参数是:PageRequest
if (pc == null || !(param instanceof PageRequest)) {
return invocation.proceed();
}
Executor executor = (Executor) (getOriginObj(invocation.getTarget()));
Transaction transaction = executor.getTransaction();
Connection conn = transaction.getConnection();
String count = pc.getCount();
PreparedStatement ps = conn.prepareStatement(count);
BoundSql countBoundSql = new BoundSql(config, count, pc.getOriginPmList(), param);
ParameterHandler parameterHandler = config.newParameterHandler(ms, param, countBoundSql);
parameterHandler.setParameters(ps);
ResultSet rs = ps.executeQuery();
PageContainer container = new PageContainer<>();
while (rs.next()) {
long totle = rs.getLong(COUNT_NAME);
container.setTotal(totle);
}
ps.close();
PageRequest> pr = (PageRequest>) param;
int pageNum = pr.getPageNum();
int pageSize = pr.getPageSize();
// 由于不希望在pageRequest中增加start参数,所以limit时改变pageNum来代替start,因此resp的pageNum需要在设置start之前进行设置
container.setPageNum(pageNum);
container.setPageSize(pr.getPageSize());
int start = (pageNum - 1) * pageSize;
pr.setPageNum(start);
if (container.getTotal() != 0) {
Object result = invocation.proceed();
List> ls = (List>) result;
container.addAll(ls);
}
return container;
}
/**
* 这里将pagerContainer进行缓存,由于分页插件的执行sql是固定的,所以可以缓存
*/
private void parsePagerContainer(Configuration config) {
if (pagerContainer == null) {
pagerContainer = new ConcurrentHashMap<>();
parseAnno(config);
updateSqlSource(config);
}
}
private void parseAnno(Configuration config) {
Collection> mappers = config.getMapperRegistry().getMappers();
if (mappers != null && !mappers.isEmpty()) {
for (Class> mapper : mappers) {
String mapperName = mapper.getName();
stream(mapper.getDeclaredMethods()).filter(method -> AnnotationUtil.hasAnnotation(method, Pager.class)).forEach(method -> {
String dictinctBy = AnnotationUtil.getAnnotationValue(method, Pager.class, "distinctBy");
PagerContainer container = new PagerContainer();
container.setDistinctBy(StringUtils.isEmpty(dictinctBy) ? ID : dictinctBy);
pagerContainer.put(mapperName + "." + method.getName(), container);
});
}
}
}
private void updateSqlSource(Configuration config) {
pagerContainer.keySet().stream().map(config::getMappedStatement).forEach(ms -> {
MetaObject mo = config.newMetaObject(ms);
String beforeSql = (String) mo.getValue("sqlSource.sqlSource.sql");
String afterSql = parseSql(beforeSql, ms.getId());
mo.setValue("sqlSource.sqlSource.sql", afterSql);
@SuppressWarnings("unchecked") List beforePmList = (List) mo.getValue("sqlSource.sqlSource.parameterMappings");
// 这里把原始的pmList保存起来,count的时会用到
pagerContainer.get(ms.getId()).setOriginPmList(beforePmList);
List afterPmList = parseParameterMappings(config, ms.getId(), beforePmList);
mo.setValue("sqlSource.sqlSource.parameterMappings", afterPmList);
});
}
private List parseParameterMappings(Configuration config, String id, List pmList) {
List result = new ArrayList<>(ofNullable(pmList).orElseGet(ArrayList::new));
result.add(new ParameterMapping.Builder(config, "pageNum", int.class).build());
result.add(new ParameterMapping.Builder(config, "pageSize", int.class).build());
// 多表情况下:由于插件改写sql会在sql的末尾增加一次查询条件,所以这里需要在sql末尾再次增加一次查询条件
if (!pagerContainer.get(id).isSingleTable()) {
result.addAll(ofNullable(pmList).orElseGet(ArrayList::new));
}
return result;
}
private String parseSql(String sql, String id) {
PagerContainer container = pagerContainer.get(id);
Select select;
String afterSql;
try {
select = (Select) CCJSqlParserUtil.parse(sql);
} catch (Exception e) {
throw new PaggerException("SQL语句异常,你的sql语句是: [" + sql + "]", e);
}
List tableList = new TablesNamesFinder().getTableList(select);
PlainSelect body = (PlainSelect) select.getSelectBody();
String columns = body.getSelectItems().stream().map(Object::toString).collect(joining(","));
String from = body.getFromItem().toString();
String where = ofNullable(body.getWhere()).map(Object::toString).orElse("");
if (tableList != null && tableList.size() == SINGLE) {
where = StringUtils.isNotBlank(where) ? (WHERE + where) : "";
sql = "SELECT " + columns + FROM + from + where;
container.setCount("SELECT COUNT(*) " + COUNT_NAME + " FROM (" + sql + ") _$_t");
String orderBy = ofNullable(body.getOrderByElements()).orElseGet(ArrayList::new).stream().map(Objects::toString).collect(joining(", "));
orderBy = StringUtils.isNoneBlank(orderBy) ? (" ORDER BY " + orderBy) : "";
afterSql = sql + orderBy + " LIMIT ?, ?";
container.setSingleTable(true);
} else {
String joins = body.getJoins().stream().map(Object::toString).collect(joining(" "));
String alias = "";
String distinctBy = container.getDistinctBy();
if (distinctBy.indexOf('.') != -1) {
alias = distinctBy.split("\\.")[0];
}
// 如果order by不为空,那么子查询的查询列需要将order by列也带上,否则H2会报错(order by列需要在查询列中),MySQL则不会
String orderBy = "";
String subQueryColumns = "";
List orderbyList = body.getOrderByElements();
if (!CollectionUtils.isEmpty(orderbyList)) {
orderBy = " ORDER BY " + orderbyList.stream().map(Object::toString).collect(joining(", "));
// order by列和主表id列重复,需要去重
Set orderbyListStr = orderbyList.stream().map(OrderByElement::getExpression).map(Objects::toString).collect(toSet());
orderbyListStr.add(distinctBy);
subQueryColumns = String.join(", ", orderbyListStr);
}
String afterFrom = FROM + from + " " + joins + WHERE + where;
String subQuery = "SELECT DISTINCT " + subQueryColumns + afterFrom;
String noCondition = "SELECT " + columns + FROM + from + " " + joins + " ";
String result = noCondition + WHERE + distinctBy + " IN (SELECT " + distinctBy + " FROM (" + subQuery + orderBy + " LIMIT ?, ?) " + alias + ")";
if (StringUtils.isNoneBlank(where)) {
result = result + " AND " + where;
}
afterSql = result + orderBy;
String count = "SELECT count(DISTINCT " + distinctBy + ") " + COUNT_NAME + afterFrom;
container.setCount(count);
}
return afterSql;
}
}