
name.remal.reflection.ClassLoaderUtils Maven / Gradle / Ivy
package name.remal.reflection;
import static java.lang.Math.max;
import static java.lang.Thread.currentThread;
import static name.remal.SneakyThrow.sneakyThrow;
import static name.remal.UncheckedCast.uncheckedCast;
import static name.remal.reflection.ExtendedURLClassLoader.LoadingOrder.PARENT_ONLY;
import static name.remal.reflection.ExtendedURLClassLoader.LoadingOrder.THIS_ONLY;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import name.remal.lambda.Function1;
import name.remal.lambda.VoidFunction1;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class ClassLoaderUtils {
static class ClassLoaderWrapper extends ClassLoader {
public ClassLoaderWrapper(@NotNull ClassLoader classLoader) {
super(classLoader);
}
@Nullable
public Package getPackageOrNull(@NotNull String name) {
return this.getPackage(name);
}
}
@Nullable
public static Package getPackageOrNull(@NotNull ClassLoader classLoader, @NotNull String packageName) {
return new ClassLoaderWrapper(classLoader).getPackageOrNull(packageName);
}
private static final Method ADD_URL_METHOD;
static {
try {
ADD_URL_METHOD = URLClassLoader.class.getDeclaredMethod("addURL", URL.class);
ADD_URL_METHOD.setAccessible(true);
} catch (NoSuchMethodException e) {
throw sneakyThrow(e);
}
}
@SuppressFBWarnings("squid:S2445")
public static void addURLsToClassLoader(@NotNull ClassLoader classLoader, @NotNull URL... urls) {
if (0 == urls.length) return;
ClassLoader systemClassLoader = ClassLoader.getSystemClassLoader();
do {
if (classLoader instanceof URLClassLoader) {
synchronized (classLoader) {
for (URL url : urls) {
try {
ADD_URL_METHOD.invoke(classLoader, url);
} catch (@NotNull IllegalAccessException | InvocationTargetException e) {
throw sneakyThrow(e);
}
}
}
return;
}
if (systemClassLoader == classLoader) break;
classLoader = classLoader.getParent();
if (classLoader == null) classLoader = systemClassLoader;
} while (true);
throw new IllegalStateException("New URL can't be added to system ClassLoader: " + systemClassLoader);
}
public static R forInstantiated(@Nullable ClassLoader classLoader, @NotNull Class type, @NotNull Class extends T> implementationType, @NotNull Function1 action) {
return forInstantiated(classLoader, type, implementationType, new InstantiatedClassesPropagation(), action);
}
public static void forInstantiated(@Nullable ClassLoader classLoader, @NotNull Class type, @NotNull Class extends T> implementationType, @NotNull VoidFunction1 action) {
forInstantiated(classLoader, type, implementationType, it -> {
action.invoke(it);
return null;
});
}
public static R forInstantiated(@Nullable ClassLoader classLoader, @NotNull Class type, @NotNull Class extends T> implementationType, @NotNull InstantiatedClassesPropagation propagation, @NotNull Function1 action) {
if (!type.isAssignableFrom(implementationType) || type == implementationType) throw new IllegalArgumentException(implementationType + " is not subtype of " + type);
propagation.addClassName(implementationType);
List propagatedClassInternalNames = propagation.getClassInternalNames();
List propagatedPackageInternalNames = propagation.getPackageInternalNames();
URL sourceURL = Optional.ofNullable(implementationType.getProtectionDomain()).map(ProtectionDomain::getCodeSource).map(CodeSource::getLocation).orElseThrow(() -> new IllegalStateException(implementationType + ": protectionDomain?.codeSource?.location == null"));
try (URLClassLoader childClassLoader = new ExtendedURLClassLoader(
resourceName -> {
for (String propagatedPackageInternalName : propagatedPackageInternalNames) {
if (resourceName.startsWith(propagatedPackageInternalName + '/')) {
return THIS_ONLY;
}
}
for (String propagatedClassInternalName : propagatedClassInternalNames) {
if (resourceName.equals(propagatedClassInternalName + ".class")
|| (resourceName.startsWith(propagatedClassInternalName + '$') && resourceName.endsWith(".class"))
) {
return THIS_ONLY;
}
}
return PARENT_ONLY;
},
new URL[]{sourceURL},
classLoader
)) {
Thread currentThread = currentThread();
ClassLoader prevContextClassLoader = currentThread.getContextClassLoader();
currentThread.setContextClassLoader(classLoader);
try {
T implementation = uncheckedCast(childClassLoader.loadClass(implementationType.getName()).newInstance());
return action.invoke(implementation);
} finally {
currentThread.setContextClassLoader(prevContextClassLoader);
}
} catch (Throwable e) {
throw sneakyThrow(e);
}
}
public static void forInstantiated(@Nullable ClassLoader classLoader, @NotNull Class type, @NotNull Class extends T> implementationType, @NotNull InstantiatedClassesPropagation propagation, @NotNull VoidFunction1 action) {
forInstantiated(classLoader, type, implementationType, propagation, it -> {
action.invoke(it);
return null;
});
}
public static R forInstantiatedWithPropagatedPackage(@Nullable ClassLoader classLoader, @NotNull Class type, @NotNull Class extends T> implementationType, @NotNull Function1 action) {
return forInstantiated(classLoader, type, implementationType, new InstantiatedClassesPropagation().addPackageName(implementationType), action);
}
public static void forInstantiatedWithPropagatedPackage(@Nullable ClassLoader classLoader, @NotNull Class type, @NotNull Class extends T> implementationType, @NotNull VoidFunction1 action) {
forInstantiatedWithPropagatedPackage(classLoader, type, implementationType, it -> {
action.invoke(it);
return null;
});
}
public static class InstantiatedClassesPropagation {
@NotNull
private final Set<@NotNull String> classInternalNames = new TreeSet<>();
@NotNull
public List<@NotNull String> getClassInternalNames() {
return new ArrayList<>(classInternalNames);
}
@NotNull
public InstantiatedClassesPropagation addClassName(@NotNull String className) {
classInternalNames.add(className.replace('.', '/'));
return this;
}
@NotNull
public InstantiatedClassesPropagation addClassNames(@NotNull String... classNames) {
for (String className : classNames) {
addClassName(className);
}
return this;
}
@NotNull
public InstantiatedClassesPropagation addClassNames(@NotNull Iterable classNames) {
for (String className : classNames) {
addClassName(className);
}
return this;
}
@NotNull
public InstantiatedClassesPropagation addClassName(@NotNull Class> clazz) {
return addClassName(clazz.getName());
}
@NotNull
public InstantiatedClassesPropagation addClassNames(@NotNull Class>... classes) {
for (Class> clazz : classes) {
addClassName(clazz);
}
return this;
}
@NotNull
private final Set<@NotNull String> packageInternalNames = new TreeSet<>();
@NotNull
public List<@NotNull String> getPackageInternalNames() {
return new ArrayList<>(packageInternalNames);
}
@NotNull
public InstantiatedClassesPropagation addPackageName(@NotNull String packageName) {
packageInternalNames.add(packageName.replace('.', '/'));
return this;
}
@NotNull
public InstantiatedClassesPropagation addPackageNames(@NotNull String... packageNames) {
for (String packageName : packageNames) {
addPackageName(packageName);
}
return this;
}
@NotNull
public InstantiatedClassesPropagation addPackageNames(@NotNull Iterable packageNames) {
for (String packageName : packageNames) {
addPackageName(packageName);
}
return this;
}
@NotNull
public InstantiatedClassesPropagation addPackageName(@NotNull Class> packageClass) {
String packageClassName = packageClass.getName();
return addPackageName(packageClassName.substring(0, max(0, packageClassName.lastIndexOf('.'))));
}
@NotNull
public InstantiatedClassesPropagation addPackageNames(@NotNull Class>... packageClasses) {
for (Class> packageClass : packageClasses) {
addPackageName(packageClass);
}
return this;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy