All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.translate.PaddingStackBatchifier Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;

import java.util.ArrayList;
import java.util.List;

/**
 * The padding stack batchifier is a {@link StackBatchifier} that also pads elements to reach the
 * same length.
 */
public final class PaddingStackBatchifier implements Batchifier {

    private static final long serialVersionUID = 1L;

    @SuppressWarnings("serial")
    private List arraysToPad;

    @SuppressWarnings("serial")
    private List dimsToPad;

    private transient List paddingSuppliers;

    @SuppressWarnings("serial")
    private List paddingSizes;

    private boolean includeValidLengths;

    private PaddingStackBatchifier(Builder builder) {
        arraysToPad = builder.arraysToPad;
        dimsToPad = builder.dimsToPad;
        paddingSuppliers = builder.paddingSuppliers;
        paddingSizes = builder.paddingSizes;
        includeValidLengths = builder.includeValidLengths;
    }

    /** {@inheritDoc} */
    @Override
    public NDList batchify(NDList[] inputs) {
        NDList validLengths = new NDList(inputs.length);
        NDManager manager = inputs[0].get(0).getManager();
        for (int i = 0; i < arraysToPad.size(); i++) {
            int arrayIndex = arraysToPad.get(i);
            int dimIndex = dimsToPad.get(i);
            NDArray padding = paddingSuppliers.get(i).get(manager);
            long paddingSize = paddingSizes.get(i);
            long maxSize = findMaxSize(inputs, arrayIndex, dimIndex);
            if (paddingSize != -1 && maxSize > paddingSize) {
                throw new IllegalArgumentException(
                        "The batchifier padding size is too small " + maxSize + " " + paddingSize);
            }
            maxSize = Math.max(maxSize, paddingSize);
            long[] arrayValidLengths = padArrays(inputs, arrayIndex, dimIndex, padding, maxSize);
            validLengths.add(manager.create(arrayValidLengths));
        }
        NDList result = Batchifier.STACK.batchify(inputs);
        if (includeValidLengths) {
            result.addAll(validLengths);
        }
        return result;
    }

    /** {@inheritDoc} */
    @Override
    public NDList[] unbatchify(NDList inputs) {
        if (!includeValidLengths) {
            return Batchifier.STACK.unbatchify(inputs);
        }
        NDList validLengths =
                new NDList(inputs.subList(inputs.size() - arraysToPad.size(), inputs.size()));
        inputs = new NDList(inputs.subList(0, inputs.size() - arraysToPad.size()));
        NDList[] split = Batchifier.STACK.unbatchify(inputs);
        for (int i = 0; i < split.length; i++) {
            NDList arrays = split[i];
            for (int j = 0; j < arraysToPad.size(); j++) {
                long validLength = validLengths.get(j).getLong(i);
                int arrayIndex = arraysToPad.get(j);
                NDArray dePadded =
                        arrays.get(arrayIndex)
                                .get(NDIndex.sliceAxis(dimsToPad.get(j) - 1, 0, validLength));
                arrays.set(arrayIndex, dePadded);
            }
        }
        return split;
    }

    /** {@inheritDoc} */
    @Override
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (!includeValidLengths) {
            return Batchifier.STACK.split(list, numOfSlices, evenSplit);
        }
        NDList validLengths =
                new NDList(list.subList(list.size() - arraysToPad.size(), list.size()));
        list = new NDList(list.subList(0, list.size() - arraysToPad.size()));
        NDList[] split = Batchifier.STACK.split(list, numOfSlices, evenSplit);
        long sliceSize = split[0].get(0).getShape().get(0);
        long totalSize = list.get(0).getShape().get(0);
        for (int i = 0; i < split.length; i++) {
            // TODO: The padding required may not be the same for all splits. For smaller splits,
            // we can remove some extra padding.
            NDList arrays = split[i];
            for (int j = 0; j < arraysToPad.size(); j++) {
                long min = i * sliceSize;
                long max = Math.min((i + 1) * sliceSize, totalSize);
                NDArray splitValidLenghts = validLengths.get(j).get(NDIndex.sliceAxis(0, min, max));
                arrays.add(splitValidLenghts);
            }
        }
        return split;
    }

    /**
     * Finds the maximum size for a particular array/dimension in a batch of inputs (which can be
     * padded to equalize their sizes).
     *
     * @param inputs the batch of inputs
     * @param arrayIndex the array (for each NDList in the batch)
     * @param dimIndex for the array in each NDList in the batch
     * @return the maximum size
     */
    public static long findMaxSize(NDList[] inputs, int arrayIndex, int dimIndex) {
        long maxSize = -1;
        for (NDList input : inputs) {
            NDArray array = input.get(arrayIndex);
            maxSize = Math.max(maxSize, array.getShape().get(dimIndex));
        }
        return maxSize;
    }

    /**
     * Pads the arrays at a particular dimension to all have the same size (updating inputs in
     * place).
     *
     * @param inputs the batch of inputs
     * @param arrayIndex the array (for each NDList in the batch)
     * @param dimIndex for the array in each NDList in the batch
     * @param padding the padding to use. Say you have a batch of arrays of Shape(10, ?, 3) and you
     *     are padding the "?" dimension. There are two padding modes:
     *     
    *
  • If you give padding of Shape(1, 3) (same dimensionality as required), it will be * repeated with {@link NDArray#repeat(long)} as necessary *
  • If you give padding of Shape(3) or Shape(0) (smaller dimensionality as required), * it will be broadcasted with {@link NDArray#broadcast(Shape)} to reach the full * required Shape(?, 3) *
* * @param maxSize the size that each array will be padded to in that dimension. In the example * above, the padding to be applied to the "?" dimension. * @return the original valid length for each dimension in the batch (same length as * inputs.length). The inputs will be updated in place. */ public static long[] padArrays( NDList[] inputs, int arrayIndex, int dimIndex, NDArray padding, long maxSize) { long[] arrayValidLengths = new long[inputs.length]; for (int i = 0; i < inputs.length; i++) { NDArray array = inputs[i].get(arrayIndex); String arrayName = array.getName(); long validLength = array.getShape().get(dimIndex); if (validLength < maxSize) { // Number of dimensions the padding must be int dimensionsRequired = array.getShape().dimension() - padding.getShape().dimension(); NDArray paddingArray; if (dimensionsRequired == 0) { paddingArray = padding.repeat( Shape.update( array.getShape(), dimIndex, maxSize - validLength)); } else if (dimensionsRequired > 0) { paddingArray = padding.broadcast( Shape.update( array.getShape(), dimIndex, maxSize - validLength)); } else { throw new IllegalArgumentException( "The padding must be <=" + dimensionsRequired + " dimensions, but found " + padding.getShape().dimension()); } array = array.concat(paddingArray.toType(array.getDataType(), false), dimIndex); } // keep input name array.setName(arrayName); inputs[i].set(arrayIndex, array); arrayValidLengths[i] = validLength; } return arrayValidLengths; } /** * Returns a {@link PaddingStackBatchifier.Builder}. * * @return a {@link PaddingStackBatchifier.Builder} */ public static PaddingStackBatchifier.Builder builder() { return new Builder(); } /** Builder to build a {@link PaddingStackBatchifier}. */ public static final class Builder { private List arraysToPad; private List dimsToPad; private List paddingSuppliers; private List paddingSizes; private boolean includeValidLengths; private Builder() { arraysToPad = new ArrayList<>(); dimsToPad = new ArrayList<>(); paddingSuppliers = new ArrayList<>(); paddingSizes = new ArrayList<>(); } /** * Sets whether to include the valid lengths (length of non-padded data) for each array. * * @param includeValidLengths true to include valid lengths * @return this builder */ public Builder optIncludeValidLengths(boolean includeValidLengths) { this.includeValidLengths = includeValidLengths; return this; } /** * Adds a new dimension to be padded in the input {@link NDList}. * * @param array which array in the {@link NDList} to pad * @param dim which dimension in the array to pad * @param supplier a supplier that produces the padding array. The padding array shape * should include both the batch and a 1 for the padded dimension. For batch array shape * NTC, the padding shape should be N x 1 x C * @return this builder */ public Builder addPad(int array, int dim, NDArraySupplier supplier) { return addPad(array, dim, supplier, -1); } /** * Adds a new dimension to be padded in the input {@link NDList}. * * @param array which array in the {@link NDList} to pad * @param dim which dimension in the array to pad * @param supplier a supplier that produces the padding array. The padding array shape * should include both the batch and a 1 for the padded dimension. For batch array shape * NTC, the padding shape should be N x 1 x C * @param paddingSize the minimum padding size to use. All sequences to pad must be less * than this size * @return this builder */ public Builder addPad(int array, int dim, NDArraySupplier supplier, int paddingSize) { arraysToPad.add(array); dimsToPad.add(dim); paddingSuppliers.add(supplier); paddingSizes.add(paddingSize); return this; } /** * Builds the {@link PaddingStackBatchifier}. * * @return the constructed {@link PaddingStackBatchifier} */ public PaddingStackBatchifier build() { return new PaddingStackBatchifier(this); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy