All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.ops.aggregates.Batch Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.aggregates;

import org.nd4j.shade.guava.collect.Lists;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;

import java.util.ArrayList;
import java.util.List;

/**
 * Wrapper for "batch of aggregates"
 *
 * @author [email protected]
 */
@Slf4j
public class Batch {
    /**
     * This batchLimit should be equal to its counterpart at helper_ptrmap.h
     *
     */
    @Getter
    @Setter
    private DataBuffer paramsSurface;

    @Getter
    private static final int batchLimit = 512;

    // all aggregates within this batch
    @Getter
    private List aggregates;

    @Getter
    private T sample;
    @Getter
    private int numAggregates;

    /**
     * This constructor takes List of Aggregates, and builds Batch instance, usable with Nd4j executioner.
     *
     * @param aggregates
     */
    public Batch(List aggregates) {
        //if (aggregates.size() > batchLimit)
        //    throw new RuntimeException("Number of aggregates is higher then " + batchLimit + " elements, multiple batches should be issued.");

        this.aggregates = aggregates;
        this.numAggregates = aggregates.size();

        // we fetch single sample for possible future use. not sure if will be used though
        this.sample = aggregates.get(0);
    }

    /**
     * This method returns opNum for batched aggregate
     * @return
     */
    public int opNum() {
        return sample.opNum();
    }

    /**
     * This method tries to append aggregate to the current batch, if it has free room
     *
     * @param aggregate
     * @return
     */
    public boolean append(T aggregate) {
        if (!isFull()) {
            aggregates.add(aggregate);
            return true;
        } else
            return false;
    }

    /**
     * This method checks, if number of batched aggregates equals to maximum possible value
     *
     * @return
     */
    public boolean isFull() {
        return batchLimit == numAggregates;
    }


    /**
     * Helper method to create batch from list of aggregates, for cases when list of aggregates is higher then batchLimit
     *
     * @param list
     * @param 
     * @return
     */
    public static  List> getBatches(List list) {
        return getBatches(list, batchLimit);
    }

    /**
     * Helper method to create batch from list of aggregates, for cases when list of aggregates is higher then batchLimit
     *
     * @param list
     * @param 
     * @return
     */
    public static  List> getBatches(List list, int partitionSize) {
        DataType c = null;
        for (val u:list) {
            for (val a:u.getArguments()) {
                // we'll be comparing to the first array
                if (c == null && a != null)
                    c = a.dataType();

                if (a != null && c != null)
                    Preconditions.checkArgument(c == a.dataType(), "All arguments must have same data type");
            }
        }

        if (c == null)
            throw new ND4JIllegalStateException("Can't infer data type from arguments");

        List> partitions = Lists.partition(list, partitionSize);
        List> split = new ArrayList<>();

        for (List partition : partitions) {
            split.add(new Batch(partition));
        }

        return split;
    }
}