All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.network.crypto.AuthEngine Maven / Gradle / Ivy

There is a newer version: 2.4.8
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.crypto;

import java.io.Closeable;
import java.io.IOException;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Properties;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Bytes;
import org.apache.commons.crypto.cipher.CryptoCipher;
import org.apache.commons.crypto.cipher.CryptoCipherFactory;
import org.apache.commons.crypto.random.CryptoRandom;
import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.util.TransportConf;

/**
 * A helper class for abstracting authentication and key negotiation details. This is used by
 * both client and server sides, since the operations are basically the same.
 */
class AuthEngine implements Closeable {

  private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class);

  private final byte[] appId;
  private final char[] secret;
  private final TransportConf conf;
  private final Properties cryptoConf;
  private final CryptoRandom random;

  /** Salt for the authentication key; generated on the client side by {@link #challenge()}. */
  private byte[] authNonce;

  /** Raw (unencrypted) challenge generated on the client side by {@link #challenge()}. */
  @VisibleForTesting
  byte[] challenge;

  private TransportCipher sessionCipher;
  private CryptoCipher encryptor;
  private CryptoCipher decryptor;

  /**
   * Creates an engine for the given application.
   *
   * @param appId the application ID; its UTF-8 bytes are bound into every challenge.
   * @param secret the shared secret, used as the password for key derivation.
   * @param conf transport configuration supplying the crypto settings.
   * @throws GeneralSecurityException if the random source cannot be created.
   */
  AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException {
    this.appId = appId.getBytes(UTF_8);
    this.conf = conf;
    this.cryptoConf = conf.cryptoConf();
    this.secret = secret.toCharArray();
    this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf);
  }

  /**
   * Create the client challenge. This derives the authentication key from a fresh random nonce,
   * initializes the auth ciphers with it, and generates a fresh random challenge.
   *
   * @return A challenge to be sent to the remote side.
   */
  ClientChallenge challenge() throws GeneralSecurityException {
    this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
    SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
      authNonce, conf.encryptionKeyLength());
    initializeForAuth(conf.cipherTransformation(), authNonce, authKey);

    this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
    return new ClientChallenge(new String(appId, UTF_8),
      conf.keyFactoryAlgorithm(),
      conf.keyFactoryIterations(),
      conf.cipherTransformation(),
      conf.encryptionKeyLength(),
      authNonce,
      challenge(appId, authNonce, challenge));
  }

  /**
   * Validates the client challenge, and create the encryption backend for the channel from the
   * parameters sent by the client.
   *
   * @param clientChallenge The challenge from the client.
   * @return A response to be sent to the client.
   */
  ServerResponse respond(ClientChallenge clientChallenge)
    throws GeneralSecurityException {

    SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
      clientChallenge.nonce, clientChallenge.keyLength);
    initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey);

    byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge);
    byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge));
    byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
    byte[] inputIv = randomBytes(conf.ivLength());
    byte[] outputIv = randomBytes(conf.ivLength());

    SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
      sessionNonce, clientChallenge.keyLength);
    this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey,
      inputIv, outputIv);

    // Note the IVs are swapped in the response.
    return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv));
  }

  /**
   * Validates the server response and initializes the cipher to use for the session.
   *
   * @param serverResponse The response from the server.
   * @throws IllegalArgumentException if the response does not match the expected value.
   */
  void validate(ServerResponse serverResponse) throws GeneralSecurityException {
    byte[] response = validateChallenge(authNonce, serverResponse.response);

    byte[] expected = rawResponse(challenge);
    // Constant-time comparison so response checking does not leak timing information.
    Preconditions.checkArgument(MessageDigest.isEqual(expected, response),
      "Invalid response from server.");

    byte[] nonce = decrypt(serverResponse.nonce);
    byte[] inputIv = decrypt(serverResponse.inputIv);
    byte[] outputIv = decrypt(serverResponse.outputIv);

    SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
      nonce, conf.encryptionKeyLength());
    this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey,
      inputIv, outputIv);
  }

  /**
   * Returns the negotiated session cipher.
   *
   * @throws IllegalStateException if neither {@link #respond} nor {@link #validate} completed.
   */
  TransportCipher sessionCipher() {
    Preconditions.checkState(sessionCipher != null, "Session cipher has not been initialized.");
    return sessionCipher;
  }

  @Override
  public void close() throws IOException {
    // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that
    // internal state is cleaned up. Error handling here is just for paranoia, and not meant to
    // accurately report the errors when they happen: the first failure is rethrown, and any
    // later one is attached to it as a suppressed exception.
    RuntimeException error = null;
    byte[] dummy = new byte[8];
    try {
      if (encryptor != null) {
        error = finishCipher(Cipher.ENCRYPT_MODE, dummy, error);
        encryptor = null;
      }
      if (decryptor != null) {
        error = finishCipher(Cipher.DECRYPT_MODE, dummy, error);
        decryptor = null;
      }
    } finally {
      // Always release the random source, even if something unexpected escaped above.
      random.close();
    }

    if (error != null) {
      throw error;
    }
  }

  /**
   * Finalizes the cipher for {@code mode} with dummy data, recording any failure instead of
   * propagating it. {@link InternalError} is caught too, since {@link #doCipherOp} rethrows it
   * (SPARK-25535) and it must not abort the rest of {@link #close()}.
   *
   * @return {@code error} if it is non-null (with any new failure added as suppressed), otherwise
   *     a new exception wrapping the failure, or {@code null} if the operation succeeded.
   */
  private RuntimeException finishCipher(int mode, byte[] dummy, RuntimeException error) {
    try {
      doCipherOp(mode, dummy, true);
      return error;
    } catch (Exception | InternalError e) {
      RuntimeException failure = new RuntimeException(e);
      if (error == null) {
        return failure;
      }
      error.addSuppressed(failure);
      return error;
    }
  }

  /** Encrypts the concatenation {@code appId || nonce || challenge} with the auth encryptor. */
  @VisibleForTesting
  byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException {
    return encrypt(Bytes.concat(appId, nonce, challenge));
  }

  /**
   * Computes the response to a challenge: the challenge read as a big-endian two's-complement
   * integer, plus one, encoded back to its minimal two's-complement byte form.
   */
  @VisibleForTesting
  byte[] rawResponse(byte[] challenge) {
    BigInteger orig = new BigInteger(challenge);
    BigInteger response = orig.add(BigInteger.ONE);
    return response.toByteArray();
  }

  private byte[] decrypt(byte[] in) throws GeneralSecurityException {
    return doCipherOp(Cipher.DECRYPT_MODE, in, false);
  }

  private byte[] encrypt(byte[] in) throws GeneralSecurityException {
    return doCipherOp(Cipher.ENCRYPT_MODE, in, false);
  }

  /**
   * Creates and initializes the encryptor/decryptor pair used during authentication. The IV is
   * the nonce truncated or zero-padded to {@code conf.ivLength()} bytes.
   */
  private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
    throws GeneralSecurityException {

    // commons-crypto currently only supports ciphers that require an initial vector; so
    // create a dummy vector so that we can initialize the ciphers. In the future, if
    // different ciphers are supported, this will have to be configurable somehow.
    byte[] iv = new byte[conf.ivLength()];
    System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));

    CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
    _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
    this.encryptor = _encryptor;

    CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
    _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
    this.decryptor = _decryptor;
  }

  /**
   * Validates an encrypted challenge as defined in the protocol, and returns the byte array
   * that corresponds to the actual challenge data.
   *
   * @throws IllegalArgumentException if the decrypted data does not start with
   *     {@code appId || nonce}.
   */
  private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge)
    throws GeneralSecurityException {

    byte[] challenge = decrypt(encryptedChallenge);
    checkSubArray(appId, challenge, 0);
    checkSubArray(nonce, challenge, appId.length);
    return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length);
  }

  /**
   * Derives a key from the shared secret with the given key derivation function.
   *
   * @param kdf the {@link SecretKeyFactory} algorithm name.
   * @param iterations the iteration count for the derivation.
   * @param salt the salt (a nonce in this protocol).
   * @param keyLength the key length in bits.
   * @return the derived key, tagged with {@code conf.keyAlgorithm()}.
   */
  private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength)
    throws GeneralSecurityException {

    SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf);
    PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength);
    try {
      long start = System.nanoTime();
      SecretKey key = factory.generateSecret(spec);
      long end = System.nanoTime();

      // Log the iteration count actually used; on the server it comes from the client.
      LOG.debug("Generated key with {} iterations in {} us.", iterations, (end - start) / 1000);

      return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
    } finally {
      // PBEKeySpec holds its own copy of the password; wipe it once the key is derived.
      spec.clearPassword();
    }
  }

  /**
   * Runs {@code in} through the encryptor or decryptor (selected by {@code mode}), using
   * {@code doFinal} when {@code isFinal} is set and {@code update} otherwise.
   *
   * @throws IllegalStateException if the selected cipher was discarded after an earlier error.
   */
  private byte[] doCipherOp(int mode, byte[] in, boolean isFinal)
    throws GeneralSecurityException {

    CryptoCipher cipher;
    switch (mode) {
      case Cipher.ENCRYPT_MODE:
        cipher = encryptor;
        break;
      case Cipher.DECRYPT_MODE:
        cipher = decryptor;
        break;
      default:
        throw new IllegalArgumentException(String.valueOf(mode));
    }

    Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error.");

    try {
      // Retry with a doubled buffer until the output fits. The base size is never 0, otherwise
      // doubling an empty buffer would never grow it and the loop would not terminate.
      int baseSize = Math.max(in.length, 1);
      int scale = 1;
      while (true) {
        int size = baseSize * scale;
        byte[] buffer = new byte[size];
        try {
          int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
            : cipher.update(in, 0, in.length, buffer, 0);
          if (outSize != buffer.length) {
            return Arrays.copyOf(buffer, outSize);
          } else {
            return buffer;
          }
        } catch (ShortBufferException e) {
          // Try again with a bigger buffer.
          scale *= 2;
        }
      }
    } catch (InternalError ie) {
      // SPARK-25535. The commons-crypto library will throw InternalError if something goes wrong,
      // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards.
      if (mode == Cipher.ENCRYPT_MODE) {
        this.encryptor = null;
      } else {
        this.decryptor = null;
      }
      throw ie;
    }
  }

  private byte[] randomBytes(int count) {
    byte[] bytes = new byte[count];
    random.nextBytes(bytes);
    return bytes;
  }

  /** Checks that the "test" array is in the data array starting at the given offset. */
  private void checkSubArray(byte[] test, byte[] data, int offset) {
    Preconditions.checkArgument(data.length >= test.length + offset);
    for (int i = 0; i < test.length; i++) {
      Preconditions.checkArgument(test[i] == data[i + offset]);
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy