
// org.deeplearning4j.spark.text.functions.CountCumSum (Maven / Gradle / Ivy)
package org.deeplearning4j.spark.text.functions;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.spark.text.accumulators.MaxPerPartitionAccumulator;
import java.util.concurrent.atomic.AtomicLong;
/**
* @author jeffreytang
*/
@SuppressWarnings("unchecked")
public class CountCumSum {
// Starting variables
private JavaSparkContext sc;
private JavaRDD sentenceCountRDD;
// Variables to fill in as we go
private JavaRDD foldWithinPartitionRDD;
private Broadcast> broadcastedMaxPerPartitionCounter;
private JavaRDD cumSumRDD;
// Constructor
public CountCumSum(JavaRDD sentenceCountRDD) {
this.sentenceCountRDD = sentenceCountRDD;
this.sc = new JavaSparkContext(sentenceCountRDD.context());
}
// Getter
public JavaRDD getCumSumRDD() {
if (cumSumRDD != null) {
return cumSumRDD;
} else {
throw new IllegalAccessError("Cumulative Sum list not defined. Call buildCumSum() first.");
}
}
// For each equivalent for partitions
public void actionForMapPartition(JavaRDD rdd) {
// Action to fill the accumulator
rdd.foreachPartition(new MapPerPartitionVoidFunction());
}
// Do cum sum within the partition
public void cumSumWithinPartition() {
// Accumulator to get the max of the cumulative sum in each partition
final Accumulator> maxPerPartitionAcc = sc.accumulator(new Counter(),
new MaxPerPartitionAccumulator());
// Partition mapping to fold within partition
foldWithinPartitionRDD = sentenceCountRDD.mapPartitionsWithIndex(
new FoldWithinPartitionFunction(maxPerPartitionAcc), true).cache();
actionForMapPartition(foldWithinPartitionRDD);
// Broadcast the counter (partition index : sum of count) to all workers
broadcastedMaxPerPartitionCounter = sc.broadcast(maxPerPartitionAcc.value());
}
public void cumSumBetweenPartition() {
cumSumRDD = foldWithinPartitionRDD.mapPartitionsWithIndex(
new FoldBetweenPartitionFunction(broadcastedMaxPerPartitionCounter), true)
.setName("cumSumRDD").cache();
foldWithinPartitionRDD.unpersist();
}
public JavaRDD buildCumSum() {
cumSumWithinPartition();
cumSumBetweenPartition();
return getCumSumRDD();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy