All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.DefaultTrainingConfig Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.initializer.XavierInitializer.FactorType;
import ai.djl.training.initializer.XavierInitializer.RandomType;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */
public class DefaultTrainingConfig implements TrainingConfig {

    private Initializer initializer;
    private Optimizer optimizer;
    private Device[] devices;
    private Loss loss;
    // Raw List types restored to their parameterized forms; raw types lose
    // compile-time type safety and trigger unchecked warnings at call sites.
    private List<Evaluator> evaluators;
    private List<TrainingListener> listeners;

    /**
     * Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code
     * DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link
     * XavierInitializer} as initializer, {@link Adam} as optimizer, and the given {@link Loss}. The
     * evaluators and listeners are left to the user's discretion.
     *
     * @param loss the loss to use for training
     */
    public DefaultTrainingConfig(Loss loss) {
        // Defaults to initializer defined in https://arxiv.org/abs/1502.01852
        this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2);
        optimizer = Adam.builder().build();
        this.loss = loss;
        evaluators = new ArrayList<>();
        listeners = new ArrayList<>();
    }

    /**
     * Sets the {@link Initializer} to use for the parameters (default from paper).
     *
     * @param initializer the initializer to use for the parameters
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig optInitializer(Initializer initializer) {
        this.initializer = initializer;
        return this;
    }

    /**
     * Sets the array of {@link Device} available for training.
     *
     * @param devices an array of devices to be set
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig optDevices(Device[] devices) {
        this.devices = devices;
        return this;
    }

    /**
     * Sets the {@link Optimizer} used during training (default {@link Adam}).
     *
     * @param optimizer the optimizer to be set
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    /**
     * Adds an {@link Evaluator} that needs to be computed during training.
     *
     * @param evaluator the evaluator to be added
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig addEvaluator(Evaluator evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    /**
     * Adds {@link TrainingListener}s for training.
     *
     * @param listeners the {@link TrainingListener}s to add
     * @return this {@code DefaultTrainingConfig}
     */
    public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners) {
        this.listeners.addAll(Arrays.asList(listeners));
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public Device[] getDevices() {
        // Lazily falls back to all available devices when none were set explicitly.
        if (devices == null) {
            return Device.getDevices(Integer.MAX_VALUE);
        }
        return devices;
    }

    /** {@inheritDoc} */
    @Override
    public Initializer getInitializer() {
        return initializer;
    }

    /** {@inheritDoc} */
    @Override
    public Optimizer getOptimizer() {
        return optimizer;
    }

    /** {@inheritDoc} */
    @Override
    public Loss getLossFunction() {
        return loss;
    }

    /** {@inheritDoc} */
    @Override
    public List<Evaluator> getEvaluators() {
        return evaluators;
    }

    /** {@inheritDoc} */
    @Override
    public List<TrainingListener> getTrainingListeners() {
        return listeners;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy