All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.nn.Parameter Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/**
 * {@code Parameter} is a container class that holds a learnable parameter of a model.
 *
 * 

Every {@code Parameter} is associated with a {@link Block}. The output of the block's forward * function depends on the values in the {@code Parameter}. During training, the values in the * {@code Parameter} are updated to reflect the training data. This process forms the crux of * learning. * * @see The D2L * chapter on parameter management */ public class Parameter implements AutoCloseable { private static final byte VERSION = 1; private String id; private String name; private Shape shape; private Type type; private Initializer initializer; private NDArray array; private boolean requiresGrad; Parameter(Builder builder) { this.id = UUID.randomUUID().toString(); this.name = builder.name; this.shape = builder.shape; this.type = builder.type; this.array = builder.array; this.requiresGrad = builder.requiresGrad; this.initializer = (builder.initializer != null) ? builder.initializer : type.getInitializer(); } /** * Gets the ID of this {@code Parameter}. * * @return the ID of this {@code Parameter} */ public String getId() { return id; } /** * Gets the name of this {@code Parameter}. * * @return the name of this {@code Parameter} */ public String getName() { return name == null ? "" : name; } /** * Gets the type of this {@code Parameter}. * * @return the type of this {@code Parameter} */ public Type getType() { return type; } /** * Sets the values of this {@code Parameter}. * * @param array the {@link NDArray} that contains values of this {@code Parameter} */ public void setArray(NDArray array) { if (shape != null) { throw new IllegalStateException("array has been set! Use either setArray or setShape"); } this.array = array; shape = array.getShape(); array.setName(name); } /** * Sets the shape of this {@code Parameter}. * * @param shape the shape of this {@code Parameter} */ public void setShape(Shape shape) { if (array != null) { throw new IllegalStateException("array has been set! Use either setArray or setShape"); } this.shape = shape; } /** * Gets the values of this {@code Parameter} as an {@link NDArray}. * * @return an {@link NDArray} that contains values of this {@code Parameter} */ public NDArray getArray() { if (!isInitialized()) { throw new IllegalStateException("The array has not been initialized"); } return array; } /** * Returns whether this parameter needs gradients to be computed. * * @return whether this parameter needs gradients to be computed */ public boolean requiresGradient() { return requiresGrad; } /** * Checks if this {@code Parameter} is initialized. * * @return {@code true} if this {@code Parameter} is initialized */ public boolean isInitialized() { return array != null; } /** * Sets the {@link Initializer} for this {@code Parameter}, if not already set. If overwrite * flag is true, sets the initializer regardless. * * @param initializer the initializer to be set */ public void setInitializer(Initializer initializer) { this.initializer = initializer; } /** * Initializes the parameter with the given {@link NDManager}, with given {@link DataType} for * the given expected input shapes. * * @param manager an NDManager to create the arrays * @param dataType the datatype of the {@code Parameter} */ public void initialize(NDManager manager, DataType dataType) { Objects.requireNonNull(initializer, "No initializer has been set"); Objects.requireNonNull(shape, "No parameter shape has been set"); if (!isInitialized()) { array = initializer.initialize(manager, shape, dataType); array.setName(name); } if (requiresGradient()) { array.setRequiresGradient(true); } } /** * Writes the parameter NDArrays to the given output stream. * * @param dos the output stream to write to * @throws IOException if the write operation fails */ public void save(DataOutputStream dos) throws IOException { if (!isInitialized()) { dos.writeChar('N'); return; } dos.writeChar('P'); dos.writeByte(VERSION); dos.writeUTF(getName()); dos.write(array.encode()); } /** * Loads parameter NDArrays from InputStream. * *

Currently, we cannot deserialize into the exact subclass of NDArray. The SparseNDArray * will be loaded as NDArray only. * * @param manager the NDManager * @param dis the InputStream * @throws IOException if failed to read * @throws MalformedModelException Exception thrown when model is not in expected format * (parameters). */ public void load(NDManager manager, DataInputStream dis) throws IOException, MalformedModelException { char magic = dis.readChar(); if (magic == 'N') { return; } else if (magic != 'P') { throw new MalformedModelException("Invalid input data."); } // Version byte version = dis.readByte(); if (version != VERSION) { throw new MalformedModelException("Unsupported encoding version: " + version); } String parameterName = dis.readUTF(); if (!parameterName.equals(getName())) { throw new MalformedModelException( "Unexpected parameter name: " + parameterName + ", expected: " + name); } array = manager.decode(dis); // set the shape of the parameter and prepare() can be skipped shape = array.getShape(); } /** {@inheritDoc} */ @Override public void close() { if (array != null) { array.close(); array = null; } } /** * Creates a builder to build a {@code Parameter}. * *

The methods start with {@code set} are required fields, and {@code opt} for optional * fields. * * @return a new builder */ public static Parameter.Builder builder() { return new Parameter.Builder(); } /** Enumerates the types of {@link Parameter}. */ public enum Type { WEIGHT( new XavierInitializer( XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2)), BIAS(Initializer.ZEROS), GAMMA(Initializer.ONES), BETA(Initializer.ZEROS), RUNNING_MEAN(Initializer.ZEROS), RUNNING_VAR(Initializer.ONES), OTHER(null); private final transient Initializer initializer; Type(Initializer initializer) { this.initializer = initializer; } /** * Gets the {@link Initializer} of this {@code ParameterType}. * * @return the {@link Initializer} of this {@code ParameterType} */ public Initializer getInitializer() { return initializer; } } /** A Builder to construct a {@code Parameter}. */ public static final class Builder { String name; Shape shape; Type type; Initializer initializer; NDArray array; boolean requiresGrad = true; /** * Sets the name of the {@code Parameter}. * * @param name the name of the {@code Parameter} * @return this {@code Parameter} */ public Builder setName(String name) { this.name = name; return this; } /** * Sets the {@code Type} of the {@code Parameter}. * * @param type the {@code Type} of the {@code Parameter} * @return this {@code Parameter} */ public Builder setType(Type type) { this.type = type; return this; } /** * Sets the shape of the {@code Parameter}. * * @param shape the shape of the {@code Parameter} * @return this {@code Parameter} */ public Builder optShape(Shape shape) { this.shape = shape; return this; } /** * Sets the Initializer of the {@code Parameter}. * * @param initializer the Initializer of the {@code Parameter} * @return this {@code Parameter} */ public Builder optInitializer(Initializer initializer) { this.initializer = initializer; return this; } /** * Sets the array of the {@code Parameter}. * * @param array the array of the {@code Parameter} * @return this {@code Parameter} */ public Builder optArray(NDArray array) { this.array = array; return this; } /** * Sets if the {@code Parameter} requires gradient. * * @param requiresGrad if the {@code Parameter} requires gradient * @return this {@code Parameter} */ public Builder optRequiresGrad(boolean requiresGrad) { this.requiresGrad = requiresGrad; return this; } /** * Builds a {@code Parameter} instance. * * @return the {@code Parameter} instance */ public Parameter build() { return new Parameter(this); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy