All downloads are free. Search and download functionality uses the official Maven repository.

net.serenitybdd.jbehave.ClassFinder Maven / Gradle / Ivy

There is a newer version: 1.46.0
Show newest version
package net.serenitybdd.jbehave;


import ch.lambdaj.function.convert.Converter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.reflections.Reflections;
import org.reflections.scanners.MethodAnnotationsScanner;
import org.reflections.scanners.ResourcesScanner;
import org.reflections.scanners.SubTypesScanner;
import org.reflections.scanners.TypeAnnotationsScanner;

import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.Set;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import static ch.lambdaj.Lambda.convert;

/**
 * Load classes from a given package.
 */
public class ClassFinder {

    private final ClassLoader classLoader;

    public ClassFinder(ClassLoader classLoader) {
        this.classLoader = classLoader;
    }

    public static ClassFinder loadClasses() {
        return new ClassFinder(getDefaultClassLoader());
    }

    public ClassFinder withClassLoader(ClassLoader classLoader) {
        return new ClassFinder(classLoader);
    }

    /**
     * Scans all classes accessible from the context class loader which belong to the given package and subpackages.
     *
     * @param packageName The base package
     * @return The classes
     */
    public List> fromPackage(String packageName) {
        if (expectedAnnotations == null) {
            return allClassesInPackage(packageName);
        } else {
            return annotatedClassesInPackage(packageName);
        }
    }

    private List> allClassesInPackage(String packageName) {
        try {
            String path = packageName.replace('.', '/');
            Enumeration resources = classResourcesOn(path);
            List dirs = new ArrayList();
            while (resources.hasMoreElements()) {
                URL resource = resources.nextElement();
                dirs.add(resource.toURI());
            }
            List> classes = Lists.newArrayList();
            for (URI directory : dirs) {
                classes.addAll(findClasses(directory, packageName));
            }
            return classes;
        } catch (Exception e) {
            throw new RuntimeException("failed to find all classes in package [" + packageName + "]", e);
        }
    }


    private List> expectedAnnotations;

    public ClassFinder annotatedWith(Class... someAnnotations) {
        expectedAnnotations = ImmutableList.copyOf(someAnnotations);
        return this;
    }

    public List> annotatedClassesInPackage(String packageName) {

        Reflections reflections = new Reflections(packageName,
                new SubTypesScanner(),
                new TypeAnnotationsScanner(),
                new MethodAnnotationsScanner(),
                new ResourcesScanner(), getClassLoader());

        Set> matchingClasses = Sets.newHashSet();
        for (Class expectedAnnotation : expectedAnnotations) {
            matchingClasses.addAll(reflections.getTypesAnnotatedWith(expectedAnnotation));
            matchingClasses.addAll(classesFrom(reflections.getMethodsAnnotatedWith(expectedAnnotation)));
        }
        return ImmutableList.copyOf(matchingClasses);

    }

    private Collection> classesFrom(Set annotatedMethods) {
        return convert(annotatedMethods, toDeclaringCasses());
    }

    private Converter> toDeclaringCasses() {
        return new Converter>() {

            public Class convert(Method from) {
                return from.getDeclaringClass();
            }
        };
    }

    private Enumeration classResourcesOn(String path) {
        try {
            return getClassLoader().getResources(path);
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not access class path at " + path, e);
        }
    }

    /**
     * Recursive method used to find all classes in a given directory and subdirs.
     *
     * @param directory   The base directory
     * @param packageName The package name for classes found inside the base directory
     * @return The classes
     */
    private List> findClasses(URI directory, String packageName) {
        try {
            final String scheme = directory.getScheme();
            final String schemeSpecificPart = directory.getSchemeSpecificPart();

            if (scheme.equals("jar") && schemeSpecificPart.contains("!")) {
                return findClassesInJar(directory, packageName);
            } else if (scheme.equals("file")) {
                return findClassesInFileSystemDirectory(directory, packageName);
            }

            throw new IllegalArgumentException("cannot handle URI with scheme [" + scheme + "]");
        } catch (Exception e) {
            throw new RuntimeException(
                    "failed to find classes" +
                    "in directory=[" + directory + "], with packageName=[" + packageName + "]",
                    e
            );
        }

    }

    private List> findClassesInJar(URI jarDirectory, String packageName) throws IOException {
        final String schemeSpecificPart = jarDirectory.getSchemeSpecificPart();

        List> classes = Lists.newArrayList();
        String [] split = schemeSpecificPart.split("!");
        URL jar = new URL(split[0]);
        ZipInputStream zip = new ZipInputStream(jar.openStream());
        ZipEntry entry;
        while ((entry = zip.getNextEntry()) != null) {
            if (entry.getName().endsWith(".class")) {
                String className = classNameFor(entry);
                if (className.startsWith(packageName) && isNotAnInnerClass(className)) {
                    classes.add(loadClassWithName(className));
                }
            }
        }

        return classes;
    }

    private List> findClassesInFileSystemDirectory(URI jarDirectory, String packageName) {
        List> classes = Lists.newArrayList();

        File directory = new File(jarDirectory);

        if (!directory.exists()) {
            return classes;
        }
        File[] files = directory.listFiles();
        if (files != null) {
            for (File file : files) {
                if (file.isDirectory()) {
                    classes.addAll(findClasses(file.toURI(), packageName + "." + file.getName()));
                } else if (file.getName().endsWith(".class") && isNotAnInnerClass(file.getName())) {
                    classes.add(correspondingClass(packageName, file));
                }
            }
        }

        return classes;
    }

    private static String classNameFor(ZipEntry entry) {
        return entry.getName().replaceAll("[$].*", "").replaceAll("[.]class", "").replace('/', '.');
    }

    private Class loadClassWithName(String className){
        try {
            return getClassLoader().loadClass(className);
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("Could not find or access class for " + className, e);
        }
     }

    private Class correspondingClass(String packageName, File file) {
        String fullyQualifiedClassName = packageName + '.' + simpleClassNameOf(file);
        return loadClassWithName(fullyQualifiedClassName);
    }

    private static ClassLoader getDefaultClassLoader() {
        return Thread.currentThread().getContextClassLoader();
    }

    private String simpleClassNameOf(File file) {
        return file.getName().substring(0, file.getName().length() - 6);
    }

    private boolean isNotAnInnerClass(String className) {
        return (!className.contains("$"));
    }

    public ClassLoader getClassLoader() {
        return classLoader;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy