package org.deeplearning4j.nn.updater;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.ArrayList;
import java.util.List;

/**
 * UpdaterBlock: used in {@link BaseMultiLayerUpdater}. This class applies the gradient updater (Adam, RMSProp,
 * momentum SGD, etc.) across multiple contiguous layers/parameters, as described in the {@link BaseMultiLayerUpdater} javadoc.
 *
 * @author Alex Black
 */
@Data
public class UpdaterBlock {
    private int paramOffsetStart;
    private int paramOffsetEnd;
    private int updaterViewOffsetStart;
    private int updaterViewOffsetEnd;
    private List<ParamState> layersAndVariablesInBlock = new ArrayList<>();

    private INDArray updaterView;
    private INDArray gradientView;
    private boolean updaterViewRequiresInitialization;

    private GradientUpdater gradientUpdater;


    @AllArgsConstructor
    @Data
    public static class ParamState {
        private final Layer layer;
        private final String paramName;
        private final int paramOffsetStart;
        private final int paramOffsetEnd;
        private final INDArray paramView;
        private final INDArray gradView;
    }

    /**
     * @param paramOffsetStart          Start offset of the parameters in this block (relative to overall net params
     *                                  view array)
     * @param paramOffsetEnd            End offset of the parameters in this block (relative to overall net params
     *                                  view array)
     * @param updaterViewOffsetStart    Start offset of the updater state array in this block (relative to overall net
     *                                  updater state view array)
     * @param updaterViewOffsetEnd      End offset of the updater state array in this block (relative to overall net
     *                                  updater state view array)
     * @param layersAndVariablesInBlock List of layers and variables in this updater block. By definition, all layers
     *                                  and variables in this list must have an identical updater configuration.
     */
    public UpdaterBlock(int paramOffsetStart, int paramOffsetEnd, int updaterViewOffsetStart, int updaterViewOffsetEnd,
                    List<ParamState> layersAndVariablesInBlock) {
        this.paramOffsetStart = paramOffsetStart;
        this.paramOffsetEnd = paramOffsetEnd;
        this.updaterViewOffsetStart = updaterViewOffsetStart;
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    public void init() {
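        //Lazily instantiate the shared GradientUpdater (Adam, RMSProp, etc.) for this block. All params in the
        // block share an identical updater configuration by construction, so the config of the first param is used.
        // The updater state (for example Adam's m/v accumulators) is backed by the updaterView array.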
        if (gradientUpdater == null) {
            ParamState varState = layersAndVariablesInBlock.get(0);
            String varName = varState.getParamName();
            gradientUpdater = varState.getLayer().conf().getLayer().getIUpdaterByParam(varName).instantiate(updaterView,
                            updaterViewRequiresInitialization); //UpdaterUtils.getGradientUpdater(varState.getLayer(), varState.getParamName());
        }
    }

    public boolean isPretrainUpdaterBlock() {
        //All params in this block belong to the same layer, and all are pretrain-only params
        ParamState vs = layersAndVariablesInBlock.get(0);
        return vs.getLayer().conf().getLayer().isPretrainParam(vs.getParamName());
    }

    public boolean skipDueToPretrainConfig() {
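        //Pretrain-only parameter blocks are updated only while the network is in pretrain mode; during normal
        // backprop training they are skipped entirely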
        if (!isPretrainUpdaterBlock())
            return false;
        ParamState vs = layersAndVariablesInBlock.get(0);
        return !vs.getLayer().conf().isPretrain(); //Skip if not pretrain
    }

    public GradientUpdater getGradientUpdater() {
        if (gradientUpdater == null) {
            init();
        }
        return gradientUpdater;
    }

    /**
     * Update the gradient for this block
     *
     * @param iteration The current iteration (i.e., total number of parameter updates so far)
     */
    public void update(int iteration) {
        update(iteration, false, gradientView, null);
    }

    public void updateExternalGradient(int iteration, INDArray fullNetworkGradientView,
                    INDArray fullNetworkParamsArray) {
        //Extract the relevant subset from the external network
        update(iteration, true, fullNetworkGradientView, fullNetworkParamsArray);
    }

    private void update(int iteration, boolean externalGradient, INDArray fullNetworkGradientView,
                    INDArray fullNetworkParamsArray) {
        //Initialize the updater, if necessary
        if (gradientUpdater == null) {
            init();
        }

        INDArray blockGradViewArray;
        if (externalGradient) {
            blockGradViewArray = fullNetworkGradientView.get(NDArrayIndex.point(0),
                            NDArrayIndex.interval(paramOffsetStart, paramOffsetEnd));
        } else {
            blockGradViewArray = gradientView;
        }
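        //Note: in both branches blockGradViewArray is a view of the underlying full-network gradient array, so the
        // updater applied below modifies that array in place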

        //First: pre-apply steps such as gradient clipping are performed on a per-layer basis.
        //Consequently, they have already been applied by this point, in MultiLayerUpdater or ComputationGraphUpdater

        //Second: apply learning rate policy. Note that by definition we have the same LR policy for every single
        // variable in the block
        Layer l0 = layersAndVariablesInBlock.get(0).getLayer();
        if (!(l0.conf().getLayer() instanceof BaseLayer)) {
            //No params for this layer
            return;
        }
        BaseLayer baseLayer = (BaseLayer) l0.conf().getLayer();
        LearningRatePolicy lrPolicy = l0.conf().getLearningRatePolicy();
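        //Apply the LR schedule if one is configured. Nesterovs is included even when no LR policy is set, so that
        // any momentum schedule is also applied (see applyLrDecayPolicy)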
        if (lrPolicy != LearningRatePolicy.None || baseLayer.getIUpdater() instanceof Nesterovs) {
            applyLrDecayPolicy(lrPolicy, iteration);
        }

        //Apply the updater itself
        gradientUpdater.applyUpdater(blockGradViewArray, iteration);

        //Post apply: l1 and l2 by params
        for (ParamState p : layersAndVariablesInBlock) {
            INDArray paramView;
            INDArray gradView;
            if (externalGradient) {
                paramView = fullNetworkParamsArray.get(NDArrayIndex.point(0),
                                NDArrayIndex.interval(p.getParamOffsetStart(), p.getParamOffsetEnd()));
                gradView = fullNetworkGradientView.get(NDArrayIndex.point(0),
                                NDArrayIndex.interval(p.getParamOffsetStart(), p.getParamOffsetEnd()));
            } else {
                //Standard case
                paramView = p.getParamView();
                gradView = p.getGradView();
            }
            postApply(p.getLayer(), p.getParamName(), gradView, paramView);
        }
    }

    /**
     * Apply L1 and L2 regularization, if necessary. Note that L1/L2 may differ for different layers in the same block
     *
     * @param layer        The layer to apply L1/L2 to
     * @param paramName    Parameter name in the given layer
     * @param gradientView Gradient view array for the layer + param
     * @param paramsView   Parameter view array for the layer + param
     */
    public void postApply(Layer layer, String paramName, INDArray gradientView, INDArray paramsView) {
        NeuralNetConfiguration conf = layer.conf();

        //TODO: do this for multiple contiguous params/layers (fewer, larger ops)

        double l2 = conf.getL2ByParam(paramName);
        if (conf.isUseRegularization() && l2 > 0) {
            //This can be an axpy op, saving an allocation...
            //gradientView += params * l2           i.e., dC/dw = dC0/dw + lambda/n * w where C0 is pre-l2 cost function
            //Equivalent to gradientView.addi(paramsView.mul(conf.getL2ByParam(paramName)));
            int length = gradientView.length();
            Nd4j.getBlasWrapper().level1().axpy(length, l2, paramsView, gradientView);
        }
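        //gradientView += sign(params) * l1        i.e., dC/dw = dC0/dw + lambda1 * sign(w), matching the L2 form above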
        if (conf.isUseRegularization() && conf.getL1ByParam(paramName) > 0) {
            gradientView.addi(Transforms.sign(paramsView, true).muli(conf.getL1ByParam(paramName)));
        }
    }

    /**
     * Apply learning rate decay, based on the configuration
     *
     * @param decay     Learning rate schedule enumeration
     * @param iteration Current iteration
     */
    public void applyLrDecayPolicy(LearningRatePolicy decay, int iteration) {
        Layer layer = layersAndVariablesInBlock.get(0).getLayer();
        String variable = layersAndVariablesInBlock.get(0).getParamName();

        NeuralNetConfiguration conf = layer.conf();
        double decayRate = layer.conf().getLrPolicyDecayRate();
        double lr = conf.getLearningRateByParam(variable);

        if (!(conf.getLayer() instanceof BaseLayer)) {
            //No params
            return;
        }

        BaseLayer baseLayer = (BaseLayer) conf.getLayer();

        double newLr;
        switch (decay) {
            case Exponential:
                newLr = lr * Math.pow(decayRate, iteration);
                break;
            case Inverse:
                newLr = lr / Math.pow((1 + decayRate * iteration), conf.getLrPolicyPower());
                break;
            case Step:
                newLr = lr * Math.pow(decayRate, Math.floor(iteration / conf.getLrPolicySteps()));
                break;
            case TorchStep:
                if (iteration > 1 && conf.getLrPolicySteps() % iteration == 0) {
                    newLr = lr * decayRate;
                } else {
                    newLr = lr;
                }
                break;
            case Poly:
                newLr = lr * Math.pow((1 - ((double) iteration) / conf.getNumIterations()), conf.getLrPolicyPower());
                break;
            case Sigmoid:
                newLr = lr / (1 + Math.exp(-decayRate * (iteration - conf.getLrPolicySteps())));
                break;
            case Schedule:
                if (baseLayer.getLearningRateSchedule().containsKey(iteration)) {
                    newLr = baseLayer.getLearningRateSchedule().get(iteration);
                } else {
                    newLr = lr;
                }
                break;
            case None:
            case Score:
                newLr = lr;
                break;
            default:
                throw new RuntimeException("Unknown Learning rate decay value: " + decay);
        }
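        //Worked example (Step policy): lr = 0.1, decayRate = 0.5, lrPolicySteps = 100
        // -> at iteration 250: newLr = 0.1 * 0.5^floor(250/100) = 0.1 * 0.25 = 0.025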

        //Handle momentum schedules. Given the new updater design, this change is purely cosmetic
        double newMomentum = 0.0;
        if (baseLayer.getIUpdater() instanceof Nesterovs) {
            if (baseLayer.getMomentumSchedule() != null && baseLayer.getMomentumSchedule().containsKey(iteration)) {
                newMomentum = baseLayer.getMomentumSchedule().get(iteration);
            } else {
                newMomentum = baseLayer.getMomentum();
            }
        }

        //Need to set the LR for *all* variables in the Updater block. All variables (by definition of being in the
        // same block) share the same LR schedule
        for (ParamState vs : layersAndVariablesInBlock) {
            vs.getLayer().conf().setLearningRateByParam(vs.getParamName(), newLr);
            if (((BaseLayer) layer.conf().getLayer()).getIUpdater() instanceof Nesterovs) {
                ((BaseLayer) vs.getLayer().conf().getLayer()).setMomentum(newMomentum);
            }
        }

        //Apply the new LR according to the schedule.
        //Note: momentum schedules are applied internally in the Nesterov config object applySchedules method
        gradientUpdater.getConfig().applySchedules(iteration, newLr);
    }
}
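
/*
 * Illustrative sketch only (not part of the original DL4J source): a minimal, self-contained demonstration of the
 * L2/L1 gradient adjustments that UpdaterBlock.postApply performs, written against plain ND4J arrays. The class and
 * variable names below are hypothetical, and the example assumes an ND4J backend is available on the classpath.
 */
class UpdaterBlockRegularizationExample {

    public static void main(String[] args) {
        double l2 = 0.1;
        double l1 = 0.01;

        INDArray params = Nd4j.create(new double[] {1.0, -2.0, 0.5});
        INDArray grad = Nd4j.create(new double[] {0.3, 0.1, -0.2});

        //L2 step: grad += l2 * params (postApply computes this with a level-1 BLAS axpy to avoid a temporary array)
        grad.addi(params.mul(l2)); //grad is now [0.40, -0.10, -0.15]

        //L1 step: grad += l1 * sign(params)
        grad.addi(Transforms.sign(params, true).muli(l1)); //grad is now [0.41, -0.11, -0.14]

        System.out.println(grad);
    }
}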



