ai.djl.nn.transformer.TransformerBaseBlock Maven / Gradle / Ivy
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn.transformer;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.Parameter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.function.Function;
/**
* Utility base class for implementing custom blocks. Provides utilities to handle nested child
* blocks, parameters and serialization of state.
*/
public abstract class TransformerBaseBlock extends AbstractBlock {
/**
* The model version of this block, used for checking if parameters are still valid during
* parameter loading.
*/
protected int version;
/**
* All direct children of this Block. Keys are names of the blocks. Use the {@link
* TransformerBaseBlock#addChildBlock(String, Block)} method to add children. All children in
* this map are automagically loaded / saved.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the children are always iterated over in insertion order. LinkedHashMap
// provides this guarantee, Map does not.
@SuppressWarnings("PMD.LooseCoupling")
protected LinkedHashMap children = new LinkedHashMap<>();
/**
* All direct parameters of this Block. Keys are name of the parameters. Use the {@link
* TransformerBaseBlock#addParameter(Parameter)} method to add children. All parameters in this
* map are automagically loaded / saved.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the parameters are always iterated over in insertion order. LinkedHashMap
// provides this guarantee, Map does not.
@SuppressWarnings("PMD.LooseCoupling")
protected LinkedHashMap parameters = new LinkedHashMap<>();
/**
* Callbacks to determine the shape of a parameter. Values may be null in which case extending
* classes need to override {@link Block#getParameterShape(String, Shape[])} and implement
* parameter shape resolution manually.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the callbacks are always iterated over in insertion order. LinkedHashMap
// provides this guarantee, Map does not.
@SuppressWarnings("PMD.LooseCoupling")
protected LinkedHashMap> parameterShapeCallbacks =
new LinkedHashMap<>();
/**
* Builds an empty block with the given version for parameter serialization.
*
* @param version the version to use for parameter serialization.
*/
public TransformerBaseBlock(int version) {
this.version = version;
}
/**
* Returns the version number to be used for parameter serialization.
*
* @return the version number to be used for parameter serialization
*/
public int getVersion() {
return version;
}
/**
* Adds a child block to this block.
*
* @param name Name of the block, must be unique or otherwise existing children with this name
* are removed, must not be null.
* @param block The block, must not be null.
* @param The type of block
* @return the block given as a parameter - that way the block can be created and reassigned to
* a member variable more easily.
*/
protected B addChildBlock(String name, B block) {
children.put(name, block);
return block;
}
/**
* Adds a parameter to this block. If parameters are added with this method, subclasses need to
* override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters
* themselves.
*
* @param parameter the parameter to add, not null
* @param the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign paramters in
* one line
*/
protected
P addParameter(P parameter) {
return addParameter(parameter, (Function) null);
}
/**
* Adds a parameter to this block. If parameters are added with this method, intialization of
* the parameter works out of the box
*
* @param parameter the parameter to add, not null
* @param shape the shape of the parameter
* @param the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign paramters in
* one line
*/
protected
P addParameter(P parameter, Shape shape) {
return addParameter(parameter, (inputShapes) -> shape);
}
/**
* Adds a parameter to this block. If parameters are added with this method, intialization of
* the parameter works out of the box
*
* @param parameter the parameter to add, not null
* @param shapeCallback the method to call once the input shape of this block is known to
* determine the shape of the given parameter
* @param
the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters
* in one line
*/
protected
P addParameter(
P parameter, Function shapeCallback) {
parameters.put(parameter.getName(), parameter);
parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
return parameter;
}
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
Function callback = parameterShapeCallbacks.get(name);
if (callback == null) {
Parameter parameter = parameters.get(name);
if (parameter == null) {
throw new IllegalArgumentException(
"No parameter named " + name + " found in this block.");
} else {
throw new IllegalStateException(
"No shape initializer for parameter "
+ name
+ "found. "
+ "Either pass an initializer for the shape when adding the parameter or override "
+ "getParameterShape in the subclass.");
}
}
return callback.apply(inputShapes);
}
@Override
public BlockList getChildren() {
return new BlockList(children);
}
/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
for (Parameter parameter : getDirectParameters()) {
parameter.initialize(manager, dataType, inputShapes);
}
initializeChildBlocks(manager, dataType, inputShapes);
return getOutputShapes(manager, inputShapes);
}
/**
* Initializes the Child blocks of this block.
*
* @param manager the manager to use for initialization
* @param dataType the requested data type
* @param inputShapes the expected input shapes
*/
public abstract void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes);
@Override
public List getDirectParameters() {
return new ArrayList<>(parameters.values());
}
@Override
public void saveParameters(DataOutputStream os) throws IOException {
os.write(version);
for (Parameter parameter : parameters.values()) {
parameter.save(os);
}
for (Block child : children.values()) {
child.saveParameters(os);
}
}
@Override
public void loadParameters(NDManager manager, DataInputStream is)
throws IOException, MalformedModelException {
int loadVersion = is.readInt();
if (loadVersion != getVersion()) {
throw new MalformedModelException(
"Cannot load parameters for "
+ this.getClass().getCanonicalName()
+ ", expected version "
+ getVersion()
+ ", got "
+ loadVersion
+ ".");
}
for (Parameter parameter : parameters.values()) {
parameter.load(manager, is);
}
for (Block child : children.values()) {
child.loadParameters(manager, is);
}
}
}