All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.tracker.CyclicalTracker Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.tracker;

import ai.djl.util.Preconditions;

/**
 * {@code CyclicalTracker} is an implementation of {@link Tracker} which is a policy of learning
 * rate adjustment that increases the learning rate off a base value in a cyclical nature, as
 * detailed in the paper Cyclical Learning Rates for
 * Training Neural Networks.
 */
public class CyclicalTracker implements Tracker {

    private float baseValue;
    private float maxValue;
    private int stepSizeUp;
    private int stepSizeDown;
    private int totalSize;
    private float stepRatio;
    private ScaleFunction scaleFunction;
    private boolean scaleModeCycle;

    /**
     * Creates a new instance of {@code CyclicalTracker}.
     *
     * @param builder the builder to create a new instance of {@code CyclicalTracker}
     */
    public CyclicalTracker(Builder builder) {
        this.baseValue = builder.baseValue;
        this.maxValue = builder.maxValue;
        this.stepSizeUp = builder.stepSizeUp;
        this.stepSizeDown = builder.stepSizeDown > 0 ? builder.stepSizeDown : builder.stepSizeUp;
        this.totalSize = this.stepSizeUp + this.stepSizeDown;
        this.stepRatio = (float) this.stepSizeUp / this.totalSize;
        if (builder.scaleFunction != null) {
            this.scaleFunction = builder.scaleFunction;
            this.scaleModeCycle = builder.scaleModeCycle;
        } else {
            switch (builder.mode) {
                case TRIANGULAR:
                    this.scaleFunction = new TriangularScaleFunction();
                    this.scaleModeCycle = true;
                    break;
                case TRIANGULAR2:
                    this.scaleFunction = new Triangular2ScaleFunction();
                    this.scaleModeCycle = true;
                    break;
                case EXP_RANGE:
                    this.scaleFunction = new ExpRangeScaleFunction(builder.gamma);
                    this.scaleModeCycle = false;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported Cyclical mode.");
            }
        }
    }

    /**
     * Creates a new builder.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** {@inheritDoc} */
    @Override
    public float getNewValue(int numUpdate) {
        int cycle = (int) Math.floor(1 + (float) numUpdate / this.totalSize);
        float x = 1 + (float) numUpdate / this.totalSize - cycle;
        float scaleFactor;
        float res;
        if (x < this.stepRatio) {
            scaleFactor = x / this.stepRatio;
        } else {
            scaleFactor = (x - 1) / (this.stepRatio - 1);
        }
        float baseHeight = (this.maxValue - this.baseValue) * scaleFactor;
        if (this.scaleModeCycle) {
            res = this.baseValue + baseHeight * this.scaleFunction.func(cycle);
        } else {
            res = this.baseValue + baseHeight * this.scaleFunction.func(numUpdate);
        }
        return res;
    }

    /** The Builder to construct an {@link CyclicalTracker} object. */
    public static final class Builder {

        private float baseValue = 0.001f;
        private float maxValue = 0.006f;
        private int stepSizeUp = 2000;
        private int stepSizeDown;
        private CyclicalMode mode = CyclicalMode.TRIANGULAR;
        private ScaleFunction scaleFunction;
        private boolean scaleModeCycle = true;
        private float gamma = 1f;

        private Builder() {}

        /**
         * Sets the initial value. Default: 0.001.
         *
         * @param baseValue the initial value
         * @return this {@code Builder}
         */
        public Builder optBaseValue(float baseValue) {
            this.baseValue = baseValue;
            return this;
        }

        /**
         * Sets the initial value. Default: 0.006.
         *
         * @param maxValue the max value
         * @return this {@code Builder}
         */
        public Builder optMaxValue(float maxValue) {
            this.maxValue = maxValue;
            return this;
        }

        /**
         * Sets the number of iterations in up half of cycle. Default: 2000.
         *
         * @param stepSizeUp number of iterations in up half of cycle
         * @return this {@code Builder}
         */
        public Builder optStepSizeUp(int stepSizeUp) {
            this.stepSizeUp = stepSizeUp;
            return this;
        }

        /**
         * Sets the number of iterations in up half of cycle. If {@code stepSizeDown} equals 0, it
         * is set to {@code stepSizeUp}. Default: 0.
         *
         * @param stepSizeDown number of iterations in the down half of cycle
         * @return this {@code Builder}
         */
        public Builder optStepSizeDown(int stepSizeDown) {
            this.stepSizeDown = stepSizeDown;
            return this;
        }

        /**
         * Sets the cyclical mode. Can be {@code CyclicalMode.TRIANGULAR}, {@code
         * CyclicalMode.TRIANGULAR2} or {@code CyclicalMode.EXP_RANGE}. Values correspond to
         * policies detailed above.
         *
         * @param mode cyclical mode
         * @return this {@code Builder}
         */
        public Builder optMode(CyclicalMode mode) {
            this.mode = mode;
            return this;
        }

        /**
         * Constant in {@code CyclicalMode.EXP_RANGE} scaling function. Default: 1.
         *
         * @param gamma constant in 'exp_range' scaling function: gamma^cycleIterations
         * @return this {@code Builder}
         */
        public Builder optGamma(float gamma) {
            this.gamma = gamma;
            return this;
        }

        /**
         * Custom scaling function. If set, {@code mode} will be ignored. Default: null.
         *
         * @param scaleFunction Custom scaling function
         * @return this {@code Builder}
         */
        public Builder optScaleFunction(ScaleFunction scaleFunction) {
            this.scaleFunction = scaleFunction;
            return this;
        }

        /**
         * Defines whether scaling function is evaluated on cycle number or update number. Default:
         * true.
         *
         * @param scaleModeCycle if true then scaling function is evaluated on cycle number, else on
         *     update number
         * @return this {@code Builder}
         */
        public Builder optScaleModeCycle(boolean scaleModeCycle) {
            this.scaleModeCycle = scaleModeCycle;
            return this;
        }

        /**
         * Builds a {@link CyclicalTracker} block.
         *
         * @return the {@link CyclicalTracker} block
         */
        public CyclicalTracker build() {
            Preconditions.checkArgument(baseValue > 0, "baseValue has to be positive!");
            Preconditions.checkArgument(maxValue > 0, "maxValue has to be positive!");
            Preconditions.checkArgument(
                    baseValue <= maxValue, "baseValue has to lower than maxValue!");
            Preconditions.checkArgument(stepSizeUp >= 1, "stepSizeUp has to be positive!");
            Preconditions.checkArgument(stepSizeDown >= 0, "stepSizeUp cannot be negative!");
            Preconditions.checkArgument(
                    gamma >= 0f && gamma <= 1f, "gamma has to be between 0 and 1!");
            return new CyclicalTracker(this);
        }
    }

    /**
     * {@code CyclicalTracker} provides three predefined cyclical modes and can be selected by this
     * enum.
     *
     * 
    *
  • TRIANGULAR: A basic triangular cycle without amplitude scaling. *
  • TRIANGULAR2: A basic triangular cycle that scales initial amplitude by half each cycle. *
  • EXP_RANGE: A cycle that scales initial amplitude by (gammacycleIterations) * at each cycle iteration. *
*/ public enum CyclicalMode { TRIANGULAR, TRIANGULAR2, EXP_RANGE } /** {@code ScaleFunction} is an interface to implement a custom scale function. */ public interface ScaleFunction { /** * Custom scaling policy. Return value has to satisfy 0<=func(steps)<=1 for all * steps>1. * * @param steps current cycles if {@code scaleModeCycle} is true; input update number if it * is false * @return scale ratio */ float func(int steps); } private static class TriangularScaleFunction implements ScaleFunction { @Override public float func(int steps) { return 1f; } } private static class Triangular2ScaleFunction implements ScaleFunction { @Override public float func(int steps) { return (float) (1 / (Math.pow(2f, steps - 1))); } } private static class ExpRangeScaleFunction implements ScaleFunction { float gamma; ExpRangeScaleFunction(float gamma) { this.gamma = gamma; } @Override public float func(int steps) { return (float) Math.pow(this.gamma, steps); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy