All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.security.sasl.gs2.Gs2SaslServer Maven / Gradle / Ivy

Go to download

This artifact provides a single JAR that contains all classes required to use remote EJB and JMS, including all dependencies. It is intended for use by those not using Maven; Maven users should just import the EJB and JMS BOMs instead (shaded JARs cause lots of problems with Maven, as it is very easy to inadvertently end up with different versions of classes on the class path).

The newest version!
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2015 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.sasl.gs2;

import static org.wildfly.security.asn1.ASN1.APPLICATION_SPECIFIC_MASK;
import static org.wildfly.security.mechanism._private.ElytronMessages.saslGs2;

import java.io.IOException;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;

import org.ietf.jgss.ChannelBinding;
import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSCredential;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;
import org.wildfly.common.Assert;
import org.wildfly.common.bytes.ByteStringBuilder;
import org.wildfly.common.iteration.ByteIterator;
import org.wildfly.common.iteration.CodePointIterator;
import org.wildfly.security.asn1.ASN1Exception;
import org.wildfly.security.asn1.DEREncoder;
import org.wildfly.security.auth.callback.IdentityCredentialCallback;
import org.wildfly.security.auth.callback.ServerCredentialCallback;
import org.wildfly.security.credential.GSSKerberosCredential;
import org.wildfly.security.sasl.util.AbstractSaslServer;

/**
 * SaslServer for the GS2 mechanism family as defined by
 * RFC 5801.
 *
 * <p>The negotiation moves through the following states:</p>
 * <ul>
 *   <li>{@code ST_NO_MESSAGE} - nothing has been received yet; an empty message produces an empty
 *       initial challenge, a non-empty message is treated as the first client message.</li>
 *   <li>{@code ST_FIRST_MESSAGE} - the empty initial challenge was sent and the first client message
 *       (GS2 header followed by the initial context token) is expected.</li>
 *   <li>{@code ST_ACCEPTOR} - the first message was accepted but the GSS context is not yet
 *       established; further tokens are passed straight to the GSS context.</li>
 * </ul>
 *
 * @author Farah Juma
 */
final class Gs2SaslServer extends AbstractSaslServer {

    /** No message has been evaluated yet. */
    private static final int ST_NO_MESSAGE = 1;
    /** An empty initial challenge has been sent; the client's first (header carrying) message is expected. */
    private static final int ST_FIRST_MESSAGE = 2;
    /** The first message was processed; subsequent tokens are fed to the GSS context. */
    private static final int ST_ACCEPTOR = 3;

    /** Whether the client is required to use channel binding (the {@code p} flag). */
    private final boolean plus;
    /** Channel binding type compared against the value sent with the {@code p=} flag. */
    private final String bindingType;
    /** Channel binding data handed to {@code Gs2Util.createChannelBinding}. */
    private final byte[] bindingData;
    /** GSS-API mechanism OID resolved from the SASL mechanism name. */
    private final Oid mechanism;
    /** Acceptor side security context; {@code null} once {@link #dispose()} has been called. */
    private GSSContext gssContext;
    /** Authorization ID sent by the client, or the authentication ID if none was sent. */
    private String authorizationID;
    /** Server name extracted from the established context's target name. */
    private String boundServerName;

    /**
     * Construct a new instance, resolving the GSS-API mechanism and creating the acceptor context.
     *
     * <p>The service credential is first requested from the callback handler; if the handler does not
     * support the callback or supplies no credential, an accept-only credential is created through
     * the {@code gssManager} instead.</p>
     *
     * @param mechanismName the SASL mechanism name
     * @param protocol the protocol name, used as the service part of the host based service name
     * @param serverName the server name, or {@code null} to create the credential with an unbound name
     * @param callbackHandler the callback handler
     * @param gssManager the GSS manager used to create names, credentials and the context
     * @param plus {@code true} if channel binding is required
     * @param bindingType the channel binding type
     * @param bindingData the channel binding data
     * @throws SaslException if the mechanism cannot be mapped to an OID, the callback handler fails
     *         with an I/O error, or the GSS context cannot be created
     */
    Gs2SaslServer(final String mechanismName, final String protocol, final String serverName, final CallbackHandler callbackHandler,
            final GSSManager gssManager, final boolean plus, final String bindingType, final byte[] bindingData) throws SaslException {
        super(mechanismName, protocol, serverName, callbackHandler, saslGs2);
        this.plus = plus;
        this.bindingType = bindingType;
        this.bindingData = bindingData;

        try {
            mechanism = Gs2.getMechanismForSaslName(gssManager, mechanismName);
        } catch (GSSException e) {
            throw saslGs2.mechMechanismToOidMappingFailed(e).toSaslException();
        }

        // Attempt to obtain a credential
        GSSCredential credential = null;
        ServerCredentialCallback credentialCallback = new ServerCredentialCallback(GSSKerberosCredential.class);

        try {
            saslGs2.trace("Obtaining GSSCredential for the service from callback handler");
            callbackHandler.handle(new Callback[] { credentialCallback });
            credential = credentialCallback.applyToCredential(GSSKerberosCredential.class, GSSKerberosCredential::getGssCredential);
        } catch (IOException e) {
            throw saslGs2.mechCallbackHandlerFailedForUnknownReason(e).toSaslException();
        } catch (UnsupportedCallbackException e) {
            // Not fatal: fall back to creating a credential through the GSSManager below
            saslGs2.trace("Unable to obtain GSSCredential from callback handler", e);
        }

        try {
            if (credential == null) {
                GSSName ourName;

                if (serverName != null) {
                    String localName = protocol + "@" + serverName;
                    saslGs2.tracef("Our name is '%s'", localName);
                    ourName = gssManager.createName(localName, GSSName.NT_HOSTBASED_SERVICE, mechanism);
                } else {
                    saslGs2.tracef("Our name is unbound");
                    ourName = null;
                }

                credential = gssManager.createCredential(ourName, GSSContext.INDEFINITE_LIFETIME, mechanism, GSSCredential.ACCEPT_ONLY);
            }
            gssContext = gssManager.createContext(credential);
        } catch (GSSException e) {
            throw saslGs2.mechUnableToCreateGssContext(e).toSaslException();
        }
    }

    /**
     * Put the negotiation into its initial state, in which no message has been received.
     */
    public void init() {
        setNegotiationState(ST_NO_MESSAGE);
    }

    /**
     * Get the authorization ID.
     *
     * @return the authorization ID sent by the client, the authentication ID if the client sent none,
     *         or {@code null} if neither has been determined yet
     */
    public String getAuthorizationID() {
        return authorizationID;
    }

    /**
     * Evaluate a client message for the given negotiation state.
     *
     * @param state the current negotiation state
     * @param message the message received from the client
     * @return the response token to send to the client, an empty array for the initial challenge, or
     *         {@code null} when an empty message arrives after completion
     * @throws SaslException if the message is malformed, channel binding requirements are not met,
     *         the GSS context rejects the token, or authorization fails
     */
    protected byte[] evaluateMessage(final int state, final byte[] message) throws SaslException {
        switch (state) {
            case ST_NO_MESSAGE: {
                if (message == null || message.length == 0) {
                    // The client sent no initial response: issue an empty challenge and wait for the
                    // first real message, which still has to go through GS2 header processing.
                    setNegotiationState(ST_FIRST_MESSAGE);
                    // Initial challenge
                    return NO_BYTES;
                }
                // Fall through
            }
            case ST_FIRST_MESSAGE: {
                assert ! gssContext.isEstablished();
                if (message == null || message.length == 0) {
                    throw saslGs2.mechClientRefusesToInitiateAuthentication().toSaslException();
                }
                ByteIterator bi = ByteIterator.ofBytes(message);
                // NOTE(review): cpi is used below to read a textual value up to the next ',';
                // presumably it shares its position with bi - confirm against ByteIterator docs.
                ByteIterator di = bi.delimitedBy(',');
                CodePointIterator cpi = di.asUtf8String();
                boolean gs2CbFlagPUsed = false;

                // == Parse message ==
                boolean nonStd = false;
                int b = bi.next();

                // gs2-nonstd-flag
                if (b == 'F') {
                    skipDelimiter(bi);
                    nonStd = true;
                    b = bi.next();
                }

                // gs2-cb-flag
                if (b == 'p') {
                    gs2CbFlagPUsed = true;
                    if (! plus) {
                        throw saslGs2.mechChannelBindingNotSupported().toSaslException();
                    }
                    if (bi.next() != '=') {
                        throw saslGs2.mechInvalidMessageReceived().toSaslException();
                    }
                    assert bindingType != null;
                    assert bindingData != null;
                    if (! bindingType.equals(cpi.drainToString())) {
                        throw saslGs2.mechChannelBindingTypeMismatch().toSaslException();
                    }
                    skipDelimiter(bi);
                } else if (b == 'y') {
                    if (plus || (bindingType != null && bindingData != null)) {
                        // server supports channel binding
                        throw saslGs2.mechChannelBindingNotProvided().toSaslException();
                    }
                    skipDelimiter(bi);
                } else if (b == 'n') {
                    if (plus) {
                        throw saslGs2.mechChannelBindingNotProvided().toSaslException();
                    }
                    skipDelimiter(bi);
                } else {
                    throw saslGs2.mechInvalidMessageReceived().toSaslException();
                }

                // gs2-authzid
                b = bi.next();
                if (b == 'a') {
                    if (bi.next() != '=') {
                        throw saslGs2.mechInvalidMessageReceived().toSaslException();
                    }
                    authorizationID = cpi.drainToString();
                    skipDelimiter(bi);
                } else if (b != ',') {
                    throw saslGs2.mechInvalidMessageReceived().toSaslException();
                }

                // Restore the initial context token header, if necessary
                byte[] token;
                int gs2HeaderStartIndex;
                int gs2HeaderLength;
                if (nonStd) {
                    // Skip the two bytes of the "F," prefix; the token is used as received
                    gs2HeaderStartIndex = 2;
                    gs2HeaderLength = (int) (bi.getIndex() - 2);
                    token = bi.drain();
                } else {
                    gs2HeaderStartIndex = 0;
                    gs2HeaderLength = (int) bi.getIndex();
                    try {
                        token = restoreTokenHeader(bi.drain());
                    } catch (ASN1Exception e) {
                        throw saslGs2.mechUnableToCreateResponseTokenWithCause(e).toSaslException();
                    }
                }

                // The channel binding is computed over the GS2 header without the non-standard flag
                ByteStringBuilder gs2HeaderExcludingNonStdFlag = new ByteStringBuilder();
                gs2HeaderExcludingNonStdFlag.append(message, gs2HeaderStartIndex, gs2HeaderLength);
                try {
                    ChannelBinding channelBinding = Gs2Util.createChannelBinding(gs2HeaderExcludingNonStdFlag.toArray(), gs2CbFlagPUsed, bindingData);
                    gssContext.setChannelBinding(channelBinding);
                } catch (GSSException e) {
                    throw saslGs2.mechUnableToSetChannelBinding(e).toSaslException();
                }

                try {
                    byte[] response = gssContext.acceptSecContext(token, 0, token.length);
                    if (! completeIfEstablished()) {
                        // Expecting subsequent exchanges
                        setNegotiationState(ST_ACCEPTOR);
                    }
                    return response;
                } catch (GSSException e) {
                    throw saslGs2.mechUnableToAcceptClientMessage(e).toSaslException();
                }
            }
            case ST_ACCEPTOR: {
                assert ! gssContext.isEstablished();
                if (message == null) {
                    // Avoid a NullPointerException on message.length below
                    throw saslGs2.mechInvalidMessageReceived().toSaslException();
                }
                try {
                    byte[] response = gssContext.acceptSecContext(message, 0, message.length);
                    completeIfEstablished();
                    return response;
                } catch (GSSException e) {
                    throw saslGs2.mechUnableToAcceptClientMessage(e).toSaslException();
                }
            }
            case COMPLETE_STATE: {
                if (message != null && message.length != 0) {
                    throw saslGs2.mechMessageAfterComplete().toSaslException();
                }
                return null;
            }
        }
        throw Assert.impossibleSwitchCase(state);
    }

    /**
     * If the GSS context is established, verify the negotiated mechanism, record the bound server
     * name, perform the authorization check, store any delegated credential and mark the negotiation
     * as complete.
     *
     * @return {@code true} if the context was established and the negotiation was completed,
     *         {@code false} if further tokens are required
     * @throws SaslException if the mechanism does not match or one of the completion steps fails
     * @throws GSSException if the negotiated mechanism cannot be read from the context
     */
    private boolean completeIfEstablished() throws SaslException, GSSException {
        if (! gssContext.isEstablished()) {
            return false;
        }
        Oid actualMechanism = gssContext.getMech();
        if (! mechanism.equals(actualMechanism)) {
            throw saslGs2.mechGssApiMechanismMismatch().toSaslException();
        }
        storeBoundServerName();
        checkAuthorizationID();
        storeDelegatedGSSCredential();
        negotiationComplete();
        return true;
    }

    /**
     * Dispose of the GSS context. Calling this method more than once has no further effect.
     *
     * @throws SaslException if the GSS context cannot be disposed
     */
    public void dispose() throws SaslException {
        final GSSContext context = gssContext;
        if (context == null) {
            // Already disposed
            return;
        }
        // Drop the reference first so the context is released even if dispose() fails
        gssContext = null;
        try {
            context.dispose();
        } catch (GSSException e) {
            throw saslGs2.mechUnableToDisposeGssContext(e).toSaslException();
        }
    }

    /**
     * Recompute and restore the initial context token header for the given token.
     *
     * @param token the initial context token without the token header
     * @return the initial context token with the token header restored
     * @throws ASN1Exception if the mechanism OID cannot be DER encoded
     */
    private byte[] restoreTokenHeader(byte[] token) throws ASN1Exception {
        final DEREncoder encoder = new DEREncoder();
        encoder.encodeImplicit(APPLICATION_SPECIFIC_MASK, 0);
        encoder.startSequence();
        try {
            encoder.writeEncoded(mechanism.getDER());
        } catch (GSSException e) {
            throw new ASN1Exception(e);
        }
        encoder.writeEncoded(token);
        encoder.endSequence();
        return encoder.getEncoded();
    }

    /**
     * Record the server name taken from the context's target name.
     *
     * <p>The target name is split on {@code '/'} and {@code '@'}; the second part is used when
     * present, otherwise the whole name is used.</p>
     *
     * @throws SaslException if the target name cannot be obtained from the GSS context
     */
    private void storeBoundServerName() throws SaslException {
        try {
            String targetName = gssContext.getTargName().toString();
            String[] targetNameParts = targetName.split("[/@]");
            boundServerName = targetNameParts.length > 1 ? targetNameParts[1] : targetName;
        } catch (GSSException e) {
            throw saslGs2.mechUnableToDetermineBoundServerName(e).toSaslException();
        }
    }

    /**
     * Check that the authenticated peer is allowed to act as the requested authorization ID.
     *
     * <p>If the client sent no (or an empty) authorization ID, the authentication ID is used as
     * the authorization ID.</p>
     *
     * @throws SaslException if the peer name cannot be determined or authorization is denied
     */
    private void checkAuthorizationID() throws SaslException {
        final String authenticationID;
        try {
            authenticationID = gssContext.getSrcName().toString();
        } catch (GSSException e) {
            throw saslGs2.mechUnableToDeterminePeerName(e).toSaslException();
        }
        saslGs2.tracef("checking if [%s] is authorized to act as [%s]...", authenticationID, authorizationID);
        if (authorizationID == null || authorizationID.isEmpty()) {
            authorizationID = authenticationID;
        }
        AuthorizeCallback authorizeCallback = new AuthorizeCallback(authenticationID, authorizationID);
        handleCallbacks(authorizeCallback);
        if (! authorizeCallback.isAuthorized()) {
            throw saslGs2.mechAuthorizationFailed(authenticationID, authorizationID).toSaslException();
        }
        saslGs2.trace("authorization id check successful");
    }

    /**
     * Pass any credential delegated by the client to the callback handler.
     *
     * <p>This is best effort: if the delegated credential cannot be read or the callback is not
     * supported, the failure is traced and authentication continues.</p>
     *
     * @throws SaslException if handling the callback fails for a reason other than the callback
     *         being unsupported
     */
    private void storeDelegatedGSSCredential() throws SaslException {
        try {
            GSSCredential gssCredential = gssContext.getDelegCred();
            if (gssCredential != null) {
                tryHandleCallbacks(new IdentityCredentialCallback(new GSSKerberosCredential(gssCredential), true));
            } else {
                saslGs2.trace("No GSSCredential delegated during authentication.");
            }
        } catch (UnsupportedCallbackException | GSSException e) {
            // Deliberately non-fatal, but leave a trace instead of swallowing silently
            saslGs2.trace("Unable to store delegated GSSCredential", e);
        }
    }

    /**
     * Consume the next byte and require it to be the {@code ','} delimiter.
     *
     * @param bi the iterator over the client message
     * @throws SaslException if the next byte is not {@code ','}
     */
    private void skipDelimiter(ByteIterator bi) throws SaslException {
        if (bi.next() != ',') {
            throw saslGs2.mechInvalidMessageReceived().toSaslException();
        }
    }

    /**
     * Get a negotiated property. Only {@link Sasl#BOUND_SERVER_NAME} is supported.
     *
     * @param propName the property name
     * @return the bound server name for {@link Sasl#BOUND_SERVER_NAME}, otherwise {@code null}
     */
    @Override
    public Object getNegotiatedProperty(String propName) {
        assertComplete();
        if (Sasl.BOUND_SERVER_NAME.equals(propName)) {
            return boundServerName;
        }
        return null;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy