All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.translate.StackBatchifier Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.translate;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

import java.util.Arrays;
import java.util.stream.LongStream;

/**
 * {@code StackBatchifier} is used to merge a list of samples to form a mini-batch of NDArray(s).
 * This is the default {@link Batchifier} for data loading.
 *
 * <p>Batching stacks the i-th array of every sample along a new leading (batch) axis. Unbatching
 * and splitting slice the arrays of a batch along that same leading axis.
 */
public class StackBatchifier implements Batchifier {

    private static final long serialVersionUID = 1L;

    /**
     * {@inheritDoc}
     *
     * <p>Every sample must contain the same number of arrays. The i-th arrays of all samples are
     * stacked together, and the stacked array keeps the name of the first sample's i-th array.
     *
     * @throws IllegalArgumentException if the samples contain different numbers of arrays, or if
     *     stacking fails and corresponding arrays are found to differ in shape or data type
     */
    @Override
    public NDList batchify(NDList[] inputs) {
        // Nothing to batch: treat like the "samples contain no arrays" case instead of failing
        // with ArrayIndexOutOfBoundsException on inputs[0].
        if (inputs.length == 0) {
            return new NDList();
        }
        int batchSize = inputs.length;
        // each input as NDList might contain several data or labels,
        // so each kind is batchified together with its counterparts from the other samples
        int numInputKinds = inputs[0].size();

        // Validate up front: a sample with MORE arrays than the first one would otherwise not
        // raise any error and its extra arrays would be silently dropped.
        for (NDList input : inputs) {
            if (input.size() != numInputKinds) {
                throw new IllegalArgumentException(
                        "You cannot batch data with different numbers of inputs");
            }
        }
        if (numInputKinds == 0) {
            return new NDList();
        }

        try {
            // stack all the data and labels together, one kind at a time
            NDList result = new NDList(numInputKinds);
            for (int i = 0; i < numInputKinds; i++) {
                NDList inputsOfKind = new NDList(batchSize);
                for (NDList input : inputs) {
                    inputsOfKind.add(input.get(i));
                }
                NDArray stacked = NDArrays.stack(inputsOfKind);
                // keep the name for stacked inputs
                stacked.setName(inputs[0].get(i).getName());
                result.add(stacked);
            }
            return result;
        } catch (IndexOutOfBoundsException | EngineException e) {
            // The error is not handled here; we only try to clarify its cause before rethrowing.
            checkConsistentShapesAndTypes(inputs, numInputKinds, e);
            // Could not clarify cause - rethrow original error.
            throw e;
        }
    }

    /**
     * Looks for a shape or data-type mismatch between corresponding arrays of the samples, to
     * explain why stacking them failed.
     *
     * @param inputs the samples passed to {@link #batchify(NDList[])}; each holds {@code
     *     numInputKinds} arrays
     * @param numInputKinds the number of arrays in every sample
     * @param cause the original failure, attached as the cause of any exception thrown here
     * @throws IllegalArgumentException if a shape or data-type mismatch is found
     */
    private static void checkConsistentShapesAndTypes(
            NDList[] inputs, int numInputKinds, Throwable cause) {
        for (int i = 0; i < numInputKinds; i++) {
            Shape kindDataShape = inputs[0].get(i).getShape();
            DataType kindDataType = inputs[0].get(i).getDataType();
            for (NDList input : inputs) {
                NDArray currInput = input.get(i);
                if (!currInput.getShape().equals(kindDataShape)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input shapes: "
                                    + currInput.getShape()
                                    + " vs "
                                    + kindDataShape,
                            cause);
                }
                if (!currInput.getDataType().equals(kindDataType)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input data types", cause);
                }
            }
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Each array of {@code inputs} is split along its first axis into as many pieces as the
     * first array's leading dimension. Each piece has that axis squeezed away and keeps the name
     * of the array it came from. Returns an empty array when {@code inputs} is empty or the batch
     * size is zero.
     */
    @Override
    public NDList[] unbatchify(NDList inputs) {
        int numInputKinds = inputs.size();
        if (numInputKinds == 0) {
            return new NDList[0];
        }
        int batchSize = Math.toIntExact(inputs.head().size(0));
        if (batchSize == 0) {
            return new NDList[0];
        }

        NDList[] dataList = new NDList[batchSize];
        Arrays.setAll(dataList, i -> new NDList(numInputKinds));

        for (NDArray input : inputs) {
            NDList splitList = input.split(batchSize);
            for (int i = 0; i < batchSize; i++) {
                NDArray array = splitList.get(i).squeeze(0);
                array.setName(input.getName());
                dataList[i].add(array);
            }
        }
        return dataList;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The number of slices is capped at the batch size (the first array's leading dimension).
     * Each slice keeps the names of the arrays it came from. Returns an empty array when {@code
     * list} is empty or the batch size is zero.
     *
     * @throws IllegalArgumentException if {@code numOfSlices} is not positive, or if an array
     *     cannot be split as requested
     */
    @Override
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (numOfSlices <= 0) {
            throw new IllegalArgumentException(
                    "The number of slices must be positive, but got " + numOfSlices + ".");
        }
        if (list.isEmpty()) {
            return new NDList[0];
        }
        int batchSize = Math.toIntExact(list.head().size(0));
        if (batchSize == 0) {
            return new NDList[0];
        }
        int sliceCount = Math.min(numOfSlices, batchSize);

        NDList[] splitted = new NDList[sliceCount];
        Arrays.setAll(splitted, i -> new NDList(list.size()));

        for (NDArray nd : list) {
            String name = nd.getName();
            NDList rows = split(nd, sliceCount, evenSplit);

            for (int i = 0; i < sliceCount; ++i) {
                NDArray array = rows.get(i);
                array.setName(name);
                splitted[i].add(array);
            }
        }
        return splitted;
    }

    /**
     * Splits an {@code NDArray} into the given number of slices along its first (batch) axis.
     *
     * <p>Usually used for data parallelism where each slice is sent to one device (i.e. GPU).
     * When {@code evenSplit} is false and the batch does not divide evenly, the remainder is
     * spread one row at a time over the leading slices. Every slice is then non-empty, and slice
     * sizes differ by at most one.
     *
     * @param array a batch of {@code NDArray}
     * @param numOfSlices the number of desired slices; must be positive
     * @param evenSplit whether to force all slices to have the same number of elements
     * @return an NDList even if {@code numOfSlices} is 1
     * @throws IllegalArgumentException if the batch is smaller than {@code numOfSlices}, or if
     *     {@code evenSplit} is true and the batch size is not a multiple of {@code numOfSlices}
     */
    private NDList split(NDArray array, int numOfSlices, boolean evenSplit) {
        int batchSize = Math.toIntExact(array.size(0));
        if (batchSize < numOfSlices) {
            throw new IllegalArgumentException(
                    "Batch size("
                            + batchSize
                            + ") is less than slice number("
                            + numOfSlices
                            + ").");
        }

        if (evenSplit) {
            if (batchSize % numOfSlices != 0) {
                throw new IllegalArgumentException(
                        "Data with batch size "
                                + batchSize
                                + " cannot be evenly split into "
                                + numOfSlices
                                + ". Use a batch size that's multiple of "
                                + numOfSlices
                                + " or set evenSplit=false to allow"
                                + " uneven partitioning of data.");
            }
            return array.split(numOfSlices);
        }

        // Balanced boundaries: the first `remainder` slices get one extra row. A ceil-based step
        // can overshoot the batch (e.g. 6 rows into 4 slices -> boundaries 2, 4, 6, leaving an
        // empty last slice), which this avoids.
        int base = batchSize / numOfSlices;
        int remainder = batchSize % numOfSlices;
        long[] indices =
                LongStream.range(1, numOfSlices)
                        .map(i -> i * base + Math.min(i, remainder))
                        .toArray();
        return array.split(indices);
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy