// ai.djl.training.Trainer (Maven / Gradle / Ivy) - the newest version.
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.training;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.UninitializedParameterException;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.EpochTrainingListener;
import ai.djl.training.listener.EvaluatorTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
/**
 * The {@code Trainer} class provides a session for model training.
 *
 * <p>{@code Trainer} provides an easy and manageable interface for training. {@code Trainer} is
 * not thread-safe.
 *
 * <p>See the tutorials on:
 *
 * <ul>
 *   <li>Training your first model
 *   <li>Training using transfer learning
 *   <li>Inference with an MXNet model
 * </ul>
 *
 * <p>See also the guide on memory management.
 */
public class Trainer implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(Trainer.class);

    private final Model model;
    /** Sub-manager of the model's manager; owned (and closed) by this trainer. */
    private final NDManager manager;
    /** Optional benchmark sink; stays {@code null} until {@link #setMetrics(Metrics)} is called. */
    private Metrics metrics;
    private final List<TrainingListener> listeners;
    private final Device[] devices;
    private final ParameterStore parameterStore;
    private final List<Evaluator> evaluators;
    private final Loss loss;
    /** May be {@code null}; exposed through {@link #getExecutorService()} as an {@link Optional}. */
    private final ExecutorService executorService;
    /** Set once {@link #checkGradients()} has succeeded so the check runs only until it passes. */
    private boolean gradientsChecked;

    /**
     * Creates an instance of {@code Trainer} with the given {@link Model} and {@link
     * TrainingConfig}.
     *
     * <p>The last step of construction notifies every configured {@link TrainingListener} via
     * {@link TrainingListener#onTrainingBegin(Trainer)}.
     *
     * @param model the model the trainer will train on
     * @param trainingConfig the configuration used by the trainer
     * @throws NullPointerException if the configuration does not provide a loss function
     */
    @SuppressWarnings("this-escape")
    public Trainer(Model model, TrainingConfig trainingConfig) {
        this.model = model;
        manager = model.getNDManager().newSubManager();
        manager.setName("trainer");
        devices = trainingConfig.getDevices();
        loss = trainingConfig.getLossFunction();
        Objects.requireNonNull(loss, "You must specify a loss for the trainer");
        // Copy the configured evaluators so that adding the loss does not modify the config's list.
        evaluators = new ArrayList<>(trainingConfig.getEvaluators());
        evaluators.add(loss); // track loss as an evaluator by default
        executorService = trainingConfig.getExecutorService();

        ParameterServer parameterServer =
                manager.getEngine().newParameterServer(trainingConfig.getOptimizer());

        parameterStore = new ParameterStore(manager, false);
        parameterStore.setParameterServer(parameterServer, devices);

        listeners = trainingConfig.getTrainingListeners();
        // Must stay last: listeners receive "this" and may call any accessor.
        notifyListeners(listener -> listener.onTrainingBegin(this));
    }

    /**
     * Initializes the {@link Model} that the {@code Trainer} is going to train.
     *
     * @param shapes an array of {@code Shape} of the inputs
     * @throws IllegalStateException if a parameter is still uninitialized after the block has been
     *     initialized
     */
    public void initialize(Shape... shapes) {
        model.getBlock().initialize(model.getNDManager(), model.getDataType(), shapes);
        // call getValue on all params to initialize on all devices
        model.getBlock()
                .getParameters()
                .forEach(
                        pair -> {
                            for (Device device : devices) {
                                try {
                                    parameterStore.getValue(pair.getValue(), device, true);
                                } catch (UninitializedParameterException e) {
                                    throw new IllegalStateException(
                                            "Failed to initialize parameter: "
                                                    + pair.getKey()
                                                    + ".\n"
                                                    + "If you are defining a Block extending"
                                                    + " AbstractBlock, check that you are"
                                                    + " initializing all child blocks as part of"
                                                    + " the overload for"
                                                    + " AbstractBlock.initializeChildBlocks().",
                                            e);
                                }
                            }
                        });
    }

    /**
     * Fetches an iterator that can iterate through the given {@link Dataset}.
     *
     * <p>The data is requested with this trainer's {@link NDManager} and its executor service (which
     * may be {@code null}).
     *
     * @param dataset the dataset to iterate through
     * @return an {@link Iterable} of {@link Batch} that contains batches of data from the dataset
     * @throws IOException for various exceptions depending on the dataset
     * @throws TranslateException if there is an error while processing input
     */
    public Iterable<Batch> iterateDataset(Dataset dataset) throws IOException, TranslateException {
        return dataset.getData(getManager(), executorService);
    }

    /**
     * Returns a new instance of {@link GradientCollector}.
     *
     * @return a new instance of {@link GradientCollector}
     */
    public GradientCollector newGradientCollector() {
        return manager.getEngine().newGradientCollector();
    }

    /**
     * Applies the forward function of the model once on the given input {@link NDList}.
     *
     * <p>The elapsed time is recorded under the {@code "forward"} metric, even when the forward
     * call throws.
     *
     * @param input the input {@link NDList}
     * @return the output of the forward function
     */
    public NDList forward(NDList input) {
        long begin = System.nanoTime();
        try {
            return model.getBlock().forward(parameterStore, input, true);
        } finally {
            addMetric("forward", begin);
        }
    }

    /**
     * Applies the forward function of the model once with both data and labels.
     *
     * <p>The elapsed time is recorded under the {@code "forward"} metric, even when the forward
     * call throws.
     *
     * @param data the input data {@link NDList}
     * @param labels the input labels {@link NDList}
     * @return the output of the forward function
     */
    public NDList forward(NDList data, NDList labels) {
        long begin = System.nanoTime();
        try {
            return model.getBlock().forward(parameterStore, data, labels, null);
        } finally {
            addMetric("forward", begin);
        }
    }

    /**
     * Applies the forward function of the model once on the given input {@link NDList} with the
     * training flag set to {@code false}. No timing metric is recorded.
     *
     * @param input the input {@link NDList}
     * @return the output of the predict function
     */
    public NDList evaluate(NDList input) {
        return model.getBlock().forward(parameterStore, input, false, null);
    }

    /**
     * Updates all of the parameters of the model once.
     *
     * <p>Until the gradient check has passed once, each call first verifies that the gradients are
     * not all zero. The update time is recorded under the {@code "step"} metric.
     *
     * @throws IllegalStateException if the gradients are all zeros
     */
    public void step() {
        if (!gradientsChecked) {
            checkGradients();
        }

        long begin = System.nanoTime();
        parameterStore.updateAllParameters();
        addMetric("step", begin);
    }

    /**
     * Returns the Metrics param used for benchmarking.
     *
     * @return the Metrics param used for benchmarking, or {@code null} if none has been set
     */
    public Metrics getMetrics() {
        return metrics;
    }

    /**
     * Attaches a Metrics param to use for benchmarking.
     *
     * @param metrics the Metrics class
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /**
     * Returns the devices used for training.
     *
     * @return the devices used for training (the internal array, not a copy)
     */
    public Device[] getDevices() {
        return devices;
    }

    /**
     * Gets the training {@link Loss} function of the trainer.
     *
     * @return the {@link Loss} function
     */
    public Loss getLoss() {
        return loss;
    }

    /**
     * Returns the model used to create this trainer.
     *
     * @return the model associated with this trainer
     */
    public Model getModel() {
        return model;
    }

    /**
     * Returns the {@link ExecutorService}.
     *
     * @return the {@link ExecutorService}, or an empty {@link Optional} if none was configured
     */
    public Optional<ExecutorService> getExecutorService() {
        return Optional.ofNullable(executorService);
    }

    /**
     * Gets all {@link Evaluator}s.
     *
     * @return the evaluators used during training, including the loss
     */
    public List<Evaluator> getEvaluators() {
        return evaluators;
    }

    /**
     * Executes a method on each of the {@link TrainingListener}s.
     *
     * @param listenerConsumer a consumer that executes the method
     */
    public final void notifyListeners(Consumer<TrainingListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }

    /**
     * Returns the {@link TrainingResult}.
     *
     * <p>The result is assembled from the registered listeners: the epoch count comes from an
     * {@link EpochTrainingListener} and the evaluations from an {@link EvaluatorTrainingListener}.
     * If several listeners of the same kind are registered, the last one wins.
     *
     * @return the {@code TrainingResult}
     */
    public TrainingResult getTrainingResult() {
        TrainingResult result = new TrainingResult();
        for (TrainingListener listener : listeners) {
            if (listener instanceof EpochTrainingListener) {
                result.setEpoch(((EpochTrainingListener) listener).getNumEpochs());
            } else if (listener instanceof EvaluatorTrainingListener) {
                EvaluatorTrainingListener l = (EvaluatorTrainingListener) listener;
                result.setEvaluations(l.getLatestEvaluations());
            }
        }
        return result;
    }

    /**
     * Gets the {@link NDManager} of this trainer, which is a sub-manager created from the model's
     * manager.
     *
     * @return the {@link NDManager}
     */
    public NDManager getManager() {
        return manager;
    }

    /** {@inheritDoc} */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        if (manager.isOpen()) {
            // The warning is only emitted when debug logging is enabled.
            if (logger.isDebugEnabled()) {
                logger.warn("Trainer for {} was not closed explicitly.", model.getName());
            }
            close();
        }
        super.finalize();
    }

    /**
     * Notifies the listeners that training has ended, syncs the parameter store and closes this
     * trainer's {@link NDManager}.
     */
    @Override
    public void close() {
        notifyListeners(listener -> listener.onTrainingEnd(this));
        parameterStore.sync();
        manager.close();
    }

    /**
     * Checks if all gradients are zeros. This prevents users from calling step() without running
     * {@code backward}.
     *
     * <p>Only the gradients on the first device are inspected. On success the check is marked as
     * done so later calls to {@link #step()} skip it.
     *
     * @throws IllegalStateException if the sum of all gradient values is zero
     */
    private void checkGradients() {
        List<NDArray> grads = new ArrayList<>();
        for (Parameter param : model.getBlock().getParameters().values()) {
            if (param.requiresGradient()) {
                grads.add(parameterStore.getValue(param, devices[0], true).getGradient());
            }
        }

        // Intermediate arrays are created under a scoped manager so they are released on exit.
        try (NDManager scoped = manager.newSubManager()) {
            scoped.tempAttachAll(new NDList(grads));
            NDList list = new NDList(grads.stream().map(NDArray::sum).toArray(NDArray[]::new));
            float gradSum = NDArrays.stack(list).sum().getFloat();
            if (gradSum == 0f) {
                throw new IllegalStateException(
                        "Gradient values are all zeros, please call gradientCollector.backward() on"
                                + " your target NDArray (usually loss), before calling step() ");
            }

            gradientsChecked = true;
        }
    }

    /**
     * Helper to add a metric for a time difference.
     *
     * <p>Nothing is recorded when no {@link Metrics} is attached or when {@code begin} is not
     * positive.
     *
     * @param metricName the metric name
     * @param begin the time difference start, as returned by {@link System#nanoTime()} (this method
     *     is called at the time difference end)
     */
    public void addMetric(String metricName, long begin) {
        if (metrics != null && begin > 0L) {
            metrics.addMetric(metricName, System.nanoTime() - begin);
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy