All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.nn.convolutional.Conv2dTranspose Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.convolutional;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.util.Preconditions;

/**
 * The input to a {@code Conv2dTranspose} is an {@link ai.djl.ndarray.NDList} with a single 4-D
 * {@link ai.djl.ndarray.NDArray}. The layout of the {@link ai.djl.ndarray.NDArray} must be "NCHW".
 * The shapes are
 *
 * <ul>
 *   <li>{@code data: (batch_size, channel, height, width)}
 *   <li>{@code weight: (num_filter, channel, kernel[0], kernel[1])}
 *   <li>{@code bias: (num_filter,)}
 *   <li>{@code out: (batch_size, num_filter, out_height, out_width)} <br>
 *       {@code out_height = f(height, kernel[0], pad[0], oPad[0], stride[0], dilate[0])} <br>
 *       {@code out_width = f(width, kernel[1], pad[1], oPad[1], stride[1], dilate[1])} <br>
 *       {@code where f(x, k, p, oP, s, d) = (x-1)*s-2*p+d*(k-1)+oP+1}, which reduces to
 *       {@code (x-1)*s-2*p+k+oP} when {@code d = 1}
 * </ul>
 *
 * <p>Both {@code weight} and {@code bias} are learn-able parameters.
 *
 * @see Deconvolution
 */
public class Conv2dTranspose extends Deconvolution {

    private static final LayoutType[] EXPECTED_LAYOUT = {
        LayoutType.BATCH, LayoutType.CHANNEL, LayoutType.HEIGHT, LayoutType.WIDTH
    };

    private static final String STRING_LAYOUT = "NCHW";
    private static final int NUM_DIMENSIONS = 4;

    /** Rank of stride/padding/outPadding/dilation shapes: one entry each for height and width. */
    private static final int NUM_SPATIAL_DIMENSIONS = NUM_DIMENSIONS - 2;

    Conv2dTranspose(Builder builder) {
        super(builder);
    }

    /** {@inheritDoc} */
    @Override
    protected LayoutType[] getExpectedLayout() {
        return EXPECTED_LAYOUT;
    }

    /** {@inheritDoc} */
    @Override
    protected String getStringLayout() {
        return STRING_LAYOUT;
    }

    /** {@inheritDoc} */
    @Override
    protected int numDimensions() {
        return NUM_DIMENSIONS;
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes, without a
     * bias term (a {@code null} bias is passed on).
     *
     * <p>Uses stride {@code (1, 1)}, padding {@code (0, 0)}, output padding {@code (0, 0)},
     * dilation {@code (1, 1)} and a single group.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight) {
        return conv2dTranspose(input, weight, null);
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes.
     *
     * <p>Uses stride {@code (1, 1)}, padding {@code (0, 0)}, output padding {@code (0, 0)},
     * dilation {@code (1, 1)} and a single group.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @param bias bias {@code NDArray} of shape (outChannel); may be {@code null}, as the
     *     two-argument overload passes
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias) {
        return conv2dTranspose(input, weight, bias, new Shape(1, 1));
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes.
     *
     * <p>Uses padding {@code (0, 0)}, output padding {@code (0, 0)}, dilation {@code (1, 1)} and a
     * single group.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @param bias bias {@code NDArray} of shape (outChannel)
     * @param stride the stride of the deconvolving kernel: Shape(height, width)
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(
            NDArray input, NDArray weight, NDArray bias, Shape stride) {
        return conv2dTranspose(input, weight, bias, stride, new Shape(0, 0));
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes.
     *
     * <p>Uses output padding {@code (0, 0)}, dilation {@code (1, 1)} and a single group.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @param bias bias {@code NDArray} of shape (outChannel)
     * @param stride the stride of the deconvolving kernel: Shape(height, width)
     * @param padding implicit paddings on both sides of the input: Shape(height, width)
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(
            NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding) {
        return conv2dTranspose(input, weight, bias, stride, padding, new Shape(0, 0));
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes.
     *
     * <p>Uses dilation {@code (1, 1)} and a single group.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @param bias bias {@code NDArray} of shape (outChannel)
     * @param stride the stride of the deconvolving kernel: Shape(height, width)
     * @param padding implicit paddings on both sides of the input: Shape(height, width)
     * @param outPadding additional size added to one side of each spatial dimension of the
     *     output: Shape(height, width)
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding) {
        return conv2dTranspose(input, weight, bias, stride, padding, outPadding, new Shape(1, 1));
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes.
     *
     * <p>Uses a single group.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @param bias bias {@code NDArray} of shape (outChannel)
     * @param stride the stride of the deconvolving kernel: Shape(height, width)
     * @param padding implicit paddings on both sides of the input: Shape(height, width)
     * @param outPadding additional size added to one side of each spatial dimension of the
     *     output: Shape(height, width)
     * @param dilation the spacing between kernel elements: Shape(height, width)
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding,
            Shape dilation) {
        return conv2dTranspose(input, weight, bias, stride, padding, outPadding, dilation, 1);
    }

    /**
     * Applies 2D deconvolution over an input signal composed of several input planes.
     *
     * <p>The ranks of all arguments are checked with {@link Preconditions#checkArgument} before
     * delegating to {@link Deconvolution#deconvolution}: {@code input} and {@code weight} must be
     * 4-D, and {@code stride}, {@code padding}, {@code outPadding} and {@code dilation} must each
     * have exactly two entries (height, width).
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, height, width)
     * @param weight filters {@code NDArray} of shape (outChannel, inputChannel/groups, height,
     *     width)
     * @param bias bias {@code NDArray} of shape (outChannel)
     * @param stride the stride of the deconvolving kernel: Shape(height, width)
     * @param padding implicit paddings on both sides of the input: Shape(height, width)
     * @param outPadding additional size added to one side of each spatial dimension of the
     *     output: Shape(height, width)
     * @param dilation the spacing between kernel elements: Shape(height, width)
     * @param groups split input into groups: input channel(input.size(1)) should be divisible by
     *     the number of groups
     * @return the output of the conv2dTranspose operation
     */
    public static NDList conv2dTranspose(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding,
            Shape dilation,
            int groups) {
        Preconditions.checkArgument(
                input.getShape().dimension() == NUM_DIMENSIONS
                        && weight.getShape().dimension() == NUM_DIMENSIONS,
                "the shape of input or weight doesn't match the conv2dTranspose");
        Preconditions.checkArgument(
                stride.dimension() == NUM_SPATIAL_DIMENSIONS
                        && padding.dimension() == NUM_SPATIAL_DIMENSIONS
                        && outPadding.dimension() == NUM_SPATIAL_DIMENSIONS
                        && dilation.dimension() == NUM_SPATIAL_DIMENSIONS,
                "the shape of stride or padding or outPadding or dilation doesn't match the"
                        + " conv2dTranspose");
        return Deconvolution.deconvolution(
                input, weight, bias, stride, padding, outPadding, dilation, groups);
    }

    /**
     * Creates a builder to build a {@code Conv2dTranspose}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** The Builder to construct a {@link Conv2dTranspose} type of {@link Block}. */
    public static final class Builder extends DeconvolutionBuilder<Builder> {

        /**
         * Creates a builder that can build a {@link Conv2dTranspose} block, defaulting to stride
         * {@code (1, 1)}, padding {@code (0, 0)}, output padding {@code (0, 0)} and dilation
         * {@code (1, 1)}.
         */
        Builder() {
            stride = new Shape(1, 1);
            padding = new Shape(0, 0);
            outPadding = new Shape(0, 0);
            dilation = new Shape(1, 1);
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Builds a {@link Conv2dTranspose} block.
         *
         * @return the {@link Conv2dTranspose} block
         */
        public Conv2dTranspose build() {
            validate();
            return new Conv2dTranspose(this);
        }
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy