All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.text.functions.FoldWithinPartitionFunction Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta_spark_1
Show newest version
package org.deeplearning4j.spark.text.functions;

import org.apache.spark.Accumulator;
import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.berkeley.Counter;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

/**
 * @author jeffreytang
 */
public class FoldWithinPartitionFunction implements Function2, Iterator> {

    public FoldWithinPartitionFunction(Accumulator> maxPartitionAcc) {
        this.maxPerPartitionAcc = maxPartitionAcc;
    }

    private Accumulator> maxPerPartitionAcc;


    @Override
    public Iterator call(Integer ind, Iterator partition) throws Exception {

        List foldedItemList = new ArrayList() {{ add(new AtomicLong(0L)); }};

        // Recurrent state implementation of cum sum
        int foldedItemListSize = 1;
        while (partition.hasNext()) {
            long curPartitionItem = partition.next().get();
            int lastFoldedIndex = foldedItemListSize - 1;
            long lastFoldedItem = foldedItemList.get(lastFoldedIndex).get();
            AtomicLong sumLastCurrent = new AtomicLong(curPartitionItem + lastFoldedItem);

            foldedItemList.set(lastFoldedIndex, sumLastCurrent);
            foldedItemList.add(sumLastCurrent);
            foldedItemListSize += 1;
        }

        // Update Accumulator
        long maxFoldedItem = foldedItemList.remove(foldedItemListSize - 1).get();
        Counter partitionIndex2maxItemCounter = new Counter<>();
        partitionIndex2maxItemCounter.incrementCount(ind, maxFoldedItem);
        maxPerPartitionAcc.add(partitionIndex2maxItemCounter);

        return foldedItemList.iterator();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy