All Downloads are FREE. The search and download functionality uses the official Maven repository.

org.deeplearning4j.optimize.solvers.StochasticGradientDescent Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.optimize.solvers;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.util.Collection;

/**
 * Optimizer that performs one gradient step per call to {@link #optimize(LayerWorkspaceMgr)}.
 *
 * <p>Each call obtains the gradient via {@code gradientAndScore}, flattens it to a 1-D array and
 * applies it to the model parameters: through the gradients accumulator when one is set on this
 * optimizer, otherwise directly through the step function.
 */
@Slf4j
public class StochasticGradientDescent extends BaseOptimizer {

    /**
     * Creates the optimizer; all arguments are forwarded unchanged to the {@link BaseOptimizer}
     * constructor.
     *
     * @param conf              network configuration
     * @param stepFunction      step function used to apply a gradient to the parameters
     * @param trainingListeners listeners notified after every iteration
     * @param model             model whose parameters are optimized
     */
    public StochasticGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction,
                                     Collection<TrainingListener> trainingListeners, Model model) {
        super(conf, stepFunction, trainingListeners, model);
    }


    /**
     * Runs a single optimization iteration.
     *
     * <p>Order of operations: apply any pending accumulator updates, compute the gradient, apply it
     * to the parameters, write the parameters back to the model, notify the training listeners,
     * increment the iteration count by one and apply the model constraints.
     *
     * @param workspaceMgr workspace manager passed on to {@code gradientAndScore}
     * @return always {@code true}
     */
    @Override
    public boolean optimize(LayerWorkspaceMgr workspaceMgr) {
        if (accumulator != null) {
            // before going FF, we're checking if there are any updates available
            if (accumulator.hasAnything()) {
                log.info("Applying external updates before FF...");

                // we'll just fire off params update process; the second array is only an
                // uninitialized buffer with the same shape and ordering as the parameters
                INDArray currentParams = model.params();
                accumulator.applyUpdate(stepFunction, currentParams,
                                Nd4j.createUninitialized(currentParams.shape(), currentParams.ordering()), false);
            }
        }

        Pair<Gradient, Double> pair = gradientAndScore(workspaceMgr);

        Gradient gradient = pair.getFirst();

        INDArray params = model.params();
        // flatten the gradient to a 1-D array of the same length
        INDArray fullGrad = gradient.gradient();
        fullGrad = fullGrad.reshape(fullGrad.length());

        // if optimizer has GradientsAccumulator defined - go for it
        if (accumulator != null) {
            // we're propagating current update; counters stay 0 for any other Model type
            int epochNum = 0;
            int iterationNum = 0;

            if (model instanceof MultiLayerNetwork) {
                iterationNum = ((MultiLayerNetwork) model).getIterationCount();
                epochNum = ((MultiLayerNetwork) model).getEpochCount();
            } else if (model instanceof ComputationGraph) {
                iterationNum = ((ComputationGraph) model).getIterationCount();
                epochNum = ((ComputationGraph) model).getEpochCount();
            }

            accumulator.storeUpdate(fullGrad, iterationNum, epochNum);

            // and applying the (possible) pending update from the accumulator;
            // if there's no update available - just go on then
            accumulator.applyUpdate(stepFunction, params, fullGrad, true);
        } else {
            // if accumulator isn't used - we just go for direct update application
            stepFunction.step(params, fullGrad);
        }

        //Note: model.params() is always in-place for MultiLayerNetwork and ComputationGraph, hence no setParams is necessary there
        //However: for pretrain layers, params are NOT a view. Thus a setParams call is necessary
        //But setParams should be a no-op for MLN and CG
        model.setParams(params);

        int iterationCount = BaseOptimizer.getIterationCount(model);
        int epochCount = BaseOptimizer.getEpochCount(model);
        // listeners are invoked while scoped out of any memory workspace
        try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            for (TrainingListener listener : trainingListeners) {
                listener.iterationDone(model, iterationCount, epochCount);
            }
        }

        BaseOptimizer.incrementIterationCount(model, 1);
        applyConstraints(model);
        return true;
    }

    /** No-op for this optimizer. */
    @Override
    public void preProcessLine() {}

    /**
     * No-op for this optimizer.
     *
     * @param gradient ignored
     */
    @Override
    public void postStep(INDArray gradient) {}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy