All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.common.rpc.RemoteExceptionCause Maven / Gradle / Ivy

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2017 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.common.rpc;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;

import org.wildfly.common.Assert;
import org.wildfly.common._private.CommonMessages;

/**
 * A remote exception cause.  Instances of this class are intended to aid with diagnostics and are not intended to be
 * directly thrown.  They may be added to other exception types as a cause or suppressed throwable.
 *
 * @author David M. Lloyd
 */
public final class RemoteExceptionCause extends Throwable {
    private static final long serialVersionUID = 7849011228540958997L;

    private static final ClassValue>> fieldGetterValue = new ClassValue>>() {
        protected Function> computeValue(final Class type) {
            final Field[] fields = type.getFields();
            final int length = fields.length;
            int i, j;
            for (i = 0, j = 0; i < length; i ++) {
                if ((fields[i].getModifiers() & (Modifier.STATIC | Modifier.PUBLIC)) == Modifier.PUBLIC) {
                    fields[j ++] = fields[i];
                }
            }
            final int finalLength = j;
            final Field[] finalFields;
            if (j < i) {
                finalFields = Arrays.copyOf(fields, j);
            } else {
                finalFields = fields;
            }
            if (finalLength == 0) {
                return t -> Collections.emptyMap();
            } else if (finalLength == 1) {
                final Field field = finalFields[0];
                return t -> {
                    try {
                        return Collections.singletonMap(field.getName(), String.valueOf(field.get(t)));
                    } catch (IllegalAccessException e) {
                        // impossible
                        throw new IllegalStateException(e);
                    }
                };
            }
            return t -> {
                Map map = new TreeMap<>();
                for (Field field : finalFields) {
                    try {
                        map.put(field.getName(), String.valueOf(field.get(t)));
                    } catch (IllegalAccessException e) {
                        // impossible
                        throw new IllegalStateException(e);
                    }
                }
                return Collections.unmodifiableMap(map);
            };
        }
    };
    private static final StackTraceElement[] EMPTY_STACK = new StackTraceElement[0];

    private final String exceptionClassName;
    private final Map fields;
    private transient String toString;

    RemoteExceptionCause(final String msg, final RemoteExceptionCause cause, final String exceptionClassName, final Map fields, boolean cloneFields) {
        super(msg);
        if (cause != null) {
            initCause(cause);
        }
        Assert.checkNotNullParam("exceptionClassName", exceptionClassName);
        this.exceptionClassName = exceptionClassName;
        if (cloneFields) {
            final Iterator> iterator = fields.entrySet().iterator();
            if (! iterator.hasNext()) {
                this.fields = Collections.emptyMap();
            } else {
                final Map.Entry e1 = iterator.next();
                final String name1 = e1.getKey();
                final String value1 = e1.getValue();
                if (name1 == null || value1 == null) {
                    throw CommonMessages.msg.cannotContainNullFieldNameOrValue();
                }
                if (! iterator.hasNext()) {
                    this.fields = Collections.singletonMap(name1, value1);
                } else {
                    Map map = new TreeMap<>();
                    map.put(name1, value1);
                    do {
                        final Map.Entry next = iterator.next();
                        map.put(next.getKey(), next.getValue());
                    } while (iterator.hasNext());
                    this.fields = Collections.unmodifiableMap(map);
                }
            }
        } else {
            this.fields = fields;
        }
    }

    /**
     * Constructs a new {@code RemoteExceptionCause} instance with an initial message.  No
     * cause is specified.
     *
     * @param msg the message
     * @param exceptionClassName the name of the exception's class (must not be {@code null})
     */
    public RemoteExceptionCause(final String msg, final String exceptionClassName) {
        this(msg, null, exceptionClassName, Collections.emptyMap(), false);
    }

    /**
     * Constructs a new {@code RemoteExceptionCause} instance with an initial message and cause.
     *
     * @param msg the message
     * @param cause the cause
     * @param exceptionClassName the name of the exception's class (must not be {@code null})
     */
    public RemoteExceptionCause(final String msg, final RemoteExceptionCause cause, final String exceptionClassName) {
        this(msg, cause, exceptionClassName, Collections.emptyMap(), false);
    }

    /**
     * Constructs a new {@code RemoteExceptionCause} instance with an initial message.  No
     * cause is specified.
     *
     * @param msg the message
     * @param exceptionClassName the name of the exception's class (must not be {@code null})
     * @param fields the public fields of the remote exception (must not be {@code null})
     */
    public RemoteExceptionCause(final String msg, final String exceptionClassName, final Map fields) {
        this(msg, null, exceptionClassName, fields, true);
    }

    /**
     * Constructs a new {@code RemoteExceptionCause} instance with an initial message and cause.
     *
     * @param msg the message
     * @param cause the cause
     * @param exceptionClassName the name of the exception's class (must not be {@code null})
     * @param fields the public fields of the remote exception (must not be {@code null})
     */
    public RemoteExceptionCause(final String msg, final RemoteExceptionCause cause, final String exceptionClassName, final Map fields) {
        this(msg, cause, exceptionClassName, fields, true);
    }

    /**
     * Get a remote exception cause for the given {@link Throwable}.  All of the cause and suppressed exceptions will
     * also be converted.
     *
     * @param t the throwable, or {@code null}
     * @return the remote exception cause, or {@code null} if {@code null} was passed in
     */
    public static RemoteExceptionCause of(Throwable t) {
        return of(t, new IdentityHashMap<>());
    }

    private static RemoteExceptionCause of(Throwable t, IdentityHashMap seen) {
        if (t == null) return null;
        if (t instanceof RemoteExceptionCause) {
            return (RemoteExceptionCause) t;
        } else {
            final RemoteExceptionCause existing = seen.get(t);
            if (existing != null) {
                return existing;
            }
            final RemoteExceptionCause e = new RemoteExceptionCause(t.getMessage(), t.getClass().getName(), fieldGetterValue.get(t.getClass()).apply(t));
            e.setStackTrace(t.getStackTrace());
            seen.put(t, e);
            final Throwable cause = t.getCause();
            if (cause != null) e.initCause(of(cause, seen));
            for (Throwable throwable : t.getSuppressed()) {
                e.addSuppressed(of(throwable, seen));
            }
            return e;
        }
    }

    /**
     * Convert this remote exception cause to a plain throwable for sending to peers which use serialization and do not
     * have this class present.  Note that this does not recursively apply; normally, a serialization framework will
     * handle the recursive application of this operation through object resolution.
     *
     * @return the throwable (not {@code null})
     */
    public Throwable toPlainThrowable() {
        final Throwable throwable = new Throwable(toString(), getCause());
        throwable.setStackTrace(getStackTrace());
        for (Throwable s : getSuppressed()) {
            throwable.addSuppressed(s);
        }
        return throwable;
    }

    /**
     * Get the original exception class name.
     *
     * @return the original exception class name (not {@code null})
     */
    public String getExceptionClassName() {
        return exceptionClassName;
    }

    /**
     * Get the field names of the remote exception.
     *
     * @return the field names of the remote exception
     */
    public Set getFieldNames() {
        return fields.keySet();
    }

    /**
     * Get the string value of the given field name.
     *
     * @param fieldName the name of the field (must not be {@code null})
     * @return the string value of the given field name
     */
    public String getFieldValue(String fieldName) {
        Assert.checkNotNullParam("fieldName", fieldName);
        return fields.get(fieldName);
    }

    /**
     * Get a string representation of this exception.  The representation will return an indication of the fact that
     * this was a remote exception, the remote exception type, and optionally details of the exception content, followed
     * by the exception message.
     *
     * @return the string representation of the exception
     */
    public String toString() {
        final String toString = this.toString;
        if (toString == null) {
            final String message = getMessage();
            StringBuilder b = new StringBuilder();
            b.append(message == null ? CommonMessages.msg.remoteException(exceptionClassName) : CommonMessages.msg.remoteException(exceptionClassName, message));
            Iterator> iterator = fields.entrySet().iterator();
            if (iterator.hasNext()) {
                b.append("\n\tPublic fields:");
                do {
                    final Map.Entry entry = iterator.next();
                    b.append('\n').append('\t').append('\t').append(entry.getKey()).append('=').append(entry.getValue());
                } while (iterator.hasNext());
            }
            return this.toString = b.toString();
        }
        return toString;
    }

    // Format:
    //   class name
    //   null | message
    //   stack trace
    //   count ( field-name field-value )*
    //   null | caused-by
    //   count suppressed*
    // Add new data to the end; old versions must ignore extra data

    private static final int ST_NULL = 0;
    private static final int ST_NEW_STRING = 1; // utf8 data follows
    private static final int ST_NEW_STACK_ELEMENT_V8 = 2; // string string string int
    private static final int ST_NEW_STACK_ELEMENT_V9 = 3; // string string string string string string int
    private static final int ST_NEW_EXCEPTION_CAUSE = 4; // recurse
    private static final int ST_INT8 = 5; // one byte
    private static final int ST_INT16 = 6; // two bytes
    private static final int ST_INT32 = 7; // four bytes
    private static final int ST_INT_MINI = 0x20; // low 5 bits == signed value
    private static final int ST_BACKREF_FAR = 0x40; // low 6 bits + next byte are distance
    private static final int ST_BACKREF_NEAR = 0x80; // low 7 bits are distance

    /**
     * Write this remote exception cause to the given stream, without using serialization.
     *
     * @param output the output stream (must not be {@code null})
     * @throws IOException if an error occurs writing the data
     */
    public void writeToStream(DataOutput output) throws IOException {
        Assert.checkNotNullParam("output", output);
        writeToStream(output, new IdentityIntMap(), new HashMap(), 0);
    }

    private static int readPackedInt(DataInput is) throws IOException {
        final int b = is.readUnsignedByte();
        if ((b & 0xE0) == ST_INT_MINI) {
            // sign-extend it
            return b << 27 >> 27;
        } else if (b == ST_INT8) {
            return is.readByte();
        } else if (b == ST_INT16) {
            return is.readShort();
        } else if (b == ST_INT32) {
            return is.readInt();
        } else {
            throw CommonMessages.msg.corruptedStream();
        }
    }

    private static void writePackedInt(DataOutput os, int val) throws IOException {
        if (-0x10 <= val && val < 0x10) {
            os.write(ST_INT_MINI | val & 0b01_1111);
        } else if (-0x80 <= val && val < 0x80) {
            os.write(ST_INT8);
            os.write(val);
        } else if (-0x8000 <= val && val < 0x8000) {
            os.write(ST_INT16);
            os.writeShort(val);
        } else {
            os.write(ST_INT32);
            os.writeInt(val);
        }
    }

    private int writeToStream(DataOutput output, IdentityIntMap seen, HashMap stringCache, int cnt) throws IOException {
        // register in cycle map
        seen.put(this, cnt++);
        // write the header byte
        output.writeByte(ST_NEW_EXCEPTION_CAUSE);
        // first write class name
        cnt = writeString(output, exceptionClassName, seen, stringCache, cnt);
        // null or message
        cnt = writeString(output, getMessage(), seen, stringCache, cnt);
        // stack trace
        cnt = writeStackTrace(output, getStackTrace(), seen, stringCache, cnt);
        // fields
        cnt = writeFields(output, fields, seen, stringCache, cnt);
        // caused-by
        cnt = writeThrowable(output, getCause(), seen, stringCache, cnt);
        // suppressed
        final Throwable[] suppressed = getSuppressed();
        writePackedInt(output, suppressed.length);
        for (final Throwable t : suppressed) {
            cnt = writeThrowable(output, t, seen, stringCache, cnt);
        }
        return cnt;
    }

