
org.deeplearning4j.spark.text.functions.CountCumSum

package org.deeplearning4j.spark.text.functions;

import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.spark.text.accumulators.MaxPerPartitionAccumulator;

import java.util.concurrent.atomic.AtomicLong;

/**
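 * Cumulative (prefix) sum of per-sentence counts across an RDD, computed in two phases:
 * a fold within each partition, followed by an offset across partitions using a
 * broadcast counter of each partition's total.
 *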
 * @author jeffreytang
 */
@SuppressWarnings("unchecked")
public class CountCumSum {

    // Starting variables
    private JavaSparkContext sc;
    private JavaRDD<AtomicLong> sentenceCountRDD;

    // Variables to fill in as we go
    private JavaRDD<Long> foldWithinPartitionRDD;
    private Broadcast<Counter<Integer>> broadcastedMaxPerPartitionCounter;
    private JavaRDD<Long> cumSumRDD;

    // Constructor
    public CountCumSum(JavaRDD<AtomicLong> sentenceCountRDD) {
        this.sentenceCountRDD = sentenceCountRDD;
        this.sc = new JavaSparkContext(sentenceCountRDD.context());
    }

    // Getter
    public JavaRDD<Long> getCumSumRDD() {
        if (cumSumRDD != null) {
            return cumSumRDD;
        } else {
            throw new IllegalAccessError("Cumulative Sum list not defined. Call buildCumSum() first.");
        }
    }

    // Per-partition equivalent of foreach
    public void actionForMapPartition(JavaRDD rdd) {
        // Action to fill the accumulator
        rdd.foreachPartition(new MapPerPartitionVoidFunction());
    }

    // Do the cumulative sum within each partition
    public void cumSumWithinPartition() {

        // Accumulator to get the max of the cumulative sum in each partition
        final Accumulator<Counter<Integer>> maxPerPartitionAcc = sc.accumulator(new Counter<Integer>(),
                                                                                new MaxPerPartitionAccumulator());
        // Partition mapping to fold within partition
        foldWithinPartitionRDD = sentenceCountRDD.mapPartitionsWithIndex(
                new FoldWithinPartitionFunction(maxPerPartitionAcc), true).cache();
        actionForMapPartition(foldWithinPartitionRDD);

        // Broadcast the counter (partition index : sum of count) to all workers
        broadcastedMaxPerPartitionCounter = sc.broadcast(maxPerPartitionAcc.value());
    }

    public void cumSumBetweenPartition() {

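        // Offset each partition's within-partition sums by the totals of all preceding
        // partitions, taken from the broadcast max-per-partition counter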
        cumSumRDD = foldWithinPartitionRDD.mapPartitionsWithIndex(
                new FoldBetweenPartitionFunction(broadcastedMaxPerPartitionCounter), true)
                                          .setName("cumSumRDD").cache();
        foldWithinPartitionRDD.unpersist();
    }

    public JavaRDD<Long> buildCumSum() {
        cumSumWithinPartition();
        cumSumBetweenPartition();
        return getCumSumRDD();
    }
}
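A minimal usage sketch, under stated assumptions: the per-sentence counts would normally come from the text pipeline in this package, and the local Spark context, example class name, input values, and expected output below are illustrative only, not part of the library.

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.spark.text.functions.CountCumSum;

public class CountCumSumExample {
    public static void main(String[] args) {
        // Local Spark context for illustration only
        SparkConf conf = new SparkConf().setAppName("CountCumSumExample").setMaster("local[2]");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Hypothetical per-sentence word counts, split over two partitions
        JavaRDD<AtomicLong> sentenceCounts = sc.parallelize(
                Arrays.asList(new AtomicLong(3), new AtomicLong(5), new AtomicLong(2), new AtomicLong(4)), 2);

        // Two-phase cumulative sum: fold within partitions, then offset across partitions
        JavaRDD<Long> cumSum = new CountCumSum(sentenceCounts).buildCumSum();

        System.out.println(cumSum.collect());   // expected prefix sums: [3, 8, 10, 14]

        sc.stop();
    }
}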
