ai.djl.nn.AbstractBlock Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* {@code AbstractBlock} is an abstract implementation of {@link Block}. It is recommended that all
* {@code Block} classes that have children extend the {@code AbstractBlock}.
*/
public abstract class AbstractBlock implements Block {

    /**
     * The shapes of this block's inputs. Recorded by {@link #beforeInitialize(Shape[])} or
     * restored by {@link #readInputShapes(DataInputStream)}; {@code null} until one of them runs.
     */
    protected Shape[] inputShapes;

    /**
     * The names of this block's inputs, paired positionally with {@link #inputShapes} by {@link
     * #describeInput()}. Defaults to a single input named {@code "data"}.
     */
    protected List<String> inputNames = Collections.singletonList("data");

    /**
     * {@inheritDoc}
     *
     * @throws IllegalStateException if the parameters of this block are not initialized, or if
     *     no input shapes have been recorded yet
     */
    @Override
    public PairList<String, Shape> describeInput() {
        // A block without parameters reports isInitialized() == true even before
        // beforeInitialize() has recorded the shapes, so check inputShapes explicitly
        // instead of letting Arrays.asList(null) throw a NullPointerException.
        if (!isInitialized() || inputShapes == null) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        return new PairList<>(inputNames, Arrays.asList(inputShapes));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Applies {@code initializer} to every direct parameter of this block, then recursively
     * to every child block.
     */
    @Override
    public void setInitializer(Initializer initializer) {
        for (Parameter parameter : getDirectParameters()) {
            // NOTE(review): the boolean is presumably an "overwrite" flag; passing false here
            // (vs. true in the name-targeted overload) lets explicit per-parameter initializers
            // win over this block-wide one. Confirm against Parameter#setInitializer.
            parameter.setInitializer(initializer, false);
        }
        for (Block child : getChildren().values()) {
            child.setInitializer(initializer);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only the direct parameters of this block are searched; child blocks are not.
     *
     * @throws IllegalArgumentException if no direct parameter is named {@code paramName}
     */
    @Override
    public void setInitializer(Initializer initializer, String paramName) {
        Parameter parameter =
                getDirectParameters()
                        .stream()
                        .filter(param -> param.getName().equals(paramName))
                        .findFirst()
                        .orElseThrow(
                                () ->
                                        new IllegalArgumentException(
                                                "Could not find parameter " + paramName));
        parameter.setInitializer(initializer, true);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The result lists this block's direct parameters first, keyed by their own names,
     * followed by every child's parameters keyed as {@code childName + "_" + paramName}.
     */
    @Override
    public ParameterList getParameters() {
        ParameterList parameters = new ParameterList();
        List<Parameter> directParams = getDirectParameters();
        directParams.forEach(param -> parameters.add(param.getName(), param));
        PairList<String, Parameter> childrenParameters = getChildrenParameters();
        childrenParameters.forEach(parameters::add);
        return parameters;
    }

    /**
     * Performs any action necessary before initialization.
     *
     * @param inputShapes the expected shapes of the input
     */
    protected void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns {@code true} only if every parameter returned by {@link #getParameters()},
     * including those of child blocks, is initialized. A block with no parameters is therefore
     * always considered initialized.
     */
    @Override
    public boolean isInitialized() {
        for (Parameter param : getParameters().values()) {
            if (!param.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Closes the value of every parameter of this block and of its children.
     */
    @Override
    public void clear() {
        getParameters().forEach(param -> param.getValue().close());
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; casting is not implemented yet
     */
    @Override
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Writes {@link #inputShapes} to {@code os}: the number of shapes as an {@code int},
     * followed by the encoded bytes of each shape. This is the layout read back by {@link
     * #readInputShapes(DataInputStream)}.
     *
     * @param os the stream to write to
     * @throws IOException if writing to the stream fails
     * @throws IllegalStateException if no input shapes have been recorded yet
     */
    protected void saveInputShapes(DataOutputStream os) throws IOException {
        if (inputShapes == null) {
            throw new IllegalStateException(
                    "Input shapes are not set; the block has not been initialized");
        }
        os.writeInt(inputShapes.length);
        for (Shape shape : inputShapes) {
            os.write(shape.getEncoded());
        }
    }

    /**
     * Reads input shapes written by {@link #saveInputShapes(DataOutputStream)} and stores them
     * in {@link #inputShapes}, replacing any previous value.
     *
     * @param is the stream to read from
     * @throws IOException if reading fails or the stream holds a negative shape count
     */
    protected void readInputShapes(DataInputStream is) throws IOException {
        int len = is.readInt();
        // A negative count can only come from a corrupt or foreign stream; report it as an
        // I/O problem instead of an unchecked NegativeArraySizeException.
        if (len < 0) {
            throw new IOException("Invalid number of input shapes: " + len);
        }
        inputShapes = new Shape[len];
        for (int i = 0; i < len; ++i) {
            inputShapes[i] = Shape.decode(is);
        }
    }

    /**
     * Collects the parameters of all child blocks, each keyed as {@code childName + "_" +
     * paramName} so that names stay distinct across children.
     *
     * @return the prefixed parameters of every child block, in child order
     */
    private ParameterList getChildrenParameters() {
        ParameterList parameters = new ParameterList();
        for (Pair<String, Block> childPair : getChildren()) {
            for (Pair<String, Parameter> paramPair : childPair.getValue().getParameters()) {
                parameters.add(
                        childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue());
            }
        }
        return parameters;
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy