All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jgroups.auth.sasl.SaslUtils Maven / Gradle / Ivy

package org.jgroups.auth.sasl;

import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslServerFactory;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.Security;
import java.util.*;

/**
 * Utility methods for handling SASL authentication
 *
 * @author David M. Lloyd
 * @author Tristan Tarrant
 */
public final class SaslUtils {

    private SaslUtils() {
    }

    /**
     * Returns an iterator of all of the registered {@code SaslServerFactory}s where the order is
     * based on the order of the Provider registration and/or class path order. Class path providers
     * are listed before global providers; in the event of a name conflict, the class path provider
     * is preferred.
     *
     * @param classLoader
     *            the class loader to use ({@code null} selects the system class loader, as with
     *            {@link ServiceLoader#load(Class, ClassLoader)})
     * @param includeGlobal
     *            {@code true} to include globally registered providers, {@code false} to exclude
     *            them
     * @return the {@code Iterator} of {@code SaslServerFactory}s
     */
    public static Iterator<SaslServerFactory> getSaslServerFactories(ClassLoader classLoader, boolean includeGlobal) {
        return getFactories(SaslServerFactory.class, classLoader, includeGlobal);
    }

    /**
     * Returns an iterator of all of the registered {@code SaslServerFactory}s where the order is
     * based on the order of the Provider registration and/or class path order.
     *
     * @return the {@code Iterator} of {@code SaslServerFactory}s
     */
    public static Iterator<SaslServerFactory> getSaslServerFactories() {
        return getFactories(SaslServerFactory.class, null, true);
    }

    /**
     * Returns an iterator of all of the registered {@code SaslClientFactory}s where the order is
     * based on the order of the Provider registration and/or class path order. Class path providers
     * are listed before global providers; in the event of a name conflict, the class path provider
     * is preferred.
     *
     * @param classLoader
     *            the class loader to use ({@code null} selects the system class loader, as with
     *            {@link ServiceLoader#load(Class, ClassLoader)})
     * @param includeGlobal
     *            {@code true} to include globally registered providers, {@code false} to exclude
     *            them
     * @return the {@code Iterator} of {@code SaslClientFactory}s
     */
    public static Iterator<SaslClientFactory> getSaslClientFactories(ClassLoader classLoader, boolean includeGlobal) {
        return getFactories(SaslClientFactory.class, classLoader, includeGlobal);
    }

    /**
     * Returns an iterator of all of the registered {@code SaslClientFactory}s where the order is
     * based on the order of the Provider registration and/or class path order.
     *
     * @return the {@code Iterator} of {@code SaslClientFactory}s
     */
    public static Iterator<SaslClientFactory> getSaslClientFactories() {
        return getFactories(SaslClientFactory.class, null, true);
    }

    /**
     * Collects factories of the given type: first those found by {@link ServiceLoader}, then
     * (optionally) those registered with the globally installed security providers.
     *
     * @param type          the factory interface; its simple name is also used as the provider
     *                      service type (e.g. {@code "SaslServerFactory"})
     * @param classLoader   class loader handed to {@link ServiceLoader#load(Class, ClassLoader)}
     * @param includeGlobal whether to also scan {@link Security#getProviders()}
     * @return an iterator over the factories, class path factories first
     */
    private static <T> Iterator<T> getFactories(Class<T> type, ClassLoader classLoader, boolean includeGlobal) {
        Set<T> factories = new LinkedHashSet<>();
        // Names of factory classes already instantiated. Seeded with the class path factories so that a
        // global provider entry pointing at the same class does not yield a second instance; this is what
        // makes the class path factory win on a name conflict.
        Set<String> loadedClasses = new HashSet<>();
        for (T factory : ServiceLoader.load(type, classLoader)) {
            factories.add(factory);
            loadedClasses.add(factory.getClass().getName());
        }
        if (includeGlobal) {
            final String serviceType = type.getSimpleName();
            final String prefix = serviceType + ".";
            for (Provider provider : Security.getProviders()) {
                for (Object key : provider.keySet()) {
                    if (!(key instanceof String)) {
                        continue;
                    }
                    String entry = (String) key;
                    // Accept only "<ServiceType>.<algorithm>" entries; keys containing a space are skipped
                    // (e.g. attribute entries of the form "<type>.<algorithm> <attribute>").
                    if (!entry.startsWith(prefix) || entry.indexOf(' ') >= 0) {
                        continue;
                    }
                    String className = provider.getProperty(entry);
                    if (className == null || !loadedClasses.add(className)) {
                        continue;
                    }
                    T factory = newProviderInstance(provider, serviceType, entry.substring(prefix.length()), type);
                    if (factory != null) {
                        factories.add(factory);
                    }
                }
            }
        }
        return factories.iterator();
    }

    /**
     * Instantiates the provider service registered under {@code serviceType.algorithm}.
     *
     * @return the new instance, or {@code null} if the provider has no such service, the service
     *         cannot be instantiated, or the instance is not of the expected {@code type}
     */
    private static <T> T newProviderInstance(Provider provider, String serviceType, String algorithm, Class<T> type) {
        Provider.Service service = provider.getService(serviceType, algorithm);
        if (service == null) {
            return null;
        }
        try {
            Object instance = service.newInstance(null);
            // Checked cast: an unchecked (T) cast would only fail later, at the caller's iterator.next()
            return type.isInstance(instance) ? type.cast(instance) : null;
        } catch (NoSuchAlgorithmException ignored) {
            // Best effort: one broken provider entry must not prevent the remaining factories from loading.
            return null;
        }
    }

    /**
     * Returns {@code true} if {@code mech} is one of {@code mechanismNames}; a {@code null} mechanism
     * or a {@code null} array never matches.
     */
    private static boolean supportsMechanism(String[] mechanismNames, String mech) {
        if (mech == null || mechanismNames == null) {
            return false;
        }
        for (String name : mechanismNames) {
            if (mech.equals(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the first {@code SaslServerFactory} (class path factories first, then global providers)
     * whose {@code getMechanismNames(props)} includes {@code mech}.
     *
     * @param mech  the SASL mechanism name to look up
     * @param props the SASL properties passed to {@code getMechanismNames}
     * @return a factory supporting {@code mech}
     * @throws IllegalArgumentException if no factory supports {@code mech} (including when it is {@code null})
     */
    public static SaslServerFactory getSaslServerFactory(String mech, Map<String, ?> props) {
        Iterator<SaslServerFactory> saslFactories = getSaslServerFactories(SaslUtils.class.getClassLoader(), true);
        while (saslFactories.hasNext()) {
            SaslServerFactory saslFactory = saslFactories.next();
            if (supportsMechanism(saslFactory.getMechanismNames(props), mech)) {
                return saslFactory;
            }
        }
        throw new IllegalArgumentException("No SASL server factory for mech " + mech);
    }

    /**
     * Returns the first {@code SaslClientFactory} (class path factories first, then global providers)
     * whose {@code getMechanismNames(props)} includes {@code mech}.
     *
     * @param mech  the SASL mechanism name to look up
     * @param props the SASL properties passed to {@code getMechanismNames}
     * @return a factory supporting {@code mech}
     * @throws IllegalArgumentException if no factory supports {@code mech} (including when it is {@code null})
     */
    public static SaslClientFactory getSaslClientFactory(String mech, Map<String, ?> props) {
        Iterator<SaslClientFactory> saslFactories = getSaslClientFactories(SaslUtils.class.getClassLoader(), true);
        while (saslFactories.hasNext()) {
            SaslClientFactory saslFactory = saslFactories.next();
            if (supportsMechanism(saslFactory.getMechanismNames(props), mech)) {
                return saslFactory;
            }
        }
        throw new IllegalArgumentException("No SASL client factory for mech " + mech);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy