
// org.deeplearning4j.spark.text.functions.FoldWithinPartitionFunction (Maven / Gradle / Ivy)
package org.deeplearning4j.spark.text.functions;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.berkeley.Counter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
/**
* @author jeffreytang
*/
public class FoldWithinPartitionFunction implements Function2, Iterator> {
public FoldWithinPartitionFunction(Accumulator> maxPartitionAcc) {
this.maxPerPartitionAcc = maxPartitionAcc;
}
private Accumulator> maxPerPartitionAcc;
@Override
public Iterator call(Integer ind, Iterator partition) throws Exception {
List foldedItemList = new ArrayList() {{ add(new AtomicLong(0L)); }};
// Recurrent state implementation of cum sum
int foldedItemListSize = 1;
while (partition.hasNext()) {
long curPartitionItem = partition.next().get();
int lastFoldedIndex = foldedItemListSize - 1;
long lastFoldedItem = foldedItemList.get(lastFoldedIndex).get();
AtomicLong sumLastCurrent = new AtomicLong(curPartitionItem + lastFoldedItem);
foldedItemList.set(lastFoldedIndex, sumLastCurrent);
foldedItemList.add(sumLastCurrent);
foldedItemListSize += 1;
}
// Update Accumulator
long maxFoldedItem = foldedItemList.remove(foldedItemListSize - 1).get();
Counter partitionIndex2maxItemCounter = new Counter<>();
partitionIndex2maxItemCounter.incrementCount(ind, maxFoldedItem);
maxPerPartitionAcc.add(partitionIndex2maxItemCounter);
return foldedItemList.iterator();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy