All downloads are free. The search and download functionality uses the official Maven repository.

org.mockito.internal.creation.bytebuddy.MockMethodAdvice Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2016 Mockito contributors
 * This program is made available under the terms of the MIT License.
 */
package org.mockito.internal.creation.bytebuddy;

import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.method.MethodDescription;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.dynamic.scaffold.MethodGraph;
import net.bytebuddy.implementation.bind.annotation.Argument;
import net.bytebuddy.implementation.bind.annotation.This;
import net.bytebuddy.implementation.bytecode.assign.Assigner;
import org.mockito.exceptions.base.MockitoException;
import org.mockito.internal.debugging.LocationImpl;
import org.mockito.internal.exceptions.stacktrace.ConditionalStackTraceFilter;
import org.mockito.internal.invocation.RealMethod;
import org.mockito.internal.invocation.SerializableMethod;
import org.mockito.internal.util.concurrent.WeakConcurrentMap;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.ref.SoftReference;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;

public class MockMethodAdvice extends MockMethodDispatcher {

    /** Maps each mock instance to the interceptor that handles its invocations. */
    final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors;

    /** Identifier passed to {@code MockMethodDispatcher.get(...)} when a serializable real-method call looks up its dispatcher. */
    private final String identifier;

    /** Per-thread record of the instance whose real method is about to be invoked. */
    private final SelfCallInfo selfCallInfo = new SelfCallInfo();

    /** Compiler used to build the method graph of a mock's runtime class. */
    private final MethodGraph.Compiler compiler = MethodGraph.Compiler.Default.forJavaHierarchy();

    /** Cache of compiled method graphs keyed by runtime class; the graphs themselves are only softly referenced. */
    private final WeakConcurrentMap<Class<?>, SoftReference<MethodGraph>> graphs
        = new WeakConcurrentMap.WithInlinedExpunction<Class<?>, SoftReference<MethodGraph>>();

    /**
     * Creates an advice dispatcher.
     *
     * @param interceptors map from mock instance to its interceptor; stored as-is, not copied
     * @param identifier   identifier later used to look up the dispatcher for serializable real-method calls
     */
    public MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) {
        this.interceptors = interceptors;
        this.identifier = identifier;
    }

    /**
     * Enter advice for instrumented methods.
     * <p>
     * Looks up the dispatcher registered for {@code identifier} and, if the receiver is a mocked
     * instance whose method is not overridden, delegates to {@code dispatcher.handle(...)}.
     * Because of {@code skipOn = Advice.OnNonDefaultValue.class}, returning a non-null value
     * makes Byte Buddy skip the instrumented method's own body.
     *
     * @return the value of {@code dispatcher.handle(...)}, or {@code null} when the call must not be intercepted
     */
    @SuppressWarnings("unused")
    @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
    private static Callable enter(@Identifier String identifier,
                                     @Advice.This Object mock,
                                     @Advice.Origin Method origin,
                                     @Advice.AllArguments Object[] arguments) throws Throwable {
        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock);
        // The three checks are evaluated left to right and short-circuit; isMocked must run
        // before isOverridden because it consumes the per-thread self-call marker.
        if (dispatcher != null && dispatcher.isMocked(mock) && !dispatcher.isOverridden(mock, origin)) {
            return dispatcher.handle(mock, origin, arguments);
        }
        return null;
    }

    /**
     * Exit advice for instrumented methods.
     * <p>
     * When the enter advice produced a callable, its result replaces the instrumented method's
     * return value; otherwise the return value is left untouched.
     */
    @SuppressWarnings({"unused", "UnusedAssignment"})
    @Advice.OnMethodExit
    private static void exit(@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned,
                             @Advice.Enter Callable mocked) throws Throwable {
        if (mocked == null) {
            // Nothing was intercepted on entry: keep whatever the method itself returned.
            return;
        }
        // Assigning to the writable @Advice.Return parameter is how the new return value is set.
        returned = mocked.call();
    }

    /**
     * Removes a contiguous run of elements from the stack trace of {@code throwable}.
     * <p>
     * The last {@code current} elements of the trace are always kept. Starting directly above
     * them and walking towards the top of the trace, elements are marked for removal until an
     * element whose class name equals {@code targetType.getName()} is found; that element is
     * removed as well. Everything above it is kept.
     *
     * @param throwable  the throwable whose stack trace is rewritten in place
     * @param current    number of trailing stack trace elements to preserve
     * @param targetType class whose stack element marks the upper end of the removed range
     * @return the same {@code throwable} instance; its trace is unchanged when the expected
     *         layout is not found
     */
    public static Throwable hideRecursiveCall(Throwable throwable, int current, Class<?> targetType) {
        try {
            StackTraceElement[] stack = throwable.getStackTrace();
            // Loop-invariant: resolve the name to look for only once.
            String targetName = targetType.getName();
            int skip = 0;
            StackTraceElement next;
            do {
                // Runs off the array (and is caught below) if no element matches.
                next = stack[stack.length - current - ++skip];
            } while (!next.getClassName().equals(targetName));
            int top = stack.length - current - skip;
            StackTraceElement[] cleared = new StackTraceElement[stack.length - skip];
            // Keep everything above the matched element ...
            System.arraycopy(stack, 0, cleared, 0, top);
            // ... and the trailing 'current' elements below the removed range.
            System.arraycopy(stack, top + skip, cleared, top, current);
            throwable.setStackTrace(cleared);
            return throwable;
        } catch (RuntimeException ignored) {
            // This should not happen unless someone instrumented or manipulated exception stack traces.
            return throwable;
        }
    }

    /**
     * Routes an intercepted invocation to the interceptor registered for {@code instance}.
     *
     * @param instance  the receiver of the intercepted call
     * @param origin    the method that was invoked
     * @param arguments the invocation arguments
     * @return a callable yielding the interceptor's result, or {@code null} when no interceptor
     *         is registered for {@code instance}
     * @throws Throwable whatever the interceptor throws
     */
    @Override
    public Callable handle(Object instance, Method origin, Object[] arguments) throws Throwable {
        MockMethodInterceptor interceptor = interceptors.get(instance);
        if (interceptor == null) {
            return null;
        }
        // Serializable instances get a real-method handle that carries the identifier rather
        // than this advice's SelfCallInfo.
        RealMethod realMethod = instance instanceof Serializable
                ? new SerializableRealMethodCall(identifier, origin, instance, arguments)
                : new RealMethodCall(selfCallInfo, origin, instance, arguments);
        // The throwable is created in this method on purpose: skipInlineMethodElement searches
        // the captured trace for the frame of MockMethodAdvice#handle.
        Throwable stackHolder = new Throwable();
        StackTraceElement[] trimmed = skipInlineMethodElement(stackHolder.getStackTrace());
        stackHolder.setStackTrace(trimmed);
        Object result = interceptor.doIntercept(instance,
                origin,
                arguments,
                realMethod,
                new LocationImpl(stackHolder));
        return new ReturnValueWrapper(result);
    }

    /**
     * Tells whether {@code instance} has an interceptor registered in {@link #interceptors}.
     *
     * @param instance the object to check
     * @return {@code true} if the map contains {@code instance} as a key
     */
    @Override
    public boolean isMock(Object instance) {
        // 'interceptors.target' is excluded explicitly: checking whether the map's own target
        // is a mock would require reading from the map and so recurse into this check.
        if (instance == interceptors.target) {
            return false;
        }
        return interceptors.containsKey(instance);
    }

    /**
     * Tells whether a call on {@code instance} should be intercepted right now.
     * <p>
     * The self-call check runs first and, when {@code instance} is the one recorded for the
     * current thread, clears that record and reports {@code false} without consulting
     * {@link #isMock(Object)}.
     *
     * @param instance the receiver of the call
     * @return {@code true} if the call is not a recorded self call and {@code instance} is a mock
     */
    @Override
    public boolean isMocked(Object instance) {
        if (!selfCallInfo.checkSuperCall(instance)) {
            return false;
        }
        return isMock(instance);
    }

    /**
     * Tells whether {@code origin} is not the method that the runtime class of {@code instance}
     * resolves for that signature.
     * <p>
     * The method graph of the runtime class is compiled on first use and cached behind a soft
     * reference; it is recompiled if the reference has been cleared.
     *
     * @param instance the receiver whose runtime class is inspected
     * @param origin   the method being invoked
     * @return {@code true} if the signature does not resolve in the method graph, or resolves to
     *         a method declared by a type other than {@code origin.getDeclaringClass()}
     */
    @Override
    public boolean isOverridden(Object instance, Method origin) {
        Class<?> type = instance.getClass();
        SoftReference<MethodGraph> reference = graphs.get(type);
        MethodGraph methodGraph = reference == null ? null : reference.get();
        if (methodGraph == null) {
            // Either never compiled for this class or the soft reference was cleared.
            methodGraph = compiler.compile(new TypeDescription.ForLoadedType(type));
            graphs.put(type, new SoftReference<MethodGraph>(methodGraph));
        }
        MethodGraph.Node node = methodGraph.locate(new MethodDescription.ForLoadedMethod(origin).asSignatureToken());
        return !node.getSort().isResolved() || !node.getRepresentative().asDefined().getDeclaringType().represents(origin.getDeclaringClass());
    }

    /**
     * Real-method handle used for instances that are not {@link Serializable}. It keeps a direct
     * reference to the advice's {@link SelfCallInfo} and to the reflective {@link Method}.
     */
    private static class RealMethodCall implements RealMethod {

        private final SelfCallInfo selfCallInfo;
        private final Method origin;
        private final Object instance;
        private final Object[] arguments;

        private RealMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) {
            this.selfCallInfo = selfCallInfo;
            this.origin = origin;
            this.instance = instance;
            this.arguments = arguments;
        }

        /** Always invokable. */
        @Override
        public boolean isInvokable() {
            return true;
        }

        /**
         * Invokes the original method reflectively on the stored instance.
         * <p>
         * The instance is recorded in the per-thread {@link SelfCallInfo} immediately before the
         * reflective call. Unlike the serializable variant, no previous value is restored afterwards.
         */
        @Override
        public Object invoke() throws Throwable {
            // Only when both the declaring class and the method are public can the access override be skipped.
            boolean publiclyAccessible = Modifier.isPublic(origin.getDeclaringClass().getModifiers())
                    && Modifier.isPublic(origin.getModifiers());
            if (!publiclyAccessible) {
                origin.setAccessible(true);
            }
            selfCallInfo.set(instance);
            return tryInvoke(origin, instance, arguments);
        }

    }

    /**
     * Real-method handle used for {@link Serializable} instances. Instead of holding the advice
     * and a {@link Method} directly, it stores the dispatcher identifier and a
     * {@link SerializableMethod}, and resolves both again when invoked.
     */
    private static class SerializableRealMethodCall implements RealMethod {

        private final String identifier;
        private final SerializableMethod origin;
        private final Object instance;
        private final Object[] arguments;

        private SerializableRealMethodCall(String identifier, Method origin, Object instance, Object[] arguments) {
            this.origin = new SerializableMethod(origin);
            this.identifier = identifier;
            this.instance = instance;
            this.arguments = arguments;
        }

        /** Always invokable. */
        @Override
        public boolean isInvokable() {
            return true;
        }

        /**
         * Invokes the original method reflectively on the stored instance.
         *
         * @throws MockitoException if the dispatcher found for the identifier is not a {@link MockMethodAdvice}
         */
        @Override
        public Object invoke() throws Throwable {
            Method method = origin.getJavaMethod();
            // Only when both the declaring class and the method are public can the access override be skipped.
            boolean publiclyAccessible = Modifier.isPublic(method.getDeclaringClass().getModifiers())
                    && Modifier.isPublic(method.getModifiers());
            if (!publiclyAccessible) {
                method.setAccessible(true);
            }
            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, instance);
            if (!(dispatcher instanceof MockMethodAdvice)) {
                throw new MockitoException("Unexpected dispatcher for advice-based super call");
            }
            SelfCallInfo callInfo = ((MockMethodAdvice) dispatcher).selfCallInfo;
            // Record this instance for the duration of the call and put the old value back afterwards.
            Object previous = callInfo.replace(instance);
            try {
                return tryInvoke(method, instance, arguments);
            } finally {
                callInfo.set(previous);
            }
        }
    }

    /**
     * Invokes {@code origin} reflectively and rethrows the target's own exception instead of the
     * reflective wrapper.
     * <p>
     * Before rethrowing, the cause's stack trace is passed through
     * {@link #hideRecursiveCall(Throwable, int, Class)} and a {@link ConditionalStackTraceFilter}.
     *
     * @return the value returned by the invoked method
     * @throws Throwable the cause of the {@link InvocationTargetException}, or any other
     *                   exception raised by {@link Method#invoke(Object, Object...)}
     */
    private static Object tryInvoke(Method origin, Object instance, Object[] arguments) throws Throwable {
        try {
            return origin.invoke(instance, arguments);
        } catch (InvocationTargetException exception) {
            Throwable cause = exception.getCause();
            ConditionalStackTraceFilter filter = new ConditionalStackTraceFilter();
            // Depth of the current stack, measured from inside this method.
            int depth = new Throwable().getStackTrace().length;
            filter.filter(hideRecursiveCall(cause, depth, origin.getDeclaringClass()));
            throw cause;
        }
    }

    /**
     * Returns a copy of {@code elements} without the element that directly follows each
     * {@code MockMethodAdvice#handle} element.
     * <p>
     * With inline mocking, mocks for concrete classes are not subclassed, so elements of the
     * stubbing methods are not filtered out. The element after {@code handle} is assumed to be
     * such an inlined method and is therefore skipped.
     *
     * @param elements the stack trace to copy; not modified
     * @return a new array with the assumed inlined-method elements removed
     */
    private static StackTraceElement[] skipInlineMethodElement(StackTraceElement[] elements) {
        List<StackTraceElement> list = new ArrayList<StackTraceElement>(elements.length);
        for (int i = 0; i < elements.length; i++) {
            StackTraceElement element = elements[i];
            list.add(element);
            if (element.getClassName().equals(MockMethodAdvice.class.getName()) && element.getMethodName().equals("handle")) {
                // If the current element is MockMethodAdvice#handle(), the next is assumed to be an inlined method.
                i++;
            }
        }
        return list.toArray(new StackTraceElement[list.size()]);
    }

    /** A {@link Callable} that simply hands back a value computed earlier. */
    private static class ReturnValueWrapper implements Callable {

        private final Object value;

        private ReturnValueWrapper(Object value) {
            this.value = value;
        }

        /** Returns the wrapped value; performs no work of its own. */
        @Override
        public Object call() {
            return value;
        }
    }

    /** Thread-local holder for the instance recorded just before its real method is invoked. */
    private static class SelfCallInfo extends ThreadLocal {

        /**
         * Stores {@code value} for the current thread.
         *
         * @return the value that was stored before this call
         */
        Object replace(Object value) {
            Object previous = get();
            set(value);
            return previous;
        }

        /**
         * Compares {@code value} by identity with the instance recorded for the current thread.
         *
         * @return {@code false} if they are the same object, in which case the record is also
         *         cleared; {@code true} otherwise
         */
        boolean checkSuperCall(Object value) {
            if (value != get()) {
                return true;
            }
            set(null);
            return false;
        }
    }

    /**
     * Marker annotation for the {@code String} identifier parameter of the advice and
     * delegation methods in this class. It is retained at runtime; how a value gets bound to
     * the annotated parameter is configured outside this file.
     */
    @Retention(RetentionPolicy.RUNTIME)
    @interface Identifier {

    }

    /**
     * Advice that makes an instrumented {@code int}-returning method yield the identity hash
     * code of the receiver whenever the receiver is a mock.
     */
    static class ForHashCode {

        /**
         * Enter advice: reports whether the receiver is a mock. Because of
         * {@code skipOn = Advice.OnNonDefaultValue.class}, returning {@code true} makes
         * Byte Buddy skip the instrumented method's own body.
         */
        @SuppressWarnings("unused")
        @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
        private static boolean enter(@Identifier String identifier,
                                     @Advice.This Object self) {
            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self);
            if (dispatcher == null) {
                return false;
            }
            return dispatcher.isMock(self);
        }

        /**
         * Exit advice (named {@code enter} like its counterpart): when the enter advice
         * returned {@code true}, the return value becomes {@code System.identityHashCode(self)}.
         */
        @SuppressWarnings({"unused", "UnusedAssignment"})
        @Advice.OnMethodExit
        private static void enter(@Advice.This Object self,
                                  @Advice.Return(readOnly = false) int hashCode,
                                  @Advice.Enter boolean skipped) {
            if (!skipped) {
                return;
            }
            hashCode = System.identityHashCode(self);
        }
    }

    /**
     * Advice that makes an instrumented {@code boolean}-returning, single-argument method fall
     * back to reference equality whenever the receiver is a mock.
     */
    static class ForEquals {

        /**
         * Enter advice: reports whether the receiver is a mock. Because of
         * {@code skipOn = Advice.OnNonDefaultValue.class}, returning {@code true} makes
         * Byte Buddy skip the instrumented method's own body.
         */
        @SuppressWarnings("unused")
        @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
        private static boolean enter(@Identifier String identifier,
                                     @Advice.This Object self) {
            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self);
            if (dispatcher == null) {
                return false;
            }
            return dispatcher.isMock(self);
        }

        /**
         * Exit advice (named {@code enter} like its counterpart): when the enter advice
         * returned {@code true}, the return value becomes {@code self == other}.
         */
        @SuppressWarnings({"unused", "UnusedAssignment"})
        @Advice.OnMethodExit
        private static void enter(@Advice.This Object self,
                                  @Advice.Argument(0) Object other,
                                  @Advice.Return(readOnly = false) boolean equals,
                                  @Advice.Enter boolean skipped) {
            if (!skipped) {
                return;
            }
            equals = self == other;
        }
    }

    /** Delegation target that re-registers a mock's interceptor after the mock has been deserialized. */
    public static class ForReadObject {

        /**
         * Performs default deserialization and then, if a dispatcher is registered for
         * {@code identifier}, stores the mock's interceptor in that dispatcher's interceptor map.
         *
         * @param identifier        identifier used to look up the dispatcher
         * @param thiz              the mock being deserialized
         * @param objectInputStream the stream the mock is read from
         * @throws IOException            propagated from {@link ObjectInputStream#defaultReadObject()}
         * @throws ClassNotFoundException propagated from {@link ObjectInputStream#defaultReadObject()}
         */
        @SuppressWarnings("unused")
        public static void doReadObject(@Identifier String identifier,
                                        @This MockAccess thiz,
                                        @Argument(0) ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, thiz);
            if (dispatcher == null) {
                return;
            }
            // A dispatcher of any other type fails here with a ClassCastException.
            MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) dispatcher;
            mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor());
        }
    }
}