
// org.deeplearning4j.spark.canova.RDDMiniBatches Maven / Gradle / Ivy
package org.deeplearning4j.spark.canova;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.rdd.RDD;
import org.deeplearning4j.spark.ordering.DataSetOrdering;
import org.nd4j.linalg.dataset.DataSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
* Created by agibsonccc on 1/18/15.
*/
public class RDDMiniBatches implements Serializable {
private int miniBatches = 10;
private JavaRDD toSplitJava;
public RDDMiniBatches(int miniBatches, JavaRDD toSplit) {
this.miniBatches = miniBatches;
this.toSplitJava = toSplit;
}
public JavaRDD miniBatchesJava() {
final int batchSize = miniBatches;
JavaRDD miniBatches = toSplitJava.mapPartitions(new FlatMapFunction, DataSet>() {
@Override
public Iterable call(Iterator dataSetIterator) throws Exception {
List ret = new ArrayList<>();
List temp = new ArrayList<>();
while (dataSetIterator.hasNext()) {
temp.add(dataSetIterator.next());
if (temp.size() == batchSize) {
ret.add(DataSet.merge(temp));
temp.clear();
}
}
if(!temp.isEmpty())
ret.add(DataSet.merge(temp));
return ret;
}
});
return miniBatches;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy