ai.djl.nn.LambdaBlock Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.function.Function;
/**
* {@code LambdaBlock} is a {@link Block} with no parameters or children.
*
* {@code LambdaBlock} allows converting any function that takes an {@code NDList} as input and
* returns an {@code NDList} into a parameter-less and child-less {@link Block}. This can be useful
* in converting activation functions, identity blocks, and more. For example, identity block can be
* created as {@code new LambdaBlock(x -> x)}.
*/
public class LambdaBlock extends AbstractBlock {
public static final String DEFAULT_NAME = "anonymous";
private static final byte VERSION = 2;
private Function lambda;
private String name;
/**
* Creates a LambdaBlock that can apply the specified function.
*
* @param lambda the function to apply
*/
public LambdaBlock(Function lambda) {
this(lambda, DEFAULT_NAME);
}
/**
* Creates a LambdaBlock that can apply the specified function.
*
* @param lambda the function to apply
* @param name the function name
*/
public LambdaBlock(Function lambda, String name) {
super(VERSION);
this.lambda = lambda;
this.name = name;
}
/**
* Returns the lambda function name.
*
* @return the lambda function name
*/
public String getName() {
return name;
}
/**
* Creates a {@link LambdaBlock} for a singleton function.
*
* @param lambda a function accepting a singleton {@link NDList} and returning another singleton
* {@link NDList}
* @return a new {@link LambdaBlock} for the function
*/
public static LambdaBlock singleton(Function lambda) {
return new LambdaBlock(
arrays -> new NDList(lambda.apply(arrays.singletonOrThrow())),
lambda.getClass().getSimpleName());
}
/**
* Creates a {@link LambdaBlock} for a singleton function.
*
* @param lambda a function accepting a singleton {@link NDList} and returning another singleton
* {@link NDList}
* @param name the function name
* @return a new {@link LambdaBlock} for the function
*/
public static LambdaBlock singleton(Function lambda, String name) {
return new LambdaBlock(arrays -> new NDList(lambda.apply(arrays.singletonOrThrow())), name);
}
/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList params) {
return lambda.apply(inputs);
}
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
try (NDManager manager = NDManager.newBaseManager()) {
NDList input = new NDList(inputShapes.length);
for (Shape shape : inputShapes) {
input.add(manager.zeros(shape));
}
NDList output = lambda.apply(input);
Shape[] outputShapes = new Shape[output.size()];
DataType[] dataTypes = new DataType[output.size()];
for (int i = 0; i < output.size(); ++i) {
outputShapes[i] = output.get(i).getShape();
dataTypes[i] = output.get(i).getDataType();
}
outputDataTypes = dataTypes;
return outputShapes;
}
}
/** {@inheritDoc} */
@Override
public void loadParameters(NDManager manager, DataInputStream is)
throws IOException, MalformedModelException {
byte version = is.readByte();
if (version == VERSION) {
readInputShapes(is);
} else if (version != 1) {
throw new MalformedModelException("Unsupported encoding version: " + version);
}
}
}