All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.training.optimizer.learningrate.LearningRateTracker Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.optimizer.learningrate;

import ai.djl.TrainingDivergedException;

/**
 * A {@code LearningRateTracker} tracks the evolution of the learning rate through the training
 * process.
 *
 * <p>Concrete trackers implement {@link #getNewLearningRate(int)}. This base class stores the
 * settings shared by all trackers (base learning rate and warm-up configuration) and provides the
 * warm-up computation in {@code getWarmUpLearningRate}.
 */
public abstract class LearningRateTracker {

    /** The learning rate configured on the builder; also used as the warm-up target. */
    float baseLearningRate;
    /** The number of updates over which the warm-up learning rate is interpolated. */
    int warmUpSteps;
    /** The learning rate at the beginning of warm-up. */
    float warmUpBeginLearningRate;
    /** The learning rate reached at the end of warm-up; initialised to the base learning rate. */
    float warmUpFinalLearningRate;
    /** The strategy used to move from the warm-up begin rate towards the warm-up final rate. */
    WarmUpMode warmUpMode;

    /**
     * Constructs a {@code LearningRateTracker} from the options collected by the given builder.
     *
     * <p>The warm-up final learning rate is set to the builder's base learning rate.
     *
     * @param builder the builder that configures learning rate options
     */
    LearningRateTracker(LrBaseBuilder<?> builder) {
        this.baseLearningRate = builder.baseLearningRate;
        this.warmUpSteps = builder.warmUpSteps;
        this.warmUpBeginLearningRate = builder.warmUpBeginLearningRate;
        this.warmUpMode = builder.warmUpMode;
        this.warmUpFinalLearningRate = baseLearningRate;
    }

    /**
     * Computes the learning rate to use during warm-up after the given number of updates.
     *
     * <p>For {@link WarmUpMode#LINEAR} the value is interpolated linearly between the warm-up
     * begin and final learning rates in proportion to {@code numUpdate / warmUpSteps}. For any
     * other mode the warm-up begin learning rate is returned unchanged.
     *
     * <p>Note: the interpolation divides by {@code warmUpSteps} using floating-point arithmetic,
     * so a value of zero yields {@code NaN} or an infinite result rather than an {@link
     * ArithmeticException}; only {@code NaN} is rejected by {@link #checkLearningRate(float)}.
     *
     * @param numUpdate the number of steps/updates performed so far
     * @return the warm-up learning rate
     * @throws TrainingDivergedException if the computed learning rate is {@code NaN}
     */
    float getWarmUpLearningRate(int numUpdate) {
        float learningRate = warmUpBeginLearningRate;
        if (warmUpMode == WarmUpMode.LINEAR) {
            learningRate =
                    warmUpBeginLearningRate
                            + (warmUpFinalLearningRate - warmUpBeginLearningRate)
                                    * numUpdate
                                    / warmUpSteps;
        }
        checkLearningRate(learningRate);
        return learningRate;
    }

    /**
     * Fetches the value of the learning rate after the given number of steps/updates.
     *
     * @param numUpdate the number of steps/updates
     * @return the learning rate to use after {@code numUpdate} steps/updates
     */
    public abstract float getNewLearningRate(int numUpdate);

    /**
     * Verifies that the given learning rate is a number.
     *
     * @param learningRate the learning rate to validate
     * @throws TrainingDivergedException if {@code learningRate} is {@code NaN}
     */
    void checkLearningRate(float learningRate) {
        if (Float.isNaN(learningRate)) {
            throw new TrainingDivergedException("Learning rate is Nan.");
        }
    }

    /**
     * Returns a new instance of {@link
     * ai.djl.training.optimizer.learningrate.FactorTracker.Builder} that can build an {@link
     * FactorTracker}.
     *
     * @return the {@link FactorTracker} {@link
     *     ai.djl.training.optimizer.learningrate.FactorTracker.Builder}
     */
    public static FactorTracker.Builder factorTracker() {
        return new FactorTracker.Builder();
    }

    /**
     * Returns a new instance of {@link
     * ai.djl.training.optimizer.learningrate.MultiFactorTracker.Builder} that can build an {@link
     * MultiFactorTracker}.
     *
     * @return the {@link MultiFactorTracker} {@link
     *     ai.djl.training.optimizer.learningrate.MultiFactorTracker.Builder}
     */
    public static MultiFactorTracker.Builder multiFactorTracker() {
        return new MultiFactorTracker.Builder();
    }

    /**
     * Returns a new instance of {@code LearningRateTracker} with the fixed learning rate.
     *
     * @param learningRate the fixed learning rate
     * @return an instance of {@code LearningRateTracker} with the fixed learning rate
     */
    public static LearningRateTracker fixedLearningRate(float learningRate) {
        return FixedLearningRate.builder().optBaseLearningRate(learningRate).build();
    }

    /**
     * The Builder to construct a {@link LearningRateTracker}.
     *
     * <p>Defaults: base learning rate {@code 0.01f}, zero warm-up steps, a warm-up begin learning
     * rate of {@code 0}, and {@link WarmUpMode#LINEAR} warm-up.
     *
     * @param <T> the concrete builder type returned by the fluent {@code opt*} methods
     */
    @SuppressWarnings("rawtypes")
    public abstract static class LrBaseBuilder<T extends LrBaseBuilder> {

        float baseLearningRate = 0.01f;
        int warmUpSteps;
        float warmUpBeginLearningRate;
        WarmUpMode warmUpMode = WarmUpMode.LINEAR;

        /**
         * Sets the base learning rate.
         *
         * @param baseLearningRate the base learning rate
         * @return this {@code Builder}
         */
        public T optBaseLearningRate(float baseLearningRate) {
            this.baseLearningRate = baseLearningRate;
            return self();
        }

        /**
         * Sets the number of steps until the point the learning rate is updated in warm-up mode.
         *
         * @param warmUpSteps the number of steps until the point the learning rate is updated in
         *     warm-up mode
         * @return this {@code Builder}
         */
        public T optWarmUpSteps(int warmUpSteps) {
            this.warmUpSteps = warmUpSteps;
            return self();
        }

        /**
         * Sets the value of the learning rate at the beginning of warm-up mode.
         *
         * @param warmUpBeginLearningRate the value of the learning rate at the beginning of warm-up
         *     mode
         * @return this {@code Builder}
         */
        public T optWarmUpBeginLearningRate(float warmUpBeginLearningRate) {
            this.warmUpBeginLearningRate = warmUpBeginLearningRate;
            return self();
        }

        /**
         * Sets the {@link WarmUpMode} for the {@link LearningRateTracker}.
         *
         * @param warmUpMode the {@link WarmUpMode} to be set
         * @return this {@code Builder}
         */
        public T optWarmUpMode(WarmUpMode warmUpMode) {
            this.warmUpMode = warmUpMode;
            return self();
        }

        /**
         * Returns this builder typed as the concrete builder class, so that the fluent {@code
         * opt*} methods can be chained without casts.
         *
         * @return this {@code Builder}
         */
        protected abstract T self();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy