org.wildfly.security.sasl.util.SaslFactories Maven / Gradle / Ivy
The newest version!
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.sasl.util;
import static org.wildfly.security.provider.util.ProviderUtil.INSTALLED_PROVIDERS;
import java.security.Provider;
import java.util.Map;
import java.util.function.BiPredicate;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;
import org.wildfly.common.Assert;
/**
* A utility class for discovering SASL client and server factories.
*
* @author David M. Lloyd
*/
public final class SaslFactories {

    // Static utility holder; not instantiable.
    private SaslFactories() {
    }

    private static final SecurityProviderSaslClientFactory providerSaslClientFactory = new SecurityProviderSaslClientFactory(INSTALLED_PROVIDERS);
    private static final SecurityProviderSaslServerFactory providerSaslServerFactory = new SecurityProviderSaslServerFactory(INSTALLED_PROVIDERS);

    /**
     * Shared zero-length array returned wherever an empty mechanism list must be reported.
     */
    static final String[] NO_STRINGS = new String[0];

    /**
     * Key in the SASL properties map under which a mechanism filter (a {@link BiPredicate} over mechanism name and
     * {@link Provider}) may be supplied; it is read by {@link #getProviderFilterPredicate(Map)}.
     */
    static final String PROVIDER_FILTER_KEY = "org.wildfly.internal.PFK";

    private static final SaslClientFactory EMPTY_SASL_CLIENT_FACTORY = new SaslClientFactory() {
        @Override
        public SaslClient createSaslClient(final String[] mechanisms, final String authorizationId, final String protocol, final String serverName, final Map<String, ?> props, final CallbackHandler cbh) throws SaslException {
            return null;
        }

        @Override
        public String[] getMechanismNames(final Map<String, ?> props) {
            return NO_STRINGS;
        }
    };

    private static final SaslServerFactory EMPTY_SASL_SERVER_FACTORY = new SaslServerFactory() {
        @Override
        public SaslServer createSaslServer(final String mechanism, final String protocol, final String serverName, final Map<String, ?> props, final CallbackHandler cbh) throws SaslException {
            return null;
        }

        @Override
        public String[] getMechanismNames(final Map<String, ?> props) {
            return NO_STRINGS;
        }
    };

    /**
     * Get a SASL client factory which returns an Elytron-provided mechanism.
     * <p>
     * A new factory, backed by the class loader that loaded this class, is created on every call.
     *
     * @return the SASL client factory (not {@code null})
     */
    public static SaslClientFactory getElytronSaslClientFactory() {
        return new ServiceLoaderSaslClientFactory(SaslFactories.class.getClassLoader());
    }

    /**
     * Get a SASL server factory which returns an Elytron-provided mechanism.
     * <p>
     * A new factory, backed by the class loader that loaded this class, is created on every call.
     *
     * @return the SASL server factory (not {@code null})
     */
    public static SaslServerFactory getElytronSaslServerFactory() {
        return new ServiceLoaderSaslServerFactory(SaslFactories.class.getClassLoader());
    }

    /**
     * Get a SASL client factory which searches all the given class loaders in order for mechanisms.
     *
     * @param classLoaders the class loaders to search, in search order (must not be {@code null})
     * @return the SASL client factory (not {@code null})
     */
    public static SaslClientFactory getSearchSaslClientFactory(ClassLoader... classLoaders) {
        Assert.checkNotNullParam("classLoaders", classLoaders);
        final SaslClientFactory[] factories = new SaslClientFactory[classLoaders.length];
        // One service-loader-backed factory per class loader, kept in the caller's order.
        for (int i = 0; i < classLoaders.length; i++) {
            factories[i] = new ServiceLoaderSaslClientFactory(classLoaders[i]);
        }
        return new AggregateSaslClientFactory(factories);
    }

    /**
     * Get a SASL server factory which searches all the given class loaders in order for mechanisms.
     *
     * @param classLoaders the class loaders to search, in search order (must not be {@code null})
     * @return the SASL server factory (not {@code null})
     */
    public static SaslServerFactory getSearchSaslServerFactory(ClassLoader... classLoaders) {
        Assert.checkNotNullParam("classLoaders", classLoaders);
        final SaslServerFactory[] factories = new SaslServerFactory[classLoaders.length];
        // One service-loader-backed factory per class loader, kept in the caller's order.
        for (int i = 0; i < classLoaders.length; i++) {
            factories[i] = new ServiceLoaderSaslServerFactory(classLoaders[i]);
        }
        return new AggregateSaslServerFactory(factories);
    }

    /**
     * Get a SASL client factory which uses the currently installed security providers to locate mechanisms.
     * <p>
     * The same shared instance is returned on every call.
     *
     * @return the SASL client factory (not {@code null})
     */
    public static SecurityProviderSaslClientFactory getProviderSaslClientFactory() {
        return providerSaslClientFactory;
    }

    /**
     * Get a SASL server factory which uses the currently installed security providers to locate mechanisms.
     * <p>
     * The same shared instance is returned on every call.
     *
     * @return the SASL server factory (not {@code null})
     */
    public static SecurityProviderSaslServerFactory getProviderSaslServerFactory() {
        return providerSaslServerFactory;
    }

    /**
     * Get a {@link SaslClientFactory} which does not provide any mechanisms.
     * <p>
     * Its {@code createSaslClient} always returns {@code null} and its {@code getMechanismNames} always returns an
     * empty array.
     *
     * @return the shared empty SASL client factory (not {@code null})
     */
    public static SaslClientFactory getEmptySaslClientFactory() {
        return EMPTY_SASL_CLIENT_FACTORY;
    }

    /**
     * Get a {@link SaslServerFactory} which does not provide any mechanisms.
     * <p>
     * Its {@code createSaslServer} always returns {@code null} and its {@code getMechanismNames} always returns an
     * empty array.
     *
     * @return the shared empty SASL server factory (not {@code null})
     */
    public static SaslServerFactory getEmptySaslServerFactory() {
        return EMPTY_SASL_SERVER_FACTORY;
    }

    /**
     * Efficiently, recursively filter mechanisms by provider.
     * <p>
     * If the filter accepts all array elements, the original array is returned unchanged, so callers must not
     * modify the result expecting a private copy. If the filter accepts no elements, the shared empty array
     * {@link #NO_STRINGS} is returned. Otherwise a new array of exactly the right size is allocated once, at the
     * deepest recursion level, and filled in while the recursion unwinds.
     * <p>
     * Recursion depth is one frame per element of {@code orig}, so this is intended for short arrays such as
     * mechanism name lists.
     *
     * @param orig the original mech array
     * @param idx the recursive read index (start from 0)
     * @param len the recursive write length (start from 0)
     * @param currentProvider the provider being tested
     * @param mechFilter the filter predicate, called with each mechanism name and {@code currentProvider}
     * @return the filtered array
     */
    static String[] filterMechanismsByProvider(String[] orig, int idx, int len, final Provider currentProvider, final BiPredicate<? super String, ? super Provider> mechFilter) {
        if (idx == orig.length) {
            // end of array
            if (len == 0) {
                // no mechs
                return NO_STRINGS;
            } else {
                // idx == len means nothing was skipped, so the original array can be reused as-is
                return idx == len ? orig : new String[len];
            }
        } else if (mechFilter.test(orig[idx], currentProvider)) {
            // ok
            final String[] filtered = filterMechanismsByProvider(orig, idx + 1, len + 1, currentProvider, mechFilter);
            // Identity check: if the deeper call returned orig itself, everything was accepted and there is
            // nothing to copy; otherwise it is the freshly allocated result array.
            if (orig != filtered) {
                // we made a copy; populate it
                filtered[len] = orig[idx];
            }
            return filtered;
        } else {
            // skip this element
            return filterMechanismsByProvider(orig, idx + 1, len, currentProvider, mechFilter);
        }
    }

    /**
     * Get the mechanism filter supplied in the given SASL properties under {@link #PROVIDER_FILTER_KEY}.
     *
     * @param props the SASL properties; may be {@code null}, as the {@code javax.security.sasl} factory methods
     *              allow a {@code null} property map
     * @return the supplied filter predicate, or a predicate accepting every mechanism/provider pair when
     *         {@code props} is {@code null} or holds no {@link BiPredicate} under the key
     */
    static BiPredicate<String, Provider> getProviderFilterPredicate(final Map<String, ?> props) {
        // A null property map means "no filter" rather than a NullPointerException.
        final Object filterObj = props == null ? null : props.get(PROVIDER_FILTER_KEY);
        if (filterObj instanceof BiPredicate) {
            // Erasure prevents checking the type arguments at runtime; whoever stores a value under this
            // internal key is expected to supply a BiPredicate<String, Provider>.
            @SuppressWarnings("unchecked")
            final BiPredicate<String, Provider> mechFilter = (BiPredicate<String, Provider>) filterObj;
            return mechFilter;
        }
        return (mechanism, provider) -> true;
    }
}