All downloads are free. Search and download functionality uses the official Maven repository.

com.cognifide.qa.bb.junit5.guice.GuiceExtension Maven / Gradle / Ivy

/*-
 * #%L
 * Bobcat
 * %%
 * Copyright (C) 2018 Cognifide Ltd.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package com.cognifide.qa.bb.junit5.guice;

import static com.cognifide.qa.bb.junit5.JUnit5Constants.NAMESPACE;
import static java.util.stream.Collectors.toSet;
import static org.junit.platform.commons.support.AnnotationSupport.findAnnotation;

import com.google.common.collect.Sets;
import com.google.inject.AbstractModule;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.junit.jupiter.api.extension.TestInstancePostProcessor;

/**
 * Extension that will start guice and is responsible for the injections to test instance Based on
 * Guice
 * Extension
 */
public class GuiceExtension implements TestInstancePostProcessor {

  @Override
  public void postProcessTestInstance(Object testInstance, ExtensionContext context)
      throws Exception {

    getOrCreateInjector(context).ifPresent(injector -> injector.injectMembers(testInstance));

  }

  /**
   * Create {@link Injector} or get existing one from test context
   */
  private static Optional getOrCreateInjector(ExtensionContext context)
      throws NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {

    Optional optionalAnnotatedElement = context.getElement();
    if (!optionalAnnotatedElement.isPresent()) {
      return Optional.empty();
    }

    AnnotatedElement element = optionalAnnotatedElement.get();
    Store store = context.getStore(NAMESPACE);

    Injector injector = store.get(element, Injector.class);
    if (injector == null) {
      injector = createInjector(context);
      store.put(element, injector);
    }

    return Optional.of(injector);
  }

  /**
   * Creates {@link Injector} from test context
   */
  private static Injector createInjector(ExtensionContext context)
      throws NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
    Optional parentInjector = getParentInjector(context);
    List modules = getNewModules(context);

    return parentInjector
        .map(injector -> injector.createChildInjector(modules))
        .orElseGet(() -> Guice.createInjector(modules));
  }

  /**
   * Retrieves {@link Injector} from parent test context
   */
  private static Optional getParentInjector(ExtensionContext context)
      throws NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
    final Optional optionalParent = context.getParent();
    if (optionalParent.isPresent()) {
      return getOrCreateInjector(optionalParent.get());
    }
    return Optional.empty();
  }

  /**
   * Gets all new {@link Module} instances for injections
   */
  private static List getNewModules(ExtensionContext context)
      throws NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
    Set> moduleTypes = getNewModuleTypes(context);
    List modules = new ArrayList<>(moduleTypes.size());
    for (Class moduleType : moduleTypes) {
      Constructor moduleCtor = moduleType.getDeclaredConstructor();
      moduleCtor.setAccessible(true);

      modules.add(moduleCtor.newInstance());
    }

    context.getElement().ifPresent(element -> {
      if (element instanceof Class) {
        modules.add(new AbstractModule() {
          @Override
          protected void configure() {
            requestStaticInjection((Class) element);
          }
        });
      }
    });

    return modules;
  }

  /**
   * Returns all new {@link Module} declared in current context (returns empty set if modules where
   * already declared)
   */
  private static Set> getNewModuleTypes(ExtensionContext context) {
    Optional optionalAnnotatedElement = context.getElement();
    if (!optionalAnnotatedElement.isPresent()) {
      return Collections.emptySet();
    }

    Set> moduleTypes = getModuleTypes(optionalAnnotatedElement.get());
    context.getParent()
        .map(GuiceExtension::getContextModuleTypes)
        .ifPresent(moduleTypes::removeAll);

    return moduleTypes;
  }

  private static Set> getContextModuleTypes(ExtensionContext context) {
    return getContextModuleTypes(Optional.of(context));
  }

  /**
   * Returns module types that are present on the given context or any of its enclosing contexts.
   */
  private static Set> getContextModuleTypes(
      Optional context) {

    Set> contextModuleTypes = new LinkedHashSet<>();
    while (context.isPresent() && (hasAnnotatedElement(context) || hasParent(context))) {
      context
          .flatMap(ExtensionContext::getElement)
          .map(GuiceExtension::getModuleTypes)
          .ifPresent(contextModuleTypes::addAll);
      context = context.flatMap(ExtensionContext::getParent);
    }

    return contextModuleTypes;
  }

  private static boolean hasAnnotatedElement(Optional context) {
    return context.flatMap(ExtensionContext::getElement).isPresent();
  }

  private static boolean hasParent(Optional context) {
    return context.flatMap(ExtensionContext::getParent).isPresent();
  }

  private static Set> getModuleTypes(AnnotatedElement element) {

    Optional[]> classes = findAnnotation(element, Modules.class)
        .map(Modules::value);

    if (classes.isPresent()) {
      return Arrays.stream(classes.get()).collect(toSet());
    }
    return Sets.newHashSet();

  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy