All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.datakernel.di.util.ReflectionUtils Maven / Gradle / Ivy

package io.datakernel.di.util;

import io.datakernel.di.annotation.Optional;
import io.datakernel.di.annotation.*;
import io.datakernel.di.core.*;
import io.datakernel.di.impl.BindingInitializer;
import io.datakernel.di.impl.BindingLocator;
import io.datakernel.di.impl.CompiledBinding;
import io.datakernel.di.module.BindingDesc;
import io.datakernel.di.module.Module;
import io.datakernel.di.module.ModuleBuilder;
import io.datakernel.di.module.ModuleBuilderBinder;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.lang.annotation.Annotation;
import java.lang.reflect.*;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Stream;

import static io.datakernel.di.core.Name.uniqueName;
import static java.util.Collections.singleton;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

/**
 * These are various reflection utilities that are used by the DSL.
 * While you should not use them normally, they are pretty well organized and thus are left public.
 */
public final class ReflectionUtils {
	/** Regex fragment that matches a single Java identifier. */
	private static final String IDENT = "\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*";

	/**
	 * Produces a shortened display name for the given type.
	 * <p>
	 * Every run of {@code identifier.} qualifiers, optionally followed by an
	 * {@code identifier$digits} prefix, is removed from {@link Type#getTypeName()}.
	 * The replacement is applied throughout the whole string, not only at its start.
	 * If the raw type carries a {@link ShortTypeName} annotation, the leading identifier
	 * of the shortened name is replaced with the annotation value.
	 *
	 * @param type the type to describe
	 * @return the shortened name
	 */
	public static String getShortName(Type type) {
		String defaultName = type.getTypeName()
				.replaceAll("(?:" + IDENT + "\\.)*(?:" + IDENT + "\\$\\d*)?", "");
		ShortTypeName override = Types.getRawType(type).getDeclaredAnnotation(ShortTypeName.class);
		if (override == null) {
			return defaultName;
		}
		// The override is literal text: quote it so that '$' and '\' are not treated
		// as group references / escapes by replaceAll.
		String replacement = java.util.regex.Matcher.quoteReplacement(override.value());
		return defaultName.replaceAll("^" + IDENT, replacement);
	}

	/**
	 * Finds the name annotation declared directly on the given element.
	 * <p>
	 * An annotation counts as a name annotation when its annotation type is itself
	 * annotated with {@link NameAnnotation}.
	 *
	 * @param annotatedElement the element to inspect
	 * @return a {@link Name} built from the single name annotation, or {@code null} if there is none
	 * @throws DIException if more than one name annotation is declared on the element
	 */
	@Nullable
	public static Name nameOf(AnnotatedElement annotatedElement) {
		Set<Annotation> names = Arrays.stream(annotatedElement.getDeclaredAnnotations())
				.filter(annotation -> annotation.annotationType().isAnnotationPresent(NameAnnotation.class))
				.collect(toSet());
		switch (names.size()) {
			case 0:
				return null;
			case 1:
				return Name.of(names.iterator().next());
			default:
				throw new DIException("More than one name annotation on " + annotatedElement);
		}
	}

	/**
	 * Collects the key-set annotations declared directly on the given element.
	 * <p>
	 * An annotation counts as a key-set annotation when its annotation type is itself
	 * annotated with {@link KeySetAnnotation}; each one is converted with {@code Name.of}.
	 *
	 * @param annotatedElement the element to inspect
	 * @return the names of all key sets found, possibly empty
	 */
	public static Set<Name> keySetsOf(AnnotatedElement annotatedElement) {
		return Arrays.stream(annotatedElement.getDeclaredAnnotations())
				.filter(annotation -> annotation.annotationType().isAnnotationPresent(KeySetAnnotation.class))
				.map(Name::of)
				.collect(toSet());
	}

	/**
	 * Builds a key from a type and the name annotation of an annotated element.
	 *
	 * @param container        type used to resolve type variables occurring in {@code type},
	 *                         or {@code null} to use {@code type} as is
	 * @param type             the declared type of the element
	 * @param annotatedElement the element whose name annotation (if any) becomes the key name
	 * @param <T>              the type the resulting key is expected to denote; this is not checked
	 * @return a key of the (possibly resolved) type with the name returned by {@link #nameOf}
	 * @throws DIException if the element declares more than one name annotation
	 */
	public static <T> Key<T> keyOf(@Nullable Type container, Type type, AnnotatedElement annotatedElement) {
		Type resolved = container != null ? Types.resolveTypeVariables(type, container) : type;
		return Key.ofType(resolved, nameOf(annotatedElement));
	}

	/**
	 * Determines the scope path declared on the given element.
	 * <p>
	 * Two mutually exclusive forms are recognised:
	 * <ul>
	 * <li>a {@link Scopes} annotation, whose values are each converted with {@code Scope.of};</li>
	 * <li>a single annotation whose type is annotated with {@link ScopeAnnotation}.</li>
	 * </ul>
	 *
	 * @param annotatedElement the element to inspect
	 * @return the scopes from {@link Scopes}, a one-element array for a single scope annotation,
	 * or {@code Scope.UNSCOPED} when nothing is declared
	 * @throws DIException if both forms are present, or more than one scope annotation is declared
	 */
	public static Scope[] getScope(AnnotatedElement annotatedElement) {
		Set<Annotation> scopes = Arrays.stream(annotatedElement.getDeclaredAnnotations())
				.filter(annotation -> annotation.annotationType().isAnnotationPresent(ScopeAnnotation.class))
				.collect(toSet());

		Scopes nested = annotatedElement.getDeclaredAnnotation(Scopes.class);
		if (nested != null) {
			if (!scopes.isEmpty()) {
				// the message names the annotation that is actually checked here (@Scopes)
				throw new DIException("Cannot have both @Scopes and a scope annotation on " + annotatedElement);
			}
			return Arrays.stream(nested.value()).map(Scope::of).toArray(Scope[]::new);
		}

		switch (scopes.size()) {
			case 0:
				return Scope.UNSCOPED;
			case 1:
				return new Scope[]{Scope.of(scopes.iterator().next())};
			default:
				throw new DIException("More than one scope annotation on " + annotatedElement);
		}
	}

	public static  List getAnnotatedElements(Class cls,
			Class annotationType, Function, T[]> extractor, boolean allowStatic) {

		List result = new ArrayList<>();
		while (cls != null) {
			for (T element : extractor.apply(cls)) {
				if (element.isAnnotationPresent(annotationType)) {
					if (!allowStatic && Modifier.isStatic(element.getModifiers())) {
						throw new DIException("@" + annotationType.getSimpleName() + " annotation is not allowed on " + element);
					}
					result.add(element);
				}
			}
			cls = cls.getSuperclass();
		}
		return result;
	}

	/**
	 * Generates an implicit binding for the key: a constructor (or factory method) binding
	 * from {@link #generateConstructorBinding} combined with the injecting initializer
	 * from {@link #generateInjectingInitializer}.
	 *
	 * @param key the requested key
	 * @param <T> the type denoted by the key
	 * @return the binding, or {@code null} if no constructor binding could be generated
	 * @throws DIException if the inject annotations on the class are inconsistent
	 */
	@Nullable
	public static <T> Binding<T> generateImplicitBinding(Key<T> key) {
		Binding<T> binding = generateConstructorBinding(key);
		if (binding == null) {
			return null;
		}
		return binding.initializeWith(generateInjectingInitializer(key));
	}

	/**
	 * Generates a binding that instantiates the raw type of the key, driven by
	 * {@link Inject} annotations found on that class.
	 * <p>
	 * Three sources are considered:
	 * <ul>
	 * <li>{@code @Inject} on the class itself: the no-arg constructor is used, or, for a
	 * non-static nested class, the constructor taking only the enclosing class;</li>
	 * <li>a single {@code @Inject} constructor;</li>
	 * <li>a single static {@code @Inject} method whose return type is the class itself.</li>
	 * </ul>
	 * Combinations of these sources, or several candidates of the same kind, are rejected.
	 *
	 * @param key the requested key
	 * @param <T> the type denoted by the key
	 * @return the generated binding, or {@code null} if the class has none of the above
	 * @throws DIException if the annotations are ambiguous or the required constructor is missing
	 */
	@SuppressWarnings("unchecked")
	@Nullable
	public static <T> Binding<T> generateConstructorBinding(Key<T> key) {
		Class<?> cls = key.getRawType();

		Inject classInjectAnnotation = cls.getAnnotation(Inject.class);
		Set<Constructor<?>> injectConstructors = Arrays.stream(cls.getDeclaredConstructors())
				.filter(c -> c.isAnnotationPresent(Inject.class))
				.collect(toSet());
		Set<Method> factoryMethods = Arrays.stream(cls.getDeclaredMethods())
				.filter(method -> method.isAnnotationPresent(Inject.class)
						&& method.getReturnType() == cls
						&& Modifier.isStatic(method.getModifiers()))
				.collect(toSet());

		if (classInjectAnnotation != null) {
			if (!injectConstructors.isEmpty()) {
				throw failedImplicitBinding(key, "inject annotation on class with inject constructor");
			}
			if (!factoryMethods.isEmpty()) {
				throw failedImplicitBinding(key, "inject annotation on class with inject factory method");
			}
			Class<?> enclosingClass = cls.getEnclosingClass();
			if (enclosingClass != null && !Modifier.isStatic(cls.getModifiers())) {
				// non-static nested class: its "default" constructor takes the enclosing instance
				try {
					return bindingFromConstructor(key, (Constructor<T>) cls.getDeclaredConstructor(enclosingClass));
				} catch (NoSuchMethodException e) {
					throw failedImplicitBinding(key, "inject annotation on local class that closes over outside variables and/or has no default constructor");
				}
			}
			try {
				return bindingFromConstructor(key, (Constructor<T>) cls.getDeclaredConstructor());
			} catch (NoSuchMethodException e) {
				throw failedImplicitBinding(key, "inject annotation on class with no default constructor");
			}
		} else {
			if (injectConstructors.size() > 1) {
				throw failedImplicitBinding(key, "more than one inject constructor");
			}
			if (!injectConstructors.isEmpty()) {
				if (!factoryMethods.isEmpty()) {
					throw failedImplicitBinding(key, "both inject constructor and inject factory method are present");
				}
				return bindingFromConstructor(key, (Constructor<T>) injectConstructors.iterator().next());
			}
		}

		// no class-level annotation and no inject constructor: fall back to a factory method
		if (factoryMethods.size() > 1) {
			throw failedImplicitBinding(key, "more than one inject factory method");
		}
		if (!factoryMethods.isEmpty()) {
			return bindingFromMethod(null, factoryMethods.iterator().next());
		}
		return null;
	}

	/**
	 * Creates (but does not throw) the exception reported when an implicit binding
	 * cannot be generated.
	 *
	 * @param requestedKey the key for which generation failed
	 * @param message      the reason, appended after the key's display string
	 * @return the exception for the caller to throw
	 */
	private static DIException failedImplicitBinding(Key<?> requestedKey, String message) {
		return new DIException("Failed to generate implicit binding for " + requestedKey.getDisplayString() + ", " + message);
	}

	/**
	 * Builds an initializer that performs field and method injection on an instance
	 * of the container type.
	 * <p>
	 * All {@link Inject} fields and all non-static {@link Inject} methods of the raw type
	 * and its superclasses are covered, fields first. A field additionally annotated with
	 * {@link Optional} becomes a non-required dependency.
	 *
	 * @param container key of the type being injected into
	 * @param <T>       the type denoted by the key
	 * @return the combined initializer
	 * @throws DIException if an {@code @Inject} field is static
	 */
	public static <T> BindingInitializer<T> generateInjectingInitializer(Key<T> container) {
		Class<?> rawType = container.getRawType();
		Stream<BindingInitializer<T>> fieldInjectors =
				getAnnotatedElements(rawType, Inject.class, Class::getDeclaredFields, false).stream()
						.map(field -> fieldInjector(container, field, !field.isAnnotationPresent(Optional.class)));
		Stream<BindingInitializer<T>> methodInjectors =
				getAnnotatedElements(rawType, Inject.class, Class::getDeclaredMethods, true).stream()
						.filter(method -> !Modifier.isStatic(method.getModifiers())) // we allow them and just filter out to allow static factory methods
						.map(method -> methodInjector(container, method));
		List<BindingInitializer<T>> initializers = Stream.concat(fieldInjectors, methodInjectors)
				.collect(toList());
		return BindingInitializer.combine(initializers);
	}

	/**
	 * Creates an initializer that injects a single field.
	 * <p>
	 * The field is made accessible immediately. Its key is derived from the generic field
	 * type, resolved against the container type, and the field's name annotation.
	 * At injection time a {@code null} instance from the binding leaves the field untouched.
	 *
	 * @param container key of the type that declares (or inherits) the field
	 * @param field     the field to inject
	 * @param required  whether the dependency on the field's key is required
	 * @param <T>       the type denoted by the container key
	 * @return an initializer with exactly one dependency
	 * @throws DIException at injection time, if the field cannot be written
	 */
	public static <T> BindingInitializer<T> fieldInjector(Key<T> container, Field field, boolean required) {
		field.setAccessible(true);
		Key<?> key = keyOf(container.getType(), field.getGenericType(), field);
		return BindingInitializer.of(
				singleton(Dependency.toKey(key, required)),
				compiledBindings -> {
					CompiledBinding<?> binding = compiledBindings.get(key);
					return (instance, instances, synchronizedScope) -> {
						Object arg = binding.getInstance(instances, synchronizedScope);
						if (arg == null) {
							// nothing was provided: keep whatever value the field already has
							return;
						}
						try {
							field.set(instance, arg);
						} catch (IllegalAccessException e) {
							throw new DIException("Not allowed to set injectable field " + field, e);
						}
					};
				});
	}

	/**
	 * Creates an initializer that calls a single injectable method with its
	 * parameters resolved as dependencies.
	 * <p>
	 * The method is made accessible immediately; its parameters are converted with
	 * {@link #toDependencies} against the container type.
	 *
	 * @param container key of the type that declares (or inherits) the method
	 * @param method    the method to call on the freshly created instance
	 * @param <T>       the type denoted by the container key
	 * @return an initializer depending on all method parameters
	 * @throws DIException at injection time, if the method is inaccessible or throws;
	 *                     in the latter case the thrown exception becomes the cause
	 */
	public static <T> BindingInitializer<T> methodInjector(Key<T> container, Method method) {
		method.setAccessible(true);
		Dependency[] dependencies = toDependencies(container.getType(), method.getParameters());
		return BindingInitializer.of(
				Stream.of(dependencies).collect(toSet()),
				compiledBindings -> {
					// look the argument bindings up once, positionally matching the parameters
					CompiledBinding<?>[] argBindings = Stream.of(dependencies)
							.map(dependency -> compiledBindings.get(dependency.getKey()))
							.toArray(CompiledBinding[]::new);
					return (instance, instances, synchronizedScope) -> {
						Object[] args = new Object[argBindings.length];
						for (int i = 0; i < argBindings.length; i++) {
							args[i] = argBindings[i].getInstance(instances, synchronizedScope);
						}
						try {
							method.invoke(instance, args);
						} catch (IllegalAccessException e) {
							throw new DIException("Not allowed to call injectable method " + method, e);
						} catch (InvocationTargetException e) {
							throw new DIException("Failed to call injectable method " + method, e.getCause());
						}
					};
				});
	}

	/**
	 * Converts executable parameters into dependencies, one per parameter and in the same order.
	 * <p>
	 * Each key is built from the parameter's parameterized type (resolved against
	 * {@code container}, when given) and its name annotation; a parameter annotated with
	 * {@link Optional} yields a non-required dependency.
	 *
	 * @param container type used to resolve type variables, or {@code null}
	 * @param parameters the parameters of a method or constructor
	 * @return an array of the same length as {@code parameters}
	 */
	@NotNull
	public static Dependency[] toDependencies(@Nullable Type container, Parameter[] parameters) {
		int count = parameters.length;
		Dependency[] result = new Dependency[count];
		if (count == 0) {
			return result;
		}
		// an actual JDK bug (fixed in Java 9): when the number of annotation slots differs from
		// the number of parameters, annotations of every parameter but the first are read
		// from the preceding Parameter object
		int annotationSlots = parameters[0].getDeclaringExecutable().getParameterAnnotations().length;
		boolean shifted = annotationSlots != count;
		for (int index = 0; index < count; index++) {
			Parameter annotationSource = parameters[index];
			if (shifted && index != 0) {
				annotationSource = parameters[index - 1];
			}
			Type declaredType = parameters[index].getParameterizedType();
			result[index] = Dependency.toKey(
					keyOf(container, declaredType, annotationSource),
					!annotationSource.isAnnotationPresent(Optional.class));
		}
		return result;
	}

	/**
	 * Creates a binding that obtains its instance by reflectively calling a method.
	 * <p>
	 * The method is made accessible immediately. Parameter types are resolved against the
	 * class of {@code module}, or against the method's declaring class when {@code module}
	 * is {@code null}.
	 *
	 * @param module the object to invoke the method on, or {@code null} for a static method
	 * @param method the provider or factory method
	 * @param <T>    the type produced by the method; the cast of the result is unchecked
	 * @return the binding, carrying location info when {@code module} is not {@code null}
	 * @throws DIException at instantiation time, if the method is inaccessible or throws;
	 *                     in the latter case the thrown exception becomes the cause
	 */
	@SuppressWarnings("unchecked")
	public static <T> Binding<T> bindingFromMethod(@Nullable Object module, Method method) {
		method.setAccessible(true);

		Binding<T> binding = Binding.to(
				args -> {
					try {
						return (T) method.invoke(module, args);
					} catch (IllegalAccessException e) {
						throw new DIException("Not allowed to call method " + method, e);
					} catch (InvocationTargetException e) {
						throw new DIException("Failed to call method " + method, e.getCause());
					}
				},
				toDependencies(module != null ? module.getClass() : method.getDeclaringClass(), method.getParameters()));

		return module != null ? binding.at(LocationInfo.from(module, method)) : binding;
	}

	@SuppressWarnings("unchecked")
	public static  Binding bindingFromGenericMethod(@Nullable Object module, Key requestedKey, Method method) {
		method.setAccessible(true);

		Type genericReturnType = method.getGenericReturnType();
		Map, Type> mapping = Types.extractMatchingGenerics(genericReturnType, requestedKey.getType());

		Dependency[] dependencies = Arrays.stream(method.getParameters())
				.map(parameter -> {
					Type type = Types.resolveTypeVariables(parameter.getParameterizedType(), mapping);
					Name name = nameOf(parameter);
					return Dependency.toKey(Key.ofType(type, name), !parameter.isAnnotationPresent(Optional.class));
				})
				.toArray(Dependency[]::new);

		Binding binding = Binding.to(
				args -> {
					try {
						return (T) method.invoke(module, args);
					} catch (IllegalAccessException e) {
						throw new DIException("Not allowed to call generic method " + method + " to provide requested key " + requestedKey, e);
					} catch (InvocationTargetException e) {
						throw new DIException("Failed to call generic method " + method + " to provide requested key " + requestedKey, e.getCause());
					}
				},
				dependencies);
		return module != null ? binding.at(LocationInfo.from(module, method)) : binding;
	}

	/**
	 * Creates a binding that obtains its instance by reflectively calling a constructor.
	 * <p>
	 * The constructor is made accessible immediately; its parameters are converted with
	 * {@link #toDependencies} against the type of the key.
	 *
	 * @param key         the key being provided
	 * @param constructor the constructor to call
	 * @param <T>         the type being constructed
	 * @return the binding
	 * @throws DIException at instantiation time, if the class cannot be instantiated, the
	 *                     constructor is inaccessible, or it throws; in the last case the
	 *                     thrown exception becomes the cause
	 */
	public static <T> Binding<T> bindingFromConstructor(Key<T> key, Constructor<T> constructor) {
		constructor.setAccessible(true);

		Dependency[] dependencies = toDependencies(key.getType(), constructor.getParameters());

		return Binding.to(
				args -> {
					try {
						return constructor.newInstance(args);
					} catch (InstantiationException e) {
						throw new DIException("Cannot instantiate object from the constructor " + constructor + " to provide requested key " + key, e);
					} catch (IllegalAccessException e) {
						throw new DIException("Not allowed to call constructor " + constructor + " to provide requested key " + key, e);
					} catch (InvocationTargetException e) {
						throw new DIException("Failed to call constructor " + constructor + " to provide requested key " + key, e.getCause());
					}
				},
				dependencies);
	}

	public static class ProviderScanResults {
		private final List bindingDescs;
		private final Map, Set>> bindingGenerators;
		private final Map, Multibinder> multibinders;

		public ProviderScanResults(List bindingDescs, Map, Set>> bindingGenerators, Map, Multibinder> multibinders) {
			this.bindingDescs = bindingDescs;
			this.bindingGenerators = bindingGenerators;
			this.multibinders = multibinders;
		}

		public List getBindingDescs() {
			return bindingDescs;
		}

		public Map, Set>> getBindingGenerators() {
			return bindingGenerators;
		}

		public Map, Multibinder> getMultibinders() {
			return multibinders;
		}
	}

	/**
	 * Scans the provider methods declared by a single class into a freshly created module.
	 *
	 * @param moduleClass the class whose declared methods are scanned
	 * @param module      the instance to invoke providers on, or {@code null} to allow static providers only
	 * @return the module produced by {@link #scanClassInto} with a new builder
	 */
	public static Module scanClass(@NotNull Class<?> moduleClass, @Nullable Object module) {
		return scanClassInto(moduleClass, module, Module.create());
	}

	/**
	 * Scans the methods declared by {@code moduleClass} and registers a binding in
	 * {@code builder} for each provider method found.
	 * <p>
	 * {@link Provides} methods without type parameters are bound under a key made of their
	 * resolved return type and name annotation. {@link Provides} methods with type parameters
	 * are registered as binding generators; they accept neither key-set annotations nor
	 * {@link Export}. {@link ProvidesIntoSet} methods are bound under a uniquely named key and
	 * additionally exposed as a singleton {@code Set} binding with a set multibinder.
	 * Only declared methods are inspected; superclasses are not visited.
	 *
	 * @param moduleClass the class whose declared methods are scanned
	 * @param module      the instance to invoke providers on, or {@code null} to allow static providers only
	 * @param builder     the builder receiving the bindings
	 * @return {@code builder}
	 * @throws IllegalStateException if a non-static provider is found while {@code module} is
	 *                               {@code null}, or a templated provider is declared incorrectly
	 * @throws DIException           if name or scope annotations on a provider are ambiguous
	 */
	public static Module scanClassInto(@NotNull Class<?> moduleClass, @Nullable Object module, ModuleBuilder builder) {
		for (Method method : moduleClass.getDeclaredMethods()) {
			if (method.isAnnotationPresent(Provides.class)) {
				if (module == null && !Modifier.isStatic(method.getModifiers())) {
					throw new IllegalStateException("Found non-static provider method while scanning for statics");
				}

				Name name = nameOf(method);
				Set<Name> keySets = keySetsOf(method);
				Scope[] methodScope = getScope(method);
				boolean exported = method.isAnnotationPresent(Export.class);

				Type returnType = Types.resolveTypeVariables(method.getGenericReturnType(), module != null ? module.getClass() : moduleClass);
				TypeVariable<Method>[] typeVars = method.getTypeParameters();

				if (typeVars.length == 0) {
					// ordinary provider: a plain binding for the resolved return type
					Key<Object> key = Key.ofType(returnType, name);

					ModuleBuilderBinder<Object> binder = builder.bind(key).to(bindingFromMethod(module, method)).in(methodScope);
					keySets.forEach(binder::as);
					if (exported) {
						binder.export();
					}
					continue;
				}

				// templated provider: registered as a generator instead of a binding
				Set<TypeVariable<Method>> unused = Arrays.stream(typeVars)
						.filter(typeVar -> !Types.contains(returnType, typeVar))
						.collect(toSet());
				if (!unused.isEmpty()) {
					throw new IllegalStateException("Generic type variables " + unused + " are not used in return type of templated provider method " + method);
				}
				if (!keySets.isEmpty()) {
					throw new IllegalStateException("Key set annotations are not supported by templated methods, method " + method);
				}
				if (exported) {
					throw new IllegalStateException("@Export annotation is not applicable for templated methods because they are generators and thus are always exported");
				}

				builder.generate(method.getReturnType(), new TemplatedProviderGenerator(methodScope, name, method, module, returnType));

			} else if (method.isAnnotationPresent(ProvidesIntoSet.class)) {
				if (module == null && !Modifier.isStatic(method.getModifiers())) {
					throw new IllegalStateException("Found non-static provider method while scanning for statics");
				}
				if (method.getTypeParameters().length != 0) {
					throw new IllegalStateException("@ProvidesIntoSet does not support templated methods, method " + method);
				}

				Type type = Types.resolveTypeVariables(method.getGenericReturnType(), module != null ? module.getClass() : moduleClass);
				Scope[] methodScope = getScope(method);
				Set<Name> keySets = keySetsOf(method);
				boolean exported = method.isAnnotationPresent(Export.class);

				// the element itself lives under a uniquely named key
				Key<Object> key = Key.ofType(type, uniqueName());
				ModuleBuilderBinder<Object> binder = builder.bind(key).to(bindingFromMethod(module, method)).in(methodScope);
				keySets.forEach(binder::as);

				// ...and is exposed as a singleton Set under the method's own name
				Key<Set<Object>> setKey = Key.ofType(Types.parameterized(Set.class, type), nameOf(method));
				Binding<Set<Object>> binding = Binding.to(Collections::singleton, key);
				if (module != null) {
					// keep the returned binding, as bindingFromMethod does, instead of discarding it
					binding = binding.at(LocationInfo.from(module, method));
				}

				ModuleBuilderBinder<Set<Object>> setBinder = builder.bind(setKey).to(binding).in(methodScope);
				keySets.forEach(setBinder::as);
				if (exported) {
					setBinder.export();
				}
				builder.multibind(setKey, Multibinder.toSet());
			}
		}

		return builder;
	}

	public static Map, Module> scanClassHierarchy(@NotNull Class moduleClass, @Nullable Object module) {
		Map, Module> result = new HashMap<>();
		Class cls = moduleClass;
		while (cls != Object.class && cls != null) {
			result.put(cls, scanClass(cls, module));
			cls = cls.getSuperclass();
		}
		return result;
	}

	/**
	 * Binding generator backed by a generic (templated) provider method.
	 * <p>
	 * Equality and hash code are based on the method scope, the name and the method only;
	 * the module instance and the resolved return type do not take part.
	 *
	 * @param <T> the type of the generated bindings
	 */
	private static class TemplatedProviderGenerator<T> implements BindingGenerator<T> {
		private final Scope[] methodScope;
		@Nullable
		private final Name name;
		private final Method method;

		// null when the provider method is static and no module instance was supplied
		@Nullable
		private final Object module;
		private final Type returnType;

		private TemplatedProviderGenerator(Scope[] methodScope, @Nullable Name name, Method method, @Nullable Object module, Type returnType) {
			this.methodScope = methodScope;
			this.name = name;
			this.method = method;
			this.module = module;
			this.returnType = returnType;
		}

		/**
		 * Generates a binding for the key when it is compatible with the provider method.
		 *
		 * @return a binding from {@link #bindingFromGenericMethod}, or {@code null} when the
		 * requested scope does not start with the method scope, the generator has a name
		 * different from the key's, or the key type does not match the return type
		 */
		@Override
		public @Nullable Binding<T> generate(BindingLocator bindings, Scope[] scope, Key<T> key) {
			if (scope.length < methodScope.length || (name != null && !name.equals(key.getName())) || !Types.matches(key.getType(), returnType)) {
				return null;
			}
			// the method scope must be a prefix of the requested scope
			for (int i = 0; i < methodScope.length; i++) {
				if (!scope[i].equals(methodScope[i])) {
					return null;
				}
			}
			return bindingFromGenericMethod(module, key, method);
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) return true;
			if (o == null || getClass() != o.getClass()) return false;

			TemplatedProviderGenerator<?> generator = (TemplatedProviderGenerator<?>) o;

			if (!Arrays.equals(methodScope, generator.methodScope)) return false;
			if (!Objects.equals(name, generator.name)) return false;
			return method.equals(generator.method);
		}

		@Override
		public int hashCode() {
			return 961 * Arrays.hashCode(methodScope) + 31 * (name != null ? name.hashCode() : 0) + method.hashCode();
		}
	}
}