/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.core;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;
// NOTE(review): ParameterStore is referenced in this file but not imported here; add its
// import after confirming the package it lives in.
/**
 * A Linear block applies a linear transformation \(Y = XW^T + b\).
 *
 * <p>It has the following shapes:
 *
 * <ul>
 *   <li>If {@code flatten} is true:
 *       <ul>
 *         <li>input X: [batchSize, x1, x2, …, xn]
 *         <li>weight W: [outChannels, x1 * x2 * … * xn]
 *         <li>bias b: [outChannels] (only when the bias is enabled)
 *         <li>output Y: [batchSize, outChannels]
 *       </ul>
 *   <li>If {@code flatten} is false:
 *       <ul>
 *         <li>input X: [x1, x2, …, xn, input_dim]
 *         <li>weight W: [outChannels, input_dim]
 *         <li>bias b: [outChannels] (only when the bias is enabled)
 *         <li>output Y: [x1, x2, …, xn, outChannels]
 *       </ul>
 * </ul>
 *
 * <p>The Linear block should be constructed using {@link Linear.Builder}.
 */
public class Linear extends AbstractBlock {
    // Newest metadata encoding version; loadMetadata accepts versions 1 through VERSION.
    private static final byte VERSION = 3;

    // Number of output features; the first axis of the weight, set from the Builder.
    private long outChannels;
    // Second axis of the weight. Only known once beforeInitialize has run (or after
    // loadMetadata has restored it), which is why the weight shape uses a callback.
    private long inputDimension;
    // When true, all but the first axis of the input are collapsed before the transform.
    private boolean flatten;

    // Input shape cached by beforeInitialize or decoded by loadMetadata; reported by
    // describeInput. Stays null until one of those has run.
    private Shape inputShape;

    // Weight parameter, shape (outChannels, inputDimension).
    private Parameter weight;
    // Bias parameter, shape (outChannels); null when the bias is disabled in the Builder.
    private Parameter bias;
/**
 * Creates a {@code Linear} block from the options collected by the given builder.
 *
 * @param builder the builder holding the output-channel count and the bias/flatten options
 */
Linear(Builder builder) {
    outChannels = builder.outChannels;
    flatten = builder.flatten;
    // "inputDimension" is only known after "beforeInitialize" is called, hence we need
    // a callback, even if we do not use the callback parameter
    // NOTE(review): both parameter assignments below are syntactically incomplete in this
    // copy. The comma-separated (Parameter, shape) arguments and the trailing ")" imply a
    // wrapping registration call whose opening line is missing, and the method's closing
    // braces are missing as well. Restore from upstream before building.
    weight =
            new Parameter("weight", this, ParameterType.WEIGHT),
            inputShapes -> new Shape(outChannels, inputDimension));
    // The bias parameter is only created when enabled; otherwise the field stays null,
    // which forward/opInputs check for.
    if (builder.bias) {
        bias =
                new Parameter("bias", this, ParameterType.BIAS),
                new Shape(outChannels));
/** {@inheritDoc} */
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList params) {
inputs = opInputs(parameterStore, inputs);
NDArrayEx ex = inputs.head().getNDArrayInternal();
return ex.fullyConnected(inputs, outChannels, flatten, bias == null, params);
/** {@inheritDoc} */
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
if (flatten) {
return new Shape[] {new Shape(inputs[0].get(0), outChannels)};
return new Shape[] {inputShape.addAll(new Shape(outChannels))};
/** {@inheritDoc} */
public PairList describeInput() {
return new PairList<>(
Collections.singletonList("linearInput"), Collections.singletonList(inputShape));
/**
 * {@inheritDoc}
 *
 * <p>Computes {@code inputDimension} (the second axis of the weight) and caches {@code
 * inputShape} from the first entry of {@code inputShapes}.
 */
public void beforeInitialize(Shape[] inputShapes) {
    // NOTE(review): several lines of this method are missing from this copy (closing braces
    // and parts of the inputShape assignments flagged below); restore from upstream.
    // Presumably stores the shapes on a field inherited from AbstractBlock.
    this.inputShapes = inputShapes;
    Shape input = inputShapes[0];
    if (flatten) {
        // All non-batch axes are collapsed, so the weight's second axis is their element count.
        Shape inChannels;
        if (input.isLayoutKnown()) {
            // Layout known: the feature axes are every axis not tagged BATCH.
            inChannels = input.filterByLayoutType(t -> !t.equals(LayoutType.BATCH));
            // NOTE(review): incomplete. The call consuming this lambda, the ternary's
            // condition, and the Pair's second argument are missing. The fragment maps each
            // (size, layout) pair, substituting -1 for some sizes (presumably the BATCH axis)
            // and keeping pair.getKey() otherwise.
            inputShape =
                    pair ->
                            new Pair<>(
                                    ? Long.valueOf(-1)
                                    : pair.getKey(),
        } else if (input.dimension() > 1) {
            // Layout unknown: axis 0 is treated as the batch axis, the rest as features.
            inChannels = input.slice(1);
            // NOTE(review): incomplete. This expression has no terminator; its continuation
            // (presumably appending the remaining axes after the -1 batch axis) is missing.
            inputShape =
                    new Shape(new long[] {-1}, new LayoutType[] {LayoutType.BATCH})
        } else {
            // A 0- or 1-dimensional input with unknown layout is used unchanged.
            inChannels = input;
            inputShape = input;
        inputDimension = inChannels.size();
    } else {
        // No flattening: the transform applies to the last axis only, so its size is the
        // weight's second axis and the remaining leading axes are cached as inputShape.
        inputDimension = input.get(input.dimension() - 1);
        inputShape = input.slice(0, input.dimension() - 1);
/** {@inheritDoc} */
protected void saveMetadata(DataOutputStream os) throws IOException {
    // NOTE(review): the body of this method is missing from this copy. To be readable by the
    // current-VERSION branch of loadMetadata it presumably needs to write, in order:
    // outChannels (long), flatten (boolean), inputDimension (long), and then inputShape in
    // the encoding consumed by Shape.decode. Restore from upstream.
/** {@inheritDoc} */
public void loadMetadata(byte version, DataInputStream is)
throws IOException, MalformedModelException {
if (version < 1 || version > VERSION) {
throw new MalformedModelException("Unsupported encoding version: " + version);
if (version == VERSION) {
outChannels = is.readLong();
flatten = is.readBoolean();
inputDimension = is.readLong();
} else if (version == 2) {
flatten = is.readBoolean();
inputDimension = is.readLong();
} else {
flatten = false;
inputDimension = Shape.decode(is).size();
inputShape = Shape.decode(is);
private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
if (inputs.size() != 1) {
throw new IllegalArgumentException("Linear requires exactly 1 NDArray");
Device device = inputs.head().getDevice();
NDList result = new NDList(inputs);
result.add(parameterStore.getValue(weight, device));
if (bias != null) {
result.add(parameterStore.getValue(bias, device));
return result;
* Creates a builder to build a {@code Linear}.
* @return a new builder
public static Builder builder() {
return new Builder();
/** The Builder to construct a {@link Linear} type of {@link Block}. */
public static final class Builder {
private long outChannels;
private boolean bias = true;
private boolean flatten;
Builder() {}
* Sets the number of output channels.
* @param outChannels the number of desired output channels
* @return this Builder
public Builder setOutChannels(long outChannels) {
this.outChannels = outChannels;
return this;
* Sets the optional parameter that indicates whether to include a bias vector with default
* value of true.
* @param bias whether to use a bias vector parameter
* @return this Builder
public Builder optBias(boolean bias) {
this.bias = bias;
return this;
* Sets the optional parameter that indicates whether the input tensor should be flattened.
* If flatten is set to true, all but the first axis of input data are collapsed
* together. If false, all but the last axis of input data are kept the same, and the
* transformation applies on the last axis.
* @param flatten whether the input tensor should be flattened.
* @return this Builder
public Builder optFlatten(boolean flatten) {
this.flatten = flatten;
return this;
* Returns the constructed {@code Linear}.
* @return the constructed {@code Linear}
* @throws IllegalArgumentException if all required parameters (outChannels) have not been
* set
public Linear build() {
if (outChannels <= 0) {
throw new IllegalArgumentException("You must specify outChannels");
return new Linear(this);