package org.deeplearning4j.spark.util;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocatedFileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RemoteIterator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction2;
import org.deeplearning4j.spark.impl.common.repartition.AssignIndexFunction;
import org.deeplearning4j.spark.impl.common.repartition.BalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.MapTupleToPairFlatMap;
import org.slf4j.Logger;
import scala.Tuple2;
import java.io.*;
import java.lang.reflect.Array;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
* Various utilities for Spark
*
* @author Alex Black
*/
public class SparkUtils {
private SparkUtils() {
}
/**
* Check the Spark configuration for an incorrect Kryo configuration, logging a warning message if necessary
*
* @param javaSparkContext Spark context
* @param log Logger to log messages to
* @return True if ok (no kryo, or correct kryo setup)
*/
public static boolean checkKryoConfiguration(JavaSparkContext javaSparkContext, Logger log) {
//Check if kryo configuration is correct:
String serializer = javaSparkContext.getConf().get("spark.serializer", null);
if (serializer != null && serializer.equals("org.apache.spark.serializer.KryoSerializer")) {
//conf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
String kryoRegistrator = javaSparkContext.getConf().get("spark.kryo.registrator", null);
if (kryoRegistrator == null || !kryoRegistrator.equals("org.nd4j.Nd4jRegistrator")) {
log.warn("***** Kryo serialization detected without Nd4j Registrator *****");
log.warn("***** ND4J Kryo registrator is required to avoid serialization (NullPointerException) issues on NDArrays *****");
log.warn("***** Use nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.Nd4jRegistrator\"); *****");
return false;
}
}
return true;
}
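/**
* Usage sketch (illustrative only, not part of the original class): a SparkConf with the Kryo serializer
* and the ND4J registrator that this check looks for. The application name is an arbitrary placeholder.
*/
private static JavaSparkContext exampleKryoSetup(Logger log) {
org.apache.spark.SparkConf conf = new org.apache.spark.SparkConf()
.setAppName("kryo-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
JavaSparkContext sc = new JavaSparkContext(conf);
checkKryoConfiguration(sc, log); //Expected to return true and log no warnings with this configuration
return sc;
}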
/**
* Write a String to a file (on HDFS or local) in UTF-8 format
*
* @param path Path to write to
* @param toWrite String to write
* @param sc Spark context
*/
public static void writeStringToFile(String path, String toWrite, JavaSparkContext sc) throws IOException {
writeStringToFile(path, toWrite, sc.sc());
}
/**
* Write a String to a file (on HDFS or local) in UTF-8 format
*
* @param path Path to write to
* @param toWrite String to write
* @param sc Spark context
*/
public static void writeStringToFile(String path, String toWrite, SparkContext sc) throws IOException {
FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
try (BufferedOutputStream bos = new BufferedOutputStream(fileSystem.create(new Path(path)))) {
bos.write(toWrite.getBytes("UTF-8"));
}
}
/**
* Read a UTF-8 format String from HDFS (or local)
*
* @param path Path to read the String from
* @param sc Spark context
*/
public static String readStringFromFile(String path, JavaSparkContext sc) throws IOException {
return readStringFromFile(path, sc.sc());
}
/**
* Read a UTF-8 format String from HDFS (or local)
*
* @param path Path to read the String from
* @param sc Spark context
*/
public static String readStringFromFile(String path, SparkContext sc) throws IOException {
FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
try (BufferedInputStream bis = new BufferedInputStream(fileSystem.open(new Path(path)))) {
byte[] asBytes = IOUtils.toByteArray(bis);
return new String(asBytes, "UTF-8");
}
}
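/**
* Usage sketch (illustrative only, not part of the original class): writing a String and reading it back.
* The HDFS path below is an assumed placeholder.
*/
private static void exampleStringRoundTrip(JavaSparkContext sc) throws IOException {
String path = "hdfs:///tmp/sparkutils-example.txt"; //Assumed example path
writeStringToFile(path, "hello from SparkUtils", sc);
String readBack = readStringFromFile(path, sc);
//readBack is now "hello from SparkUtils"
}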
/**
* Write an object to HDFS (or local) using default Java object serialization
*
* @param path Path to write the object to
* @param toWrite Object to write
* @param sc Spark context
*/
public static void writeObjectToFile(String path, Object toWrite, JavaSparkContext sc) throws IOException {
writeObjectToFile(path, toWrite, sc.sc());
}
/**
* Write an object to HDFS (or local) using default Java object serialization
*
* @param path Path to write the object to
* @param toWrite Object to write
* @param sc Spark context
*/
public static void writeObjectToFile(String path, Object toWrite, SparkContext sc) throws IOException {
FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
try (BufferedOutputStream bos = new BufferedOutputStream(fileSystem.create(new Path(path)))) {
ObjectOutputStream oos = new ObjectOutputStream(bos);
oos.writeObject(toWrite);
}
}
/**
* Read an object from HDFS (or local) using default Java object serialization
*
* @param path File to read
* @param type Class of the object to read
* @param sc Spark context
* @param <T> Type of the object to read
*/
public static <T> T readObjectFromFile(String path, Class<T> type, JavaSparkContext sc) throws IOException {
return readObjectFromFile(path, type, sc.sc());
}
/**
* Read an object from HDFS (or local) using default Java object serialization
*
* @param path File to read
* @param type Class of the object to read
* @param sc Spark context
* @param <T> Type of the object to read
*/
public static <T> T readObjectFromFile(String path, Class<T> type, SparkContext sc) throws IOException {
FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(fileSystem.open(new Path(path))))) {
Object o;
try {
o = ois.readObject();
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
return (T) o;
}
}
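/**
* Usage sketch (illustrative only, not part of the original class): round-tripping a Serializable object
* via Java serialization. The path is an assumed placeholder; any Serializable type works the same way.
*/
private static void exampleObjectRoundTrip(JavaSparkContext sc) throws IOException {
String path = "hdfs:///tmp/sparkutils-example.bin"; //Assumed example path
ArrayList<Integer> toWrite = new ArrayList<>(java.util.Arrays.asList(1, 2, 3));
writeObjectToFile(path, toWrite, sc);
ArrayList<?> readBack = readObjectFromFile(path, ArrayList.class, sc);
//readBack now contains [1, 2, 3]
}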
/**
* Repartition the specified RDD (or not) using the given {@link Repartition} and {@link RepartitionStrategy} settings
*
* @param rdd RDD to repartition
* @param repartition Setting for when repartitioning is to be conducted
* @param repartitionStrategy Setting for how repartitioning is to be conducted
* @param objectsPerPartition Desired number of objects per partition
* @param numPartitions Total number of partitions
* @param <T> Type of the RDD
* @return Repartitioned RDD, or original RDD if no repartitioning was conducted
*/
public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, Repartition repartition, RepartitionStrategy repartitionStrategy,
int objectsPerPartition, int numPartitions) {
if (repartition == Repartition.Never) return rdd;
switch (repartitionStrategy) {
case SparkDefault:
if (repartition == Repartition.NumPartitionsWorkersDiffers && rdd.partitions().size() == numPartitions)
return rdd;
//Either repartition always, or workers/num partitions differs
return rdd.repartition(numPartitions);
case Balanced:
return repartitionBalanceIfRequired(rdd, repartition, objectsPerPartition, numPartitions);
default:
throw new RuntimeException("Unknown repartition strategy: " + repartitionStrategy);
}
}
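/**
* Usage sketch (illustrative only, not part of the original class): repartitioning only when the current
* partition count differs from the target, using the balanced strategy. The counts (16 objects per
* partition, 32 partitions) are arbitrary placeholders.
*/
private static <T> JavaRDD<T> exampleRepartition(JavaRDD<T> rdd) {
return repartition(rdd, Repartition.NumPartitionsWorkersDiffers, RepartitionStrategy.Balanced, 16, 32);
}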
/**
* Repartition a RDD (given the {@link Repartition} setting) such that we have approximately {@code numPartitions} partitions,
* each of which has {@code objectsPerPartition} objects.
*
* @param rdd RDD to repartition
* @param repartition Repartitioning setting
* @param objectsPerPartition Number of objects we want in each partition
* @param numPartitions Number of partitions to have
* @param <T> Type of RDD
* @return Repartitioned RDD, or the original RDD if no repartitioning was performed
*/
public static <T> JavaRDD<T> repartitionBalanceIfRequired(JavaRDD<T> rdd, Repartition repartition, int objectsPerPartition, int numPartitions) {
int origNumPartitions = rdd.partitions().size();
switch (repartition) {
case Never:
return rdd;
case NumPartitionsWorkersDiffers:
if (origNumPartitions == numPartitions) return rdd;
case Always:
//Repartition: either always, or origNumPartitions != numWorkers
//First: count number of elements in each partition. Need to know this so we can work out how to properly index each example,
// so we can in turn create properly balanced partitions after repartitioning
//Because the objects (DataSets etc) should be small, this should be OK
//Count each partition...
List<Tuple2<Integer, Integer>> partitionCounts = rdd.mapPartitionsWithIndex(new CountPartitionsFunction<T>(), true).collect();
int totalObjects = 0;
int initialPartitions = partitionCounts.size();
boolean allCorrectSize = true;
int[] countPerPartition = new int[partitionCounts.size()];
int x = 0;
for (Tuple2<Integer, Integer> t2 : partitionCounts) {
int partitionSize = t2._2();
countPerPartition[x++] = partitionSize;
allCorrectSize &= (partitionSize == objectsPerPartition);
totalObjects += t2._2();
}
if (numPartitions * objectsPerPartition < totalObjects) {
allCorrectSize = true;
for (Tuple2<Integer, Integer> t2 : partitionCounts) {
allCorrectSize &= (t2._2() == objectsPerPartition);
}
}
if (initialPartitions == numPartitions && allCorrectSize) {
//Don't need to do any repartitioning here - already in the format we want
return rdd;
}
//In each partition: work out the start offset (so we can work out the index of each element)
int[] elementStartOffsetByPartitions = new int[countPerPartition.length];
for (int i = 1; i < elementStartOffsetByPartitions.length; i++) {
elementStartOffsetByPartitions[i] = elementStartOffsetByPartitions[i - 1] + countPerPartition[i - 1];
}
//Index each element for repartitioning (can only do manual repartitioning on a JavaPairRDD)
JavaRDD<Tuple2<Integer, T>> indexed = rdd.mapPartitionsWithIndex(new AssignIndexFunction<T>(elementStartOffsetByPartitions), true);
JavaPairRDD<Integer, T> pairIndexed = indexed.mapPartitionsToPair(new MapTupleToPairFlatMap<Integer, T>(), true);
int remainder = (totalObjects - numPartitions * objectsPerPartition) % numPartitions;
pairIndexed = pairIndexed.partitionBy(new BalancedPartitioner(numPartitions, objectsPerPartition, remainder));
return pairIndexed.values();
default:
throw new RuntimeException("Unknown setting for repartition: " + repartition);
}
}
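/**
* Usage sketch (illustrative only, not part of the original class): forcing a balanced repartition into
* 10 partitions of 32 objects each. The numbers are arbitrary placeholders.
*/
private static <T> JavaRDD<T> exampleBalancedRepartition(JavaRDD<T> rdd) {
return repartitionBalanceIfRequired(rdd, Repartition.Always, 32, 10);
}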
/**
* Randomly split the specified RDD into a number of RDDs, each of which has {@code numObjectsPerSplit} objects.
*
* This is similar to how RDD.randomSplit works (i.e., split via filtering), but this should result in more
* equal splits (instead of the independent binomial sampling used there, based on weighting).
* This balanced splitting approach is important when the number of DataSet objects we want in each split is small,
* as random sampling variance of {@link JavaRDD#randomSplit(double[])} is quite large relative to the number of examples
* in each split. Note however that this method doesn't guarantee that partitions will be balanced
*
* The downside is that we need the total object count (whereas {@link JavaRDD#randomSplit(double[])} does not). However, randomSplit
* requires a full pass of the data anyway (in order to do filtering upon it) so this should not add much overhead in practice
*
* @param totalObjectCount Total number of objects in the RDD to split
* @param numObjectsPerSplit Number of objects in each split
* @param data Data to split
* @param <T> Generic type for the RDD
* @return The RDD split up (without replacement) into a number of smaller RDDs
*/
public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data) {
return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, new Random().nextLong());
}
/**
* Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} with control over the RNG seed
*/
public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data, long rngSeed) {
JavaRDD<T>[] splits;
if (totalObjectCount <= numObjectsPerSplit) {
splits = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, 1);
splits[0] = data;
} else {
int numSplits = totalObjectCount / numObjectsPerSplit; //Intentional round down
splits = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, numSplits);
for (int i = 0; i < numSplits; i++) {
splits[i] = data.mapPartitionsWithIndex(new SplitPartitionsFunction<T>(i, numSplits, rngSeed), true);
}
}
return splits;
}
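/**
* Usage sketch (illustrative only, not part of the original class): splitting an RDD known to contain
* 1000 objects into 10 splits of roughly 100 objects each, with a fixed seed for reproducibility.
* The counts and seed are arbitrary placeholders; the caller must know (or count) the total object count.
*/
private static <T> JavaRDD<T>[] exampleBalancedSplit(JavaRDD<T> data) {
return balancedRandomSplit(1000, 100, data, 12345L);
}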
/**
* Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} but for Pair RDDs
*/
public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaPairRDD<T, U> data) {
return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, new Random().nextLong());
}
/**
* Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} but for pair RDDs, and with control over the RNG seed
*/
public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaPairRDD<T, U> data, long rngSeed) {
JavaPairRDD<T, U>[] splits;
if (totalObjectCount <= numObjectsPerSplit) {
splits = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, 1);
splits[0] = data;
} else {
int numSplits = totalObjectCount / numObjectsPerSplit; //Intentional round down
splits = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, numSplits);
for (int i = 0; i < numSplits; i++) {
//What we really need is a .mapPartitionsToPairWithIndex function
//but, of course, Spark doesn't provide this
//So we need to do a two-step process here...
JavaRDD<Tuple2<T, U>> split = data.mapPartitionsWithIndex(new SplitPartitionsFunction2<T, U>(i, numSplits, rngSeed), true);
splits[i] = split.mapPartitionsToPair(new MapTupleToPairFlatMap<T, U>(), true);
}
}
return splits;
}
/**
* List of the files in the given directory (path), as a {@code JavaRDD<String>}
*
* @param sc Spark context
* @param path Path to list files in
* @return Paths in the directory
* @throws IOException If error occurs getting directory contents
*/
public static JavaRDD<String> listPaths(JavaSparkContext sc, String path) throws IOException {
List<String> paths = new ArrayList<>();
Configuration config = new Configuration();
FileSystem hdfs = FileSystem.get(URI.create(path), config);
RemoteIterator<LocatedFileStatus> fileIter = hdfs.listFiles(new Path(path), false);
while (fileIter.hasNext()) {
String filePath = fileIter.next().getPath().toString();
paths.add(filePath);
}
return sc.parallelize(paths);
}
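/**
* Usage sketch (illustrative only, not part of the original class): listing the files in a directory
* and counting them. The directory path is an assumed placeholder.
*/
private static long exampleListPaths(JavaSparkContext sc) throws IOException {
JavaRDD<String> paths = listPaths(sc, "hdfs:///data/input"); //Assumed example directory
return paths.count();
}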
}