// org.nd4j.linalg.api.ops.aggregates.Batch (Maven / Gradle / Ivy)
package org.nd4j.linalg.api.ops.aggregates;
import com.google.common.collect.Lists;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.buffer.DataBuffer;
import java.util.ArrayList;
import java.util.List;
/**
* Wrapper for "batch of aggregates"
*
* @author [email protected]
*/
@Slf4j
public class Batch {
/**
* This batchLimit should be equal to its counterpart at helper_ptrmap.h
*
*/
@Getter
@Setter
private DataBuffer paramsSurface;
@Getter
private static final int batchLimit = 512;
// all aggregates within this batch
@Getter
private List aggregates;
@Getter
private T sample;
@Getter
private int numAggregates;
/**
* This constructor takes List of Aggregates, and builds Batch instance, usable with Nd4j executioner.
*
* @param aggregates
*/
public Batch(List aggregates) {
//if (aggregates.size() > batchLimit)
// throw new RuntimeException("Number of aggregates is higher then " + batchLimit + " elements, multiple batches should be issued.");
this.aggregates = aggregates;
this.numAggregates = aggregates.size();
// we fetch single sample for possible future use. not sure if will be used though
this.sample = aggregates.get(0);
}
/**
* This method returns opNum for batched aggregate
* @return
*/
public int opNum() {
return sample.opNum();
}
/**
* This method tries to append aggregate to the current batch, if it has free room
*
* @param aggregate
* @return
*/
public boolean append(T aggregate) {
if (!isFull()) {
aggregates.add(aggregate);
return true;
} else
return false;
}
/**
* This method checks, if number of batched aggregates equals to maximum possible value
*
* @return
*/
public boolean isFull() {
return batchLimit == numAggregates;
}
/**
* Helper method to create batch from list of aggregates, for cases when list of aggregates is higher then batchLimit
*
* @param list
* @param
* @return
*/
public static List> getBatches(List list) {
return getBatches(list, batchLimit);
}
/**
* Helper method to create batch from list of aggregates, for cases when list of aggregates is higher then batchLimit
*
* @param list
* @param
* @return
*/
public static List> getBatches(List list, int partitionSize) {
List> partitions = Lists.partition(list, partitionSize);
List> split = new ArrayList<>();
for (List partition : partitions) {
split.add(new Batch(partition));
}
return split;
}
}