// ai.djl.training.DefaultTrainingConfig (Maven / Gradle / Ivy)
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.training;
import ai.djl.Device;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.initializer.XavierInitializer.FactorType;
import ai.djl.training.initializer.XavierInitializer.RandomType;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */
/** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */
public class DefaultTrainingConfig implements TrainingConfig {

    private Initializer initializer;
    private Optimizer optimizer;
    private Device[] devices;
    private Loss loss;
    // Parameterized types restored: the raw "List" declarations forced unchecked
    // access at every use site and lost the element-type contract of TrainingConfig.
    private List<Evaluator> evaluators;
    private List<TrainingListener> listeners;

    /**
     * Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code
     * DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link
     * XavierInitializer} as initializer, {@link Adam} as optimizer, and the given {@link Loss}. The
     * evaluators and listeners are left to the user's discretion.
     *
     * @param loss the loss to use for training
     */
    public DefaultTrainingConfig(Loss loss) {
        // Defaults to initializer defined in https://arxiv.org/abs/1502.01852
        this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2);
        optimizer = Adam.builder().build();
        this.loss = loss;
        evaluators = new ArrayList<>();
        listeners = new ArrayList<>();
    }

    /**
     * Sets the {@link Initializer} to use for the parameters (default from paper).
     *
     * @param initializer the initializer to use for the parameters
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig optInitializer(Initializer initializer) {
        this.initializer = initializer;
        return this;
    }

    /**
     * Sets the array of {@link Device} available for training.
     *
     * @param devices an array of devices to be set
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig optDevices(Device[] devices) {
        // NOTE(review): the array is stored without a defensive copy, so later
        // mutation by the caller is visible here — matches the original behavior.
        this.devices = devices;
        return this;
    }

    /**
     * Sets the {@link Optimizer} used during training (default {@link Adam}).
     *
     * @param optimizer the optimizer to be set
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    /**
     * Adds an {@link Evaluator} that needs to be computed during training.
     *
     * @param evaluator the evaluator to be added
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig addEvaluator(Evaluator evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    /**
     * Adds {@link TrainingListener}s for training.
     *
     * @param listeners the {@link TrainingListener}s to add
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners) {
        this.listeners.addAll(Arrays.asList(listeners));
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public Device[] getDevices() {
        if (devices == null) {
            // No devices configured: fall back to every device the engine reports.
            return Device.getDevices(Integer.MAX_VALUE);
        }
        return devices;
    }

    /** {@inheritDoc} */
    @Override
    public Initializer getInitializer() {
        return initializer;
    }

    /** {@inheritDoc} */
    @Override
    public Optimizer getOptimizer() {
        return optimizer;
    }

    /** {@inheritDoc} */
    @Override
    public Loss getLossFunction() {
        return loss;
    }

    /** {@inheritDoc} */
    @Override
    public List<Evaluator> getEvaluators() {
        // Returns the live internal list (not a copy), as the original did.
        return evaluators;
    }

    /** {@inheritDoc} */
    @Override
    public List<TrainingListener> getTrainingListeners() {
        // Returns the live internal list (not a copy), as the original did.
        return listeners;
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy