All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mysql.cj.sasl.ScramSha1SaslClient Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta2
Show newest version
/*
 * Copyright (c) 2020, Oracle and/or its affiliates.
 *
 * This program is free software; you can redistribute it and/or modify it under
 * the terms of the GNU General Public License, version 2.0, as published by the
 * Free Software Foundation.
 *
 * This program is also distributed with certain software (including but not
 * limited to OpenSSL) that is licensed under separate terms, as designated in a
 * particular file or component or in included license documentation. The
 * authors of MySQL hereby grant you an additional permission to link the
 * program and your derivative works with the separately licensed software that
 * they have included with MySQL.
 *
 * Without limiting anything contained in the foregoing, this file, which is
 * part of MySQL Connector/J, is also subject to the Universal FOSS Exception,
 * version 1.0, a copy of which can be found at
 * http://oss.oracle.com/licenses/universal-foss-exception.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License, version 2.0,
 * for more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
 */

package com.mysql.cj.sasl;

import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

import com.mysql.cj.exceptions.ExceptionFactory;
import com.mysql.cj.util.SaslPrep;
import com.mysql.cj.util.SaslPrep.StringType;
import com.mysql.cj.util.StringUtils;

/**
 * A {@link SaslClient} implementation for SCRAM-SHA-1 (RFC 5802), without channel binding.
 * The mechanism name was renamed to "MYSQLCJ-SCRAM-SHA-1" in order to avoid future conflicts with an officially supported one. If/when there is one official
 * SCRAM-SHA-1 then it will have to be thoroughly tested and this code can be obsoleted.
 * <p>
 * The exchange consists of three calls to {@link #evaluateChallenge(byte[])}: the first produces the client-first-message, the second consumes the
 * server-first-message and produces the client-final-message, and the third verifies the server-final-message. Instances keep per-exchange state in
 * mutable fields and perform no synchronization, so each instance should drive a single exchange from one thread at a time.
 */
public class ScramSha1SaslClient implements SaslClient {
    public static final String MECHANISM_NAME = "MYSQLCJ-SCRAM-SHA-1";

    /**
     * Stages of the SCRAM exchange. Each stage knows its successor; {@link #TERMINATED} is its own successor.
     */
    private enum ScramExchangeStage {
        TERMINATED(null), SERVER_FINAL(TERMINATED), SERVER_FIRST_CLIENT_FINAL(SERVER_FINAL), CLIENT_FIRST(SERVER_FIRST_CLIENT_FINAL);

        private final ScramExchangeStage next;

        private ScramExchangeStage(ScramExchangeStage next) {
            this.next = next;
        }

        public ScramExchangeStage getNext() {
            return this.next == null ? this : this.next;
        }
    }

    private static final int MINIMUM_ITERATIONS = 4096;
    private static final String SHA1_ALGORITHM = "SHA-1";
    private static final String HMAC_SHA1_ALGORITHM = "HmacSHA1";
    private static final String PBKDF2_HMAC_SHA1_ALGORITHM = "PBKDF2WithHmacSHA1";
    private static final int SHA1_OUTPUT_BITS = 20 * 8; // SHA-1 hashes are 20 bytes long.
    private static final String GS2_CBIND_FLAG = "n";
    // Fixed ASCII labels from RFC 5802; encoded explicitly so the bytes never depend on the platform default charset.
    private static final byte[] CLIENT_KEY = "Client Key".getBytes(StandardCharsets.US_ASCII);
    private static final byte[] SERVER_KEY = "Server Key".getBytes(StandardCharsets.US_ASCII);
    private static final int CLIENT_NONCE_LENGTH = 32;
    // One shared SecureRandom for nonce generation instead of creating (and seeding) a new one per exchange.
    private static final Random NONCE_RANDOM = new SecureRandom();

    private final String authorizationId;
    private final String authenticationId;
    private final String password;

    private ScramExchangeStage scramStage = ScramExchangeStage.CLIENT_FIRST;
    private String cNonce;
    private String gs2Header;
    private String clientFirstMessageBare;
    private byte[] serverSignature;

    /**
     * Creates a client for a single SCRAM-SHA-1 exchange.
     *
     * @param authorizationId
     *            the identity to act as; {@code null} or empty means no authzid is sent.
     * @param authenticationId
     *            the user name to authenticate; when {@code null} or empty, {@code authorizationId} is used instead.
     * @param password
     *            the user password; {@code null} is treated as an empty password.
     * @throws SaslException
     *             if neither {@code authenticationId} nor {@code authorizationId} is provided.
     */
    public ScramSha1SaslClient(String authorizationId, String authenticationId, String password) throws SaslException {
        this.authorizationId = StringUtils.isNullOrEmpty(authorizationId) ? "" : authorizationId;
        this.authenticationId = StringUtils.isNullOrEmpty(authenticationId) ? this.authorizationId : authenticationId;
        if (StringUtils.isNullOrEmpty(this.authenticationId)) {
            throw new SaslException("The authenticationId cannot be null or empty.");
        }
        this.password = StringUtils.isNullOrEmpty(password) ? "" : password;
    }

    @Override
    public String getMechanismName() {
        return MECHANISM_NAME;
    }

    @Override
    public boolean hasInitialResponse() {
        return true;
    }

    /**
     * Advances the SCRAM-SHA-1 exchange by one step.
     * <ol>
     * <li>First call: returns the client-first-message; {@code challenge} is ignored.</li>
     * <li>Second call: consumes the server-first-message and returns the client-final-message.</li>
     * <li>Third call: verifies the server-final-message and returns {@code null}.</li>
     * </ol>
     * Any failure terminates the exchange: afterwards {@link #isComplete()} returns {@code true} and further calls throw.
     *
     * @param challenge
     *            the server message for the current stage.
     * @return
     *         the next client message, or {@code null} once the server-final-message has been verified.
     * @throws SaslException
     *             if the server message is missing, malformed, or fails verification, or if the exchange is already terminated.
     */
    @Override
    public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
        try {
            switch (this.scramStage) {
                case CLIENT_FIRST:
                    return buildClientFirstMessage();

                case SERVER_FIRST_CLIENT_FINAL:
                    return processServerFirstMessage(challenge);

                case SERVER_FINAL:
                    verifyServerFinalMessage(challenge);
                    return null;

                default:
                    throw new SaslException("Unexpected SCRAM authentication message.");
            }
        } catch (Throwable e) {
            // Abort the exchange; the finally block leaves TERMINATED in place because it is its own successor.
            this.scramStage = ScramExchangeStage.TERMINATED;
            throw e;
        } finally {
            this.scramStage = this.scramStage.getNext();
        }
    }

    /**
     * Builds the client-first-message and records the GS2 header, client nonce and client-first-message-bare needed by the later stages.
     *
     * @return
     *         the UTF-8 encoded client-first-message.
     */
    private byte[] buildClientFirstMessage() {
        // client-first-message      = gs2-header client-first-message-bare
        // gs2-header                = gs2-cbind-flag "," [ authzid ] ","
        // gs2-cbind-flag            = "n"
        // authzid                   = "a=" saslname
        // client-first-message-bare = username "," nonce
        // username                  = "n=" saslname
        // saslname                  = 1*(value-safe-char / "=2C" / "=3D")
        // value-safe-char           = UTF8
        //                             ;; UTF8 except NUL, "=", and ",".
        // nonce                     = "r=" c-nonce [s-nonce]
        //                             ;; Second part provided by server.
        // c-nonce                   = printable
        // printable                 = %x21-2B / %x2D-7E
        //                             ;; Printable ASCII except ",".
        this.gs2Header = GS2_CBIND_FLAG + "," + (StringUtils.isNullOrEmpty(this.authorizationId) ? "" : "a=" + prepUserName(this.authorizationId)) + ",";
        this.cNonce = generateRandomPrintableAsciiString(CLIENT_NONCE_LENGTH);
        this.clientFirstMessageBare = "n=" + prepUserName(this.authenticationId) + ",r=" + this.cNonce;
        String clientFirstMessage = this.gs2Header + this.clientFirstMessageBare;

        return StringUtils.getBytes(clientFirstMessage, "UTF-8");
    }

    /**
     * Validates the server-first-message, builds the client-final-message, and precomputes the expected ServerSignature.
     *
     * @param challenge
     *            the server-first-message.
     * @return
     *         the UTF-8 encoded client-final-message.
     * @throws SaslException
     *             if the server message is missing, malformed, carries a mandatory extension, has a nonce that does not extend the client nonce,
     *             or announces too few iterations.
     */
    private byte[] processServerFirstMessage(byte[] challenge) throws SaslException {
        // 1st part: server-first-message.
        String serverFirstMessage = decodeServerMessage(challenge);
        Map<String, String> serverFirstAttributes = parseChallenge(serverFirstMessage);

        if (serverFirstAttributes.containsKey("m")) {
            // RFC 5802, Section 5.1: the reserved "m" attribute MUST cause authentication failure.
            throw new SaslException("Unsupported mandatory SCRAM extension in server first message.");
        }
        if (!serverFirstAttributes.containsKey("r") || !serverFirstAttributes.containsKey("s") || !serverFirstAttributes.containsKey("i")) {
            throw new SaslException("Missing required SCRAM attribute from server first message.");
        }

        String sNonce = serverFirstAttributes.get("r");
        if (!sNonce.startsWith(this.cNonce)) {
            throw new SaslException("Invalid server nonce for SCRAM-SHA-1 authentication.");
        }

        byte[] salt;
        int iterations;
        try {
            salt = Base64.getDecoder().decode(serverFirstAttributes.get("s"));
            iterations = Integer.parseInt(serverFirstAttributes.get("i"));
        } catch (IllegalArgumentException e) { // Also covers NumberFormatException.
            throw new SaslException("Malformed SCRAM attribute in server first message.", e);
        }
        if (iterations < MINIMUM_ITERATIONS) {
            throw new SaslException("Announced SCRAM-SHA-1 iteration count is too low.");
        }

        // 2nd part: client-final-message.

        // client-final-message-without-proof = channel-binding "," nonce
        // channel-binding                    = "c=" base64
        //                                      ;; base64 encoding of cbind-input.
        // cbind-input                        = gs2-header
        String clientFinalMessageWithoutProof = "c=" + Base64.getEncoder().encodeToString(StringUtils.getBytes(this.gs2Header, "UTF-8")) + ",r=" + sNonce;

        // Compute ClientProof:
        //   SaltedPassword  := Hi(Normalize(password), salt, i)
        //   ClientKey       := HMAC(SaltedPassword, "Client Key")
        //   StoredKey       := H(ClientKey)
        //   AuthMessage     := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
        //   ClientSignature := HMAC(StoredKey, AuthMessage)
        //   ClientProof     := ClientKey XOR ClientSignature
        byte[] saltedPassword = hi(SaslPrep.prepare(this.password, StringType.STORED), salt, iterations);
        byte[] clientKey = hmac(saltedPassword, CLIENT_KEY);
        byte[] storedKey = h(clientKey);
        String authMessage = this.clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof;
        byte[] authMessageBytes = StringUtils.getBytes(authMessage, "UTF-8");
        byte[] clientSignature = hmac(storedKey, authMessageBytes);
        byte[] clientProof = xorInPlace(clientKey.clone(), clientSignature);

        // client-final-message               = client-final-message-without-proof "," proof
        // proof                              = "p=" base64
        String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.getEncoder().encodeToString(clientProof);

        // Compute ServerSignature (verified against the server-final-message in the next stage):
        //   ServerKey       := HMAC(SaltedPassword, "Server Key")
        //   ServerSignature := HMAC(ServerKey, AuthMessage)
        byte[] serverKey = hmac(saltedPassword, SERVER_KEY);
        this.serverSignature = hmac(serverKey, authMessageBytes);

        return StringUtils.getBytes(clientFinalMessage, "UTF-8");
    }

    /**
     * Checks the server-final-message against the ServerSignature computed in the previous stage.
     *
     * @param challenge
     *            the server-final-message.
     * @throws SaslException
     *             if the server reports an error, the message is missing or malformed, or the signature does not match.
     */
    private void verifyServerFinalMessage(byte[] challenge) throws SaslException {
        String serverFinalMessage = decodeServerMessage(challenge);
        Map<String, String> serverFinalAttributes = parseChallenge(serverFinalMessage);

        if (serverFinalAttributes.containsKey("e")) {
            throw new SaslException("Authentication failed due to server error '" + serverFinalAttributes.get("e") + "'.");
        }

        if (!serverFinalAttributes.containsKey("v")) {
            throw new SaslException("Missing required SCRAM attribute from server final message.");
        }

        // verifier = "v=" base64
        //            ;; base-64 encoded ServerSignature.
        byte[] verifier;
        try {
            verifier = Base64.getDecoder().decode(serverFinalAttributes.get("v"));
        } catch (IllegalArgumentException e) {
            throw new SaslException("Malformed SCRAM attribute in server final message.", e);
        }

        // MessageDigest.isEqual does not short-circuit on the first differing byte.
        if (!MessageDigest.isEqual(this.serverSignature, verifier)) {
            throw new SaslException("SCRAM-SHA-1 server signature could not be verified.");
        }
    }

    @Override
    public boolean isComplete() {
        return this.scramStage == ScramExchangeStage.TERMINATED;
    }

    @Override
    public byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException {
        throw new IllegalStateException("Integrity and/or privacy has not been negotiated.");
    }

    @Override
    public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException {
        throw new IllegalStateException("Integrity and/or privacy has not been negotiated.");
    }

    @Override
    public Object getNegotiatedProperty(String propName) {
        return null;
    }

    @Override
    public void dispose() throws SaslException {
    }

    /**
     * Decodes a server message as UTF-8.
     *
     * @param challenge
     *            the raw server message.
     * @return
     *         the decoded message.
     * @throws SaslException
     *             if no message was supplied.
     */
    private static String decodeServerMessage(byte[] challenge) throws SaslException {
        if (challenge == null) {
            throw new SaslException("Missing SCRAM server message.");
        }
        return StringUtils.toString(challenge, "UTF-8");
    }

    /**
     * Prepares a user name for use as a SCRAM saslname. It applies SASLprep (query profile), then escapes "=" as "=3D" and "," as "=2C"
     * (RFC 5802, Section 5.1).
     *
     * @param userName
     *            the user name to prepare.
     * @return
     *         the prepared and escaped saslname.
     */
    private static String prepUserName(String userName) {
        // "=" must be escaped first so the "=" introduced by "=2C" is not escaped a second time.
        return SaslPrep.prepare(userName, StringType.QUERY).replace("=", "=3D").replace(",", "=2C");
    }

    /**
     * Parses a SASL challenge.
     * 
     * @param challenge
     *            the server message (challenge) to parse.
     * @return
     *         a {@link Map} with the key/value pairs obtained from the server challenge.
     * @throws SaslException
     *             if any comma-separated attribute lacks the "key=value" form.
     */
    private static Map<String, String> parseChallenge(String challenge) throws SaslException {
        Map<String, String> attributesMap = new HashMap<>();
        for (String attribute : challenge.split(",")) {
            String[] keyValue = attribute.split("=", 2);
            if (keyValue.length != 2) {
                throw new SaslException("Malformed SCRAM server message.");
            }
            attributesMap.put(keyValue[0], keyValue[1]);
        }
        return attributesMap;
    }

    /**
     * Generates a RFC 5802 safe nonce: "a sequence of random printable ASCII characters excluding ','"
     * 
     * @param length
     *            the length of the nonce.
     * @return
     *         a randomly generated string formed by printable ASCII characters except comma.
     */
    private static String generateRandomPrintableAsciiString(int length) {
        final int first = 0x21; // First printable ASCII character: exclamation mark (!).
        final int last = 0x7E; // Last printable ASCII character: tilde (~).
        final int excl = 0x2C; // Comma (,) is excluded as per RFC 5802 (https://tools.ietf.org/html/rfc5802#section-5.1).
        final int bound = last - first + 1; // nextInt's bound is exclusive; +1 makes 'last' reachable.
        char[] result = new char[length];

        for (int i = 0; i < length;) {
            int randomValue = NONCE_RANDOM.nextInt(bound) + first;
            if (randomValue != excl) {
                result[i++] = (char) randomValue;
            }
        }
        return new String(result);
    }

    /**
     * The "H(str)" cryptographic hash function as described in RFC 5802, Section 2.2. This
     * implementation corresponds to SHA-1.
     * 
     * @param str
     *            the string to hash.
     * @return
     *         the hash value of the given string.
     */
    private static byte[] h(byte[] str) {
        try {
            return MessageDigest.getInstance(SHA1_ALGORITHM).digest(str);
        } catch (NoSuchAlgorithmException e) {
            throw ExceptionFactory.createException("Failed computing authentication hashes", e);
        }
    }

    /**
     * The "HMAC(key, str)" HMAC keyed hash algorithm as described in RFC 5802, Section 2.2.
     * This implementation corresponds to 'HmacSHA1'.
     * 
     * @param key
     *            the hash key.
     * @param str
     *            the input string.
     * @return
     *         the hashed value of the given params.
     */
    private static byte[] hmac(byte[] key, byte[] str) {
        try {
            Mac hmacSha1 = Mac.getInstance(HMAC_SHA1_ALGORITHM);
            hmacSha1.init(new SecretKeySpec(key, HMAC_SHA1_ALGORITHM));
            return hmacSha1.doFinal(str);
        } catch (NoSuchAlgorithmException | InvalidKeyException e) {
            throw ExceptionFactory.createException("Failed computing authentication hashes", e);
        }
    }

    /**
     * The "Hi(str, salt, i)" PBKDF2 function as described in RFC 5802, Section 2.2.
     * This implementation corresponds to 'PBKDF2WithHmacSHA1'.
     * 
     * @param str
     *            the string value to use as the internal HMAC key.
     * @param salt
     *            the input string to hash in the initial iteration.
     * @param iterations
     *            the number of iterations to run the algorithm.
     * 
     * @return
     *         an hash value with an output length equal to the length of H(str).
     */
    private static byte[] hi(String str, byte[] salt, int iterations) {
        PBEKeySpec spec = new PBEKeySpec(str.toCharArray(), salt, iterations, SHA1_OUTPUT_BITS);
        try {
            SecretKeyFactory factory = SecretKeyFactory.getInstance(PBKDF2_HMAC_SHA1_ALGORITHM);
            return factory.generateSecret(spec).getEncoded();
        } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
            // Keep the cause, consistent with h() and hmac().
            throw ExceptionFactory.createException("Failed computing authentication hashes", e);
        } finally {
            // Wipe the spec's internal copy of the password characters.
            spec.clearPassword();
        }
    }

    /**
     * Combines the two byte arrays in a XOR operation, changing the contents of the first.
     * 
     * @param inOut
     *            the left operand of the XOR operation and the destination of the result.
     * @param other
     *            the right operand of the XOR operation; must be at least as long as {@code inOut}.
     * @return
     *         the same as the param inOut, after being updated.
     */
    private static byte[] xorInPlace(byte[] inOut, byte[] other) {
        for (int i = 0; i < inOut.length; i++) {
            inOut[i] ^= other[i];
        }
        return inOut;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy