/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.function.Predicate;

/**
 * A {@code Block} is a composable function that forms a neural network.
 *
 * <p>Blocks serve a purpose similar to functions that convert an input NDList to an output
 * NDList. They can represent single operations, parts of a neural network, and even the whole
 * neural network. What makes blocks special is that they contain a number of parameters that are
 * used in their function and are trained during deep learning. As these parameters are trained,
 * the functions represented by the blocks get more and more accurate. Each block consists of the
 * following components:
 *
 * <ul>
 *   <li>Forward function
 *   <li>Parameters
 *   <li>Child blocks
 * </ul>
 *
 * <p>The core purpose of a {@code Block} is to perform an operation on the inputs, and return an
 * output. It is defined in the {@link #forward(ParameterStore, NDList, boolean) forward} method.
 * The forward function could be defined explicitly in terms of parameters or implicitly and could
 * be a combination of the functions of the child blocks.
 *
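 * <p>As a minimal sketch (not part of this interface's contract), an inference-style forward pass
 * on an already initialized block could look like this; the {@code block}, {@code parameterStore}
 * and {@code input} variables are assumed to be in scope:
 *
 * <pre>{@code
 * // training = false, so dropout and similar layers run in inference mode
 * NDList output = block.forward(parameterStore, new NDList(input), false);
 * }</pre>
 *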
 * <p>The parameters of a {@code Block} are instances of {@link Parameter} which are required for
 * the operation in the forward function. For example, in a {@link Conv2d} block, the parameters
 * are {@code weight} and {@code bias}. During training, these parameters are updated to reflect
 * the training data, and that forms the crux of learning.
 *
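 * <p>For illustration only, the direct parameters of a convolution block could be inspected as
 * sketched below; the exact {@link Conv2d} builder methods shown here are an assumption and may
 * differ between versions:
 *
 * <pre>{@code
 * Block conv = Conv2d.builder()
 *         .setKernelShape(new Shape(3, 3))
 *         .setFilters(32)
 *         .build();
 * // prints the parameter names, typically "weight" and "bias"
 * conv.getDirectParameters().keys().forEach(System.out::println);
 * }</pre>
 *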
 * <p>When building these block functions, the easiest way is to use composition. Similar to how
 * functions are built by calling other functions, blocks can be built by combining other blocks.
 * We refer to the containing block as the parent and the sub-blocks as the children.
 *
 * <p>We provide helpers for creating two common structures of blocks. For blocks that call their
 * children in a chain, use {@link SequentialBlock}. If a block calls all of its children in
 * parallel and then combines their results, use {@link ParallelBlock}. For blocks that do not fit
 * these structures, you should directly extend the {@link AbstractBlock} class.
 *
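 * <p>As a rough sketch of such composition (assuming the {@code Linear} block from {@code
 * ai.djl.nn.core} and the {@code Activation} helpers from this package), a small multi-layer
 * perceptron could be chained together like this:
 *
 * <pre>{@code
 * Block mlp = new SequentialBlock()
 *         .add(Linear.builder().setUnits(128).build())
 *         .add(Activation::relu)
 *         .add(Linear.builder().setUnits(10).build());
 * }</pre>
 *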
 * <p>A block does not necessarily have to have children and parameters. For example, {@link
 * SequentialBlock} and {@link ParallelBlock} don't have any parameters, but do have child blocks.
 * Similarly, {@link Conv2d} does not have children, but has parameters. There can be special cases
 * where blocks have neither parameters nor children. One such example is {@link LambdaBlock}.
 * {@link LambdaBlock} takes in a function, and applies that function to its input in the {@link
 * #forward(ParameterStore, NDList, boolean) forward} method.
 *
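 * <p>A minimal sketch of such a parameter-less, child-less block, wrapping a plain function in a
 * {@link LambdaBlock}:
 *
 * <pre>{@code
 * // doubles the single array in the input NDList; no parameters are created or trained
 * Block doubler = new LambdaBlock(list -> new NDList(list.singletonOrThrow().mul(2)));
 * }</pre>
 *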
 * <p>Now that we understand the components of the block, we can explore what the block really
 * represents. A block combined with the recursive, hierarchical structure of its children forms a
 * network. It takes in the input to the network, performs its operation, and returns the output
 * of the network. When a block is added as a child of another block, it becomes a sub-network of
 * that block.
 *
 * <p>The life-cycle of a block has 3 stages:
 *
 * <ul>
 *   <li>Construction
 *   <li>Initialization
 *   <li>Training
 * </ul>
 *
 * <p>Construction is the process of building the network. During this stage, blocks are created
 * with appropriate arguments and the desired network is built by creating a hierarchy of parent
 * and child blocks. At this stage, it is a bare-bones network. The parameter values are not
 * created and the shapes of the inputs are not known. The block is ready for initialization.
 *
 * <p>Initialization is the process of initializing all the parameters of the block and its
 * children, according to the inputs expected. It involves setting an {@link Initializer}, deciding
 * the {@link DataType}, and the shapes of the input. The parameter arrays are {@link
 * ai.djl.ndarray.NDArray} that are initialized according to the {@link Initializer} set. At this
 * stage, the block is expecting a specific type of input, and is ready to be trained.
 *
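 * <p>A minimal initialization sketch, assuming {@code mlp} is a freshly constructed {@code Block}
 * (for instance the sequential block sketched above) and {@code NormalInitializer} from {@code
 * ai.djl.training.initializer} is used for the weights:
 *
 * <pre>{@code
 * NDManager manager = NDManager.newBaseManager();
 * mlp.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
 * // a batch of 32 inputs with 784 features each
 * mlp.initialize(manager, DataType.FLOAT32, new Shape(32, 784));
 * }</pre>
 *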
 * <p>Training is when we start feeding the training data as input to the block, get the output,
 * and try to update parameters to learn. For more information about training, please refer to the
 * javadoc at {@link ai.djl.training.Trainer}. At the end of training, a block represents a
 * fully-trained model.
 *
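 * <p>Training itself is normally driven through {@link ai.djl.training.Trainer}, but at the block
 * level a single training-mode pass is just a forward call with {@code training} set to true. A
 * sketch, assuming an initialized block {@code mlp} and an {@code NDManager} named {@code
 * manager} as above:
 *
 * <pre>{@code
 * ParameterStore store = new ParameterStore(manager, false);
 * NDList data = new NDList(manager.randomUniform(0f, 1f, new Shape(32, 784)));
 * // training = true turns on dropout and similar training-only behavior
 * NDList prediction = mlp.forward(store, data, true);
 * }</pre>
 *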
 * <p>It is also possible to freeze parameters and blocks to avoid them being trained. When
 * loading models or building blocks with pre-trained data, they default to being frozen. If you
 * wish to further refine these elements, use {@link Block#freezeParameters(boolean)} to unfreeze
 * them.
 *
 * @see "This tutorial on creating your first network"
 * @see "The D2L chapter on blocks and blocks with direct parameters"
 */
public interface Block {

    /**
     * Applies the operating function of the block once. This method should be called only on
     * blocks that are initialized.
     *
     * @param parameterStore the parameter store
     * @param inputs the input NDList
     * @param training true for a training forward pass (turn on dropout and layerNorm)
     * @return the output of the forward pass
     */
    default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training) {
        return forward(parameterStore, inputs, training, null);
    }

    /**
     * Applies the operating function of the block once. This method should be called only on
     * blocks that are initialized.
     *
     * @param parameterStore the parameter store
     * @param inputs the input NDList
     * @param training true for a training forward pass (turn on dropout and layerNorm)
     * @param params optional parameters
     * @return the output of the forward pass
     */
    NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params);

    /**
     * A forward call using both training data and labels.
     *
     * <p>Within this forward call, it can be assumed that training is true.
     *
     * @param parameterStore the parameter store
     * @param data the input data NDList
     * @param labels the input labels NDList
     * @param params optional parameters
     * @return the output of the forward pass
     * @see #forward(ParameterStore, NDList, boolean, PairList)
     */
    default NDList forward(
            ParameterStore parameterStore,
            NDList data,
            NDList labels,
            PairList<String, Object> params) {
        return forward(parameterStore, data, true, params);
    }

    /**
     * Sets an {@link Initializer} to all the parameters that match the parameter type in the
     * block.
     *
     * @param initializer the initializer to set
     * @param type the Parameter Type we want to setInitializer
     */
    void setInitializer(Initializer initializer, Parameter.Type type);

    /**
     * Sets an {@link Initializer} to the specified direct parameter of the block, overriding the
     * initializer of the parameter, if already set.
     *
     * @param initializer the initializer to be set
     * @param paramName the name of the parameter
     */
    void setInitializer(Initializer initializer, String paramName);

    /**
     * Sets an {@link Initializer} to all the parameters that match the predicate in the block.
     *
     * @param initializer the initializer to be set
     * @param predicate the predicate function to indicate parameters you want to set
     */
    void setInitializer(Initializer initializer, Predicate<Parameter> predicate);

    /**
     * Initializes the parameters of the block, sets requires-gradient where needed, and infers
     * the block's input shape. This method must be called before calling {@code forward}.
     *
     * @param manager the NDManager to initialize the parameters
     * @param dataType the datatype of the parameters
     * @param inputShapes the shapes of the inputs to the block
     */
    void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

    /**
     * Returns whether the block is initialized (the block has an input shape and its parameters
     * have non-null arrays).
     *
     * @return whether the block is initialized
     */
    boolean isInitialized();

    /**
     * Guaranteed to throw an exception. Not yet implemented.
     *
     * @param dataType the data type to cast to
     * @throws UnsupportedOperationException always
     */
    void cast(DataType dataType);

    /**
     * Closes all the parameters of the block. All the updates made during training will be lost.
     */
    void clear();

    /**
     * Returns a {@link PairList} of input names and shapes.
     *
     * @return the {@link PairList} of input names and shapes
     */
    PairList<String, Shape> describeInput();

    /**
     * Returns a list of all the children of the block.
     *
     * @return the list of child blocks
     */
    BlockList getChildren();

    /**
     * Returns a list of all the direct parameters of the block.
     *
     * @return the list of {@link Parameter}
     */
    ParameterList getDirectParameters();

    /**
     * Returns a list of all the parameters of the block, including the parameters of its children
     * fetched recursively.
     *
     * @return the list of all parameters of the block
     */
    ParameterList getParameters();

    /**
     * Returns the expected output shapes of the block for the specified input shapes.
     *
     * @param inputShapes the shapes of the inputs
     * @return the expected output shapes of the block
     */
    Shape[] getOutputShapes(Shape[] inputShapes);

    /**
     * Returns the expected output shapes of the block for the specified input shapes.
     *
     * @param inputShapes the shapes of the inputs
     * @param inputDataTypes the datatypes of the inputs
     * @return the expected output shapes of the block
     */
    default Shape[] getOutputShapes(Shape[] inputShapes, DataType[] inputDataTypes) {
        return getOutputShapes(inputShapes);
    }

    /**
     * Returns the input shapes of the block.
     * The input shapes are only available after the block is initialized, otherwise an {@link
     * IllegalStateException} is thrown.
     *
     * @return the input shapes of the block
     */
    Shape[] getInputShapes();

    /**
     * Returns the output dataTypes of the block.
     *
     * @return the output dataTypes of the block
     */
    DataType[] getOutputDataTypes();

    /**
     * Writes the parameters of the block to the given outputStream.
     *
     * @param os the outputstream to save the parameters to
     * @throws IOException if an I/O error occurs
     */
    void saveParameters(DataOutputStream os) throws IOException;

    /**
     * Loads the parameters from the given input stream.
     *
     * @param manager an NDManager to create the parameter arrays
     * @param is the inputstream that streams the parameter values
     * @throws IOException if an I/O error occurs
     * @throws MalformedModelException if the model file is corrupted or unsupported
     */
    void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException;

    /**
     * Freezes or unfreezes all parameters inside the block for training.
     *
     * @param freeze true if the parameters should be frozen
     * @see Parameter#freeze(boolean)
     */
    default void freezeParameters(boolean freeze) {
        for (Parameter parameter : getParameters().values()) {
            parameter.freeze(freeze);
        }
    }

    /**
     * Validates that the actual layout matches the expected layout.
     *
     * @param expectedLayout the expected layout
     * @param actualLayout the actual layout
     * @throws UnsupportedOperationException if the actual layout does not match the expected
     *     layout
     */
    static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout) {
        if (actualLayout.length != expectedLayout.length) {
            throw new UnsupportedOperationException(
                    "Expected layout: "
                            + LayoutType.toString(expectedLayout)
                            + ", but got: "
                            + LayoutType.toString(actualLayout));
        }
        for (int i = 0; i < actualLayout.length; i++) {
            if (actualLayout[i] != LayoutType.UNKNOWN && actualLayout[i] != expectedLayout[i]) {
                throw new UnsupportedOperationException(
                        "Expected layout: "
                                + LayoutType.toString(expectedLayout)
                                + ", but got: "
                                + LayoutType.toString(actualLayout));
            }
        }
    }
}