    private int writeFields(final DataOutput output, final Map fields, final IdentityIntMap seen, final HashMap stringCache, int cnt) throws IOException {
        writePackedInt(output, fields.size());
        for (Map.Entry entry : fields.entrySet()) {
            cnt = writeString(output, entry.getKey(), seen, stringCache, cnt);
            cnt = writeString(output, entry.getValue(), seen, stringCache, cnt);
        }
        return cnt;
    }

    private int writeStackTrace(final DataOutput output, final StackTraceElement[] stackTrace, final IdentityIntMap seen, final HashMap stringCache, int cnt) throws IOException {
        // don't bother recording in seen because stack traces are always copied
        final int length = stackTrace.length;
        writePackedInt(output, length);
        for (StackTraceElement element : stackTrace) {
            cnt = writeStackElement(output, element, seen, stringCache, cnt);
        }
        return cnt;
    }

    private int writeStackElement(final DataOutput output, final StackTraceElement element, final IdentityIntMap seen, final HashMap stringCache, int cnt) throws IOException {
        final int idx = seen.get(element, - 1);
        final int distance = cnt - idx;
        if (idx == -1 || distance > (1 << 14) - 1) {
            output.write(ST_NEW_STACK_ELEMENT_V8);
            cnt = writeString(output, element.getClassName(), seen, stringCache, cnt);
            cnt = writeString(output, element.getMethodName(), seen, stringCache, cnt);
            cnt = writeString(output, element.getFileName(), seen, stringCache, cnt);
            writePackedInt(output, element.getLineNumber());
            seen.put(element, cnt++);
            return cnt;
        } else {
            if (distance < 127) {
                output.writeByte(ST_BACKREF_NEAR | distance);
            } else {
                assert distance <= 0x3fff;
                output.writeByte(ST_BACKREF_FAR | distance >> 8);
                output.writeByte(distance);
            }
            return cnt;
        }
    }

    private int writeThrowable(final DataOutput output, final Throwable throwable, final IdentityIntMap seen, final HashMap stringCache, final int cnt) throws IOException {
        if (throwable == null) {
            output.write(ST_NULL);
            return cnt;
        } else {
            final int idx = seen.get(throwable, - 1);
            final int distance = cnt - idx;
            if (idx == - 1 || distance >= 0x4000) {
                RemoteExceptionCause nested;
                if (throwable instanceof RemoteExceptionCause) {
                    nested = (RemoteExceptionCause) throwable;
                } else {
                    seen.put(throwable, cnt); // do not increment yet
                    nested = of(throwable);
                }
                return nested.writeToStream(output, seen, stringCache, cnt); // this will increment it
            } else {
                if (distance < 127) {
                    output.writeByte(ST_BACKREF_NEAR | distance);
                } else {
                    assert distance <= 0x3fff;
                    output.writeByte(ST_BACKREF_FAR | distance >> 8);
                    output.writeByte(distance);
                }
                return cnt;
            }
        }
    }

    private int writeString(final DataOutput output, String string, final IdentityIntMap seen, final HashMap stringCache, final int cnt) throws IOException {
        if (string == null) {
            output.write(ST_NULL);
            return cnt;
        }
        // make sure we never duplicate a string
        string = stringCache.computeIfAbsent(string, Function.identity());
        final int idx = seen.get(string, - 1);
        final int distance = cnt - idx;
        if (idx == -1 || distance > (1 << 14) - 1) {
            seen.put(string, cnt);
            output.write(ST_NEW_STRING);
            output.writeUTF(string);
            return cnt + 1;
        } else {
            if (distance < 127) {
                output.writeByte(ST_BACKREF_NEAR | distance);
            } else {
                assert distance <= 0x3fff;
                output.writeByte(ST_BACKREF_FAR | distance >> 8);
                output.writeByte(distance);
            }
            return cnt;
        }
    }

    public static RemoteExceptionCause readFromStream(DataInput input) throws IOException {
        return readObject(input, RemoteExceptionCause.class, new ArrayList<>(), false);
    }

    private static  T readObject(DataInput input, Class expect, ArrayList cache, final boolean allowNull) throws IOException {
        final int b = input.readUnsignedByte();
        if (b == ST_NULL) {
            if (! allowNull) {
                throw CommonMessages.msg.corruptedStream();
            }
            return null;
        } else if (b == ST_NEW_STRING) {
            if (expect != String.class) {
                throw CommonMessages.msg.corruptedStream();
            }
            final String str = input.readUTF();
            cache.add(str);
            return expect.cast(str);
        } else if (b == ST_NEW_EXCEPTION_CAUSE) {
            if (expect != RemoteExceptionCause.class) {
                throw CommonMessages.msg.corruptedStream();
            }
            final int idx = cache.size();
            cache.add(null);
            String exClassName = readObject(input, String.class, cache, false);
            String exMessage = readObject(input, String.class, cache, true);
            int length = readPackedInt(input);
            StackTraceElement[] stackTrace;
            if (length == 0) {
                stackTrace = EMPTY_STACK;
            } else {
                stackTrace = new StackTraceElement[length];
                for (int i = 0; i < length; i++) {
                    stackTrace[i] = readObject(input, StackTraceElement.class, cache, false);
                }
            }
            Map fields;
            length = readPackedInt(input);
            if (length == 0) {
                fields = Collections.emptyMap();
            } else if (length == 1) {
                fields = Collections.singletonMap(readObject(input, String.class, cache, false), readObject(input, String.class, cache, false));
            } else {
                fields = new HashMap<>(length);
                for (int i = 0; i < length; i++) {
                    fields.put(readObject(input, String.class, cache, false), readObject(input, String.class, cache, false));
                }
            }
            final RemoteExceptionCause result = new RemoteExceptionCause(exMessage, null, exClassName, fields, false);
            cache.set(idx, result);
            RemoteExceptionCause causedBy = readObject(input, RemoteExceptionCause.class, cache, true);
            result.initCause(causedBy);
            length = readPackedInt(input);
            result.setStackTrace(stackTrace);
            for (int i = 0; i < length; i++) {
                // this can't actually be null because we passed {@code false} in to allowNull
                //noinspection ConstantConditions
                result.addSuppressed(readObject(input, RemoteExceptionCause.class, cache, false));
            }
            return expect.cast(result);
        } else if (b == ST_NEW_STACK_ELEMENT_V8) {
            if (expect != StackTraceElement.class) {
                throw CommonMessages.msg.corruptedStream();
            }
            // this can't actually be null because we passed {@code false} in to allowNull
            //noinspection ConstantConditions
            final StackTraceElement element = new StackTraceElement(
                readObject(input, String.class, cache, false),
                readObject(input, String.class, cache, false),
                readObject(input, String.class, cache, true),
                readPackedInt(input)
            );
            cache.add(element);
            return expect.cast(element);
        } else if (b == ST_NEW_STACK_ELEMENT_V9) {
            if (expect != StackTraceElement.class) {
                throw CommonMessages.msg.corruptedStream();
            }
            // discard CL name, module name, and module version
            readObject(input, String.class, cache, true);
            readObject(input, String.class, cache, true);
            readObject(input, String.class, cache, true);
            // this can't actually be null because we passed {@code false} in to allowNull
            //noinspection ConstantConditions
            final StackTraceElement element = new StackTraceElement(
                readObject(input, String.class, cache, false),
                readObject(input, String.class, cache, false),
                readObject(input, String.class, cache, true),
                readPackedInt(input)
            );
            cache.add(element);
            return expect.cast(element);
        } else if ((b & ST_BACKREF_NEAR) != 0) {
            int idx = b & 0x7f;
            if (idx > cache.size()) {
                throw CommonMessages.msg.corruptedStream();
            }
            Object obj = cache.get(cache.size() - idx);
            if (expect.isInstance(obj)) {
                return expect.cast(obj);
            } else {
                throw CommonMessages.msg.corruptedStream();
            }
        } else if ((b & ST_BACKREF_FAR) != 0) {
            final int b2 = input.readUnsignedByte();
            int idx = (b & 0x3f) << 8 | b2;
            if (idx > cache.size()) {
                throw CommonMessages.msg.corruptedStream();
            }
            Object obj = cache.get(cache.size() - idx);
            if (expect.isInstance(obj)) {
                return expect.cast(obj);
            } else {
                throw CommonMessages.msg.corruptedStream();
            }
        } else {
            throw CommonMessages.msg.corruptedStream();
        }
    }

    private static final String[] NO_STRINGS = new String[0];
    private static final RemoteExceptionCause[] NO_REMOTE_EXCEPTION_CAUSES = new RemoteExceptionCause[0];

    Object writeReplace() {
        final Throwable[] origSuppressed = getSuppressed();
        final int length = origSuppressed.length;
        final RemoteExceptionCause[] suppressed;
        if (length == 0) {
            suppressed = NO_REMOTE_EXCEPTION_CAUSES;
        } else {
            suppressed = new RemoteExceptionCause[length];
            for (int i = 0; i < length; i ++) {
                suppressed[i] = of(origSuppressed[i]);
            }
        }
        String[] fieldArray;
        final int size = fields.size();
        if (size == 0) {
            fieldArray = NO_STRINGS;
        } else {
            fieldArray = new String[size << 1];
            int i = 0;
            for (Map.Entry entry : fields.entrySet()) {
                fieldArray[i++] = entry.getKey();
                fieldArray[i++] = entry.getValue();
            }
        }
        return new Serialized(getMessage(), exceptionClassName, of(getCause()), suppressed, getStackTrace(), fieldArray);
    }

    public RemoteExceptionCause getCause() {
        return (RemoteExceptionCause) super.getCause();
    }

    static final class Serialized implements Serializable {
        private static final long serialVersionUID = - 2201431870774913071L;

        // small field names serialize smaller

        final String m;
        final String cn;
        final RemoteExceptionCause c;
        final RemoteExceptionCause[] s;
        final StackTraceElement[] st;
        final String[] f;

        Serialized(final String m, final String cn, final RemoteExceptionCause c, final RemoteExceptionCause[] s, final StackTraceElement[] st, final String[] f) {
            this.m = m;
            this.cn = cn;
            this.c = c;
            this.s = s;
            this.st = st;
            this.f = f;
        }

        Object readResolve() {
            final Map fields;
            if (f == null) {
                fields = Collections.emptyMap();
            } else {
                final int fl = f.length;
                if ((fl & 1) != 0) {
                    throw CommonMessages.msg.invalidOddFields();
                } else if (fl == 0) {
                    fields = Collections.emptyMap();
                } else if (fl == 2) {
                    fields = Collections.singletonMap(f[0], f[1]);
                } else {
                    final TreeMap map = new TreeMap<>();
                    for (int i = 0; i < fl; i += 2) {
                        map.put(f[i], f[i + 1]);
                    }
                    fields = Collections.unmodifiableMap(map);
                }
            }
            final RemoteExceptionCause ex = new RemoteExceptionCause(m, c, cn, fields, false);
            ex.setStackTrace(st);
            final RemoteExceptionCause[] suppressed = s;
            if (suppressed != null) for (RemoteExceptionCause c : suppressed) {
                ex.addSuppressed(c);
            }
            return ex;
        }
    }
}