All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.security.mechanism.scram.ScramClient Maven / Gradle / Ivy

There is a newer version: 2.6.0.Final
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2015 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.mechanism.scram;

import static org.wildfly.security.mechanism._private.ElytronMessages.saslScram;

import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;

import org.wildfly.common.bytes.ByteStringBuilder;
import org.wildfly.common.codec.DecodeException;
import org.wildfly.common.iteration.ByteIterator;
import org.wildfly.security.mechanism._private.MechanismUtil;
import org.wildfly.security.mechanism.AuthenticationMechanismException;
import org.wildfly.security.mechanism.ScramServerErrorCode;
import org.wildfly.security.password.interfaces.ScramDigestPassword;
import org.wildfly.security.password.spec.IteratedSaltedPasswordAlgorithmSpec;
import org.wildfly.security.sasl.util.StringPrep;

/**
 * A client-side implementation for the SCRAM authentication.
 *
 * @author David M. Lloyd
 */
public final class ScramClient {
    // Supplies the security providers; passed on to the password credential lookup.
    private final Supplier<Provider[]> providers;
    private final ScramMechanism mechanism;
    // May be null, in which case no "a=" attribute is written to the message header.
    private final String authorizationId;
    private final CallbackHandler callbackHandler;
    // Optional; when null, getRandom() falls back to ThreadLocalRandom.
    private final SecureRandom secureRandom;
    // Null when no channel binding data is available. Stored as given (not copied).
    private final byte[] bindingData;
    private final String bindingType;
    // Inclusive bounds enforced on the iteration count announced by the server.
    private final int minimumIterationCount;
    private final int maximumIterationCount;

    /**
     * Constructs a new {@code ScramClient} instance. All arguments are stored as given; in particular
     * {@code bindingData} is not copied.
     *
     * @param mechanism the SCRAM mechanism used for the authentication.
     * @param authorizationId the ID of the user to be authorized (may be {@code null}).
     * @param callbackHandler the callbackHandler used for the authentication.
     * @param secureRandom an optional secure RNG to use (may be {@code null}).
     * @param bindingData the binding data for the "PLUS" channel binding option (may be {@code null}).
     * @param bindingType the binding type for the "PLUS" channel binding option.
     * @param minimumIterationCount the minimum number of iterations for password hashing.
     * @param maximumIterationCount the maximum number of iterations for password hashing.
     * @param providers the supplier of the security providers.
     */
    ScramClient(final ScramMechanism mechanism, final String authorizationId, final CallbackHandler callbackHandler, final SecureRandom secureRandom, final byte[] bindingData, final String bindingType, final int minimumIterationCount, final int maximumIterationCount, final Supplier<Provider[]> providers) {
        this.mechanism = mechanism;
        this.authorizationId = authorizationId;
        this.callbackHandler = callbackHandler;
        this.secureRandom = secureRandom;
        this.bindingData = bindingData;
        this.bindingType = bindingType;
        this.minimumIterationCount = minimumIterationCount;
        this.maximumIterationCount = maximumIterationCount;
        this.providers = providers;
    }

    /**
     * Selects the random number generator used when generating the client nonce.
     *
     * @return the {@code SecureRandom} supplied at construction time, or the calling thread's
     *         {@link ThreadLocalRandom} when none was supplied (which is not a secure RNG).
     */
    Random getRandom() {
        if (secureRandom == null) {
            return ThreadLocalRandom.current();
        }
        return secureRandom;
    }

    /**
     * Gets the SCRAM mechanism this client was created for.
     *
     * @return the SCRAM mechanism used for the authentication.
     */
    public ScramMechanism getMechanism() {
        return this.mechanism;
    }

    /**
     * Gets the authorization ID this client was created with.
     *
     * @return the ID of the user to be authorized, exactly as passed to the constructor.
     */
    public String getAuthorizationId() {
        return this.authorizationId;
    }

    /**
     * Gets the channel binding type used with the "PLUS" channel binding option.
     *
     * @return the binding type, exactly as passed to the constructor.
     */
    public String getBindingType() {
        return this.bindingType;
    }

    /**
     * Gets the channel binding data for the "PLUS" channel binding option without copying it.
     * The returned array is the internal one, so callers must not modify it.
     *
     * @return the internal binding data array, or {@code null} if there is none.
     */
    byte[] getRawBindingData() {
        return this.bindingData;
    }

    /**
     * Gets a defensive copy of the channel binding data for the "PLUS" channel binding option.
     *
     * @return a fresh copy of the binding data, or {@code null} if there is none.
     */
    public byte[] getBindingData() {
        if (this.bindingData == null) {
            return null;
        }
        return this.bindingData.clone();
    }

    /**
     * Builds the initial client message. The callback handler is asked for the authentication name through a
     * {@link NameCallback}; a non-empty authorization ID is offered to it as the default name.
     * <p>
     * The message is assembled as: a channel binding flag ({@code n,}, {@code y,} or {@code p=<type>,}), an optional
     * {@code a=<authorization ID>}, a comma, then {@code n=<name>,r=<nonce>}. The offset at which the
     * {@code n=} part begins is recorded in the returned message.
     *
     * @return the initial response to send to the server
     * @throws AuthenticationMechanismException if the callback handler does not support the name callback or
     *         no login name was given
     */
    public ScramInitialClientMessage getInitialResponse() throws AuthenticationMechanismException {
        // NameCallback rejects an empty default name, so only pass a non-empty authorization ID
        final NameCallback nameCallback;
        if (authorizationId != null && ! authorizationId.isEmpty()) {
            nameCallback = new NameCallback("User name", authorizationId);
        } else {
            nameCallback = new NameCallback("User name");
        }
        try {
            MechanismUtil.handleCallbacks(saslScram, callbackHandler, nameCallback);
        } catch (UnsupportedCallbackException e) {
            throw saslScram.mechCallbackHandlerDoesNotSupportUserName(e);
        }
        final String authenticationName = nameCallback.getName();
        if (authenticationName == null) {
            throw saslScram.mechNoLoginNameGiven();
        }

        final long prepProfile = StringPrep.PROFILE_SASL_STORED | StringPrep.MAP_SCRAM_LOGIN_CHARS;
        final ByteStringBuilder message = new ByteStringBuilder();

        // channel binding flag: binding is reported whenever binding data is present
        final boolean binding = bindingData != null;
        if (! binding) {
            message.append("n,");
        } else if (mechanism.isPlus()) {
            message.append("p=");
            message.append(bindingType);
            message.append(',');
        } else {
            message.append("y,");
        }

        // optional authorization ID; the separating comma is written either way
        if (authorizationId != null) {
            message.append('a').append('=');
            StringPrep.encode(authorizationId, message, prepProfile);
        }
        message.append(',');

        // everything from here on is the part that is later fed into the MAC computations
        final int bareStart = message.length();
        message.append('n').append('=');
        StringPrep.encode(authenticationName, message, prepProfile);
        message.append(',').append('r').append('=');
        final byte[] nonce = ScramUtil.generateNonce(48, getRandom());
        message.append(nonce);

        return new ScramInitialClientMessage(this, authenticationName, binding, nonce, bareStart, message.toArray());
    }

    /**
     * Parses the initial server message of the form {@code r=<nonce>,s=<base64 salt>,i=<iteration count>} and
     * wraps the parsed values in a {@link ScramInitialServerMessage}.
     * <p>
     * A message starting with {@code e=} is treated as a server-side rejection. The nonce must begin with the
     * client nonce from {@code initialResponse}; the remainder up to the next comma is taken as the server nonce.
     * The iteration count must lie within the configured minimum and maximum.
     *
     * @param initialResponse the initial client response for the server.
     * @param bytes the byte array containing the initial server message to parse (copied before parsing).
     * @return the initial server message.
     * @throws AuthenticationMechanismException if the message is malformed or truncated, the server reported an
     *         error, the nonces do not match, or the iteration count is out of range.
     */
    public ScramInitialServerMessage parseInitialServerMessage(final ScramInitialClientMessage initialResponse, final byte[] bytes) throws AuthenticationMechanismException {
        // parse and keep a private copy of the message
        final byte[] challenge = bytes.clone();
        final ByteIterator iter = ByteIterator.ofBytes(challenge);
        final byte[] serverNonce;
        final byte[] salt;
        final int iterationCount;
        try {
            // server error: e=<error string>
            if (iter.peekNext() == 'e') {
                iter.next();
                if (iter.next() != '=') {
                    throw saslScram.mechInvalidMessageReceived();
                }
                throw saslScram.scramServerRejectedAuthentication(ScramServerErrorCode.fromErrorString(iter.delimitedBy(',').asUtf8String().drainToString()));
            }

            // r=<client nonce><server nonce>
            if (! (iter.next() == 'r' && iter.next() == '=')) {
                throw saslScram.mechInvalidMessageReceived();
            }
            final byte[] clientNonce = initialResponse.getRawNonce();
            final boolean clientNonceEchoed = iter.limitedTo(clientNonce.length).contentEquals(ByteIterator.ofBytes(clientNonce));
            if (! clientNonceEchoed) {
                throw saslScram.mechNoncesDoNotMatch();
            }
            serverNonce = iter.delimitedBy(',').drain();
            iter.next(); // skip the ',' separator

            // s=<base64 salt>
            if (! (iter.next() == 's' && iter.next() == '=')) {
                throw saslScram.mechInvalidMessageReceived();
            }
            salt = iter.delimitedBy(',').asUtf8String().base64Decode().drain();
            iter.next(); // skip the ',' separator

            // i=<iteration count>
            if (! (iter.next() == 'i' && iter.next() == '=')) {
                throw saslScram.mechInvalidMessageReceived();
            }
            iterationCount = ScramUtil.parsePosInt(iter);
            if (iterationCount < minimumIterationCount) {
                throw saslScram.mechIterationCountIsTooLow(iterationCount, minimumIterationCount);
            } else if (iterationCount > maximumIterationCount) {
                throw saslScram.mechIterationCountIsTooHigh(iterationCount, maximumIterationCount);
            }
        } catch (NoSuchElementException | DecodeException | NumberFormatException ex) {
            // truncated message, bad base64 or bad number
            throw saslScram.mechInvalidMessageReceived();
        }
        return new ScramInitialServerMessage(initialResponse, serverNonce, salt, iterationCount, challenge);
    }

    /**
     * Handles the initial challenge from the server and creates the final client message.
     * <p>
     * The method first builds the final message without proof ({@code c=<base64 header>,r=<client nonce><server nonce>}),
     * then obtains a {@link ScramDigestPassword} from the callback handler using the iteration count and salt sent
     * by the server. The password digest is used as the salted password from which the client key, the stored key
     * and the client proof are derived. Finally {@code ,p=<base64 proof>} is appended to the message.
     * <p>
     * NOTE(review): when trace logging is enabled, the salted password, client key and stored key are written to
     * the log in hexadecimal form; confirm that trace output is never enabled where logs are not protected.
     *
     * @param initialResponse the initial client message.
     * @param initialChallenge the initial server message.
     * @return the final client message.
     * @throws AuthenticationMechanismException if the mechanism of the initial response or of the initial challenge
     * is not the mechanism of this client, if an error occurs while obtaining the password, or if the required
     * MAC or message digest algorithm is not available or the key is rejected
     */
    public ScramFinalClientMessage handleInitialChallenge(ScramInitialClientMessage initialResponse, ScramInitialServerMessage initialChallenge) throws AuthenticationMechanismException {
        boolean trace = saslScram.isTraceEnabled();

        // both messages must belong to the same mechanism as this client (identity comparison)
        if (initialResponse.getMechanism() != mechanism) {
            throw saslScram.mechUnmatchedMechanism(mechanism.toString(), initialResponse.getMechanism().toString());
        }
        if (initialChallenge.getMechanism() != mechanism) {
            throw saslScram.mechUnmatchedMechanism(mechanism.toString(), initialChallenge.getMechanism().toString());
        }

        final boolean plus = mechanism.isPlus();

        // "c=" carries the base64 encoding of the header (binding flag, optional "a=<authorization ID>", comma)
        // followed, for "PLUS" mechanisms only, by the raw channel binding data; b2 collects the unencoded bytes
        final ByteStringBuilder encoded = new ByteStringBuilder();
        encoded.append('c').append('=');
        ByteStringBuilder b2 = new ByteStringBuilder();
        if (bindingData != null) {
            if(trace) saslScram.tracef("[C] Binding data: %s%n", ByteIterator.ofBytes(bindingData).hexEncode().drainToString());
            if (plus) {
                b2.append("p=");
                b2.append(bindingType);
            } else {
                b2.append('y');
            }
            b2.append(',');
            if (getAuthorizationId() != null) {
                b2.append("a=");
                StringPrep.encode(getAuthorizationId(), b2, StringPrep.PROFILE_SASL_STORED | StringPrep.MAP_SCRAM_LOGIN_CHARS);
            }
            b2.append(',');
            if (plus) {
                b2.append(bindingData);
            }
            encoded.appendLatin1(b2.iterate().base64Encode());
        } else {
            // no binding data: flag is "n" and nothing follows the header
            b2.append('n');
            b2.append(',');
            if (getAuthorizationId() != null) {
                b2.append("a=");
                StringPrep.encode(getAuthorizationId(), b2, StringPrep.PROFILE_SASL_STORED | StringPrep.MAP_SCRAM_LOGIN_CHARS);
            }
            b2.append(',');
            assert !plus;
            encoded.appendLatin1(b2.iterate().base64Encode());
        }
        // combined nonce: the client nonce immediately followed by the server nonce
        encoded.append(',').append('r').append('=').append(initialResponse.getRawNonce()).append(initialChallenge.getRawServerNonce());
        // no extensions are written; "encoded" now holds the final message without the proof

        // ask the callback handler for a password matching the server's iteration count and salt;
        // the same spec object is passed for both parameter arguments of the lookup
        final IteratedSaltedPasswordAlgorithmSpec parameters = new IteratedSaltedPasswordAlgorithmSpec(
            initialChallenge.getIterationCount(),
            initialChallenge.getRawSalt()
        );
        ScramDigestPassword password = MechanismUtil.getPasswordCredential(
            initialResponse.getAuthenticationName(),
            callbackHandler,
            ScramDigestPassword.class,
            mechanism.getPasswordAlgorithm(),
            parameters,
            parameters,
            providers,
            saslScram);
        // the password digest is used as the salted password
        final byte[] saltedPassword = password.getDigest();
        if (trace) saslScram.tracef("[C] Client salted password: %s", ByteIterator.ofBytes(saltedPassword).hexEncode().drainToString());

        try {
            final Mac mac = Mac.getInstance(getMechanism().getHmacName());
            final MessageDigest messageDigest = MessageDigest.getInstance(getMechanism().getMessageDigestName());

            // client key = MAC keyed with the salted password over ScramUtil.CLIENT_KEY_BYTES
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            final byte[] clientKey = mac.doFinal(ScramUtil.CLIENT_KEY_BYTES);
            if(trace) saslScram.tracef("[C] Client key: %s", ByteIterator.ofBytes(clientKey).hexEncode().drainToString());
            // stored key = digest of the client key
            final byte[] storedKey = messageDigest.digest(clientKey);
            if(trace) saslScram.tracef("[C] Stored key: %s%n", ByteIterator.ofBytes(storedKey).hexEncode().drainToString());
            // client signature = MAC keyed with the stored key over, in this order:
            //   initial client message from its initial part index, ',', initial server message, ',',
            //   final client message without proof
            mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm()));
            final byte[] initialResponseBytes = initialResponse.getRawMessageBytes();
            mac.update(initialResponseBytes, initialResponse.getInitialPartIndex(), initialResponseBytes.length - initialResponse.getInitialPartIndex());
            if (trace) saslScram.tracef("[C] Using client first message: %s%n", ByteIterator.ofBytes(initialResponseBytes, initialResponse.getInitialPartIndex(), initialResponseBytes.length - initialResponse.getInitialPartIndex()).hexEncode().drainToString());
            mac.update((byte) ',');
            mac.update(initialChallenge.getRawMessageBytes());
            if(trace) saslScram.tracef("[C] Using server first message: %s%n", ByteIterator.ofBytes(initialChallenge.getRawMessageBytes()).hexEncode().drainToString());
            mac.update((byte) ',');
            // presumably feeds the bytes collected so far in "encoded" into the MAC -- confirm against ByteStringBuilder
            encoded.updateMac(mac);
            if(trace) saslScram.tracef("[C] Using client final message without proof: %s%n", ByteIterator.ofBytes(encoded.toArray()).hexEncode().drainToString());
            // at this point the array holds the client signature ...
            final byte[] clientProof = mac.doFinal();
            if(trace) saslScram.tracef("[C] Client signature: %s%n", ByteIterator.ofBytes(clientProof).hexEncode().drainToString());
            // ... which is combined with the client key to form the proof
            // (presumably ScramUtil.xor XORs the second array into the first in place -- confirm)
            ScramUtil.xor(clientProof, clientKey);
            if(trace) saslScram.tracef("[C] Client proof: %s%n", ByteIterator.ofBytes(clientProof).hexEncode().drainToString());
            // remember where the ",p=" part begins so the proof can later be separated from the rest of the message
            int proofStart = encoded.length();
            // append the base64-encoded proof
            encoded.append(',').append('p').append('=');
            encoded.appendLatin1(ByteIterator.ofBytes(clientProof).base64Encode());
            if(trace) saslScram.tracef("[C] Client final message: %s%n", ByteIterator.ofBytes(encoded.toArray()).hexEncode().drainToString());
            return new ScramFinalClientMessage(initialResponse, initialChallenge, password, clientProof, encoded.toArray(), proofStart);
        } catch (NoSuchAlgorithmException | InvalidKeyException e) {
            throw saslScram.mechMacAlgorithmNotSupported(e);
        }
    }

    /**
     * Parses the final server message and creates {@link ScramFinalServerMessage} from parsed information.
     * <p>
     * Two forms are recognized: {@code e=<error string>}, which is reported as a rejection by the server, and
     * {@code v=<base64 server signature>}. Any data following the signature makes the message invalid.
     * An empty or truncated message is reported as an invalid message as well.
     *
     * @param messageBytes the byte array of the final server message.
     * @return the final server message.
     * @throws AuthenticationMechanismException if the message is malformed or truncated, or the server rejected the
     *         authentication request.
     */
    public ScramFinalServerMessage parseFinalServerMessage(final byte[] messageBytes) throws AuthenticationMechanismException {
        final ByteIterator bi = ByteIterator.ofBytes(messageBytes);
        final byte[] sig;
        try {
            int c = bi.next();
            if (c == 'e') {
                if (bi.next() == '=') {
                    throw saslScram.scramServerRejectedAuthentication(ScramServerErrorCode.fromErrorString(bi.delimitedBy(',').asUtf8String().drainToString()));
                }
                throw saslScram.mechInvalidMessageReceived();
            } else if (c == 'v' && bi.next() == '=') {
                sig = bi.delimitedBy(',').asUtf8String().base64Decode().drain();
            } else {
                throw saslScram.mechInvalidMessageReceived();
            }
            // nothing (not even a further attribute) may follow the signature
            if (bi.hasNext()) {
                throw saslScram.mechInvalidMessageReceived();
            }
        } catch (NoSuchElementException | IllegalArgumentException e) {
            // NoSuchElementException: the message ended early (e.g. empty, "v" or "e" alone);
            // IllegalArgumentException: covers decoding failures, as in the original handling
            throw saslScram.mechInvalidMessageReceived();
        }
        return new ScramFinalServerMessage(sig, messageBytes);
    }

    /**
     * Verifies the final challenge received from the server.
     * <p>
     * The expected server signature is recomputed locally: a server key is derived by keying the MAC with the
     * password digest and applying it to {@code ScramUtil.SERVER_KEY_BYTES}; the MAC keyed with that server key is
     * then computed over the initial client message (from its initial part index), a comma, the initial server
     * message, a comma, and the final client message up to its proof offset. The result is compared with the
     * signature sent by the server using a constant-time comparison.
     *
     * @param finalResponse the final client message.
     * @param finalChallenge the final server message.
     * @throws AuthenticationMechanismException if the MAC algorithm is unavailable or the key is rejected, or the
     *         server signature is invalid.
     */
    public void verifyFinalChallenge(final ScramFinalClientMessage finalResponse, final ScramFinalServerMessage finalChallenge) throws AuthenticationMechanismException {
        boolean trace = saslScram.isTraceEnabled();

        try {
            final Mac mac = Mac.getInstance(getMechanism().getHmacName());

            // derive the server key from the password digest
            ScramDigestPassword password = finalResponse.getPassword();
            mac.init(new SecretKeySpec(password.getDigest(), mac.getAlgorithm()));
            byte[] serverKey = mac.doFinal(ScramUtil.SERVER_KEY_BYTES);
            if(trace) saslScram.tracef("[C] Server key: %s%n", ByteIterator.ofBytes(serverKey).hexEncode().drainToString());

            // recompute the server signature over the exchanged messages
            mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm()));
            byte[] clientFirstMessage = finalResponse.getInitialResponse().getRawMessageBytes();
            int bareStart = finalResponse.getInitialResponse().getInitialPartIndex();
            mac.update(clientFirstMessage, bareStart, clientFirstMessage.length - bareStart);
            mac.update((byte) ',');
            byte[] serverFirstMessage = finalResponse.getInitialChallenge().getRawMessageBytes();
            mac.update(serverFirstMessage);
            mac.update((byte) ',');
            byte[] clientFinalMessage = finalResponse.getRawMessageBytes();
            // only the part before the proof is covered by the signature
            mac.update(clientFinalMessage, 0, finalResponse.getProofOffset());
            byte[] serverSignature = mac.doFinal();
            if(trace) saslScram.tracef("[C] Recovered server signature: %s%n", ByteIterator.ofBytes(serverSignature).hexEncode().drainToString());

            // constant-time comparison, so the time taken does not reveal how many leading bytes matched
            if (! MessageDigest.isEqual(finalChallenge.getRawServerSignature(), serverSignature)) {
                throw saslScram.mechServerAuthenticityCannotBeVerified();
            }
        } catch (IllegalArgumentException | InvalidKeyException | NoSuchAlgorithmException e) {
            throw saslScram.mechMacAlgorithmNotSupported(e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy