All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.optimizer.Optimizer Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.optimizer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/**
 * An {@code Optimizer} updates the weight parameters to minimize the loss function. {@code
 * Optimizer} is an abstract class that provides the base implementation for optimizers.
 *
 * @see The D2L chapters on
 *     optimization algorithms
 */
public abstract class Optimizer {

    protected float rescaleGrad;
    protected float clipGrad;
    private float weightDecays;
    private int beginNumUpdate;
    private int numUpdate;
    private Map updateCounts = new ConcurrentHashMap<>();

    /**
     * Creates a new instance of {@code Optimizer}.
     *
     * @param builder the builder used to create an instance of {@code Optimizer}
     */
    public Optimizer(OptimizerBuilder builder) {
        this.rescaleGrad = builder.rescaleGrad;
        this.weightDecays = builder.weightDecays;
        this.clipGrad = builder.clipGrad;
        this.beginNumUpdate = builder.beginNumUpdate;
    }

    /**
     * Returns a new instance of {@link ai.djl.training.optimizer.Sgd.Builder} that can build an
     * {@link Sgd} optimizer.
     *
     * @return the {@link Sgd} {@link ai.djl.training.optimizer.Sgd.Builder}
     */
    public static Sgd.Builder sgd() {
        return new Sgd.Builder();
    }

    /**
     * Returns a new instance of {@link ai.djl.training.optimizer.Nag.Builder} that can build an
     * {@link Nag} optimizer.
     *
     * @return the {@link Nag} {@link ai.djl.training.optimizer.Nag.Builder}
     */
    public static Nag.Builder nag() {
        return new Nag.Builder();
    }

    /**
     * Returns a new instance of {@link ai.djl.training.optimizer.Adam.Builder} that can build an
     * {@link Adam} optimizer.
     *
     * @return the {@link Adam} {@link ai.djl.training.optimizer.Adam.Builder}
     */
    public static Adam.Builder adam() {
        return new Adam.Builder();
    }

    /**
     * Returns a new instance of {@link ai.djl.training.optimizer.AdamW.Builder} that can build an
     * {@link AdamW} optimizer.
     *
     * @return the {@link AdamW} {@link ai.djl.training.optimizer.AdamW.Builder}
     */
    public static AdamW.Builder adamW() {
        return new AdamW.Builder();
    }

    /**
     * Returns a new instance of {@link RmsProp.Builder} that can build an {@link RmsProp}
     * optimizer.
     *
     * @return the {@link RmsProp} {@link RmsProp.Builder}
     */
    public static RmsProp.Builder rmsprop() {
        return new RmsProp.Builder();
    }

    /**
     * Returns a new instance of {@link ai.djl.training.optimizer.Adagrad.Builder} that can build an
     * {@link Adagrad} optimizer.
     *
     * @return the {@link Adagrad} {@link ai.djl.training.optimizer.Adagrad.Builder}
     */
    public static Adagrad.Builder adagrad() {
        return new Adagrad.Builder();
    }

    /**
     * Returns a new instance of {@link ai.djl.training.optimizer.Adadelta.Builder} that can build
     * an {@link Adadelta} optimizer.
     *
     * @return the {@link Adadelta} {@link ai.djl.training.optimizer.Adadelta.Builder}
     */
    public static Adadelta.Builder adadelta() {
        return new Adadelta.Builder();
    }

    /**
     * Gets the value of weight decay.
     *
     * @return the value of weight decay
     */
    protected float getWeightDecay() {
        return weightDecays;
    }

    protected int updateCount(String parameterId) {
        // if index exists, increment update count, if not, use begin number of update + 1
        int count =
                updateCounts.compute(
                        parameterId, (key, val) -> (val == null) ? beginNumUpdate + 1 : val + 1);
        numUpdate = Math.max(numUpdate, count);
        return numUpdate;
    }

    /**
     * Updates the parameters according to the gradients.
     *
     * @param parameterId the parameter to be updated
     * @param weight the weights of the parameter
     * @param grad the gradients
     */
    public abstract void update(String parameterId, NDArray weight, NDArray grad);

    protected NDArray withDefaultState(
            Map> state,
            String key,
            Device device,
            Function defaultFunction) {
        Map arrayMap =
                state.computeIfAbsent(
                        key,
                        k -> {
                            Map map = new ConcurrentHashMap<>();
                            NDArray s = defaultFunction.apply(k);
                            // TODO attach s to the NDManager of ParameterStore
                            s.detach(); // s is detached because it would be put into the optimizer
                            // callback manager and closed after the optimizer callback
                            // when using the MxParameterServer. For now, this will let it be closed
                            // by the
                            // GC when the optimizer is out of scope. Ideally, it should be put into
                            // the
                            // trainer manager instead.
                            map.put(device, s);
                            return map;
                        });
        return arrayMap.computeIfAbsent(
                device, k -> arrayMap.values().iterator().next().toDevice(device, true));
    }

    /** The Builder to construct an {@link Optimizer}. */
    @SuppressWarnings("rawtypes")
    public abstract static class OptimizerBuilder {

        private float rescaleGrad = 1.0f;
        private float weightDecays;
        private float clipGrad = -1;
        private int beginNumUpdate;

        protected OptimizerBuilder() {}

        /**
         * Sets the value used to rescale the gradient. This is used to alleviate the effect of
         * batching on the loss. Usually, the value is set to \( 1/batch_size \). Defaults to 1.
         *
         * @param rescaleGrad the value used to rescale the gradient
         * @return this {@code Builder}
         */
        public T setRescaleGrad(float rescaleGrad) {
            this.rescaleGrad = rescaleGrad;
            return self();
        }

        /**
         * Sets the value of weight decay. Weight decay augments the objective function with a
         * regularization term that penalizes large weights.
         *
         * @param weightDecays the value of weight decay to be set
         * @return this {@code Builder}
         */
        public T optWeightDecays(float weightDecays) {
            this.weightDecays = weightDecays;
            return self();
        }

        /**
         * Sets the value of the \(clipGrad\). Clips the gradient to the range of \([-clipGrad,
         * clipGrad]\). If \(clipGrad \lt 0\), gradient clipping is turned off.
         *
         * 

\(grad = max(min(grad, clipGrad), -clipGrad)\) * * @param clipGrad the value of \(clipGrad\) * @return this {@code Builder} */ public T optClipGrad(float clipGrad) { this.clipGrad = clipGrad; return self(); } /** * Sets the initial value of the number of updates. * * @param beginNumUpdate the initial value of the number of updates * @return this {@code Builder} */ public T optBeginNumUpdate(int beginNumUpdate) { this.beginNumUpdate = beginNumUpdate; return self(); } protected abstract T self(); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy