All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.util.SparkUtils Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.util;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocatedFileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RemoteIterator;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.data.BatchDataSetsFunction;
import org.deeplearning4j.spark.data.shuffle.SplitDataSetExamplesPairFlatMapFunction;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction2;
import org.deeplearning4j.spark.impl.common.repartition.BalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.HashingBalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.MapTupleToPairFlatMap;
import org.deeplearning4j.spark.impl.repartitioner.EqualRepartitioner;
import org.deeplearning4j.core.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import scala.Tuple2;

import java.io.*;
import java.lang.reflect.Array;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.*;

@Slf4j
public class SparkUtils {

    private static final String KRYO_EXCEPTION_MSG = "Kryo serialization detected without an appropriate registrator "
                    + "for ND4J INDArrays.\nWhen using Kryo, An appropriate Kryo registrator must be used to avoid"
                    + " serialization issues (NullPointerException) with off-heap data in INDArrays.\n"
                    + "Use nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.kryo.Nd4jRegistrator\");\n"
                    + "See https://deeplearning4j.konduit.ai/distributed-deep-learning/howto#how-to-use-kryo-serialization-with-dl-4-j-and-nd-4-j for more details";

    private static String sparkExecutorId;

    private SparkUtils() {}

    /**
     * Check the spark configuration for incorrect Kryo configuration, logging a warning message if necessary
     *
     * @param javaSparkContext Spark context
     * @param log              Logger to log messages to
     * @return True if ok (no kryo, or correct kryo setup)
     */
    public static boolean checkKryoConfiguration(JavaSparkContext javaSparkContext, Logger log) {
        //Check if kryo configuration is correct:
        String serializer = javaSparkContext.getConf().get("spark.serializer", null);
        if (serializer != null && serializer.equals("org.apache.spark.serializer.KryoSerializer")) {
            String kryoRegistrator = javaSparkContext.getConf().get("spark.kryo.registrator", null);
            if (kryoRegistrator == null || !kryoRegistrator.equals("org.nd4j.kryo.Nd4jRegistrator")) {

                //It's probably going to fail later due to Kryo failing on the INDArray deserialization (off-heap data)
                //But: the user might be using a custom Kryo registrator that can handle ND4J INDArrays, even if they
                // aren't using the official ND4J-provided one
                //Either way: Let's test serialization now of INDArrays now, and fail early if necessary
                SerializerInstance si;
                ByteBuffer bb;
                try {
                    si = javaSparkContext.env().serializer().newInstance();
                    bb = si.serialize(Nd4j.linspace(1, 5, 5), null);
                } catch (Exception e) {
                    //Failed for some unknown reason during serialization - should never happen
                    throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
                }

                if (bb == null) {
                    //Should probably never happen
                    throw new RuntimeException(
                                    KRYO_EXCEPTION_MSG + "\n(Got: null ByteBuffer from Spark SerializerInstance)");
                } else {
                    //Could serialize successfully, but still may not be able to deserialize if kryo config is wrong
                    boolean equals;
                    INDArray deserialized;
                    try {
                        deserialized = (INDArray) si.deserialize(bb, null);
                        //Equals method may fail on malformed INDArrays, hence should be within the try-catch
                        equals = Nd4j.linspace(1, 5, 5).equals(deserialized);
                    } catch (Exception e) {
                        throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
                    }
                    if (!equals) {
                        throw new RuntimeException(KRYO_EXCEPTION_MSG + "\n(Error during deserialization: test array"
                                        + " was not deserialized successfully)");
                    }

                    //Otherwise: serialization/deserialization was successful using Kryo
                    return true;
                }
            }
        }
        return true;
    }

    /**
     * Write a String to a file (on HDFS or local) in UTF-8 format
     *
     * @param path    Path to write to
     * @param toWrite String to write
     * @param sc      Spark context
     */
    public static void writeStringToFile(String path, String toWrite, JavaSparkContext sc) throws IOException {
        writeStringToFile(path, toWrite, sc.sc());
    }

    /**
     * Write a String to a file (on HDFS or local) in UTF-8 format
     *
     * @param path    Path to write to
     * @param toWrite String to write
     * @param sc      Spark context
     */
    public static void writeStringToFile(String path, String toWrite, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        try (BufferedOutputStream bos = new BufferedOutputStream(fileSystem.create(new Path(path)))) {
            bos.write(toWrite.getBytes("UTF-8"));
        }
    }

    /**
     * Read a UTF-8 format String from HDFS (or local)
     *
     * @param path Path to write the string
     * @param sc   Spark context
     */
    public static String readStringFromFile(String path, JavaSparkContext sc) throws IOException {
        return readStringFromFile(path, sc.sc());
    }

    /**
     * Read a UTF-8 format String from HDFS (or local)
     *
     * @param path Path to write the string
     * @param sc   Spark context
     */
    public static String readStringFromFile(String path, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        try (BufferedInputStream bis = new BufferedInputStream(fileSystem.open(new Path(path)))) {
            byte[] asBytes = IOUtils.toByteArray(bis);
            return new String(asBytes, "UTF-8");
        }
    }

    /**
     * Write an object to HDFS (or local) using default Java object serialization
     *
     * @param path    Path to write the object to
     * @param toWrite Object to write
     * @param sc      Spark context
     */
    public static void writeObjectToFile(String path, Object toWrite, JavaSparkContext sc) throws IOException {
        writeObjectToFile(path, toWrite, sc.sc());
    }

    /**
     * Write an object to HDFS (or local) using default Java object serialization
     *
     * @param path    Path to write the object to
     * @param toWrite Object to write
     * @param sc      Spark context
     */
    public static void writeObjectToFile(String path, Object toWrite, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        try (BufferedOutputStream bos = new BufferedOutputStream(fileSystem.create(new Path(path)))) {
            ObjectOutputStream oos = new ObjectOutputStream(bos);
            oos.writeObject(toWrite);
        }
    }

    /**
     * Read an object from HDFS (or local) using default Java object serialization
     *
     * @param path File to read
     * @param type Class of the object to read
     * @param sc   Spark context
     * @param   Type of the object to read
     */
    public static  T readObjectFromFile(String path, Class type, JavaSparkContext sc) throws IOException {
        return readObjectFromFile(path, type, sc.sc());
    }

    /**
     * Read an object from HDFS (or local) using default Java object serialization
     *
     * @param path File to read
     * @param type Class of the object to read
     * @param sc   Spark context
     * @param   Type of the object to read
     */
    public static  T readObjectFromFile(String path, Class type, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(fileSystem.open(new Path(path))))) {
            Object o;
            try {
                o = ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            return (T) o;
        }
    }

    /**
     * Repartition the specified RDD (or not) using the given {@link Repartition} and {@link RepartitionStrategy} settings
     *
     * @param rdd                 RDD to repartition
     * @param repartition         Setting for when repartiting is to be conducted
     * @param repartitionStrategy Setting for how repartitioning is to be conducted
     * @param objectsPerPartition Desired number of objects per partition
     * @param numPartitions       Total number of partitions
     * @param                  Type of the RDD
     * @return Repartitioned RDD, or original RDD if no repartitioning was conducted
     */
    public static  JavaRDD repartition(JavaRDD rdd, Repartition repartition,
                    RepartitionStrategy repartitionStrategy, int objectsPerPartition, int numPartitions) {
        if (repartition == Repartition.Never)
            return rdd;

        switch (repartitionStrategy) {
            case SparkDefault:
                if (repartition == Repartition.NumPartitionsWorkersDiffers && rdd.partitions().size() == numPartitions)
                    return rdd;

                //Either repartition always, or workers/num partitions differs
                return rdd.repartition(numPartitions);
            case Balanced:
                return repartitionBalanceIfRequired(rdd, repartition, objectsPerPartition, numPartitions);
            case ApproximateBalanced:
                return repartitionApproximateBalance(rdd, repartition, numPartitions);
            default:
                throw new RuntimeException("Unknown repartition strategy: " + repartitionStrategy);
        }
    }

    public static  JavaRDD repartitionApproximateBalance(JavaRDD rdd, Repartition repartition,
                    int numPartitions) {
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions)
                    return rdd;
            case Always:
                // Count each partition...
                List partitionCounts =
                                rdd.mapPartitionsWithIndex(new Function2, Iterator>() {
                                    @Override
                                    public Iterator call(Integer integer, Iterator tIterator)
                                                    throws Exception {
                                        int count = 0;
                                        while (tIterator.hasNext()) {
                                            tIterator.next();
                                            count++;
                                        }
                                        return Collections.singletonList(count).iterator();
                                    }
                                }, true).collect();

                Integer totalCount = 0;
                for (Integer i : partitionCounts)
                    totalCount += i;
                List partitionWeights = new ArrayList<>(Math.max(numPartitions, origNumPartitions));
                Double ideal = (double) totalCount / numPartitions;
                // partitions in the initial set and not in the final one get -1 => elements always jump
                // partitions in the final set not in the initial one get 0 => aim to receive the average amount
                for (int i = 0; i < Math.min(origNumPartitions, numPartitions); i++) {
                    partitionWeights.add((double) partitionCounts.get(i) / ideal);
                }
                for (int i = Math.min(origNumPartitions, numPartitions); i < Math.max(origNumPartitions,
                                numPartitions); i++) {
                    // we shrink the # of partitions
                    if (i >= numPartitions)
                        partitionWeights.add(-1D);
                    // we enlarge the # of partitions
                    else
                        partitionWeights.add(0D);
                }

                // this method won't trigger a spark job, which is different from {@link org.apache.spark.rdd.RDD#zipWithIndex}

                JavaPairRDD, T> indexedRDD = rdd.zipWithUniqueId()
                                .mapToPair(new PairFunction, Tuple2, T>() {
                                    @Override
                                    public Tuple2, T> call(Tuple2 tLongTuple2) {
                                        return new Tuple2<>(
                                                        new Tuple2(tLongTuple2._2(), 0),
                                                        tLongTuple2._1());
                                    }
                                });

                HashingBalancedPartitioner hbp =
                                new HashingBalancedPartitioner(Collections.singletonList(partitionWeights));
                JavaPairRDD, T> partitionedRDD = indexedRDD.partitionBy(hbp);

                return partitionedRDD.map(new Function, T>, T>() {
                    @Override
                    public T call(Tuple2, T> indexNPayload) {
                        return indexNPayload._2();
                    }
                });
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }
    }

    /**
     * Repartition a RDD (given the {@link Repartition} setting) such that we have approximately
     * {@code numPartitions} partitions, each of which has {@code objectsPerPartition} objects.
     *
     * @param rdd RDD to repartition
     * @param repartition Repartitioning setting
     * @param objectsPerPartition Number of objects we want in each partition
     * @param numPartitions Number of partitions to have
     * @param  Type of RDD
     * @return Repartitioned RDD, or the original RDD if no repartitioning was performed
     */
    public static  JavaRDD repartitionBalanceIfRequired(JavaRDD rdd, Repartition repartition,
                    int objectsPerPartition, int numPartitions) {
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions)
                    return rdd;
            case Always:
                //Repartition: either always, or origNumPartitions != numWorkers

                //First: count number of elements in each partition. Need to know this so we can work out how to properly index each example,
                // so we can in turn create properly balanced partitions after repartitioning
                //Because the objects (DataSets etc) should be small, this should be OK

                //Count each partition...
                List> partitionCounts =
                                rdd.mapPartitionsWithIndex(new CountPartitionsFunction(), true).collect();
                int totalObjects = 0;
                int initialPartitions = partitionCounts.size();

                boolean allCorrectSize = true;
                int x = 0;
                for (Tuple2 t2 : partitionCounts) {
                    int partitionSize = t2._2();
                    allCorrectSize &= (partitionSize == objectsPerPartition);
                    totalObjects += t2._2();
                }

                if (numPartitions * objectsPerPartition < totalObjects) {
                    allCorrectSize = true;
                    for (Tuple2 t2 : partitionCounts) {
                        allCorrectSize &= (t2._2() == objectsPerPartition);
                    }
                }

                if (initialPartitions == numPartitions && allCorrectSize) {
                    //Don't need to do any repartitioning here - already in the format we want
                    return rdd;
                }

                //Index each element for repartitioning (can only do manual repartitioning on a JavaPairRDD)
                JavaPairRDD pairIndexed = indexedRDD(rdd);

                int remainder = (totalObjects - numPartitions * objectsPerPartition) % numPartitions;
                log.trace("About to rebalance: numPartitions={}, objectsPerPartition={}, remainder={}", numPartitions, objectsPerPartition, remainder);
                pairIndexed = pairIndexed
                                .partitionBy(new BalancedPartitioner(numPartitions, objectsPerPartition, remainder));
                return pairIndexed.values();
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }
    }

    public static  JavaPairRDD indexedRDD(JavaRDD rdd) {
        return rdd.zipWithIndex().mapToPair(new PairFunction, Integer, T>() {
            @Override
            public Tuple2 call(Tuple2 elemIdx) {
                return new Tuple2<>(elemIdx._2().intValue(), elemIdx._1());
            }
        });
    }

    public static  JavaRDD repartitionEqually(JavaRDD rdd, Repartition repartition, int numPartitions){
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions)
                    return rdd;
            case Always:
                return new EqualRepartitioner().repartition(rdd, -1, numPartitions);
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }
    }

    /**
     * Random split the specified RDD into a number of RDDs, where each has {@code numObjectsPerSplit} in them.
     * 

* This similar to how RDD.randomSplit works (i.e., split via filtering), but this should result in more * equal splits (instead of independent binomial sampling that is used there, based on weighting) * This balanced splitting approach is important when the number of DataSet objects we want in each split is small, * as random sampling variance of {@link JavaRDD#randomSplit(double[])} is quite large relative to the number of examples * in each split. Note however that this method doesn't guarantee that partitions will be balanced *

* Downside is we need total object count (whereas {@link JavaRDD#randomSplit(double[])} does not). However, randomSplit * requires a full pass of the data anyway (in order to do filtering upon it) so this should not add much overhead in practice * * @param totalObjectCount Total number of objects in the RDD to split * @param numObjectsPerSplit Number of objects in each split * @param data Data to split * @param Generic type for the RDD * @return The RDD split up (without replacement) into a number of smaller RDDs */ public static JavaRDD[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD data) { return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, new Random().nextLong()); } /** * Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} with control over the RNG seed */ public static JavaRDD[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD data, long rngSeed) { JavaRDD[] splits; if (totalObjectCount <= numObjectsPerSplit) { splits = (JavaRDD[]) Array.newInstance(JavaRDD.class, 1); splits[0] = data; } else { int numSplits = totalObjectCount / numObjectsPerSplit; //Intentional round down splits = (JavaRDD[]) Array.newInstance(JavaRDD.class, numSplits); for (int i = 0; i < numSplits; i++) { splits[i] = data.mapPartitionsWithIndex(new SplitPartitionsFunction(i, numSplits, rngSeed), true); } } return splits; } /** * Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} but for Pair RDDs */ public static JavaPairRDD[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaPairRDD data) { return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, new Random().nextLong()); } /** * Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} but for pair RDDs, and with control over the RNG seed */ public static JavaPairRDD[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaPairRDD data, long rngSeed) { JavaPairRDD[] splits; if (totalObjectCount 
<= numObjectsPerSplit) { splits = (JavaPairRDD[]) Array.newInstance(JavaPairRDD.class, 1); splits[0] = data; } else { int numSplits = totalObjectCount / numObjectsPerSplit; //Intentional round down splits = (JavaPairRDD[]) Array.newInstance(JavaPairRDD.class, numSplits); for (int i = 0; i < numSplits; i++) { //What we really need is a .mapPartitionsToPairWithIndex function //but, of course Spark doesn't provide this //So we need to do a two-step process here... JavaRDD> split = data.mapPartitionsWithIndex( new SplitPartitionsFunction2(i, numSplits, rngSeed), true); splits[i] = split.mapPartitionsToPair(new MapTupleToPairFlatMap(), true); } } return splits; } /** * List of the files in the given directory (path), as a {@code JavaRDD} * * @param sc Spark context * @param path Path to list files in * @return Paths in the directory * @throws IOException If error occurs getting directory contents */ public static JavaRDD listPaths(JavaSparkContext sc, String path) throws IOException { return listPaths(sc, path, false); } /** * List of the files in the given directory (path), as a {@code JavaRDD} * * @param sc Spark context * @param path Path to list files in * @param recursive Whether to walk the directory tree recursively (i.e., include subdirectories) * @return Paths in the directory * @throws IOException If error occurs getting directory contents */ public static JavaRDD listPaths(JavaSparkContext sc, String path, boolean recursive) throws IOException { //NativeImageLoader.ALLOWED_FORMATS return listPaths(sc, path, recursive, (Set)null); } /** * List of the files in the given directory (path), as a {@code JavaRDD} * * @param sc Spark context * @param path Path to list files in * @param recursive Whether to walk the directory tree recursively (i.e., include subdirectories) * @param allowedExtensions If null: all files will be accepted. If non-null: only files with the specified extension will be allowed. 
* Exclude the extension separator - i.e., use "txt" not ".txt" here. * @return Paths in the directory * @throws IOException If error occurs getting directory contents */ public static JavaRDD listPaths(JavaSparkContext sc, String path, boolean recursive, String[] allowedExtensions) throws IOException { return listPaths(sc, path, recursive, (allowedExtensions == null ? null : new HashSet<>(Arrays.asList(allowedExtensions)))); } /** * List of the files in the given directory (path), as a {@code JavaRDD} * * @param sc Spark context * @param path Path to list files in * @param recursive Whether to walk the directory tree recursively (i.e., include subdirectories) * @param allowedExtensions If null: all files will be accepted. If non-null: only files with the specified extension will be allowed. * Exclude the extension separator - i.e., use "txt" not ".txt" here. * @return Paths in the directory * @throws IOException If error occurs getting directory contents */ public static JavaRDD listPaths(JavaSparkContext sc, String path, boolean recursive, Set allowedExtensions) throws IOException { return listPaths(sc, path, recursive, allowedExtensions, sc.hadoopConfiguration()); } /** * List of the files in the given directory (path), as a {@code JavaRDD} * * @param sc Spark context * @param path Path to list files in * @param recursive Whether to walk the directory tree recursively (i.e., include subdirectories) * @param allowedExtensions If null: all files will be accepted. If non-null: only files with the specified extension will be allowed. * Exclude the extension separator - i.e., use "txt" not ".txt" here. * @param config Hadoop configuration to use. Must not be null. 
* @return Paths in the directory * @throws IOException If error occurs getting directory contents */ public static JavaRDD listPaths(@NonNull JavaSparkContext sc, String path, boolean recursive, Set allowedExtensions, @NonNull Configuration config) throws IOException { List paths = new ArrayList<>(); FileSystem hdfs = FileSystem.get(URI.create(path), config); RemoteIterator fileIter = hdfs.listFiles(new org.apache.hadoop.fs.Path(path), recursive); while (fileIter.hasNext()) { String filePath = fileIter.next().getPath().toString(); if(allowedExtensions == null){ paths.add(filePath); } else { String ext = FilenameUtils.getExtension(path); if(allowedExtensions.contains(ext)){ paths.add(filePath); } } } return sc.parallelize(paths); } /** * Randomly shuffle the examples in each DataSet object, and recombine them into new DataSet objects * with the specified BatchSize * * @param rdd DataSets to shuffle/recombine * @param newBatchSize New batch size for the DataSet objects, after shuffling/recombining * @param numPartitions Number of partitions to use when splitting/recombining * @return A new {@link JavaRDD}, with the examples shuffled/combined in each */ public static JavaRDD shuffleExamples(JavaRDD rdd, int newBatchSize, int numPartitions) { //Step 1: split into individual examples, mapping to a pair RDD (random key in range 0 to numPartitions) JavaPairRDD singleExampleDataSets = rdd.flatMapToPair(new SplitDataSetExamplesPairFlatMapFunction(numPartitions)); //Step 2: repartition according to the random keys singleExampleDataSets = singleExampleDataSets.partitionBy(new HashPartitioner(numPartitions)); //Step 3: Recombine return singleExampleDataSets.values().mapPartitions(new BatchDataSetsFunction(newBatchSize)); } /** * Get the Spark executor ID
* The ID is parsed from the JVM launch args. If that is not specified (or can't be obtained) then the value * from {@link UIDProvider#getJVMUID()} is returned * @return */ public static String getSparkExecutorId(){ if(sparkExecutorId != null) return sparkExecutorId; synchronized (SparkUtils.class){ //re-check, in case some other thread set it while waiting for lock if(sparkExecutorId != null) return sparkExecutorId; String s = System.getProperty("sun.java.command"); if(s == null || s.isEmpty() || !s.contains("executor-id")){ sparkExecutorId = UIDProvider.getJVMUID(); return sparkExecutorId; } int idx = s.indexOf("executor-id"); String sub = s.substring(idx); String[] split = sub.split(" "); if(split.length < 2){ sparkExecutorId = UIDProvider.getJVMUID(); return sparkExecutorId; } sparkExecutorId = split[1]; return sparkExecutorId; } } public static Broadcast asByteArrayBroadcast(JavaSparkContext sc, INDArray array){ ByteArrayOutputStream baos = new ByteArrayOutputStream(); try { Nd4j.write(array, new DataOutputStream(baos)); } catch (IOException e){ throw new RuntimeException(e); //Should never happen } byte[] paramBytes = baos.toByteArray(); //See docs in EvaluationRunner for why we use byte[] instead of INDArray (thread locality etc) return sc.broadcast(paramBytes); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy