ai.djl.nn.Blocks Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
/** Utility class that provides some useful blocks. */
public final class Blocks {

    private Blocks() {}

    /**
     * Flattens the {@link ai.djl.ndarray.NDArray} provided as input to a 2-D {@link
     * ai.djl.ndarray.NDArray} of shape (batch, size), where batch is the length of the first
     * dimension of the input.
     *
     * @param array an array to be flattened
     * @return the flattened 2-D {@link ai.djl.ndarray.NDArray}
     */
    public static NDArray batchFlatten(NDArray array) {
        long batch = array.size(0);
        if (batch == 0) {
            // With an empty batch the array holds no elements, so a -1 placeholder cannot be
            // resolved (any value satisfies 0 * x == 0). Compute the second dimension
            // explicitly from the remaining dimensions instead.
            return array.reshape(batch, array.getShape().slice(1).size());
        }
        return array.reshape(batch, -1);
    }

    /**
     * Flattens the {@link ai.djl.ndarray.NDArray} provided as input to a 2-D {@link
     * ai.djl.ndarray.NDArray} of shape (batch, size), where the batch dimension is inferred from
     * the total number of elements.
     *
     * <p>The total number of elements in {@code array} is expected to be a multiple of {@code
     * size}; otherwise the underlying reshape presumably fails (engine-defined behavior).
     *
     * @param array an array to be flattened
     * @param size the length of the second dimension of the result
     * @return the flattened 2-D {@link ai.djl.ndarray.NDArray}
     */
    public static NDArray batchFlatten(NDArray array, long size) {
        return array.reshape(-1, size);
    }

    /**
     * Creates a {@link Block} whose forward function applies the {@link #batchFlatten(NDArray)
     * batchFlatten} method.
     *
     * @return a {@link Block} whose forward function applies the {@link #batchFlatten(NDArray)
     *     batchFlatten} method
     */
    public static Block batchFlattenBlock() {
        return LambdaBlock.singleton(Blocks::batchFlatten);
    }

    /**
     * Creates a {@link Block} whose forward function applies the {@link #batchFlatten(NDArray,
     * long) batchFlatten} method. The size of input to the block returned must be batch_size *
     * size.
     *
     * @param size the expected size of each input
     * @return a {@link Block} whose forward function applies the {@link #batchFlatten(NDArray,
     *     long) batchFlatten} method
     */
    public static Block batchFlattenBlock(long size) {
        return LambdaBlock.singleton(array -> batchFlatten(array, size));
    }

    /**
     * Creates a {@link LambdaBlock} that performs the identity function, i.e. whose function
     * returns its input unchanged.
     *
     * @return an identity {@link Block}
     */
    public static Block identityBlock() {
        return new LambdaBlock(x -> x);
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy