// org.iartisan.runtime.jdbc.PaginationInterceptor Maven / Gradle / Ivy
package org.iartisan.runtime.jdbc;
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.executor.resultset.DefaultResultSetHandler;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeHandler;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.iartisan.runtime.bean.Page;
import org.iartisan.runtime.jdbc.annotations.Pagination;
import org.iartisan.runtime.jdbc.dialects.MySQLDialect;
import org.iartisan.runtime.utils.CollectionUtil;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.*;
import java.util.List;
import java.util.Map;
import java.util.Properties;
/**
 * MySQL pagination plugin for MyBatis.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection, Integer)}. For SELECT statements whose
 * mapper method is annotated with {@link Pagination}, it:
 * <ol>
 *   <li>runs a {@code SELECT COUNT(1)} over the original SQL on the same connection and stores the
 *       result via {@code Page.setTotalRecords};</li>
 *   <li>rewrites the statement SQL using {@link MySQLDialect#buildPaginationSQL};</li>
 *   <li>resets the statement's {@link RowBounds} so MyBatis does not paginate again in memory.</li>
 * </ol>
 *
 * @author King
 * @since 2017/6/19
 */
@Deprecated
@Intercepts({
        @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class})
})
public class PaginationInterceptor implements Interceptor {

    /** Key under which a {@link Page} is expected in a multi-parameter map ({@code @Param("page")}). */
    private static final String PAGE_PARAM_NAME = "page";

    /** Wraps the original query so the total row count can be computed. */
    private static final String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s ) TOTAL";

    /** Cached handle on BoundSql's private {@code additionalParameters} field; {@code null} if unavailable. */
    private static final Field ADDITIONAL_PARAMETERS_FIELD = lookupAdditionalParametersField();

    // Only MySQL pagination is supported for now.
    private final MySQLDialect mySQLDialect = MySQLDialect.newInstance();

    // Shared across invocations so reflection metadata is cached instead of rebuilt on every prepare().
    private final DefaultReflectorFactory reflectorFactory = new DefaultReflectorFactory();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        if (target instanceof StatementHandler) {
            // NOTE(review): assumes the target is the bare RoutingStatementHandler; if another plugin
            // wraps it in a proxy first, this cast fails.
            RoutingStatementHandler routingStatementHandler = (RoutingStatementHandler) target;
            MetaObject metaObject = MetaObject.forObject(routingStatementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                    SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, reflectorFactory);
            MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
            // Check whether the mapper method behind this SELECT requires pagination.
            if (mappedStatement.getSqlCommandType() == SqlCommandType.SELECT
                    && isPaginatedStatement(mappedStatement.getId())) {
                Connection connection = (Connection) invocation.getArgs()[0];
                paginate(routingStatementHandler, metaObject, mappedStatement, connection);
            }
        }
        return invocation.proceed();
    }

    /**
     * Counts the total rows of the original query, then rewrites the statement to fetch only the
     * requested page.
     *
     * @throws IllegalArgumentException if the statement parameters contain no {@link Page}
     */
    private void paginate(RoutingStatementHandler handler, MetaObject metaObject, MappedStatement mappedStatement,
                          Connection connection) throws SQLException, IllegalAccessException {
        Page page = extractPage(handler.getParameterHandler().getParameterObject());
        if (null == page) {
            // A @Pagination method must receive a Page parameter.
            throw new IllegalArgumentException("分页方法中缺少:org.iartisan.runtime.bean.Page 对象");
        }
        BoundSql boundSql = handler.getBoundSql();
        // Clamp so a page number below 1 yields offset 0 rather than a negative LIMIT offset.
        String querySQL = mySQLDialect.buildPaginationSQL(boundSql.getSql(),
                Math.max(0, (page.getCurrPage() - 1) * page.getPageSize()), page.getPageSize());
        // Count against the original SQL, before it is replaced below.
        countTotal(mappedStatement, boundSql, page, connection);
        metaObject.setValue("delegate.boundSql.sql", querySQL);
        // Disable MyBatis' in-memory (RowBounds) pagination: the SQL itself now does the paging.
        metaObject.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
        metaObject.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
    }

    /**
     * Locates the {@link Page} among the statement parameters.
     *
     * <p>A lone {@code Page} argument is returned as-is. For a parameter map, the value under
     * {@code "page"} (i.e. {@code @Param("page")}) is preferred; otherwise the first {@code Page}-typed
     * value is used.
     *
     * @return the page, or {@code null} when none is present
     */
    private static Page extractPage(Object parameterObject) {
        if (parameterObject instanceof Page) {
            return (Page) parameterObject;
        }
        if (parameterObject instanceof Map) {
            Map<?, ?> paramMap = (Map<?, ?>) parameterObject;
            // MyBatis' ParamMap.get() throws BindingException for unknown keys, so probe with containsKey first.
            if (paramMap.containsKey(PAGE_PARAM_NAME)) {
                Object candidate = paramMap.get(PAGE_PARAM_NAME);
                if (candidate instanceof Page) {
                    return (Page) candidate;
                }
            }
            for (Object value : paramMap.values()) {
                if (value instanceof Page) {
                    return (Page) value;
                }
            }
        }
        return null;
    }

    /**
     * Runs the count query for {@code boundSql} on {@code connection} and stores the result in {@code page}.
     */
    private void countTotal(MappedStatement mappedStatement, BoundSql boundSql, Page page, Connection connection) throws SQLException, IllegalAccessException {
        String countSql = String.format(SQL_BASE_COUNT, boundSql.getSql());
        int total = 0;
        // Both the statement and its result set are closed even if parameter binding or execution fails.
        try (PreparedStatement statement = connection.prepareStatement(countSql)) {
            setParameters(mappedStatement, statement, boundSql);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    total = resultSet.getInt(1);
                }
            }
        }
        page.setTotalRecords(total);
    }

    /**
     * Resolves BoundSql's private {@code additionalParameters} field once.
     *
     * @return the accessible field, or {@code null} if it does not exist or cannot be made accessible
     */
    private static Field lookupAdditionalParametersField() {
        try {
            Field field = BoundSql.class.getDeclaredField("additionalParameters");
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException | SecurityException ignored) {
            // Best effort: setParameters() then relies only on BoundSql's public accessors.
            return null;
        }
    }

    /**
     * Binds the statement parameters described by {@code boundSql} onto {@code ps}, resolving each
     * value from BoundSql's additional parameters or the parameter object.
     *
     * @param mappedStatement source of the configuration and type handlers
     * @param ps              statement to bind into (1-based JDBC indexes)
     * @param boundSql        SQL, parameter mappings and parameter object to bind
     */
    public void setParameters(MappedStatement mappedStatement, PreparedStatement ps, BoundSql boundSql) throws SQLException, IllegalAccessException {
        final Object parameterObject = boundSql.getParameterObject();
        final Configuration configuration = mappedStatement.getConfiguration();
        final TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        // Read the dynamic parameters map via reflection, when the field is available.
        Map additionalParameters = ADDITIONAL_PARAMETERS_FIELD == null
                ? null
                : (Map) ADDITIONAL_PARAMETERS_FIELD.get(boundSql);
        ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null) {
            return;
        }
        for (int i = 0; i < parameterMappings.size(); i++) {
            ParameterMapping parameterMapping = parameterMappings.get(i);
            if (parameterMapping.getMode() == ParameterMode.OUT) {
                continue;
            }
            Object value;
            String propertyName = parameterMapping.getProperty();
            if (boundSql.hasAdditionalParameter(propertyName)) {//issue#448 ask first for additional params
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                // A simple (type-handled) parameter is bound directly.
                value = parameterObject;
            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                value = metaObject.getValue(propertyName);
                if (value == null && CollectionUtil.isNotEmpty(additionalParameters)) {
                    value = additionalParameters.get(propertyName);
                }
            }
            TypeHandler typeHandler = parameterMapping.getTypeHandler();
            JdbcType jdbcType = parameterMapping.getJdbcType();
            if (value == null && jdbcType == null) {
                jdbcType = configuration.getJdbcTypeForNull();
            }
            typeHandler.setParameter(ps, i + 1, value, jdbcType);
        }
    }

    /**
     * Returns whether the statement id (mapper namespace, a dot, then the method name) maps to a
     * {@link Pagination}-annotated mapper method. Statements whose namespace is not a loadable class
     * are treated as not paginated.
     */
    private boolean isPaginatedStatement(String statementId) {
        Class<?> mapperClass = resolveMapperClass(statementId);
        return mapperClass != null && needPagination(mapperClass, resolveMethodName(statementId));
    }

    /** Loads the mapper class named by the statement id's namespace, or returns {@code null} if impossible. */
    private static Class<?> resolveMapperClass(String statementId) {
        int separator = statementId.lastIndexOf('.');
        if (separator <= 0) {
            return null;
        }
        try {
            return Class.forName(statementId.substring(0, separator));
        } catch (ClassNotFoundException ignored) {
            // The namespace is not a Java class on this class path, so there is no annotation to inspect.
            return null;
        }
    }

    /** Returns the part of the statement id after its last dot (the mapper method name). */
    private static String resolveMethodName(String statementId) {
        return statementId.substring(statementId.lastIndexOf('.') + 1);
    }

    /**
     * Returns whether any public method of {@code clazz} named {@code methodName} carries {@link Pagination}.
     *
     * <p>All same-named overloads are checked, so the result does not depend on the unspecified order
     * of {@link Class#getMethods()}.
     */
    public boolean needPagination(Class<?> clazz, String methodName) {
        for (Method method : clazz.getMethods()) {
            if (method.getName().equals(methodName) && method.isAnnotationPresent(Pagination.class)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}