All downloads are free. Search and download functionality uses the official Maven repository.

name.remal.reflection.ClassLoaderUtils Maven / Gradle / Ivy

package name.remal.reflection;

import static java.lang.Math.max;
import static java.lang.Thread.currentThread;
import static name.remal.SneakyThrow.sneakyThrow;
import static name.remal.UncheckedCast.uncheckedCast;
import static name.remal.reflection.ExtendedURLClassLoader.LoadingOrder.PARENT_ONLY;
import static name.remal.reflection.ExtendedURLClassLoader.LoadingOrder.THIS_ONLY;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import name.remal.lambda.Function1;
import name.remal.lambda.VoidFunction1;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Static helpers around {@link ClassLoader}s: package lookup, reflective URL injection into
 * {@link URLClassLoader}s, and running an action against an implementation class that has been
 * re-loaded and instantiated in a dedicated child class loader.
 */
public class ClassLoaderUtils {

    /**
     * {@link ClassLoader} subclass that uses the given loader as its parent and exposes the
     * protected {@link ClassLoader#getPackage(String)} lookup.
     */
    static class ClassLoaderWrapper extends ClassLoader {
        public ClassLoaderWrapper(@NotNull ClassLoader classLoader) {
            super(classLoader);
        }

        /**
         * Delegates to the protected {@code getPackage(name)} of this wrapper.
         *
         * @param name package name to look up
         * @return whatever {@code getPackage} returns, possibly {@code null}
         */
        @Nullable
        public Package getPackageOrNull(@NotNull String name) {
            return this.getPackage(name);
        }
    }

    /**
     * Looks up a package through a temporary {@link ClassLoaderWrapper} whose parent is
     * {@code classLoader}.
     *
     * @param classLoader loader used as the parent of the temporary wrapper
     * @param packageName package name to look up
     * @return the package, or {@code null} if the lookup yields nothing
     */
    @Nullable
    public static Package getPackageOrNull(@NotNull ClassLoader classLoader, @NotNull String packageName) {
        return new ClassLoaderWrapper(classLoader).getPackageOrNull(packageName);
    }

    /**
     * Initialization-on-demand holder for the reflective handle to {@code URLClassLoader.addURL}.
     *
     * <p>The lookup is deliberately kept out of the outer class's static initializer:
     * {@code setAccessible(true)} can fail on runtimes that do not open {@code java.net}, and a
     * failure there must only affect {@link #addURLsToClassLoader}, not every method of this class.
     */
    private static final class AddUrlMethodHolder {

        static final Method ADD_URL_METHOD = lookupAddUrlMethod();

        private static Method lookupAddUrlMethod() {
            try {
                Method method = URLClassLoader.class.getDeclaredMethod("addURL", URL.class);
                method.setAccessible(true);
                return method;
            } catch (NoSuchMethodException e) {
                throw sneakyThrow(e);
            }
        }
    }

    /**
     * Adds the given URLs to the first {@link URLClassLoader} found while walking from
     * {@code classLoader} up through its parents. A {@code null} parent is replaced by the system
     * class loader, and the walk stops once the system class loader has been examined.
     *
     * <p>The URLs are added by reflectively invoking {@code addURL} while holding the monitor of
     * the target loader. Reflection failures are rethrown as-is via {@code sneakyThrow}.
     *
     * @param classLoader loader to start the search from
     * @param urls        URLs to add; nothing happens when the array is empty
     * @throws IllegalStateException if no {@link URLClassLoader} is found up to and including the
     *                               system class loader
     */
    @SuppressFBWarnings("squid:S2445")
    public static void addURLsToClassLoader(@NotNull ClassLoader classLoader, @NotNull URL... urls) {
        if (0 == urls.length) {
            return;
        }

        ClassLoader systemClassLoader = ClassLoader.getSystemClassLoader();
        ClassLoader current = classLoader;
        while (true) {
            if (current instanceof URLClassLoader) {
                synchronized (current) {
                    for (URL url : urls) {
                        try {
                            AddUrlMethodHolder.ADD_URL_METHOD.invoke(current, url);
                        } catch (IllegalAccessException | InvocationTargetException e) {
                            throw sneakyThrow(e);
                        }
                    }
                }
                return;
            }
            if (systemClassLoader == current) {
                break;
            }
            // A null parent means the bootstrap loader; continue with the system class loader instead.
            ClassLoader parent = current.getParent();
            current = parent != null ? parent : systemClassLoader;
        }

        throw new IllegalStateException("New URL can't be added to system ClassLoader: " + systemClassLoader);
    }

    // NOTE(review): the generic signatures of the forInstantiated* methods (<T, R>, Class<T>,
    // Class<? extends T>, Function1<T, R>, VoidFunction1<T>) were reconstructed from the method bodies;
    // confirm the type-parameter order against name.remal.lambda.Function1.

    /**
     * Same as {@link #forInstantiated(ClassLoader, Class, Class, InstantiatedClassesPropagation, Function1)}
     * with a fresh, empty {@link InstantiatedClassesPropagation}.
     */
    public static <T, R> R forInstantiated(
        @Nullable ClassLoader classLoader,
        @NotNull Class<T> type,
        @NotNull Class<? extends T> implementationType,
        @NotNull Function1<T, R> action
    ) {
        return forInstantiated(classLoader, type, implementationType, new InstantiatedClassesPropagation(), action);
    }

    /**
     * Result-less variant of {@link #forInstantiated(ClassLoader, Class, Class, Function1)}.
     */
    public static <T> void forInstantiated(
        @Nullable ClassLoader classLoader,
        @NotNull Class<T> type,
        @NotNull Class<? extends T> implementationType,
        @NotNull VoidFunction1<T> action
    ) {
        forInstantiated(classLoader, type, implementationType, it -> {
            action.invoke(it);
            return null;
        });
    }

    /**
     * Re-loads {@code implementationType} in a temporary child class loader, instantiates it with
     * its no-arg constructor and passes the instance to {@code action}.
     *
     * <p>The child loader is an {@code ExtendedURLClassLoader} over the code-source location of
     * {@code implementationType}, with {@code classLoader} as its parent. Resources belonging to a
     * propagated package, a propagated class or one of its {@code $}-nested classes are resolved
     * with {@code THIS_ONLY}; everything else with {@code PARENT_ONLY}.
     * {@code implementationType} itself is always added to {@code propagation}, which mutates it.
     *
     * <p>While the instance is created and {@code action} runs, the current thread's context class
     * loader is set to {@code classLoader} (the parent, not the child loader); the previous context
     * class loader is restored afterwards. The child loader is closed on exit.
     *
     * @param classLoader        parent for the temporary child loader
     * @param type               type the implementation must be a strict subtype of
     * @param implementationType class to re-load and instantiate
     * @param propagation        classes and packages that must be loaded by the child loader
     * @param action             callback receiving the new instance
     * @return whatever {@code action} returns
     * @throws IllegalArgumentException if {@code implementationType} equals {@code type} or is not
     *                                  assignable to it
     * @throws IllegalStateException    if the code-source location of {@code implementationType}
     *                                  cannot be determined
     */
    public static <T, R> R forInstantiated(
        @Nullable ClassLoader classLoader,
        @NotNull Class<T> type,
        @NotNull Class<? extends T> implementationType,
        @NotNull InstantiatedClassesPropagation propagation,
        @NotNull Function1<T, R> action
    ) {
        if (!type.isAssignableFrom(implementationType) || type == implementationType) {
            throw new IllegalArgumentException(implementationType + " is not subtype of " + type);
        }

        propagation.addClassName(implementationType);
        List<String> propagatedClassInternalNames = propagation.getClassInternalNames();
        List<String> propagatedPackageInternalNames = propagation.getPackageInternalNames();

        URL sourceURL = Optional.ofNullable(implementationType.getProtectionDomain())
            .map(ProtectionDomain::getCodeSource)
            .map(CodeSource::getLocation)
            .orElseThrow(() -> new IllegalStateException(implementationType + ": protectionDomain?.codeSource?.location == null"));

        try (URLClassLoader childClassLoader = new ExtendedURLClassLoader(
            resourceName -> {
                for (String propagatedPackageInternalName : propagatedPackageInternalNames) {
                    if (resourceName.startsWith(propagatedPackageInternalName + '/')) {
                        return THIS_ONLY;
                    }
                }
                for (String propagatedClassInternalName : propagatedClassInternalNames) {
                    if (resourceName.equals(propagatedClassInternalName + ".class")
                        || (resourceName.startsWith(propagatedClassInternalName + '$') && resourceName.endsWith(".class"))
                    ) {
                        return THIS_ONLY;
                    }
                }
                return PARENT_ONLY;
            },
            new URL[]{sourceURL},
            classLoader
        )) {

            Thread currentThread = currentThread();
            ClassLoader prevContextClassLoader = currentThread.getContextClassLoader();
            currentThread.setContextClassLoader(classLoader);
            try {
                T implementation = uncheckedCast(childClassLoader.loadClass(implementationType.getName()).newInstance());
                return action.invoke(implementation);

            } finally {
                currentThread.setContextClassLoader(prevContextClassLoader);
            }

        } catch (Throwable e) {
            // Everything, including failures while closing the child loader, is rethrown unchanged.
            throw sneakyThrow(e);
        }
    }

    /**
     * Result-less variant of
     * {@link #forInstantiated(ClassLoader, Class, Class, InstantiatedClassesPropagation, Function1)}.
     */
    public static <T> void forInstantiated(
        @Nullable ClassLoader classLoader,
        @NotNull Class<T> type,
        @NotNull Class<? extends T> implementationType,
        @NotNull InstantiatedClassesPropagation propagation,
        @NotNull VoidFunction1<T> action
    ) {
        forInstantiated(classLoader, type, implementationType, propagation, it -> {
            action.invoke(it);
            return null;
        });
    }

    /**
     * Same as {@link #forInstantiated(ClassLoader, Class, Class, InstantiatedClassesPropagation, Function1)}
     * with a propagation that contains the package of {@code implementationType}.
     */
    public static <T, R> R forInstantiatedWithPropagatedPackage(
        @Nullable ClassLoader classLoader,
        @NotNull Class<T> type,
        @NotNull Class<? extends T> implementationType,
        @NotNull Function1<T, R> action
    ) {
        return forInstantiated(classLoader, type, implementationType, new InstantiatedClassesPropagation().addPackageName(implementationType), action);
    }

    /**
     * Result-less variant of
     * {@link #forInstantiatedWithPropagatedPackage(ClassLoader, Class, Class, Function1)}.
     */
    public static <T> void forInstantiatedWithPropagatedPackage(
        @Nullable ClassLoader classLoader,
        @NotNull Class<T> type,
        @NotNull Class<? extends T> implementationType,
        @NotNull VoidFunction1<T> action
    ) {
        forInstantiatedWithPropagatedPackage(classLoader, type, implementationType, it -> {
            action.invoke(it);
            return null;
        });
    }

    /**
     * Mutable, chainable collection of class names and package names, stored in internal form
     * ({@code '.'} replaced by {@code '/'}), that {@code forInstantiated} loads through its child
     * class loader.
     */
    public static class InstantiatedClassesPropagation {

        @NotNull
        private final Set<@NotNull String> classInternalNames = new TreeSet<>();

        /**
         * @return a new list holding the registered class names in internal form, in sorted order
         */
        @NotNull
        public List<@NotNull String> getClassInternalNames() {
            return new ArrayList<>(classInternalNames);
        }

        /**
         * Registers a class by name; dots are converted to slashes.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addClassName(@NotNull String className) {
            classInternalNames.add(className.replace('.', '/'));
            return this;
        }

        /**
         * Registers every given class name.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addClassNames(@NotNull String... classNames) {
            for (String className : classNames) {
                addClassName(className);
            }
            return this;
        }

        /**
         * Registers every given class name.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addClassNames(@NotNull Iterable<String> classNames) {
            for (String className : classNames) {
                addClassName(className);
            }
            return this;
        }

        /**
         * Registers a class using {@link Class#getName()}.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addClassName(@NotNull Class<?> clazz) {
            return addClassName(clazz.getName());
        }

        /**
         * Registers every given class.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addClassNames(@NotNull Class<?>... classes) {
            for (Class<?> clazz : classes) {
                addClassName(clazz);
            }
            return this;
        }

        @NotNull
        private final Set<@NotNull String> packageInternalNames = new TreeSet<>();

        /**
         * @return a new list holding the registered package names in internal form, in sorted order
         */
        @NotNull
        public List<@NotNull String> getPackageInternalNames() {
            return new ArrayList<>(packageInternalNames);
        }

        /**
         * Registers a package by name; dots are converted to slashes.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addPackageName(@NotNull String packageName) {
            packageInternalNames.add(packageName.replace('.', '/'));
            return this;
        }

        /**
         * Registers every given package name.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addPackageNames(@NotNull String... packageNames) {
            for (String packageName : packageNames) {
                addPackageName(packageName);
            }
            return this;
        }

        /**
         * Registers every given package name.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addPackageNames(@NotNull Iterable<String> packageNames) {
            for (String packageName : packageNames) {
                addPackageName(packageName);
            }
            return this;
        }

        /**
         * Registers the package of the given class, taken as the part of {@link Class#getName()}
         * before the last {@code '.'}. A name without any dot registers the empty string.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addPackageName(@NotNull Class<?> packageClass) {
            String packageClassName = packageClass.getName();
            return addPackageName(packageClassName.substring(0, max(0, packageClassName.lastIndexOf('.'))));
        }

        /**
         * Registers the package of every given class.
         *
         * @return this instance
         */
        @NotNull
        public InstantiatedClassesPropagation addPackageNames(@NotNull Class<?>... packageClasses) {
            for (Class<?> packageClass : packageClasses) {
                addPackageName(packageClass);
            }
            return this;
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy