All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.ops.aggregates.Batch Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.api.ops.aggregates;

import com.google.common.collect.Lists;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;

import java.util.ArrayList;
import java.util.List;

/**
 * Wrapper for "batch of aggregates"
 *
 * @author [email protected]
 */
@Slf4j
public class Batch {
    /**
     * This batchLimit should be equal to its counterpart at helper_ptrmap.h
     *
     */
    @Getter @Setter private DataBuffer paramsSurface;

    @Getter private static final int batchLimit = 512;

    // all aggregates within this batch
    @Getter private List aggregates;

    @Getter private T sample;
    @Getter private int numAggregates;

    /**
     * This constructor takes List of Aggregates, and builds Batch instance, usable with Nd4j executioner.
     *
     * @param aggregates
     */
    public Batch(List aggregates) {
        //if (aggregates.size() > batchLimit)
        //    throw new RuntimeException("Number of aggregates is higher then " + batchLimit + " elements, multiple batches should be issued.");

        this.aggregates = aggregates;
        this.numAggregates = aggregates.size();

        // we fetch single sample for possible future use. not sure if will be used though
        this.sample = aggregates.get(0);
    }

    /**
     * This method returns opNum for batched aggregate
     * @return
     */
    public int opNum() {
        return sample.opNum();
    }

    /**
     * This method tries to append aggregate to the current batch, if it has free room
     *
     * @param aggregate
     * @return
     */
    public boolean append(T aggregate) {
       if (!isFull()) {
           aggregates.add(aggregate);
           return true;
       } else return false;
    }

    /**
     * This method checks, if number of batched aggregates equals to maximum possible value
     *
     * @return
     */
    public boolean isFull() {
        return batchLimit == numAggregates;
    }


    /**
     * Helper method to create batch from list of aggregates, for cases when list of aggregates is higher then batchLimit
     *
     * @param list
     * @param 
     * @return
     */
    public static  List> getBatches(List list) {
        return getBatches(list, batchLimit);
    }

    /**
     * Helper method to create batch from list of aggregates, for cases when list of aggregates is higher then batchLimit
     *
     * @param list
     * @param 
     * @return
     */
    public static  List> getBatches(List list, int partitionSize) {
        List> partitions =  Lists.partition(list, partitionSize);
        List> split = new ArrayList<>();

        for (List partition: partitions) {
            split.add(new Batch(partition));
        }

        return split;
    }
}