All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.security.sasl.util.SaslFactories Maven / Gradle / Ivy

The newest version!
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.sasl.util;

import static org.wildfly.security.provider.util.ProviderUtil.INSTALLED_PROVIDERS;

import java.security.Provider;
import java.util.Map;
import java.util.function.BiPredicate;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;

import org.wildfly.common.Assert;

/**
 * A utility class for discovering SASL client and server factories.
 *
 * @author David M. Lloyd
 */
public final class SaslFactories {

    private SaslFactories() {
    }

    private static final SecurityProviderSaslClientFactory providerSaslClientFactory = new SecurityProviderSaslClientFactory(INSTALLED_PROVIDERS);
    private static final SecurityProviderSaslServerFactory providerSaslServerFactory = new SecurityProviderSaslServerFactory(INSTALLED_PROVIDERS);

    /**
     * Shared zero-length result array. Sharing is safe because an empty array has no elements to mutate.
     */
    static final String[] NO_STRINGS = new String[0];

    /**
     * Key in the SASL properties map under which an optional mechanism/provider filter ({@link BiPredicate}) may be
     * supplied. The filter is read by {@link #getProviderFilterPredicate(Map)}.
     */
    static final String PROVIDER_FILTER_KEY = "org.wildfly.internal.PFK";

    /**
     * A client factory that offers no mechanisms. It never creates a client and reports no mechanism names.
     */
    private static final SaslClientFactory EMPTY_SASL_CLIENT_FACTORY = new SaslClientFactory() {
        @Override
        public SaslClient createSaslClient(final String[] mechanisms, final String authorizationId, final String protocol, final String serverName, final Map<String, ?> props, final CallbackHandler cbh) throws SaslException {
            return null;
        }

        @Override
        public String[] getMechanismNames(final Map<String, ?> props) {
            return NO_STRINGS;
        }
    };

    /**
     * A server factory that offers no mechanisms. It never creates a server and reports no mechanism names.
     */
    private static final SaslServerFactory EMPTY_SASL_SERVER_FACTORY = new SaslServerFactory() {
        @Override
        public SaslServer createSaslServer(final String mechanism, final String protocol, final String serverName, final Map<String, ?> props, final CallbackHandler cbh) throws SaslException {
            return null;
        }

        @Override
        public String[] getMechanismNames(final Map<String, ?> props) {
            return NO_STRINGS;
        }
    };

    /**
     * Get a SASL client factory which returns an Elytron-provided mechanism.
     * <p>
     * Each call returns a new {@link ServiceLoaderSaslClientFactory} backed by the class loader of this class.
     *
     * @return the SASL client factory (not {@code null})
     */
    public static SaslClientFactory getElytronSaslClientFactory() {
        return new ServiceLoaderSaslClientFactory(SaslFactories.class.getClassLoader());
    }

    /**
     * Get a SASL server factory which returns an Elytron-provided mechanism.
     * <p>
     * Each call returns a new {@link ServiceLoaderSaslServerFactory} backed by the class loader of this class.
     *
     * @return the SASL server factory (not {@code null})
     */
    public static SaslServerFactory getElytronSaslServerFactory() {
        return new ServiceLoaderSaslServerFactory(SaslFactories.class.getClassLoader());
    }

    /**
     * Get a SASL client factory which searches all the given class loaders in order for mechanisms.
     *
     * @param classLoaders the class loaders to search, in search order (the array must not be {@code null})
     * @return the SASL client factory (not {@code null})
     */
    public static SaslClientFactory getSearchSaslClientFactory(ClassLoader... classLoaders) {
        Assert.checkNotNullParam("classLoaders", classLoaders);
        final SaslClientFactory[] factories = new SaslClientFactory[classLoaders.length];
        for (int i = 0; i < factories.length; i++) {
            // one service-loader-backed factory per class loader, preserving the caller's order
            factories[i] = new ServiceLoaderSaslClientFactory(classLoaders[i]);
        }
        return new AggregateSaslClientFactory(factories);
    }

    /**
     * Get a SASL server factory which searches all the given class loaders in order for mechanisms.
     *
     * @param classLoaders the class loaders to search, in search order (the array must not be {@code null})
     * @return the SASL server factory (not {@code null})
     */
    public static SaslServerFactory getSearchSaslServerFactory(ClassLoader... classLoaders) {
        Assert.checkNotNullParam("classLoaders", classLoaders);
        final SaslServerFactory[] factories = new SaslServerFactory[classLoaders.length];
        for (int i = 0; i < factories.length; i++) {
            // one service-loader-backed factory per class loader, preserving the caller's order
            factories[i] = new ServiceLoaderSaslServerFactory(classLoaders[i]);
        }
        return new AggregateSaslServerFactory(factories);
    }

    /**
     * Get a SASL client factory which uses the currently installed security providers to locate mechanisms.
     * <p>
     * The same shared instance, built from {@code INSTALLED_PROVIDERS}, is returned on every call.
     *
     * @return the SASL client factory (not {@code null})
     */
    public static SecurityProviderSaslClientFactory getProviderSaslClientFactory() {
        return providerSaslClientFactory;
    }

    /**
     * Get a SASL server factory which uses the currently installed security providers to locate mechanisms.
     * <p>
     * The same shared instance, built from {@code INSTALLED_PROVIDERS}, is returned on every call.
     *
     * @return the SASL server factory (not {@code null})
     */
    public static SecurityProviderSaslServerFactory getProviderSaslServerFactory() {
        return providerSaslServerFactory;
    }

    /**
     * Get a {@link SaslClientFactory} which does not provide any mechanisms.
     *
     * @return a shared factory whose {@code createSaslClient} always returns {@code null} and whose
     *         {@code getMechanismNames} always returns an empty array (not {@code null})
     */
    public static SaslClientFactory getEmptySaslClientFactory() {
        return EMPTY_SASL_CLIENT_FACTORY;
    }

    /**
     * Get a {@link SaslServerFactory} which does not provide any mechanisms.
     *
     * @return a shared factory whose {@code createSaslServer} always returns {@code null} and whose
     *         {@code getMechanismNames} always returns an empty array (not {@code null})
     */
    public static SaslServerFactory getEmptySaslServerFactory() {
        return EMPTY_SASL_SERVER_FACTORY;
    }

    /**
     * Efficiently, recursively filter mechanisms by provider.  If the filter accepts all array elements, the original
     * array is returned unchanged.  If the filter accepts none of the array elements, an empty array is returned.
     * Otherwise a new array holding the accepted elements, in their original order, is returned.
     * <p>
     * The filter is applied exactly once to each element from {@code idx} onward, in index order.  No array is
     * allocated unless at least one element is rejected (and at least one is accepted).  The recursion depth equals
     * {@code orig.length - idx}.
     * <p>
     * Callers normally pass {@code 0} for both {@code idx} and {@code len}.  For other values, {@code len} counts
     * elements treated as already accepted.  Any newly allocated array leaves its first {@code len} slots for the
     * caller to fill.
     *
     * @param orig the original mech array
     * @param idx the recursive read index (start from 0)
     * @param len the recursive write length (start from 0)
     * @param currentProvider the provider being tested
     * @param mechFilter the filter predicate, called as {@code test(mechanismName, currentProvider)}
     * @return the filtered array
     */
    static String[] filterMechanismsByProvider(String[] orig, int idx, int len, final Provider currentProvider, final BiPredicate<String, Provider> mechFilter) {
        if (idx == orig.length) {
            // end of array: decide the result shape now that the accepted count is known
            if (len == 0) {
                // no mechs
                return NO_STRINGS;
            } else {
                // every element was accepted -> reuse the original; otherwise allocate an exact-size copy
                return idx == len ? orig : new String[len];
            }
        } else if (mechFilter.test(orig[idx], currentProvider)) {
            // ok
            String[] filtered = filterMechanismsByProvider(orig, idx + 1, len + 1, currentProvider, mechFilter);
            if (orig != filtered) {
                // we made a copy; populate it while unwinding
                filtered[len] = orig[idx];
            }
            return filtered;
        } else {
            // skip this element
            return filterMechanismsByProvider(orig, idx + 1, len, currentProvider, mechFilter);
        }
    }

    /**
     * Extract the mechanism/provider filter from a SASL properties map.
     *
     * @param props the SASL properties.  May be {@code null}, as the {@link SaslClientFactory} and
     *        {@link SaslServerFactory} contracts permit; {@code null} is treated as "no filter"
     * @return the {@link BiPredicate} stored under {@link #PROVIDER_FILTER_KEY} if one is present; otherwise a
     *         predicate that accepts every mechanism/provider pair (never {@code null})
     */
    static BiPredicate<String, Provider> getProviderFilterPredicate(final Map<String, ?> props) {
        // A null map must not cause an NPE; it simply carries no filter.
        final Object filterObj = props == null ? null : props.get(PROVIDER_FILTER_KEY);
        if (filterObj instanceof BiPredicate) {
            // The type arguments cannot be checked at runtime.  Presumably whoever stores a filter under
            // PROVIDER_FILTER_KEY supplies a BiPredicate<String, Provider> (mechanism name, provider).
            @SuppressWarnings("unchecked")
            final BiPredicate<String, Provider> mechFilter = (BiPredicate<String, Provider>) filterObj;
            return mechFilter;
        }
        return (mechName, provider) -> true;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy