All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.mapreduce.security.SecureShuffleUtils Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.mapreduce.security;


import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.security.MessageDigest;
import javax.crypto.SecretKey;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.codec.binary.Base64;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapreduce.security.token.JobTokenSecretManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Charsets;

/**
 * Utilities for generating keys and hashes for shuffle messages, and for
 * verifying them.
 */
@InterfaceAudience.Private
@InterfaceStability.Unstable
public class SecureShuffleUtils {
  private static final Logger LOG =
      LoggerFactory.getLogger(SecureShuffleUtils.class);

  /** Header name used for the URL hash. */
  public static final String HTTP_HEADER_URL_HASH = "UrlHash";
  /** Header name used for the reply hash. */
  public static final String HTTP_HEADER_REPLY_URL_HASH = "ReplyHash";

  /**
   * Computes the hash of {@code msg} under {@code key} and returns it
   * Base64-encoded.
   *
   * @param msg bytes to hash
   * @param key secret key handed to {@link JobTokenSecretManager#computeHash}
   * @return the Base64 encoding of the hash, decoded as a UTF-8 string
   */
  public static String generateHash(byte[] msg, SecretKey key) {
    return new String(Base64.encodeBase64(generateByteHash(msg, key)),
        Charsets.UTF_8);
  }

  /**
   * Computes the raw hash of {@code msg} under {@code key}.
   *
   * @param msg bytes to hash
   * @param key secret key handed to {@link JobTokenSecretManager#computeHash}
   * @return the raw hash bytes
   */
  private static byte[] generateByteHash(byte[] msg, SecretKey key) {
    return JobTokenSecretManager.computeHash(msg, key);
  }

  /**
   * Checks whether {@code hash} equals the hash of {@code msg} under
   * {@code key}.
   *
   * @param hash the hash to check (raw bytes, not Base64)
   * @param msg the message the hash is supposed to cover
   * @param key secret key used to recompute the expected hash
   * @return true if the recomputed hash matches {@code hash}
   */
  private static boolean verifyHash(byte[] hash, byte[] msg, SecretKey key) {
    byte[] expected = generateByteHash(msg, key);
    // Constant-time comparison: an early-exit compare would let response
    // timing reveal how many leading bytes of a forged hash were correct.
    return MessageDigest.isEqual(expected, hash);
  }

  /**
   * Computes the Base64-encoded hash of a string's UTF-8 bytes.
   *
   * @param enc_str string to hash
   * @param key secret key used for hashing
   * @return Base64-encoded hash
   * @throws IOException declared for API compatibility; not thrown here
   */
  public static String hashFromString(String enc_str, SecretKey key)
      throws IOException {
    return generateHash(enc_str.getBytes(Charsets.UTF_8), key);
  }

  /**
   * Verifies that {@code base64Hash} is the Base64-encoded hash of
   * {@code msg} under {@code key}.
   *
   * @param base64Hash Base64-encoded hash to check; {@code null} is treated
   *        as a failed verification
   * @param msg message the hash is supposed to cover (hashed as UTF-8)
   * @param key secret key used to recompute the expected hash
   * @throws IOException if the hash is missing or does not match
   */
  public static void verifyReply(String base64Hash, String msg, SecretKey key)
      throws IOException {
    boolean verified = base64Hash != null && verifyHash(
        Base64.decodeBase64(base64Hash.getBytes(Charsets.UTF_8)),
        msg.getBytes(Charsets.UTF_8), key);
    if (!verified) {
      throw new IOException("Verification of the hashReply failed");
    }
  }

  /**
   * Shuffle-specific helper: builds the string to be hashed from a URL.
   *
   * @param url the request URL; its port, path and query are used
   * @return string for encoding
   */
  public static String buildMsgFrom(URL url) {
    return buildMsgFrom(url.getPath(), url.getQuery(), url.getPort());
  }

  /**
   * Shuffle-specific helper: builds the string to be hashed from an incoming
   * servlet request.
   *
   * @param request the request; its local port, URI and query string are used
   * @return string for encoding
   */
  public static String buildMsgFrom(HttpServletRequest request) {
    return buildMsgFrom(request.getRequestURI(), request.getQueryString(),
        request.getLocalPort());
  }

  /**
   * Shuffle-specific helper: builds {@code port + path + "?" + query}.
   * A {@code null} query becomes the literal text "null" through string
   * concatenation. Both public overloads build the message this way, so the
   * format must stay unchanged to keep hashes comparable.
   *
   * @param uri_path path component of the URI
   * @param uri_query query component of the URI (may be null)
   * @param port port number prefixed to the message
   * @return string for encoding
   */
  private static String buildMsgFrom(String uri_path, String uri_query, int port) {
    return String.valueOf(port) + uri_path + "?" + uri_query;
  }

  /**
   * Converts a byte array to a lowercase hexadecimal string with exactly two
   * digits per byte; for example {@code {0x0f, (byte) 0xa0}} becomes
   * {@code "0fa0"}.
   *
   * @param ba bytes to convert
   * @return hex representation of {@code ba}; empty for an empty array
   */
  public static String toHex(byte[] ba) {
    StringBuilder hex = new StringBuilder(ba.length * 2);
    for (byte b : ba) {
      // Always emit both nibbles: unpadded output would render {0x01, 0x23}
      // and {0x12, 0x03} identically as "123".
      hex.append(Character.forDigit((b >> 4) & 0xF, 16));
      hex.append(Character.forDigit(b & 0xF, 16));
    }
    return hex.toString();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy