// org.wildfly.security.sasl.util.LocalPrincipalSaslClientFactory Maven / Gradle / Ivy
// The newest version!
/*
* JBoss, Home of Professional Open Source.
* Copyright 2017 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.sasl.util;
import static org.wildfly.security.sasl._private.ElytronMessages.sasl;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import javax.net.ssl.SSLSession;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.x500.X500Principal;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import org.wildfly.security.auth.callback.CredentialCallback;
import org.wildfly.security.auth.callback.SSLCallback;
import org.wildfly.security.auth.principal.AnonymousPrincipal;
import org.wildfly.security.auth.principal.NamePrincipal;
import org.wildfly.security.credential.Credential;
import org.wildfly.security.credential.X509CertificateChainCredential;
import org.wildfly.security.sasl.WildFlySasl;
/**
* A delegating SASL client factory whose instances can track and return the assumed principal used for authentication. Use
* the {@link WildFlySasl#PRINCIPAL} negotiated property to retrieve the principal.
*
* @author David M. Lloyd
*/
public final class LocalPrincipalSaslClientFactory extends AbstractDelegatingSaslClientFactory {
/**
* Construct a new instance.
*
* @param delegate the delegate client factory (must not be {@code null})
*/
public LocalPrincipalSaslClientFactory(final SaslClientFactory delegate) {
super(delegate);
}
public SaslClient createSaslClient(final String[] mechanisms, final String authorizationId, final String protocol, final String serverName, final Map props, final CallbackHandler cbh) throws SaslException {
Supplier principalSupplier;
CallbackHandler realCallbackHandler;
if (authorizationId != null) {
Principal principal = new NamePrincipal(authorizationId);
principalSupplier = () -> principal;
realCallbackHandler = cbh;
} else {
final ClientPrincipalQueryCallbackHandler handler = new ClientPrincipalQueryCallbackHandler(cbh);
principalSupplier = handler::getPrincipal;
realCallbackHandler = handler;
}
final SaslClient delegate = super.createSaslClient(mechanisms, authorizationId, protocol, serverName, props, realCallbackHandler);
if (delegate == null) {
return null;
}
return new LocalPrincipalSaslClient(delegate, principalSupplier);
}
static final class ClientPrincipalQueryCallbackHandler implements CallbackHandler {
private final CallbackHandler delegate;
private final AtomicReference principalRef = new AtomicReference<>(AnonymousPrincipal.getInstance());
ClientPrincipalQueryCallbackHandler(final CallbackHandler delegate) {
this.delegate = delegate;
}
public void handle(final Callback[] callbacks) throws IOException, UnsupportedCallbackException {
try {
delegate.handle(callbacks);
} finally {
// try to determine the used principal
for (Callback callback : callbacks) {
if (callback instanceof NameCallback) {
final String name = ((NameCallback) callback).getName();
if (name != null) {
principalRef.set(new NamePrincipal(name));
}
} else if (callback instanceof CredentialCallback) {
final Credential credential = ((CredentialCallback) callback).getCredential();
if (credential instanceof X509CertificateChainCredential) {
final X500Principal principal = ((X509CertificateChainCredential) credential).getFirstCertificate().getSubjectX500Principal();
if (principal != null) {
principalRef.set(principal);
}
}
} else if (callback instanceof SSLCallback) {
// SSL callback always comes before name callback
final SSLSession sslSession = ((SSLCallback) callback).getSslConnection().getSession();
if (sslSession != null) {
final Principal localPrincipal = sslSession.getLocalPrincipal();
if (localPrincipal != null) {
principalRef.set(localPrincipal);
}
}
}
}
}
}
public Principal getPrincipal() {
return principalRef.get();
}
}
final class LocalPrincipalSaslClient extends AbstractDelegatingSaslClient {
private final Supplier principalSupplier;
LocalPrincipalSaslClient(final SaslClient delegate, final Supplier principalSupplier) {
super(delegate);
this.principalSupplier = principalSupplier;
}
@Override
public Object getNegotiatedProperty(final String propName) {
if (! isComplete()) {
throw sasl.mechAuthenticationNotComplete();
}
// The mechanism might be smart enough to know its principal; if so, use their value instead of our guess.
final Object value = super.getNegotiatedProperty(propName);
return value == null && WildFlySasl.PRINCIPAL.equals(propName) ? principalSupplier.get() : value;
}
}
}