All downloads are free. Search and download functionality uses the official Maven repository.

io.helidon.integrations.graal.mp.nativeimage.extension.HelidonMpFeature Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2023 Oracle and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.helidon.integrations.graal.mp.nativeimage.extension;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.helidon.config.mp.MpConfigProviderResolver;
import io.helidon.integrations.graal.nativeimage.extension.HelidonReflectionConfiguration;
import io.helidon.integrations.graal.nativeimage.extension.NativeTrace;
import io.helidon.integrations.graal.nativeimage.extension.NativeUtil;
import io.helidon.logging.common.LogConfig;

import io.github.classgraph.ClassGraph;
import io.github.classgraph.ScanResult;
import org.graalvm.nativeimage.hosted.Feature;
import org.graalvm.nativeimage.hosted.RuntimeProxyCreation;
import org.graalvm.nativeimage.hosted.RuntimeReflection;

/**
 * Helidon MP feature for GraalVM native image.
 */
public class HelidonMpFeature implements Feature {
    private static final String AT_REGISTER_REST_CLIENT = "org.eclipse.microprofile.rest.client.inject.RegisterRestClient";

    private final NativeTrace tracer = new NativeTrace();
    private NativeUtil util;

    @Override
    public void beforeAnalysis(BeforeAnalysisAccess access) {
        // need the application classloader
        Class logConfigClass = access.findClassByName(LogConfig.class.getName());
        ClassLoader classLoader = logConfigClass.getClassLoader();
        // load configuration
        HelidonReflectionConfiguration config = HelidonReflectionConfiguration.load(access, classLoader, tracer);

        // classpath scanning using the correct classloader
        ScanResult scan = new ClassGraph()
                .overrideClasspath(access.getApplicationClassPath())
                .enableAllInfo()
                .scan();

        util = NativeUtil.create(tracer,
                                 scan,
                                 access::findClassByName,
                                 config.excluded()::contains);
        BeforeAnalysisContext context = new BeforeAnalysisContext(access, scan, config.excluded());


        /*
        Now handle all MP specific tasks
         */
        // rest client registration (proxy support)
        processRegisterRestClient(context);

        // all classes used as return types and parameters in JAX-RS resources
        processJaxRsTypes(context);

        // JAX-RS types required for headers, query params etc.
        addJaxRsConversions(context);

        /*
         *
         *  And finally register with native image
         *
         */
        registerForReflection(context);
    }

    @Override
    public void beforeCompilation(BeforeCompilationAccess access) {
        MpConfigProviderResolver.buildTimeEnd();
    }

    @Override
    public void duringSetup(DuringSetupAccess access) {
        new WeldFeature().duringSetup(access);
    }

    private void registerForReflection(BeforeAnalysisContext context) {
        Collection toRegister = context.toRegister();

        tracer.section(() -> "Registering " + toRegister.size() + " classes for reflection");

        // register for reflection
        for (Register register : toRegister) {
            // first validate if all fields are on classpath
            if (!register.validated) {
                register.validate();
            }
            // only register classes on the image classpath (not necessarily discovered by the scanning)
            if (register.valid) {
                register(register.clazz);

                if (!register.clazz.isInterface()) {
                    register.fields.forEach(this::register);
                    register.constructors.forEach(this::register);
                }

                register.methods.forEach(this::register);
            } else {
                tracer.trace(() -> register.clazz.getName() + " is not registered, as it had failed fields or superclass.");
            }
        }
    }

    private void register(Class clazz) {
        tracer.trace(() -> "Registering " + clazz.getName() + " for reflection");

        RuntimeReflection.register(clazz);
    }

    private void register(Field field) {
        tracer.trace(() -> "    "
                + Modifier.toString(field.getModifiers())
                + " " + typeToString(field.getGenericType())
                + " " + field.getName());

        RuntimeReflection.register(field);
    }

    private void register(Constructor constructor) {
        tracer.trace(() -> "    " + constructor.getDeclaringClass().getSimpleName()
                + "("
                + params(constructor.getParameterTypes())
                + ")");

        RuntimeReflection.register(constructor);
    }

    private void register(Method method) {
        tracer.trace(() -> "    "
                + Modifier.toString(method.getModifiers())
                + " " + typeToString(method.getGenericReturnType())
                + " " + method.getName()
                + "(" + params(method.getGenericParameterTypes()) + ")");

        RuntimeReflection.register(method);
    }

    private String typeToString(Type type) {
        if (type instanceof Class) {
            return ((Class) type).getName();
        } else {
            return type.toString();
        }
    }

    private String params(Type[] parameterTypes) {
        if (parameterTypes.length == 0) {
            return "";
        }
        return Arrays.stream(parameterTypes)
                .map(this::typeToString)
                .collect(Collectors.joining(", "));
    }

    private void addJaxRsConversions(BeforeAnalysisContext context) {
        addJaxRsConversions(context, "jakarta.ws.rs.QueryParam");
        addJaxRsConversions(context, "jakarta.ws.rs.PathParam");
        addJaxRsConversions(context, "jakarta.ws.rs.HeaderParam");
        addJaxRsConversions(context, "jakarta.ws.rs.MatrixParam");
        addJaxRsConversions(context, "jakarta.ws.rs.BeanParam");
    }

    private void addJaxRsConversions(BeforeAnalysisContext context, String annotation) {
        tracer.parsing(() -> "Looking up annotated by " + annotation);

        Set> allTypes = new HashSet<>();

        // we need fields and method parameters
        context.scan()
                .getClassesWithFieldAnnotation(annotation)
                .stream()
                .flatMap(theClass -> theClass.getFieldInfo().stream())
                .filter(field -> field.hasAnnotation(annotation))
                .map(fieldInfo -> util.getSimpleType(context.access()::findClassByName, fieldInfo))
                .filter(Objects::nonNull)
                .forEach(allTypes::add);

        // method annotations
        context.scan()
                .getClassesWithMethodParameterAnnotation(annotation)
                .stream()
                .flatMap(theClass -> theClass.getMethodInfo().stream())
                .flatMap(theMethod -> Stream.of(theMethod.getParameterInfo()))
                .filter(param -> param.hasAnnotation(annotation))
                .map(param -> util.getSimpleType(context.access()::findClassByName, param))
                .filter(Objects::nonNull)
                .forEach(allTypes::add);

        // now let's find all static methods `valueOf` and `fromString`
        for (Class type : allTypes) {
            try {
                Method valueOf = type.getDeclaredMethod("valueOf", String.class);
                RuntimeReflection.register(valueOf);
                tracer.parsing(() -> "Registering " + valueOf);
            } catch (NoSuchMethodException ignored) {
                try {
                    Method fromString = type.getDeclaredMethod("fromString", String.class);
                    RuntimeReflection.register(fromString);
                    tracer.parsing(() -> "Registering " + fromString);
                } catch (NoSuchMethodException ignored2) {
                }
            }
        }
    }

    private void processJaxRsTypes(BeforeAnalysisContext context) {
        tracer.parsing(() -> "Looking up JAX-RS resource methods.");

        new JaxRsMethodAnalyzer(context, util)
                .find()
                .forEach(it -> {
                    tracer.parsing(() -> " class " + it);
                    context.register(it).addAll();
                });
    }

    @SuppressWarnings("unchecked")
    private void processRegisterRestClient(BeforeAnalysisContext context) {

        Class restClientAnnotation = (Class) context.access()
                .findClassByName(AT_REGISTER_REST_CLIENT);

        if (null == restClientAnnotation) {
            return;
        }

        tracer.parsing(() -> "Looking up annotated by " + AT_REGISTER_REST_CLIENT);

        Set> annotatedSet = util.findAnnotated(AT_REGISTER_REST_CLIENT);
        Class autoCloseable = context.access().findClassByName("java.lang.AutoCloseable");
        Class closeable = context.access().findClassByName("java.io.Closeable");

        annotatedSet.forEach(it -> {
            if (context.isExcluded(it)) {
                tracer.parsing(() -> "Class " + it.getName() + " annotated by " + AT_REGISTER_REST_CLIENT + " is excluded");
            } else {
                // we need to add it for reflection
                processClassHierarchy(context, it);
                // and we also need to create a proxy
                tracer.parsing(() -> "Registering a proxy for class " + it.getName());
                RuntimeProxyCreation.register(it, autoCloseable, closeable);
            }
        });
    }

    private void processClassHierarchy(BeforeAnalysisContext context, Class superclass) {

        // this class is always registered (interface or class)
        context.register(superclass).addDefaults();

        tracer.parsing(() -> "Looking up implementors of " + superclass.getName());

        processSubClasses(context, superclass);

        util.findInterfaces(superclass)
                .forEach(it -> addSingleClass(context, it));
    }

    private void addSingleClass(BeforeAnalysisContext context,
                                Class theClass) {
        if (context.process(theClass)) {
            tracer.parsing(theClass::getName);
            tracer.parsing(() -> "  Added for registration");
            superclasses(context, theClass);
            context.register(theClass).addDefaults();
        }
    }

    private void processClasses(BeforeAnalysisContext context, Set> classes) {
        for (Class aClass : classes) {
            if (context.process(aClass)) {
                tracer.parsing(() -> "    " + aClass.getName());
                tracer.parsing(() -> "        Added for registration");

                superclasses(context, aClass);
                context.register(aClass).addDefaults();

                int modifiers = aClass.getModifiers();
                if (!Modifier.isFinal(modifiers)) {
                    processSubClasses(context, aClass);
                }
            }
        }
    }

    private void superclasses(BeforeAnalysisContext context, Class aClass) {
        Set> superclasses = util.findSuperclasses(aClass);
        for (Class superclass : superclasses) {
            if (context.process(superclass)) {
                tracer.parsing(superclass::getName);
                tracer.parsing(() -> "  Added for registration");
                context.register(superclass).addDefaults();
            }
        }
    }

    private void processSubClasses(BeforeAnalysisContext context, Class aClass) {
        Set> subclasses = util.findSubclasses(aClass.getName());

        processClasses(context, subclasses);
    }

    final class BeforeAnalysisContext {
        private final BeforeAnalysisAccess access;
        private final Set> processed = new HashSet<>();
        private final Set> excluded = new HashSet<>();
        private final Map, Register> registers = new HashMap<>();
        private final ScanResult scan;

        private BeforeAnalysisContext(BeforeAnalysisAccess access, ScanResult scan, Set> excluded) {
            this.access = access;
            this.scan = scan;
            this.excluded.addAll(excluded);
        }

        public boolean process(Class theClass) {
            return processed.add(theClass);
        }

        public Register register(Class theClass) {
            return registers.computeIfAbsent(theClass, Register::new);
        }

        public Collection toRegister() {
            return registers.values();
        }

        BeforeAnalysisAccess access() {
            return access;
        }

        ScanResult scan() {
            return scan;
        }

        boolean isExcluded(Class theClass) {
            return excluded.contains(theClass);
        }
    }

    private class Register {
        private final Set methods = new HashSet<>();
        private final Set fields = new HashSet<>();
        private final Set> constructors = new HashSet<>();

        private final Class clazz;

        private boolean validated;
        private boolean valid = true;

        private Register(Class clazz) {
            this.clazz = clazz;
        }

        void validate() {
            validated = true;
            validateTypeParams();
            if (!valid) {
                return;
            }
            addFields(true, true);
        }

        boolean add(Method m) {
            return methods.add(m);
        }

        boolean add(Field f) {
            return fields.add(f);
        }

        boolean add(Constructor c) {
            return constructors.add(c);
        }

        void addAll() {
            if (!validated) {
                validated = true;
                validateTypeParams();
            }
            if (!valid) {
                return;
            }
            addFields(true, false);
            if (!valid) {
                return;
            }
            addMethods();
            if (clazz.isInterface()) {
                return;
            }
            addConstructors();
        }

        void addDefaults() {
            validated = true;
            validateTypeParams();
            if (!valid) {
                return;
            }
            addFields(false, false);
            if (!valid) {
                return;
            }
            addMethods();
            if (clazz.isInterface()) {
                return;
            }
            addConstructors();
        }

        void addFields(boolean all, boolean validateOnly) {
            try {
                Field[] fields = clazz.getFields();
                // add all public fields
                for (Field field : fields) {
                    if (!validateOnly) {
                        add(field);
                    }
                }
            } catch (NoClassDefFoundError e) {
                this.valid = false;

                if (validateOnly) {
                    tracer.trace(() -> "Validation of fields of "
                            + clazz.getName()
                            + " failed, as a type is not on classpath: "
                            + e.getMessage());
                } else {
                    tracer.trace(() -> "Public fields of "
                            + clazz.getName()
                            + " not added to reflection, as a type is not on classpath: "
                            + e.getMessage());
                }

            }
            try {
                for (Field declaredField : clazz.getDeclaredFields()) {
                    // there may be fields referencing classes not on the classpath
                    if (!Modifier.isPublic(declaredField.getModifiers())) {
                        // public already registered
                        if (all || declaredField.getAnnotations().length > 0) {
                            if (!validateOnly) {
                                add(declaredField);
                            }
                        }
                    }
                }
            } catch (NoClassDefFoundError e) {
                this.valid = false;

                if (validateOnly) {
                    tracer.trace(() -> "Validation of fields of "
                            + clazz.getName()
                            + " failed, as a type is not on classpath: "
                            + e.getMessage());
                } else {
                    tracer.trace(() -> "Fields of "
                            + clazz.getName()
                            + " not added to reflection, as a type is not on classpath: "
                            + e.getMessage());
                }
            }
        }

        void addMethods() {
            try {
                Method[] methods = clazz.getMethods();
                for (Method method : methods) {
                    boolean register;

                    // we do not want wait, notify etc
                    register = (method.getDeclaringClass() != Object.class);

                    if (register) {
                        // we do not want toString(), hashCode(), equals(java.lang.Object)
                        switch (method.getName()) {
                        case "hashCode":
                        case "toString":
                            register = !util.hasParams(method);
                            break;
                        case "equals":
                            register = !util.hasParams(method, Object.class);
                            break;
                        default:
                            // do nothing
                        }
                    }

                    if (register) {
                        tracer.trace(() -> "  " + method.getName() + "(" + Arrays.toString(method.getParameterTypes()) + ")");

                        add(method);
                    }
                }
            } catch (Throwable e) {
                tracer.trace(() -> "   Cannot register methods of " + clazz.getName() + ": "
                        + e.getClass().getName() + ": " + e.getMessage());
            }
        }

        @SuppressWarnings("ResultOfMethodCallIgnored")
        private void validateTypeParams() {
            try {
                clazz.getGenericSuperclass();
            } catch (Exception e) {
                // this is now reported with each build, because ProtobufEncoder is part of netty codec
                tracer.parsing(() -> "Type parameter of superclass is not on classpath of "
                        + clazz.getName()
                        + " error: "
                        + e.getMessage());
                valid = false;
            }
        }

        private void addConstructors() {
            try {
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    add(constructor);
                }
            } catch (NoClassDefFoundError e) {
                tracer.trace(() -> "Public constructors of "
                        + clazz.getName()
                        + " not added to reflection, as a type is not on classpath: "
                        + e.getMessage());
            }
            try {
                // add all declared
                Constructor[] constructors = clazz.getDeclaredConstructors();
                for (Constructor constructor : constructors) {
                    add(constructor);
                }
            } catch (NoClassDefFoundError e) {
                tracer.trace(() -> "Constructors of "
                        + clazz.getName()
                        + " not added to reflection, as a type is not on classpath: "
                        + e.getMessage());
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy