All Downloads are FREE. Search and download functionalities are using the official Maven repository.

co.streamx.fluent.extree.expression.ExpressionClassCracker Maven / Gradle / Ivy

The newest version!
package co.streamx.fluent.extree.expression;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.Type;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;

class ExpressionClassCracker {

    private static final String DUMP_FOLDER_SYSTEM_PROPERTY = "jdk.internal.lambda.dumpProxyClasses";
    private static final URLClassLoader lambdaClassLoader;
    private static final String lambdaClassLoaderCreationError;

    private static final ExpressionClassCracker instance = new ExpressionClassCracker();

    public static ExpressionClassCracker get() {
        return instance;
    }

    static {
        String folderPath = System.getProperty(DUMP_FOLDER_SYSTEM_PROPERTY);
        if (folderPath == null) {
            lambdaClassLoaderCreationError = "Ensure that the '" + DUMP_FOLDER_SYSTEM_PROPERTY
                    + "' system property is properly set.";
            lambdaClassLoader = null;
        } else {
            File folder = new File(folderPath);
            if (!folder.isDirectory()) {
                lambdaClassLoaderCreationError = "Ensure that the '" + DUMP_FOLDER_SYSTEM_PROPERTY
                        + "' system property is properly set (" + folderPath + " does not exist).";
                lambdaClassLoader = null;
            } else {
                URL folderURL;
                try {
                    folderURL = folder.toURI().toURL();
                } catch (MalformedURLException mue) {
                    throw new RuntimeException(mue);
                }

                lambdaClassLoaderCreationError = null;
                lambdaClassLoader = new URLClassLoader(new URL[] { folderURL });
            }
        }
    }

    private ExpressionClassCracker() {
    }

    private static final class ParameterReplacer extends SimpleExpressionVisitor {
        private List paramIndices;
        private final Object lambda;
        private LambdaExpression parsedLambda;
        private List prevParamIndices;

        public ParameterReplacer(int paramIndex, Object lambda) {
            this.paramIndices = Arrays.asList(paramIndex);
            this.lambda = lambda;
        }

        public LambdaExpression getParsedLambda() {
            return parsedLambda;
        }

        @Override
        public Expression visit(InvocationExpression e) {
            if (this.paramIndices.isEmpty())
                return e;
            prevParamIndices = this.paramIndices;
            try {
                return super.visit(e);
            } finally {
                this.paramIndices = prevParamIndices;
            }
        }

        @Override
        protected List visitArguments(List original) {
            try {
                return super.visitArguments(original);
            } finally {
                List paramIndices = this.paramIndices;
                List newParamIndices = new ArrayList<>();
                for (int i = 0; i < original.size(); i++) {
                    Expression e = original.get(i);
                    if (e.getExpressionType() == ExpressionType.Parameter) {
                        ParameterExpression p = (ParameterExpression) e;
                        if (paramIndices.contains(p.getIndex()))
                            newParamIndices.add(i);
                    }
                }

                this.paramIndices = newParamIndices;
            }
        }

        @Override
        public Expression visit(MemberExpression e) {
            Expression instance = e.getInstance();
            if (instance != null && instance.getExpressionType() == ExpressionType.Parameter) {
                int index = ((ParameterExpression) instance).getIndex();
                if (prevParamIndices.contains(index)) {
                    if (lambda != null && parsedLambda == null) {
                        Method method = (Method) e.getMember();
                        try {
                            method = lambda.getClass().getDeclaredMethod(method.getName(), method.getParameterTypes());
                        } catch (NoSuchMethodException nsme) {
                            // should never happen
                            throw new RuntimeException(nsme);
                        }
                        parsedLambda = ExpressionClassCracker.get().lambda(lambda, method, true);
                    }
                    return Expression.delegate(e.getResultType(), Expression.parameter(LambdaExpression.class, index),
                            visitParameters(e.getParameters()));
                }
            }
            return super.visit(e);
        }

    }

    private static class SerializedLambdaObjectExtractor extends ObjectOutputStream {

        private SerializedLambda serializedLambda;

        public SerializedLambdaObjectExtractor() throws IOException {
            super(new ByteArrayOutputStream(8));
            enableReplaceObject(true);
        }

        public SerializedLambda extract(Object lambda) throws IOException {
            writeObject(lambda);
            return serializedLambda;
        }

        @Override
        protected Object replaceObject(Object obj) throws IOException {
            serializedLambda = (SerializedLambda) obj;
            return null;
        }
    }

    LambdaExpression lambda(Object lambda,
                               boolean synthetic) {
        return lambda(lambda, null, synthetic);
    }

    private LambdaExpression lambda(Object lambda,
                                       Method lambdaMethod,
                                       boolean synthetic) {
        Class lambdaClass = lambda.getClass();
        if (!isFunctional(lambdaClass))
            throw new IllegalArgumentException("The requested object is not a Java lambda");

        if (lambda instanceof Serializable) {

            try (SerializedLambdaObjectExtractor extractor = new SerializedLambdaObjectExtractor()) {
                SerializedLambda extracted = extractor.extract(lambda);

                ClassLoader lambdaClassLoader = lambdaClass.getClassLoader();
                return lambda(extracted, lambdaClassLoader, synthetic);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }

        }

        return lambdaFromFileSystem(lambda, lambdaMethod, null);
    }

    LambdaExpression lambdaFromFileSystem(Object lambda,
                                             Method lambdaMethod,
                                             ClassLoader classLoader) {
        ExpressionClassVisitor lambdaVisitor = parseFromFileSystem(lambda, lambdaMethod, classLoader);

        return createLambda(lambdaVisitor, null);
    }

    LambdaExpression lambdaFromClassLoader(ClassLoader classLoader,
                                              String className,
                                              Expression instance,
                                              String method,
                                              String methodDescriptor) {

        SerializedDescriptor desc = new SerializedDescriptor(className, method, methodDescriptor, -1, methodDescriptor);
        boolean isCacheable = instance == null || instance.getResultType().isSynthetic();
        if (isCacheable) {
            LambdaExpression cached = cache.get(desc);
            if (cached != null) {
//                System.out.println("Cache hit #2: " + cached);
                return cached;
            }
        }

        ExpressionClassVisitor lambdaVisitor = parseClass(classLoader, className, instance, method, methodDescriptor);

        LambdaExpression parsed = createLambda(lambdaVisitor, desc);

        if (isCacheable) {
            LambdaExpression cached = cache.putIfAbsent(desc, parsed);
            if (cached != null)
                parsed = cached;
        }

        return parsed;
    }

    private LambdaExpression createLambda(ExpressionClassVisitor lambdaVisitor,
                                             SerializedDescriptor key) {
        Expression lambdaExpression = lambdaVisitor.getResult();
        Class lambdaType = lambdaVisitor.getType();
        List lambdaParams = Arrays.asList(lambdaVisitor.getParams());

        Expression stripped = lambdaType == Void.TYPE ? null : stripConvertExpressions(lambdaExpression);

        List block = lambdaVisitor.getStatements();
        if (block != null && !block.isEmpty()) {

            block = new ArrayList<>(block);
            if (lambdaExpression != null)
                block.add(lambdaExpression);

            lambdaExpression = Expression.block(lambdaType, block);
        } else if (stripped instanceof InvocationExpression) {

            InvocationExpression invocation = (InvocationExpression) stripped;
            InvocableExpression target = invocation.getTarget();
            if (target instanceof LambdaExpression) {
                REDUCE_CHECK: for (;;) {
                    if (!lambdaType.isAssignableFrom(target.getResultType()))
                        break;
                    List params = lambdaParams;
                    List args = invocation.getArguments();
                    int psize = params.size();
                    if (psize != args.size())
                        break;
                    for (int i = 0; i < psize; i++) {
                        Expression arg = args.get(i);
                        if (!(arg instanceof ParameterExpression))
                            break REDUCE_CHECK;
                        ParameterExpression parg = (ParameterExpression) arg;
                        ParameterExpression param = params.get(i);
                        if (parg.getIndex() != param.getIndex())
                            break REDUCE_CHECK;
                        if (!param.getResultType().isAssignableFrom(parg.getResultType()))
                            break REDUCE_CHECK;
                    }
                    return (LambdaExpression) target;
                }
            }

        }

        Expression actualExpression = TypeConverter.convert(lambdaExpression, lambdaType);
        return Expression.lambda(lambdaType, actualExpression, lambdaParams, lambdaVisitor.getLocals(), key);
    }

    LambdaExpression lambda(SerializedLambda extracted,
                               ClassLoader lambdaClassLoader) {
        return lambda(extracted, lambdaClassLoader, true);
    }

    @Data
    @EqualsAndHashCode
    @RequiredArgsConstructor
    private static class SerializedDescriptor {

        public SerializedDescriptor(SerializedLambda lambda) {
            this.implClass = lambda.getImplClass();
            this.implMethodName = lambda.getImplMethodName();
            this.implMethodSignature = lambda.getImplMethodSignature();
            this.implMethodKind = lambda.getImplMethodKind();
            this.instantiatedMethodType = lambda.getInstantiatedMethodType();
        }

        public SerializedDescriptor withImplClass(String implClass) {
            return new SerializedDescriptor(implClass, implClass, implClass, implMethodKind, implClass);
        }

        private final String implClass;
        private final String implMethodName;
        private final String implMethodSignature;
        private final int implMethodKind;
        private final String instantiatedMethodType;
    }

    private static final Map> cache = new ConcurrentHashMap<>();

    LambdaExpression lambda(SerializedLambda extracted,
                               ClassLoader lambdaClassLoader,
                               boolean synthetic) {
        int capturedLength = extracted.getCapturedArgCount();
        SerializedDescriptor desc = new SerializedDescriptor(extracted);

        boolean hasThis = extracted.getImplMethodKind() == MethodHandleInfo.REF_invokeInterface
                || extracted.getImplMethodKind() == MethodHandleInfo.REF_invokeSpecial
                || extracted.getImplMethodKind() == MethodHandleInfo.REF_invokeVirtual;

        Expression instance;

        if (hasThis) {
            if (capturedLength == 0) {
                try {
                    instance = Expression
                            .parameter(lambdaClassLoader.loadClass(extracted.getImplClass().replace('/', '.')), 0);
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
            } else {
                Object arg0 = extracted.getCapturedArg(0);
                if (desc != null)
                    desc = desc.withImplClass(arg0.getClass().getName());
                instance = Expression.constant(arg0);
            }
        } else {
            instance = null;
        }

        boolean noNeedHandleCapturedArgs = capturedLength == 0 || (hasThis && capturedLength <= 1);
        boolean isCacheable = noNeedHandleCapturedArgs && (!hasThis || instance.getResultType().isSynthetic());
        if (isCacheable) {
            LambdaExpression cached = cache.get(desc);
            if (cached != null) {
//                System.out.println("Cache hit #1: " + cached);
                return cached;
            }
        }

        ExpressionClassVisitor actualVisitor = parseClass(lambdaClassLoader, extracted.getImplClass(), 
                instance, extracted.getImplMethodName(), extracted.getImplMethodSignature(), synthetic);

        final Class type = actualVisitor.getType();
        Expression reducedExpression = type == Void.TYPE ? actualVisitor.getResult()
                : TypeConverter.convert(actualVisitor.getResult(), type);

        List block = actualVisitor.getStatements();
        if (block != null && !block.isEmpty()) {

            block = new ArrayList<>(block);
            if (reducedExpression != null)
                block.add(reducedExpression);

            reducedExpression = Expression.block(type, block);
        }

        ParameterExpression[] params;

        // in case there is no captured args, we my assume the instantiated method signature to be the most accurate,
        // e.g. handle the case of a parameter for this
        if (capturedLength == 0) {

            Type[] argTypes = Type.getArgumentTypes(extracted.getInstantiatedMethodType());
            params = new ParameterExpression[argTypes.length];

            for (int i = 0; i < argTypes.length; i++)
                params[i] = Expression.parameter(actualVisitor.getClass(argTypes[i]), i);
        } else {
            params = actualVisitor.getParams();
        }

        LambdaExpression extractedLambda = Expression.lambda(type, reducedExpression,
                Collections.unmodifiableList(Arrays.asList(params)), actualVisitor.getLocals(), desc);

        if (noNeedHandleCapturedArgs) {
            if (isCacheable) {
                LambdaExpression cached = cache.putIfAbsent(desc, extractedLambda);
                if (cached != null)
                    extractedLambda = cached;
            }
            return extractedLambda;
        }

        List args = new ArrayList<>(params.length);

        for (int i = hasThis ? 1 : 0; i < capturedLength; i++) {
            Object arg = extracted.getCapturedArg(i);
            if (arg instanceof SerializedLambda) {
                SerializedLambda argLambda = (SerializedLambda) arg;

                LambdaExpression argExtractedLambda = lambda(argLambda, lambdaClassLoader);

                extractedLambda = (LambdaExpression) extractedLambda
                        .accept(new ParameterReplacer(args.size(), null));

                arg = argExtractedLambda;
            }
            args.add(Expression.constant(arg));
        }

        List finalParams = new ArrayList<>(params.length - capturedLength);
        int boundArgs = args.size();
        for (int y = boundArgs; y < params.length; y++) {
            ParameterExpression param = params[y];
            ParameterExpression arg = Expression.parameter(param.getResultType(), y - boundArgs);
            args.add(arg);
            finalParams.add(arg);
        }

//        cached = cache.putIfAbsent(desc, extractedLambda);
//        if (cached != null)
//            extractedLambda = cached;

        InvocationExpression newTarget = Expression.invoke(extractedLambda, args);

        return Expression.lambda(type, newTarget, Collections.unmodifiableList(finalParams),
                Collections.emptyList(), desc);
    }

    @SuppressWarnings("unchecked")
     T parseSyntheticArguments(T expression,
                                                     List arguments) {

        for (int i = 0; i < arguments.size(); i++) {
            Expression e = arguments.get(i);
            if (e.getExpressionType() == ExpressionType.Constant) {
                Object value = ((ConstantExpression) e).getValue();
                if (value != null && isFunctional(value.getClass())) {
                    ParameterReplacer replacer = new ParameterReplacer(i, value);
                    expression = (T) expression.accept(replacer);
                    if (replacer.getParsedLambda() != null) {
                        arguments.set(i, Expression.constant(replacer.getParsedLambda()));
                    }
                }
            }
        }
        return expression;
    }

    private static boolean isFunctional(Class clazz) {
        if (clazz.isSynthetic())
            return true;

        for (Class i : clazz.getInterfaces())
            if (i.isAnnotationPresent(FunctionalInterface.class))
                return true;

        return false;
    }

    ExpressionClassVisitor parseFromFileSystem(Object lambda,
                                               Method lambdaMethod,
                                               ClassLoader classLoader) {
        if (classLoader == null) {
            if (lambdaClassLoader == null)
                throw new RuntimeException(lambdaClassLoaderCreationError);
            classLoader = lambdaClassLoader;
        }

        Class lambdaClass;

        if (lambdaMethod == null) {
            lambdaClass = lambda.getClass();
            lambdaMethod = findFunctionalMethod(lambdaClass);
        } else {
            lambdaClass = lambdaMethod.getDeclaringClass();
        }
        String lambdaClassName = lambdaClassName(lambdaClass);
        return parseClass(classLoader, lambdaClassName,
                lambda instanceof Expression ? (Expression) lambda : Expression.constant(lambda),
                lambdaMethod);
    }

    private String lambdaClassName(Class lambdaClass) {
        String lambdaClassName = lambdaClass.getName();
        int lastIndexOfSlash = lambdaClassName.lastIndexOf('/');
        String className = lastIndexOfSlash > 0 ? lambdaClassName.substring(0, lastIndexOfSlash) : lambdaClassName;
        return className;
    }

    private String classFilePath(String className) {
        return className.replace('.', '/') + ".class";
    }

    private Method findFunctionalMethod(Class functionalClass) {
        for (Method m : functionalClass.getMethods()) {
            if (!m.isDefault()) {
                return m;
            }
        }
        throw new IllegalArgumentException("Not a lambda expression. No non-default method.");
    }

    private ExpressionClassVisitor parseClass(ClassLoader classLoader,
                                              String className,
                                              Expression instance,
                                              Method method) {
        return parseClass(classLoader, className, instance, method.getName(), Type.getMethodDescriptor(method), false);
    }

    private ExpressionClassVisitor parseClass(ClassLoader classLoader,
                                              String className,
                                              Expression instance,
                                              String method,
                                              String methodDescriptor) {
        return parseClass(classLoader, className, instance, method, methodDescriptor, true);
    }

    private ExpressionClassVisitor parseClass(ClassLoader classLoader,
                                              String className,
                                              Expression instance,
                                              String method,
                                              String methodDescriptor,
                                              boolean synthetic) {
        String classFilePath = classFilePath(className);
        ExpressionClassVisitor visitor = new ExpressionClassVisitor(classLoader, instance, method, methodDescriptor,
                synthetic);
        try (InputStream classStream = getResourceAsStream(classLoader, classFilePath)) {
            ClassReader reader = new ClassReader(classStream);
            reader.accept(visitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
            return visitor;
        } catch (IOException e) {
            throw new RuntimeException("error parsing class file " + classFilePath, e);
        }
    }

    private InputStream getResourceAsStream(ClassLoader classLoader,
                                            String path)
            throws FileNotFoundException {
        InputStream stream = classLoader.getResourceAsStream(path);
        if (stream == null)
            throw new FileNotFoundException(path);
        return stream;
    }

    private Expression stripConvertExpressions(Expression expression) {
        while (expression.getExpressionType() == ExpressionType.Convert) {
            expression = ((UnaryExpression) expression).getFirst();
        }
        return expression;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy