All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.optimize.solvers.BaseOptimizer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.optimize.solvers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.*;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeGradientStepFunction;
import org.deeplearning4j.optimize.terminations.EpsTermination;
import org.deeplearning4j.optimize.terminations.ZeroDirection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Base optimizer
 * @author Adam Gibson
 */
public abstract class BaseOptimizer implements ConvexOptimizer {

    protected NeuralNetConfiguration conf;
    protected static final Logger log = LoggerFactory.getLogger(BaseOptimizer.class);
    protected StepFunction stepFunction;
    protected Collection iterationListeners = new ArrayList<>();
    protected Collection terminationConditions = new ArrayList<>();
    protected Model model;
    protected BackTrackLineSearch lineMaximizer;
    protected Updater updater;
    protected ComputationGraphUpdater computationGraphUpdater;
    protected double step;
    private int batchSize;
    protected double score,oldScore;
    protected double stepMax = Double.MAX_VALUE;
    public final static String GRADIENT_KEY = "g";
    public final static String SCORE_KEY = "score";
    public final static String PARAMS_KEY = "params";
    public final static String SEARCH_DIR = "searchDirection";
    protected Map searchState = new ConcurrentHashMap<>();

    /**
     *
     * @param conf
     * @param stepFunction
     * @param iterationListeners
     * @param model
     */
    public BaseOptimizer(NeuralNetConfiguration conf,StepFunction stepFunction,Collection iterationListeners,Model model) {
        this(conf, stepFunction, iterationListeners, Arrays.asList(new ZeroDirection(), new EpsTermination()), model);
    }


    /**
     *
     * @param conf
     * @param stepFunction
     * @param iterationListeners
     * @param terminationConditions
     * @param model
     */
    public BaseOptimizer(NeuralNetConfiguration conf,StepFunction stepFunction,Collection iterationListeners,Collection terminationConditions,Model model) {
        this.conf = conf;
        this.stepFunction = (stepFunction != null ? stepFunction : getDefaultStepFunctionForOptimizer(this.getClass()));
        this.iterationListeners = iterationListeners != null ? iterationListeners : new ArrayList();
        this.terminationConditions = terminationConditions;
        this.model = model;
        lineMaximizer = new BackTrackLineSearch(model,this.stepFunction,this);
        lineMaximizer.setStepMax(stepMax);
        lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());

    }


    @Override
    public double score() {
        model.computeGradientAndScore();
        return model.score();
    }

    @Override
    public Updater getUpdater() {
        if(updater == null) {
            updater = UpdaterCreator.getUpdater(model);
        }
        return updater;
    }

    @Override
    public void setUpdater(Updater updater){
        this.updater = updater;
    }



    @Override
    public ComputationGraphUpdater getComputationGraphUpdater() {
        if(computationGraphUpdater == null && model instanceof ComputationGraph){
            computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph)model);
        }
        return computationGraphUpdater;
    }

    @Override
    public void setUpdaterComputationGraph(ComputationGraphUpdater updater) {
        this.computationGraphUpdater = updater;
    }

    @Override
    public void setListeners(Collection listeners){
        if(listeners == null) this.iterationListeners = Collections.emptyList();
        else this.iterationListeners = listeners;
    }

    @Override
    public NeuralNetConfiguration getConf() { return conf; }

    @Override
    public Pair gradientAndScore() {
        oldScore = score;
        model.computeGradientAndScore();

        if(iterationListeners != null && iterationListeners.size() > 0){
            for(IterationListener l : iterationListeners){
                if(l instanceof TrainingListener){
                    ((TrainingListener)l).onGradientCalculation(model);
                }
            }
        }

        Pair pair = model.gradientAndScore();
        score = pair.getSecond();
        updateGradientAccordingToParams(pair.getFirst(), model, model.batchSize());
        return pair;
    }

    /**
     * Optimize call. This runs the optimizer.
     * @return whether it converged or not
     */
    // TODO add flag to allow retaining state between mini batches and when to apply updates
    @Override
    public  boolean optimize() {
        //validate the input before training
        INDArray gradient;
        INDArray searchDirection;
        INDArray parameters = null;
        model.validateInput();
        Pair pair = gradientAndScore();
        if(searchState.isEmpty()){
            searchState.put(GRADIENT_KEY, pair.getFirst().gradient());
            setupSearchState(pair);		//Only do this once
        } else {
            searchState.put(GRADIENT_KEY, pair.getFirst().gradient());
        }

        //pre existing termination conditions
        /*
         * Commented out for now; this has been problematic for testing/debugging
         * Revisit & re-enable later. */
        for(TerminationCondition condition : terminationConditions){
            if(condition.terminate(0.0,0.0,new Object[]{pair.getFirst().gradient()})) {
                log.info("Hit termination condition " + condition.getClass().getName());
                return true;
            }
        }

        //calculate initial search direction
        preProcessLine();

        for(int i = 0; i < conf.getNumIterations(); i++) {
            gradient = (INDArray) searchState.get(GRADIENT_KEY);
            searchDirection = (INDArray) searchState.get(SEARCH_DIR);
            parameters = (INDArray) searchState.get(PARAMS_KEY);

            //perform one line search optimization
            try {
                step = lineMaximizer.optimize(parameters, gradient, searchDirection);
            } catch (InvalidStepException e) {
                log.warn("Invalid step...continuing another iteration: {}",e.getMessage());
                step = 0.0;
            }

            //Update parameters based on final/best step size returned by line search:
            if(step != 0.0) {
                stepFunction.step(parameters, searchDirection, step);	//Calculate params. given step size
                model.setParams(parameters);
            }else {
                log.debug("Step size returned by line search is 0.0.");
            }

            pair = gradientAndScore();

            //updates searchDirection
            postStep(pair.getFirst().gradient());

            //invoke listeners for debugging
            for(IterationListener listener : iterationListeners)
                listener.iterationDone(model,i);

            //check for termination conditions based on absolute change in score
            checkTerminalConditions(pair.getFirst().gradient(), oldScore, score, i);
            incrementIterationCount(model, 1);
        }
        return true;
    }

    protected  void postFirstStep(INDArray gradient) {
        //no-op
    }

    @Override
    public boolean checkTerminalConditions(INDArray gradient, double oldScore, double score, int i){
        for(TerminationCondition condition : terminationConditions){
            if(condition.terminate(score,oldScore,new Object[]{gradient})){
                log.debug("Hit termination condition on iteration {}: score={}, oldScore={}, condition={}", i, score, oldScore, condition);
                if(condition instanceof EpsTermination && conf.getLayer() != null && conf.getLearningRatePolicy() == LearningRatePolicy.Score) {
                    model.applyLearningRateScoreDecay();
                }
                return true;
            }
        }
        return false;
    }

    @Override
    public int batchSize() {
        return batchSize;
    }

    @Override
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }


    /**
     * Pre preProcess to setup initial searchDirection approximation
     */
    @Override
    public  void preProcessLine() {
        //no-op
    }
    /**
     * Post step to update searchDirection with new gradient and parameter information
     */
    @Override
    public  void postStep(INDArray gradient) {
        //no-op
    }


    @Override
    public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
        if(model instanceof ComputationGraph){
            ComputationGraph graph = (ComputationGraph)model;
            if(computationGraphUpdater == null){
                computationGraphUpdater = new ComputationGraphUpdater(graph);
            }
            computationGraphUpdater.update(graph, gradient, getIterationCount(model), batchSize);
        } else {

            if (updater == null)
                updater = UpdaterCreator.getUpdater(model);
            Layer layer = (Layer) model;
            updater.update(layer, gradient, getIterationCount(model), batchSize);
        }
    }

    /**
     * Setup the initial search state
     * @param pair
     */
    @Override
    public  void setupSearchState(Pair pair) {
        INDArray gradient = pair.getFirst().gradient(conf.variables());
        INDArray params = model.params().dup(); //Need dup here: params returns an array that isn't a copy (hence changes to this are problematic for line search methods)
        searchState.put(GRADIENT_KEY,gradient);
        searchState.put(SCORE_KEY,pair.getSecond());
        searchState.put(PARAMS_KEY,params);
    }


    public static StepFunction getDefaultStepFunctionForOptimizer( Class optimizerClass ){
        if( optimizerClass == StochasticGradientDescent.class ){
            return new NegativeGradientStepFunction();
        } else {
            return new NegativeDefaultStepFunction();
        }
    }

    public static int getIterationCount(Model model){
        if(model instanceof MultiLayerNetwork){
            return ((MultiLayerNetwork)model).getLayerWiseConfigurations().getIterationCount();
        } else if(model instanceof ComputationGraph){
            return ((ComputationGraph)model).getConfiguration().getIterationCount();
        } else {
            return model.conf().getIterationCount();
        }
    }

    public static void incrementIterationCount(Model model, int incrementBy){
        if(model instanceof MultiLayerNetwork){
            MultiLayerConfiguration conf = ((MultiLayerNetwork)model).getLayerWiseConfigurations();
            conf.setIterationCount(conf.getIterationCount() + incrementBy);
        } else if(model instanceof ComputationGraph){
            ComputationGraphConfiguration conf = ((ComputationGraph)model).getConfiguration();
            conf.setIterationCount(conf.getIterationCount() + incrementBy);
        } else {
            model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy);
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy