All downloads are free. The search and download features use the official Maven repository.

ai.djl.nn.convolutional.Convolution Maven / Gradle / Ivy

/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.convolutional;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/** Computes N-D convolution on (N+2)-D input. */
public abstract class Convolution extends ParameterBlock {

    private static final byte VERSION = 2;

    protected Shape kernel;
    protected Shape stride;
    protected Shape pad;
    protected Shape dilate;
    protected int numFilters;
    protected int numGroups;
    protected boolean includeBias;

    protected Parameter weight;
    protected Parameter bias;

    /**
     * Creates a {@link Convolution} object.
     *
     * @param builder the {@code Builder} that has the necessary configurations
     */
    public Convolution(ConvolutionBuilder builder) {
        kernel = builder.kernel;
        stride = builder.stride;
        pad = builder.pad;
        dilate = builder.dilate;
        numFilters = builder.numFilters;
        numGroups = builder.numGroups;
        includeBias = builder.includeBias;

        weight = new Parameter("weight", this, ParameterType.WEIGHT);
        if (includeBias) {
            bias = new Parameter("bias", this, ParameterType.BIAS);
        }
    }

    /**
     * Returns the expected layout of the input.
     *
     * @return the expected layout of the input
     */
    protected abstract LayoutType[] getExpectedLayout();

    /**
     * Returns the string representing the layout of the input.
     *
     * @return the string representing the layout of the input
     */
    protected abstract String getStringLayout();

    /**
     * Returns the number of dimensions of the input.
     *
     * @return the number of dimensions of the input
     */
    protected abstract int numDimensions();

    /** {@inheritDoc} */
    @Override
    public NDList forward(
            ParameterStore parameterStore, NDList inputs, PairList params) {
        inputs = opInputs(parameterStore, inputs);
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        return ex.convolution(
                inputs,
                kernel,
                stride,
                pad,
                dilate,
                numFilters,
                numGroups,
                getStringLayout(),
                !includeBias,
                params);
    }

    /** {@inheritDoc} */
    @Override
    protected void beforeInitialize(Shape[] inputs) {
        this.inputShapes = inputs;
        Shape inputShape = inputs[0];
        Block.validateLayout(getExpectedLayout(), inputShape.getLayout());
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        long[] shape = new long[numDimensions()];
        shape[0] = inputs[0].get(0);
        shape[1] = numFilters;
        for (int i = 0; i < numDimensions() - 2; i++) {
            shape[2 + i] =
                    (inputs[0].get(2 + i)
                                            + 2 * pad.get(i)
                                            - dilate.get(0) * (kernel.get(i) - 1)
                                            - 1)
                                    / stride.get(0)
                            + 1;
        }
        return new Shape[] {new Shape(shape)};
    }

    /** {@inheritDoc} */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Shape shape = inputShapes[0];
        switch (name) {
            case "weight":
                return new Shape(numFilters, shape.get(1)).addAll(kernel);
            case "bias":
                return new Shape(numFilters);
            default:
                throw new IllegalArgumentException("Invalid parameter name");
        }
    }

    /** {@inheritDoc} */
    @Override
    public List getDirectParameters() {
        List parameters = new ArrayList<>();
        parameters.add(weight);
        if (includeBias) {
            parameters.add(bias);
        }
        return parameters;
    }

    /** {@inheritDoc} */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        saveInputShapes(os);
        weight.save(os);
        if (bias != null) {
            bias.save(os);
        }
    }

    /** {@inheritDoc} */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        weight.load(manager, is);
        if (bias != null) {
            bias.load(manager, is);
        }
    }

    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        NDArray data = inputs.singletonOrThrow();
        Device device = data.getDevice();
        NDList ret = new NDList(3);
        ret.add(data);
        ret.add(parameterStore.getValue(weight, device));
        if (bias != null) {
            ret.add(parameterStore.getValue(bias, device));
        }
        return ret;
    }

    /**
     * A builder that can build any {@code Convolution} block.
     *
     * @param  the type of {@code Convolution} block to build
     */
    @SuppressWarnings("rawtypes")
    public abstract static class ConvolutionBuilder {

        protected Shape kernel;
        protected Shape stride;
        protected Shape pad;
        protected Shape dilate;
        protected int numFilters;
        protected int numGroups = 1;
        protected boolean includeBias = true;

        /**
         * Sets the shape of the kernel.
         *
         * @param kernel the shape of the kernel
         * @return this Builder
         */
        public T setKernel(Shape kernel) {
            this.kernel = kernel;
            return self();
        }

        /**
         * Sets the stride of the convolution. Defaults to 1 in each dimension.
         *
         * @param stride the shape of the stride
         * @return this Builder
         */
        public T optStride(Shape stride) {
            this.stride = stride;
            return self();
        }

        /**
         * Sets the padding along each dimension. Defaults to 0 along each dimension.
         *
         * @param pad the shape of padding along each dimension
         * @return this Builder
         */
        public T optPad(Shape pad) {
            this.pad = pad;
            return self();
        }

        /**
         * Sets the dilation along each dimension. Defaults to 1 along each dimension.
         *
         * @param dilate the shape of dilation along each dimension
         * @return this Builder
         */
        public T optDilate(Shape dilate) {
            this.dilate = dilate;
            return self();
        }

        /**
         * Sets the Required number of filters.
         *
         * @param numFilters the number of convolution filters(channels)
         * @return this Builder
         */
        public T setNumFilters(int numFilters) {
            this.numFilters = numFilters;
            return self();
        }

        /**
         * Sets the number of group partitions.
         *
         * @param numGroups the number of group partitions
         * @return this Builder
         */
        public T optNumGroups(int numGroups) {
            this.numGroups = numGroups;
            return self();
        }

        /**
         * Sets the optional parameter of whether to include a bias vector. Includes bias by
         * default.
         *
         * @param includeBias whether to use a bias vector parameter
         * @return this Builder
         */
        public T optBias(boolean includeBias) {
            this.includeBias = includeBias;
            return self();
        }

        /**
         * Validates that the required arguments are set.
         *
         * @throws IllegalArgumentException if the required arguments are not set
         */
        protected void validate() {
            if (kernel == null || numFilters == 0) {
                throw new IllegalArgumentException("Kernel and numFilters must be set");
            }
        }

        protected abstract T self();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy