
smile.deep.optimizer.SGD

/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.deep.optimizer;

import java.util.Arrays;
import smile.base.mlp.Layer;
import smile.math.MathEx;
import smile.math.TimeFunction;
import smile.math.matrix.Matrix;

/**
 * Stochastic gradient descent (with momentum) optimizer.
 *
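 * Without momentum, the update is w += eta(t) * g, where g is the gradient
 * accumulated in the layer and eta(t) is the (possibly time-varying)
 * learning rate. With momentum, a velocity v = alpha(t) * v + eta(t) * g
 * is maintained and added to the weights instead, which smooths the
 * updates across mini-batches.
 *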
 * @author Haifeng Li
 */
public class SGD implements Optimizer {
    /**
     * The learning rate.
     */
    private final TimeFunction learningRate;
    /**
     * The momentum factor.
     */
    private final TimeFunction momentum;

    /**
     * Constructor.
     */
    public SGD() {
        this(TimeFunction.constant(0.01));
    }

    /**
     * Constructor.
     * @param learningRate the learning rate.
     */
    public SGD(TimeFunction learningRate) {
        this(learningRate, null);
    }

    /**
     * Constructor.
     * @param learningRate the learning rate.
     * @param momentum the momentum factor, or null to disable momentum.
     */
    public SGD(TimeFunction learningRate, TimeFunction momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    @Override
    public String toString() {
        return momentum == null ?
                String.format("SGD(%s)", learningRate) :
                String.format("SGD(%s, %s)", learningRate, momentum);
    }

    @Override
    public void update(Layer layer, int m, int t) {
        Matrix weightGradient = layer.weightGradient.get();
        double[] biasGradient = layer.biasGradient.get();

        // Instead of computing the average gradient explicitly,
        // we scale down the learning rate by the number of samples.
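        // For example, with eta(t) = 0.01 and a mini-batch of m = 32 samples,
        // eta = 0.01 / 32, so adding eta times the summed gradient is the
        // same as adding 0.01 times the mean gradient.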
        double eta = learningRate.apply(t) / m;
        int n = layer.n;

        if (momentum == null) {
            layer.weight.add(eta, weightGradient);
            for (int i = 0; i < n; i++) {
                layer.bias[i] += eta * biasGradient[i];
            }
        } else {
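            // Classical momentum: accumulate the velocity v = alpha * v + eta * g,
            // then add it to the weights and biases below.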
            double alpha = momentum.apply(t);
            Matrix weightUpdate = layer.weightUpdate.get();
            double[] biasUpdate = layer.biasUpdate.get();

            weightUpdate.add(alpha, eta, weightGradient);
            for (int i = 0; i < n; i++) {
                biasUpdate[i] = alpha * biasUpdate[i] + eta * biasGradient[i];
            }

            layer.weight.add(weightUpdate);
            MathEx.add(layer.bias, biasUpdate);
        }

        weightGradient.fill(0.0);
        Arrays.fill(biasGradient, 0.0);
    }
}
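A minimal usage sketch (not part of the artifact source; variable names and
values are illustrative, and the way the optimizer is driven by a training
loop is an assumption): the constructors above accept TimeFunction schedules,
and update(layer, m, t) is presumably called by the training code once per
mini-batch of m samples at time step t.

import smile.deep.optimizer.SGD;
import smile.math.TimeFunction;

public class SGDExample {
    public static void main(String[] args) {
        // Default constructor: constant learning rate of 0.01, no momentum.
        SGD plain = new SGD();

        // Custom constant learning rate, no momentum.
        SGD slow = new SGD(TimeFunction.constant(0.001));

        // Learning rate 0.01 with momentum factor 0.9; both schedules are
        // evaluated at time step t when update(layer, m, t) runs.
        SGD withMomentum = new SGD(TimeFunction.constant(0.01), TimeFunction.constant(0.9));

        // toString() reports the configured schedules (see the method above).
        System.out.println(plain);
        System.out.println(slow);
        System.out.println(withMomentum);
    }
}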



