All downloads are FREE. Search and download functionality uses the official Maven repository.

com.github.netty.protocol.dubbo.serialization.DefaultSerializeClassChecker Maven / Gradle / Ivy

The newest version!
package com.github.netty.protocol.dubbo.serialization;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.*;

public class DefaultSerializeClassChecker implements AllowClassNotifyListener {
    public static final DefaultSerializeClassChecker INSTANCE = new DefaultSerializeClassChecker(SerializeSecurityManager.INSTANCE);
    private static final long MAGIC_HASH_CODE = 0xcbf29ce484222325L;
    private static final long MAGIC_PRIME = 0x100000001b3L;
    private final SerializeSecurityManager serializeSecurityManager;
    //        private static final ErrorTypeAwareLogger logger =
    //                LoggerFactory.getErrorTypeAwareLogger(DefaultSerializeClassChecker.class);
    private volatile SerializeCheckStatus checkStatus = AllowClassNotifyListener.DEFAULT_STATUS;
    private volatile boolean checkSerializable = true;
    private volatile long[] allowPrefixes = new long[0];

    private volatile long[] disAllowPrefixes = new long[0];

    public DefaultSerializeClassChecker(SerializeSecurityManager manager) {
        serializeSecurityManager = manager;
        serializeSecurityManager.registerListener(this);
    }

    private static long[] loadPrefix(Set allowedList) {
        long[] array = new long[allowedList.size()];

        int index = 0;
        for (String name : allowedList) {
            if (name == null || name.isEmpty()) {
                continue;
            }

            long hashCode = MAGIC_HASH_CODE;
            for (int j = 0; j < name.length(); ++j) {
                char ch = name.charAt(j);
                if (ch == '$') {
                    ch = '.';
                }
                hashCode ^= ch;
                hashCode *= MAGIC_PRIME;
            }

            array[index++] = hashCode;
        }

        if (index != array.length) {
            array = Arrays.copyOf(array, index);
        }
        Arrays.sort(array);
        return array;
    }


    @Override
    public synchronized void notifyPrefix(Set allowedList, Set disAllowedList) {
        this.allowPrefixes = loadPrefix(allowedList);
        this.disAllowPrefixes = loadPrefix(disAllowedList);
    }

    @Override
    public synchronized void notifyCheckStatus(SerializeCheckStatus status) {
        this.checkStatus = status;
    }


    @Override
    public synchronized void notifyCheckSerializable(boolean checkSerializable) {
        this.checkSerializable = checkSerializable;
    }

    /**
     * Try load class
     *
     * @param className class name
     * @throws IllegalArgumentException if class is blocked
     * @return Class
     */
    public Class loadClass(ClassLoader classLoader, String className) throws ClassNotFoundException {
        Class aClass = loadClass0(classLoader, className);
        if (!aClass.isPrimitive() && !Serializable.class.isAssignableFrom(aClass)) {
            String msg = "[Serialization Security] Serialized class " + className
                    + " has not implement Serializable interface. "
                    + "Current mode is strict check, will disallow to deserialize it by default. ";
            if (serializeSecurityManager.getWarnedClasses()
                    .add(className)) {
                //                    logger.error(PROTOCOL_UNTRUSTED_SERIALIZE_CLASS, "", "", msg);
            }

            if (checkSerializable) {
                throw new IllegalArgumentException(msg);
            }
        }

        return aClass;
    }

    private Class loadClass0(ClassLoader classLoader, String className) throws ClassNotFoundException {
        if (checkStatus == SerializeCheckStatus.DISABLE) {
            return ClassUtils.forName(className, classLoader);
        }

        long hash = MAGIC_HASH_CODE;
        for (int i = 0, typeNameLength = className.length(); i < typeNameLength; ++i) {
            char ch = className.charAt(i);
            if (ch == '$') {
                ch = '.';
            }
            hash ^= ch;
            hash *= MAGIC_PRIME;

            if (Arrays.binarySearch(allowPrefixes, hash) >= 0) {
                return ClassUtils.forName(className, classLoader);
            }
        }

        if (checkStatus == SerializeCheckStatus.STRICT) {
            String msg = "[Serialization Security] Serialized class " + className + " is not in allow list. "
                    + "Current mode is `STRICT`, will disallow to deserialize it by default. "
                    + "Please add it into security/serialize.allowlist or follow FAQ to configure it.";
            if (serializeSecurityManager.getWarnedClasses()
                    .add(className)) {
                //                    logger.error(PROTOCOL_UNTRUSTED_SERIALIZE_CLASS, "", "", msg);
            }

            throw new IllegalArgumentException(msg);
        }

        hash = MAGIC_HASH_CODE;
        for (int i = 0, typeNameLength = className.length(); i < typeNameLength; ++i) {
            char ch = className.charAt(i);
            if (ch == '$') {
                ch = '.';
            }
            hash ^= ch;
            hash *= MAGIC_PRIME;

            if (Arrays.binarySearch(disAllowPrefixes, hash) >= 0) {
                String msg = "[Serialization Security] Serialized class " + className + " is in disallow list. "
                        + "Current mode is `WARN`, will disallow to deserialize it by default. "
                        + "Please add it into security/serialize.allowlist or follow FAQ to configure it.";
                if (serializeSecurityManager.getWarnedClasses()
                        .add(className)) {
                    //                        logger.warn(PROTOCOL_UNTRUSTED_SERIALIZE_CLASS, "", "", msg);
                }

                throw new IllegalArgumentException(msg);
            }
        }

        hash = MAGIC_HASH_CODE;
        for (int i = 0, typeNameLength = className.length(); i < typeNameLength; ++i) {
            char ch = Character.toLowerCase(className.charAt(i));
            if (ch == '$') {
                ch = '.';
            }
            hash ^= ch;
            hash *= MAGIC_PRIME;

            if (Arrays.binarySearch(disAllowPrefixes, hash) >= 0) {
                String msg = "[Serialization Security] Serialized class " + className + " is in disallow list. "
                        + "Current mode is `WARN`, will disallow to deserialize it by default. "
                        + "Please add it into security/serialize.allowlist or follow FAQ to configure it.";
                if (serializeSecurityManager.getWarnedClasses()
                        .add(className)) {
                    //                        logger.warn(PROTOCOL_UNTRUSTED_SERIALIZE_CLASS, "", "", msg);
                }

                throw new IllegalArgumentException(msg);
            }
        }

        Class clazz = ClassUtils.forName(className, classLoader);
        if (serializeSecurityManager.getWarnedClasses()
                .add(className)) {
            //                logger.warn(
            //                        PROTOCOL_UNTRUSTED_SERIALIZE_CLASS,
            //                        "",
            //                        "",
            //                        "[Serialization Security] Serialized class " + className + " is not in
            //                        allow list. "
            //                                + "Current mode is `WARN`, will allow to deserialize it by default. "
            //                                + "Dubbo will set to `STRICT` mode by default in the future. "
            //                                + "Please add it into security/serialize.allowlist or follow FAQ to
            //                                configure it.");
        }
        return clazz;
    }

    public boolean isCheckSerializable() {
        return checkSerializable;
    }


    public static class ClassUtils {
        /**
         * Suffix for array class names: "[]"
         */
        public static final String ARRAY_SUFFIX = "[]";
        /**
         * Prefix for internal array class names: "[L"
         */
        private static final String INTERNAL_ARRAY_PREFIX = "[L";
        /**
         * Map with primitive type name as key and corresponding primitive type as value, for example: "int" ->
         * "int.class".
         */
        private static final Map> PRIMITIVE_TYPE_NAME_MAP = new HashMap<>(32);
        /**
         * Map with primitive wrapper type as key and corresponding primitive type as value, for example: Integer.class
         * -> int.class.
         */
        private static final Map, Class> PRIMITIVE_WRAPPER_TYPE_MAP = new HashMap<>(16);

        static {
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Boolean.class, boolean.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Byte.class, byte.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Character.class, char.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Double.class, double.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Float.class, float.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Integer.class, int.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Long.class, long.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Short.class, short.class);
            PRIMITIVE_WRAPPER_TYPE_MAP.put(Void.class, void.class);

            Set> primitiveTypeNames = new HashSet<>(32);
            primitiveTypeNames.addAll(PRIMITIVE_WRAPPER_TYPE_MAP.values());
            primitiveTypeNames.addAll(Arrays.asList(boolean[].class, byte[].class, char[].class, double[].class,
                    float[].class, int[].class, long[].class, short[].class));
            for (Class primitiveTypeName : primitiveTypeNames) {
                PRIMITIVE_TYPE_NAME_MAP.put(primitiveTypeName.getName(), primitiveTypeName);
            }
        }

        public static ClassLoader getClassLoader(Class clazz) {
            ClassLoader cl = null;
            if (!clazz.getName()
                    .startsWith("org.apache.dubbo")) {
                cl = clazz.getClassLoader();
            }
            if (cl == null) {
                try {
                    cl = Thread.currentThread()
                            .getContextClassLoader();
                } catch (Exception ignored) {
                    // Cannot access thread context ClassLoader - falling back to system class loader...
                }
                if (cl == null) {
                    // No thread context class loader -> use class loader of this class.
                    cl = clazz.getClassLoader();
                    if (cl == null) {
                        // getClassLoader() returning null indicates the bootstrap ClassLoader
                        try {
                            cl = ClassLoader.getSystemClassLoader();
                        } catch (Exception ignored) {
                            // Cannot access system ClassLoader - oh well, maybe the caller can live with null...
                        }
                    }
                }
            }

            return cl;
        }

        /**
         * Return the default ClassLoader to use: typically the thread context ClassLoader, if available; the
         * ClassLoader that loaded the ClassUtils class will be used as fallback.
         * 

* Call this method if you intend to use the thread context ClassLoader in a scenario where you absolutely need * a non-null ClassLoader reference: for example, for class path resource loading (but not necessarily for * Class.forName, which accepts a null ClassLoader * reference as well). * * @return the default ClassLoader (never null) * @see Thread#getContextClassLoader() */ public static ClassLoader getClassLoader() { return getClassLoader(Hessian2FactoryManager.class); } public static Class forName(String name) throws ClassNotFoundException { return forName(name, getClassLoader()); } /** * Replacement for Class.forName() that also returns Class instances for primitives (like "int") * and array class names (like "String[]"). * * @param name the name of the Class * @param classLoader the class loader to use (may be null, which indicates the default class * loader) * @return Class instance for the supplied name * @throws ClassNotFoundException if the class was not found * @throws LinkageError if the class file could not be loaded * @see Class#forName(String, boolean, ClassLoader) */ public static Class forName( String name, ClassLoader classLoader) throws ClassNotFoundException, LinkageError { Class clazz = resolvePrimitiveClassName(name); if (clazz != null) { return clazz; } // "java.lang.String[]" style arrays if (name.endsWith(ARRAY_SUFFIX)) { String elementClassName = name.substring(0, name.length() - ARRAY_SUFFIX.length()); Class elementClass = forName(elementClassName, classLoader); return Array.newInstance(elementClass, 0) .getClass(); } // "[Ljava.lang.String;" style arrays int internalArrayMarker = name.indexOf(INTERNAL_ARRAY_PREFIX); if (internalArrayMarker != -1 && name.endsWith(";")) { String elementClassName = null; if (internalArrayMarker == 0) { elementClassName = name.substring(INTERNAL_ARRAY_PREFIX.length(), name.length() - 1); } else if (name.startsWith("[")) { elementClassName = name.substring(1); } Class elementClass = forName(elementClassName, 
classLoader); return Array.newInstance(elementClass, 0) .getClass(); } ClassLoader classLoaderToUse = classLoader; if (classLoaderToUse == null) { classLoaderToUse = getClassLoader(); } return classLoaderToUse.loadClass(name); } /** * Resolve the given class name as primitive class, if appropriate, according to the JVM's naming rules for * primitive classes. *

* Also supports the JVM's internal class names for primitive arrays. Does * not support the "[]" suffix notation for primitive arrays; this is * only supported by {@link #forName}. * * @param name the name of the potentially primitive class * @return the primitive class, or null if the name does not denote a primitive class or primitive * array class */ public static Class resolvePrimitiveClassName(String name) { Class result = null; // Most class names will be quite long, considering that they // SHOULD sit in a package, so a length check is worthwhile. if (name != null && name.length() <= 8) { // Could be a primitive - likely. result = PRIMITIVE_TYPE_NAME_MAP.get(name); } return result; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy