All downloads are FREE. Search and download functionality uses the official Maven repository.

com.ideaaedi.mybatis.data.security.support.EncryptParser Maven / Gradle / Ivy

There is a newer version: 1.4.3-mp3.5.1
Show newest version
package com.ideaaedi.mybatis.data.security.support;

import com.ideaaedi.mybatis.data.security.annotation.Encrypt;
import com.ideaaedi.mybatis.data.security.enums.TypeEnum;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSessionFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StopWatch;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * 加解密解析器
 *
 * @author JustryDeng
 * @since 2021/2/11 11:41:24
 */
public class EncryptParser implements SmartInitializingSingleton {
    
    private static final Logger log = LoggerFactory.getLogger(EncryptParser.class);
    
    /** 临时的class-method缓存(启动项目完毕后,会清除) */
    @SuppressWarnings("rawtypes")
    private static final Map TMP_CLAZZ_METHOD_CACHE = new ConcurrentHashMap<>(64);
    
    /** MappedStatementId-EncryptInfoHolder缓存 */
    private final Map statementIdAndEncryptInfoCache  = new ConcurrentHashMap<>(128);
    
    private final List sqlSessionFactoryList;

    private final EncryptOop encryptOop;
    
    private final DecryptOop decryptOop;
    
    public EncryptParser(@Autowired ApplicationContext applicationContext, @Autowired EncryptExecutor encryptExecutor) {
        encryptOop = new EncryptOop(encryptExecutor);
        decryptOop = new DecryptOop(encryptExecutor);
        // 兼容低版本spring的写法
        Map maps = applicationContext.getBeansOfType(SqlSessionFactory.class);
        //noinspection ConstantConditions
        if (maps == null) {
            this.sqlSessionFactoryList = new ArrayList<>(1);
        } else {
            this.sqlSessionFactoryList = new ArrayList<>(maps.values());
        }
    }
    
    /**
     * 定位EncryptInfoHolder
     *
     * @param mappedStatementId
     *            sql对应的MappedStatement的id
     * @return  mappedStatementId对应的加解密信息详情类
     */
    @Nullable
    public EncryptInfoHolder determineEncryptInfoHolder(String mappedStatementId) {
        return statementIdAndEncryptInfoCache.get(mappedStatementId);
    }
    
    /**
     * 加密
     */
    public Object doEncrypt(Object parameter, EncryptInfoHolder encryptInfoHolder) {
        return encryptOop.doEncrypt(parameter, encryptInfoHolder);
    }
    
    /**
     * 解密
     */
    public Object doDecrypt(Object rowResult, EncryptInfoHolder encryptInfoHolder) {
        return decryptOop.doDecrypt(rowResult, encryptInfoHolder);
    }
    
    public List getSqlSessionFactoryList() {
        return sqlSessionFactoryList;
    }
    
    @Override
    public void afterSingletonsInstantiated() {
        // step0. 获取mybatis配置
        StopWatch stopWatch = new StopWatch("encrypt parser");
        stopWatch.start("get configuration");
        
        List configurationList = sqlSessionFactoryList.stream().map(SqlSessionFactory::getConfiguration).collect(Collectors.toList());
        log.info("[EncryptParser] find org.apache.ibatis.session.Configuration count {}.", configurationList.size());
        stopWatch.stop();
        // step1. 获取sql对应的MappedStatement的Id
        stopWatch.start("get mappedStatementIdSet");
        Set mappedStatementIdSet = configurationList.stream().flatMap(x -> {
            Collection mappedStatements = x.getMappedStatements();
            Object[] objects = mappedStatements.toArray();
            List list = new ArrayList<>(16);
            for (Object object : objects) {
                if (object instanceof MappedStatement) {
                    list.add((MappedStatement) object);
                }
            }
            return list.stream();
        }).map(MappedStatement::getId).collect(Collectors.toSet());
        // 过滤掉那些自动生成的方法,如:为回填主键而自动生成的sql,如:com.aspire.ssm.test.mapper.Abc.Mapper.insertData!selectKey
        mappedStatementIdSet = mappedStatementIdSet.stream().filter(x -> !x.contains("!")).collect(Collectors.toSet());
        log.info("[EncryptParser] find org.apache.ibatis.mapping.MappedStatement {}.", mappedStatementIdSet);
        stopWatch.stop();
        // step2. 获取sql对应的Method-MappedStatement的Id map
        stopWatch.start("get methodStatementIdMap,methodSet");
        Map methodStatementIdMap = mappedStatementIdSet.stream().collect(
                Collectors.toMap(this::determineMethodByMappedStatementId, Function.identity()));
        Set methodSet = methodStatementIdMap.keySet();
        stopWatch.stop();
        // step3. 获取对应的注解
        stopWatch.start("get params,return info");
        Map methodParamsMap = methodSet.stream().collect(Collectors.toMap(Function.identity(), Method::getParameters));
        Map methodParamAnnotationMap = methodSet.stream().collect(Collectors.toMap(Function.identity(), Method::getParameterAnnotations));
        Map> methodReturnMap = methodSet.stream().collect(Collectors.toMap(Function.identity(), Method::getReturnType));
        stopWatch.stop();
        // step4. 解析加密信息(其中,会对对已有代码进行一些强约束校验)
        stopWatch.start("get encryptInfoHolderList");
        List encryptInfoHolderList = methodSet.stream().map(method ->
                EncryptInfoHolder.Factory.create(
                        methodStatementIdMap.get(method),
                        method,
                        methodParamsMap.get(method),
                        methodParamAnnotationMap.get(method),
                        methodReturnMap.get(method)
                )
        ).collect(Collectors.toList());
        stopWatch.stop();
        // step5. 后处理收尾
        stopWatch.start("do completely");
        Map encryptDecryptInfoMap = encryptInfoHolderList.stream()
                .collect(Collectors.toMap(EncryptInfoHolder::getMappedStatementId, Function.identity()));
        // 存放加密信息
        statementIdAndEncryptInfoCache.putAll(encryptDecryptInfoMap);
        log.info("[EncryptParser] parse end. Obtain encryptDecryptInfoMap {}.", encryptDecryptInfoMap);
        // 清除临时缓存
        TMP_CLAZZ_METHOD_CACHE.clear();
        stopWatch.stop();
        log.info("[EncryptParser] time-consuming statistics {}.", stopWatch);
    }
    
    /**
     * 获取mappedStatementId对应的方法
     *
     * @param mappedStatementId
     *            sql对应的MappedStatement的Id
     * @return  mappedStatementId对应的方法
     */
    private Method determineMethodByMappedStatementId(String mappedStatementId) {
        try {
            int lastDotIndex = mappedStatementId.lastIndexOf(".");
            Class targetClass = Class.forName(mappedStatementId.substring(0, lastDotIndex));
            Method[] declareMethods = TMP_CLAZZ_METHOD_CACHE.computeIfAbsent(targetClass, Class::getDeclaredMethods);
    
            String targetMethodName = mappedStatementId.substring(lastDotIndex + 1);
            List targetMethodList = Arrays.stream(declareMethods).filter(x -> targetMethodName.equals(x.getName())).collect(Collectors.toList());
            int size = targetMethodList.size();
            if (size != 1) {
                throw new IllegalStateException(String.format("except find [%s] 1, but actual find %s", mappedStatementId, size));
            }
            return targetMethodList.get(0);
        }catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
    
    /**
     * 加密相关方法
     * 

* 注:启动项目时,在{@link EncryptParser#afterSingletonsInstantiated()}中做了一些强约束,所以在运行时进行加密时,可以跳过一些没必要的校验,以提升性能 *

* * @author JustryDeng * @since 2021/2/11 16:55:24 */ public static class EncryptOop { private final EncryptExecutor encryptExecutor; private EncryptOop(EncryptExecutor encryptExecutor) { this.encryptExecutor = encryptExecutor; } /** * 加密 * * @param parameter * 到mybatis插件时的参数。即:{@link Executor#update(MappedStatement, Object)}的第二个参数 * @param encryptInfoHolder * 对应的加密信息 * @return 加密后的parameter */ public Object doEncrypt(Object parameter, EncryptInfoHolder encryptInfoHolder) { if (parameter == null) { return null; } TypeEnum typeEnum = TypeEnum.parseType(parameter.getClass()); if (typeEnum == TypeEnum.PRIMITIVE_OR_WRAPPER || typeEnum ==TypeEnum.STRING) { // 当parameter是这两种场景时,是不需要加密的 ( // 提示:@Encrypt使用在ElementType.PARAMETER前时,已强制要求同时使用@Param注解指定名称,此时的parameter类型是Map而非直接是String) return parameter; } List encryptBeanInfoList = encryptInfoHolder.getEncryptBeanInfoList(); EncryptInfoHolder.ParamEncryptDetailInfo encryptParamInfo = encryptInfoHolder.getEncryptParamInfo(); Object newParameter; switch (typeEnum) { case MAP: //noinspection unchecked newParameter = mapEncrypt((Map) parameter, encryptBeanInfoList, encryptParamInfo); break; case COLLECTION: //noinspection unchecked newParameter = collectionEncrypt((Collection) parameter, encryptBeanInfoList); break; case ARRAY: newParameter = arrayEncrypt((Object[]) parameter, encryptBeanInfoList); break; case CUSTOM_BEAN: newParameter = pojoEncrypt(parameter, encryptBeanInfoList); break; default: newParameter = parameter; } return newParameter; } /** * string加密 * * @param strName * 字段名。注:当@Encrypt应用于ElementType.PARAMETER上时,此为@Param指定的名称。 * @param strValue * 待加密的字段值 * @param annotation * 加密注解信息 * @param extParam * 字段所在的当前对象。 注: 当@Encrypt应用于ElementType.PARAMETER上时,extParam为null * @return 加密后的字符串 */ private String stringEncrypt(String strName, String strValue, Encrypt annotation, Object extParam) { if (StringUtils.isEmpty(strValue)) { return strValue; } if (extParam == null) { return encryptExecutor.encryptParameter(strName, 
strValue, annotation); } else { return encryptExecutor.encryptField(strName, strValue, annotation, extParam); } } /** * pojo加密 */ private T pojoEncrypt(T originPojo, List encryptBeanInfoList) { if (originPojo == null) { return originPojo; } final T returnPojo; if (originPojo instanceof PojoCloneable) { //noinspection unchecked returnPojo = (T)((PojoCloneable) originPojo).clonePojo(); } else { returnPojo = originPojo; } EncryptInfoHolder.BeanEncryptDetailInfo beanEncryptDetailInfo = encryptBeanInfoList.stream() .filter(x -> x.getBeanClass() == returnPojo.getClass()).findFirst().orElse(null); if (beanEncryptDetailInfo != null) { Map fieldEncryptMap = beanEncryptDetailInfo.getFieldEncryptMap(); fieldEncryptMap.forEach(((field, encrypt) -> { Object oldValue = null; try { oldValue = FieldUtils.readField(field, returnPojo, true); if (oldValue == null) { return; } String oldValueStr = oldValue.toString(); if (StringUtils.isEmpty(oldValueStr)) { return; } FieldUtils.writeField(field, returnPojo, stringEncrypt(field.getName(), oldValueStr, encrypt, returnPojo), true); } catch (IllegalAccessException e) { throw new RuntimeException(String.format("handle field [%s] occur exception. 
oldValue is %s", field, oldValue), e); } })); } return returnPojo; } /** * map加密 */ @SuppressWarnings("unchecked") private Map mapEncrypt(Map map, List encryptBeanInfoList, EncryptInfoHolder.ParamEncryptDetailInfo encryptParamInfo) { if (CollectionUtils.isEmpty(map)) { return map; } Map paramEncryptMap = encryptParamInfo.getParamEncryptMap(); MapperMethod.ParamMap paramMap = new MapperMethod.ParamMap<>(); paramMap.putAll(map); String key; Object value; boolean valueIsString; // 避免同一个map中的value重复加密 final Set alreadyEncryptedSet = new HashSet<>(8); for (Map.Entry entry : map.entrySet()) { key = entry.getKey(); value = entry.getValue(); // value 是null 或者 是基础类型获其包装类 if (value == null || ClassUtils.isPrimitiveOrWrapper(value.getClass())) { continue; } if (alreadyEncryptedSet.contains(value)) { continue; } // value是字符串 valueIsString = value instanceof String; if (valueIsString) { if (paramEncryptMap.containsKey(key)) { paramMap.put(key, stringEncrypt(key, (String)value, paramEncryptMap.get(key), null)); } continue; } // value是map if (value instanceof Map) { // 只有第一层Map待会有@Param的信息,后面层的map是不会有@Param的,为了避免误加密里层map,所以这里传一个EMPTY进去即可 value = mapEncrypt((Map) value, encryptBeanInfoList, EncryptInfoHolder.ParamEncryptDetailInfo.EMPTY); paramMap.put(key, value); alreadyEncryptedSet.add(value); continue; } // value是collection if (value instanceof Collection) { value = collectionEncrypt((Collection)value, encryptBeanInfoList); paramMap.put(key, value); alreadyEncryptedSet.add(value); continue; } // value是array if (value instanceof Object[]) { Object newValue = arrayEncrypt((Object[])value, encryptBeanInfoList); paramMap.put(key, newValue); alreadyEncryptedSet.add(value); continue; } // value是Enum,那么不处理 if (value instanceof Enum) { continue; } // 上述情况都不成立,那么value是普通pojo value = pojoEncrypt(value, encryptBeanInfoList); paramMap.put(key, value); alreadyEncryptedSet.add(value); } alreadyEncryptedSet.clear(); return paramMap; } /** * 集合加密 */ private Collection 
collectionEncrypt(Collection collection, List encryptBeanInfoList) { ArrayList list = new ArrayList<>(collection.size()); if (CollectionUtils.isEmpty(collection)) { return list; } // 统一泛型,一般的,第一个子元素的类型就能代表所有子元素的类型 Object subElement = collection.iterator().next(); Objects.requireNonNull(subElement, "exist null element in collection " + collection); TypeEnum typeEnum = TypeEnum.parseType(subElement.getClass()); if (typeEnum != TypeEnum.CUSTOM_BEAN) { return collection; } // 只处理普通业务bean collection.forEach(pojo -> list.add(pojoEncrypt(pojo, encryptBeanInfoList))); return list; } /** * 数组加密 */ private T[] arrayEncrypt(@NonNull T[] array, List encryptBeanInfoList) { if (array.length == 0) { return array; } // 统一泛型,一般的,第一个子元素的类型就能代表所有子元素的类型 Object subElement = array[0]; Objects.requireNonNull(subElement, "exist null element in array " + Arrays.toString(array)); TypeEnum typeEnum = TypeEnum.parseType(subElement.getClass()); if (typeEnum != TypeEnum.CUSTOM_BEAN) { return array; } // 只处理普通业务bean for (int i = 0; i < array.length; i++) { array[i] = pojoEncrypt(array[i], encryptBeanInfoList); } return array; } } /** * 解密相关方法 *

* 注:启动项目时,在{@link EncryptParser#afterSingletonsInstantiated()}中做了一些强约束,所以在运行时进行解密时,可以跳过一些没必要的校验,以提升性能 *

* * @author JustryDeng * @since 2021/2/11 16:55:24 */ public static class DecryptOop { private final EncryptExecutor encryptExecutor; private DecryptOop(EncryptExecutor encryptExecutor) { this.encryptExecutor = encryptExecutor; } /** * 解密 * * @param rowResult * sql查询出来的原生结果 * @param encryptInfoHolder * 对应的解密信息 * @return 解密后的结果 */ public Object doDecrypt(Object rowResult, EncryptInfoHolder encryptInfoHolder) { if (rowResult == null) { return null; } TypeEnum typeEnum = TypeEnum.parseType(rowResult.getClass()); EncryptInfoHolder.BeanEncryptDetailInfo decryptBeanInfo = encryptInfoHolder.getDecryptBeanInfo(); Object newResult; String rowResultStr = null; if (log.isTraceEnabled()) { rowResultStr = rowResult.toString(); } switch (typeEnum) { case PRIMITIVE_OR_WRAPPER: case STRING: case SYSTEM_BEAN: return rowResult; case MAP: //noinspection unchecked newResult = mapDecrypt((Map) rowResult, decryptBeanInfo); break; case COLLECTION: //noinspection unchecked newResult = collectionDecrypt((Collection) rowResult, decryptBeanInfo); break; case ARRAY: newResult = arrayDecrypt((Object[]) rowResult, decryptBeanInfo); break; case CUSTOM_BEAN: newResult = pojoDecrypt(rowResult, decryptBeanInfo); break; default: throw new IllegalArgumentException("Cannot support for typeEnum [" + typeEnum + "]"); } if (log.isTraceEnabled()) { log.trace("doDecrypt from [{}] to [{}]", rowResultStr, newResult); } return newResult; } /** * string解密 * * @param strName * 字段名。注:当@Encrypt应用于ElementType.PARAMETER上时,此为@Param指定的名称。 * @param strValue * 待解密的字段值 * @param annotation * 解密注解信息 * @param extParam * 字段所在的当前对象。 * @return 解密后的字符串 */ private String stringDecrypt(String strName, String strValue, Encrypt annotation, Object extParam) { if (StringUtils.isEmpty(strValue) || extParam == null) { return strValue; } return encryptExecutor.decryptField(strName, strValue, annotation, extParam); } /** * pojo解密 */ private T pojoDecrypt (T originPojo, EncryptInfoHolder.BeanEncryptDetailInfo decryptBeanInfo) { if 
(originPojo == null) { return originPojo; } final T returnPojo; if (originPojo instanceof PojoCloneable) { //noinspection unchecked returnPojo = (T)((PojoCloneable) originPojo).clonePojo(); } else { returnPojo = originPojo; } Map fieldEncryptMap = decryptBeanInfo.getFieldEncryptMap(); fieldEncryptMap.forEach(((field, encrypt) -> { Object oldValue = null; try { oldValue = FieldUtils.readField(field, returnPojo, true); if (oldValue == null) { return; } String oldValueStr = oldValue.toString(); if (StringUtils.isEmpty(oldValueStr)) { return; } FieldUtils.writeField(field, returnPojo, stringDecrypt(field.getName(), oldValueStr, encrypt, returnPojo), true); } catch (IllegalAccessException e) { throw new RuntimeException(String.format("handle field [%s] occur exception. oldValue is %s", field, oldValue), e); } })); return returnPojo; } /** * map解密 */ private Map mapDecrypt(Map map, EncryptInfoHolder.BeanEncryptDetailInfo decryptBeanInfo) { if (CollectionUtils.isEmpty(map)) { return map; } map.forEach((k, v) -> { Object newObj = pojoDecrypt(v, decryptBeanInfo); // 同一个引用,其实这里不显式的赋值也是可以的 map.put(k, newObj); }); return map; } /** * 集合解密 */ private Collection collectionDecrypt(Collection collection, EncryptInfoHolder.BeanEncryptDetailInfo decryptBeanInfo) { if (CollectionUtils.isEmpty(collection)) { return collection; } ArrayList list = new ArrayList<>(collection.size()); // 只处理普通业务bean collection.forEach(pojo -> list.add(pojoDecrypt(pojo, decryptBeanInfo))); return list; } /** * 数组解密 */ private T[] arrayDecrypt(@NonNull T[] array, EncryptInfoHolder.BeanEncryptDetailInfo decryptBeanInfo) { if (array.length == 0) { return array; } // 只处理普通业务bean for (int i = 0; i < array.length; i++) { array[i] = pojoDecrypt(array[i], decryptBeanInfo); } return array; } } }