All downloads are free. Search and download functionality uses the official Maven repository.

nl.weeaboo.lua2.io.ObjectSerializer Maven / Gradle / Ivy

package nl.weeaboo.lua2.io;

import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Writes Lua objects to a binary stream.
 *
 * @see LuaSerializer
 */
public class ObjectSerializer extends ObjectOutputStream {

    /**
     * Problem level.
     */
    public enum ErrorLevel {
        NONE, WARNING, ERROR
    }

    private static final Logger LOG = LoggerFactory.getLogger(ObjectSerializer.class);

    private final @Nullable Environment env; // Null if empty or not used
    private final Set validPackages = new HashSet<>();
    private final Set> validClasses = new HashSet<>();
    private final ExecutorService executor;

    private final List errors = new ArrayList<>();
    private final List warnings = new ArrayList<>();
    private final Map, Stats> classCounter = new IdentityHashMap<>();

    private ErrorLevel packageErrorLevel = ErrorLevel.ERROR;
    private boolean collectStats = true;
    private boolean checkTypes;

    protected ObjectSerializer(OutputStream out, Environment e) throws IOException {
        super(out);

        env = (e.size() != 0 ? e : null);
        executor = new DelayedIoExecutor("LuaObjectSerializer");

        resetValidPackages();
        resetValidClasses();

        onPackageLimitChanged();
    }

    @Override
    public void close() throws IOException {
        try {
            super.close();
        } finally {
            executor.shutdown();
        }
    }

    private static String toErrorString(String[] errors) {
        StringBuilder sb = new StringBuilder();
        sb.append(errors.length).append(" error(s) occurred while writing objects:");

        int t = 1;
        for (String err : errors) {
            sb.append('\n');
            sb.append(t);
            sb.append(": ");
            sb.append(err);
            t++;
        }
        return sb.toString();
    }

    /**
     * Clears the list and returns its former contents.
     */
    private static String[] consume(Collection list) {
        String[] result = list.toArray(new String[list.size()]);
        list.clear();
        return result;
    }

    /**
     * @return An array containing all warnings encountered during serialization.
     * @throws IOException if any fatal errors were encountered.
     */
    public String[] checkErrors() throws IOException {
        String[] errors = consume(this.errors);
        String[] warnings = consume(this.warnings);

        if (errors.length > 0) {
            throw new RuntimeException(toErrorString(errors));
        }

        if (collectStats) {
            Entry, Stats>[] entries = classCounter.entrySet().toArray(new Entry[0]);
            Arrays.sort(entries, new Comparator, Stats>>() {
                @Override
                public int compare(Entry, Stats> e1, Entry, Stats> e2) {
                    return -e1.getValue().compareTo(e2.getValue());
                }
            });
            for (Entry, Stats> entry : entries) {
                LOG.debug("[stats] {}: {}", entry.getKey().getName(), entry.getValue());
            }
        }

        return warnings;
    }

    @Override
    protected @Nullable Object replaceObject(Object obj) {
        Class clazz = obj.getClass();

        // Updating stats
        if (collectStats) {
            Stats stats = classCounter.get(clazz);
            if (stats == null) {
                stats = new Stats();
                classCounter.put(clazz, stats);
            }
            stats.count++;
        }

        // Environment
        if (env != null) {
            String id = env.getId(obj);
            if (id != null) {
                return new RefEnvironment(id);
            }
        }

        if (checkTypes) {
            // Whitelisted types
            if (clazz.getAnnotation(LuaSerializable.class) != null) {
                return obj; // Whitelist types with the LuaSerializable annotation
            } else if (clazz.isArray()) {
                return obj; // Whitelist array types
            } else if (clazz.isEnum()) {
                return obj; // Whitelist enum types
            }

            if (packageErrorLevel != ErrorLevel.NONE) {
                if (!isValidClass(clazz)) {

                    String message = "Class outside valid packages: " + clazz.getName() + " :: " + obj;
                    if (packageErrorLevel == ErrorLevel.ERROR) {
                        errors.add(message);
                        return null; // Don't serialize object in case of error
                    } else if (packageErrorLevel == ErrorLevel.WARNING) {
                        warnings.add(message);
                    }
                }
            }
        }

        return obj;
    }

    private boolean isValidClass(Class clazz) {
        if (validClasses.contains(clazz)) {
            return true;
        }

        // Check if this package is a valid package (or a sub-package of a valid package)
        String packageName = clazz.getPackage().getName();
        if (validPackages.contains(packageName)) {
            return true;
        }

        return false;
    }

    /**
     * Calls {@link ObjectOutputStream#writeObject(Object)} on a new thread.
     * 

* This method can be used to avoid stack space issues when serializing large object graphs. * * @throws IOException If the thread throws an exception, or if the wait for the thread to finish is interrupted. */ public void writeObjectOnNewThread(final Object obj) throws IOException { Future future = executor.submit(createAsyncWriteTask(obj)); try { future.get(); } catch (InterruptedException e) { throw new IOException("Async write interrupted: " + e); } catch (ExecutionException e) { throw new IOException("Error during async write", e.getCause()); } } protected Callable createAsyncWriteTask(final Object obj) { return new Callable() { @Override public Void call() throws IOException { writeObject(obj); return null; } }; } private void resetValidPackages() { validPackages.clear(); validPackages.add("java.util"); validPackages.add("java.util.atomic"); validPackages.add("java.util.concurrent"); } private void resetValidClasses() { validClasses.clear(); Collections.>addAll(validClasses, Boolean.class, Byte.class, Short.class, Integer.class, Long.class, Float.class, Double.class, String.class, Class.class, Random.class, BitSet.class); } private void onPackageLimitChanged() { checkTypes = (packageErrorLevel != ErrorLevel.NONE); updateEnableReplace(); } private void updateEnableReplace() { boolean replace = (env != null || checkTypes || collectStats); try { enableReplaceObject(replace); } catch (SecurityException se) { LOG.error("Error calling 'enableReplaceObject'", se); } } /** * Determines the behavior when a non-allowed class is written. * @see #setAllowedPackages(Collection) * @see #setAllowedClasses(Collection) */ public ErrorLevel getPackageErrorLevel() { return packageErrorLevel; } /** * @see #getPackageErrorLevel() */ public void setPackageErrorLevel(ErrorLevel el) { if (packageErrorLevel != el) { packageErrorLevel = el; onPackageLimitChanged(); } } /** * Defines a set of packages that may be written. *

* Every class serialized must belong to either an allowed package, be an allowed class, or have the * {@link LuaSerializable} annotation. * * @see #setAllowedClasses(Collection) * @see #getPackageErrorLevel() */ public void setAllowedPackages(Collection packages) { resetValidPackages(); validPackages.addAll(packages); } /** * Defines a set of classes that may be written. *

* Every class serialized must belong to either an allowed package, be an allowed class, or have the * {@link LuaSerializable} annotation. * * @see #setAllowedPackages(Collection) * @see #getPackageErrorLevel() */ public void setAllowedClasses(Collection> classes) { resetValidClasses(); validClasses.addAll(classes); } /** * If {@code true}, tracks various statistics during use and warns if certain values (primarily stack * depth) become dangerously large. */ public boolean getCollectStats() { return collectStats; } /** * @see #getCollectStats() */ public void setCollectStats(boolean enable) { if (collectStats != enable) { collectStats = enable; updateEnableReplace(); } } // Inner Classes private static class Stats implements Comparable { public int count; @Override public int compareTo(Stats s) { return (count < s.count ? -1 : (count == s.count ? 0 : 1)); } @Override public String toString() { return Integer.toString(count); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy