All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.canova.RDDMiniBatches Maven / Gradle / Ivy

package org.deeplearning4j.spark.canova;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.rdd.RDD;
import org.deeplearning4j.spark.ordering.DataSetOrdering;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Created by agibsonccc on 1/18/15.
 */
public class RDDMiniBatches  implements Serializable {
    private int miniBatches = 10;
    private JavaRDD toSplitJava;

    public RDDMiniBatches(int miniBatches, JavaRDD toSplit) {
        this.miniBatches = miniBatches;
        this.toSplitJava = toSplit;
    }

    public JavaRDD miniBatchesJava() {
        final int batchSize = miniBatches;
        JavaRDD miniBatches = toSplitJava.mapPartitions(new FlatMapFunction, DataSet>() {
            @Override
            public Iterable call(Iterator dataSetIterator) throws Exception {
                List ret = new ArrayList<>();
                List temp = new ArrayList<>();
                while (dataSetIterator.hasNext()) {
                    temp.add(dataSetIterator.next());
                    if (temp.size() == batchSize) {
                        ret.add(DataSet.merge(temp));
                        temp.clear();
                    }
                }

                if(!temp.isEmpty())
                    ret.add(DataSet.merge(temp));

                return ret;
            }
        });

        return miniBatches;
    }




}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy