All downloads are FREE. The search and download functionality uses the official Maven repository.

org.apache.logging.log4j.util.internal.SerializationUtil Maven / Gradle / Ivy

There is a newer version: 6.1.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.logging.log4j.util.internal;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.StreamCorruptedException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import org.apache.logging.log4j.status.StatusLogger;
import org.apache.logging.log4j.util.FilteredObjectInputStream;

/**
 * Provides methods to increase the safety of object serialization/deserialization.
 */
public final class SerializationUtil {

    private static final String DEFAULT_FILTER_CLASS =
            "org.apache.logging.log4j.util.internal.DefaultObjectInputFilter";
    private static final Method setObjectInputFilter;
    private static final Method getObjectInputFilter;
    private static final Method newObjectInputFilter;

    static {
        Method[] methods = ObjectInputStream.class.getMethods();
        Method setMethod = null;
        Method getMethod = null;
        for (final Method method : methods) {
            if (method.getName().equals("setObjectInputFilter")) {
                setMethod = method;
            } else if (method.getName().equals("getObjectInputFilter")) {
                getMethod = method;
            }
        }
        Method newMethod = null;
        try {
            if (setMethod != null) {
                final Class clazz = Class.forName(DEFAULT_FILTER_CLASS);
                methods = clazz.getMethods();
                for (final Method method : methods) {
                    if (method.getName().equals("newInstance") && Modifier.isStatic(method.getModifiers())) {
                        newMethod = method;
                        break;
                    }
                }
            }
        } catch (final ClassNotFoundException ex) {
            // Ignore the exception
        }
        newObjectInputFilter = newMethod;
        setObjectInputFilter = setMethod;
        getObjectInputFilter = getMethod;
    }

    public static final List REQUIRED_JAVA_CLASSES = Arrays.asList(
            "java.math.BigDecimal",
            "java.math.BigInteger",
            // for Message delegate
            "java.rmi.MarshalledObject",
            // all primitives
            "boolean",
            "byte",
            "char",
            "double",
            "float",
            "int",
            "long",
            "short");

    public static final List REQUIRED_JAVA_PACKAGES =
            Arrays.asList("java.lang.", "java.time.", "java.util.", "org.apache.logging.log4j.");

    public static void writeWrappedObject(final Serializable obj, final ObjectOutputStream out) throws IOException {
        final ByteArrayOutputStream bout = new ByteArrayOutputStream();
        try (final ObjectOutputStream oos = new ObjectOutputStream(bout)) {
            oos.writeObject(obj);
            oos.flush();
            out.writeObject(bout.toByteArray());
        }
    }

    @SuppressFBWarnings(
            value = "OBJECT_DESERIALIZATION",
            justification =
                    "Object deserialization uses either Java 9 native filter or our custom filter to limit the kinds of classes deserialized.")
    public static Object readWrappedObject(final ObjectInputStream in) throws IOException, ClassNotFoundException {
        assertFiltered(in);
        final byte[] data = (byte[]) in.readObject();
        final ByteArrayInputStream bin = new ByteArrayInputStream(data);
        final ObjectInputStream ois;
        if (in instanceof FilteredObjectInputStream) {
            ois = new FilteredObjectInputStream(bin, ((FilteredObjectInputStream) in).getAllowedClasses());
        } else {
            try {
                final Object obj = getObjectInputFilter.invoke(in);
                final Object filter = newObjectInputFilter.invoke(null, obj);
                ois = new ObjectInputStream(bin);
                setObjectInputFilter.invoke(ois, filter);
            } catch (IllegalAccessException | InvocationTargetException ex) {
                throw new StreamCorruptedException("Unable to set ObjectInputFilter on stream");
            }
        }
        try {
            return ois.readObject();
        } catch (final Exception | LinkageError e) {
            StatusLogger.getLogger().warn("Ignoring {} during deserialization", e.getMessage());
            return null;
        } finally {
            ois.close();
        }
    }

    public static void assertFiltered(final java.io.ObjectInputStream stream) {
        if (!(stream instanceof FilteredObjectInputStream) && setObjectInputFilter == null) {
            throw new IllegalArgumentException(
                    "readObject requires a FilteredObjectInputStream or an ObjectInputStream that accepts an ObjectInputFilter");
        }
    }

    /**
     * Gets the class name of an array component recursively.
     * 

* If {@code clazz} is not an array class its name is returned. *

* @param clazz the binary name of a class. */ public static String stripArray(final Class clazz) { Class currentClazz = clazz; while (currentClazz.isArray()) { currentClazz = currentClazz.getComponentType(); } return currentClazz.getName(); } /** * Gets the class name of an array component recursively. *

* If {@code name} is not the name of an array class it is returned unchanged. *

* @param name the name of a class. * @see Class#getName() */ public static String stripArray(final String name) { final int offset = name.lastIndexOf('[') + 1; if (offset == 0) { return name; } // Reference types if (name.charAt(offset) == 'L') { return name.substring(offset + 1, name.length() - 1); } // Primitive classes switch (name.substring(offset)) { case "Z": return "boolean"; case "B": return "byte"; case "C": return "char"; case "D": return "double"; case "F": return "float"; case "I": return "int"; case "J": return "long"; case "S": return "short"; default: throw new IllegalArgumentException("Unsupported array class signature '" + name + "'"); } } private SerializationUtil() {} }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy