
org.deeplearning4j.nn.updater.UpdaterUtils

There is a newer version of this artifact: 1.0.0-M2.1
package org.deeplearning4j.nn.updater;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.nd4j.linalg.learning.config.IUpdater;

import java.util.Objects;

/**
 * Created by Alex on 14/04/2017.
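 *
 * Utility methods for deciding whether two layer parameters have compatible updater
 * configurations, i.e. whether a single updater instance (and its state) could be applied
 * to both of them.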
 */
public class UpdaterUtils {


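    /**
     * Determine whether the updater configurations for the two specified parameters are equal,
     * such that they could (in principle) share a single updater instance and updater state.
     * This requires the IUpdater configurations to be equal, the learning rate schedules to be
     * equal, and any pretrain-only parameters to belong to the same layer.
     *
     * @param layer1 First layer
     * @param param1 Name of the parameter in the first layer
     * @param layer2 Second layer
     * @param param2 Name of the parameter in the second layer
     * @return True if the updater configurations for the two parameters are equal
     */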
    public static boolean updaterConfigurationsEquals(Layer layer1, String param1, Layer layer2, String param2) {
        org.deeplearning4j.nn.conf.layers.Layer l1 = layer1.conf().getLayer();
        org.deeplearning4j.nn.conf.layers.Layer l2 = layer2.conf().getLayer();
        IUpdater u1 = l1.getIUpdaterByParam(param1);
        IUpdater u2 = l2.getIUpdaterByParam(param2);
        if (!u1.equals(u2)) {
            //Different updaters or different config
            return false;
        }
        //For updaters to be equal (and hence combinable), we require that:
        //(a) The updater-specific configurations are equal (inc. LR)
        //(b) The learning rate *schedules* are equal
        //(c) If one or more of the params are pretrainable params, they are in the same layer
        //    This last point is necessary as we don't want to modify the pretrain gradient/updater state during
        //    backprop, or modify the pretrain gradient/updater state of one layer while training another
        if (!lrSchedulesEqual(layer1, param1, layer2, param2)) {
            return false;
        }

        boolean isPretrainParam1 = layer1.conf().getLayer().isPretrainParam(param1);
        boolean isPretrainParam2 = layer2.conf().getLayer().isPretrainParam(param2);
        if (isPretrainParam1 || isPretrainParam2) {
            //One or both of the params are pretrain params.
            //Either the layers differ -> don't want to combine pretrain updaters across layers
            //Or one is pretrain and the other isn't -> don't want to combine pretrain updaters within a layer
            return layer1 == layer2 && isPretrainParam1 && isPretrainParam2;
        }

        return true;
    }

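    /**
     * Determine whether the learning rates, learning rate policies and the policy-specific
     * settings (decay rate, power, steps or schedule map, as applicable) are equal for the
     * two specified parameters.
     *
     * @return True if the learning rate schedules for the two parameters are equal
     */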
    public static boolean lrSchedulesEqual(Layer layer1, String param1, Layer layer2, String param2) {

        LearningRatePolicy lp1 = layer1.conf().getLearningRatePolicy();
        LearningRatePolicy lp2 = layer2.conf().getLearningRatePolicy();

        if (lp1 != lp2) {
            return false;
        }

        double lr1 = layer1.conf().getLearningRateByParam(param1);
        double lr2 = layer2.conf().getLearningRateByParam(param2);
        if (lr1 != lr2) {
            return false;
        }

        double dr1 = layer1.conf().getLrPolicyDecayRate();
        double dr2 = layer2.conf().getLrPolicyDecayRate();

        boolean lrConfigEqual;
        switch (lp1) {
            case None:
                lrConfigEqual = true;
                break;
            case Exponential:
                lrConfigEqual = dr1 == dr2;
                break;
            case Inverse:
                lrConfigEqual = dr1 == dr2 && layer1.conf().getLrPolicyPower() == layer2.conf().getLrPolicyPower();
                break;
            case Poly:
                lrConfigEqual = layer1.conf().getLrPolicyPower() == layer2.conf().getLrPolicyPower();
                break;
            case Sigmoid:
            case Step:
                lrConfigEqual = dr1 == dr2 && layer1.conf().getLrPolicySteps() == layer2.conf().getLrPolicySteps();
                break;
            case TorchStep:
                lrConfigEqual = layer1.conf().getLrPolicyPower() == layer2.conf().getLrPolicyPower();
                break;
            case Schedule:
                BaseLayer bl1 = (BaseLayer) layer1.conf().getLayer();
                BaseLayer bl2 = (BaseLayer) layer2.conf().getLayer();
                lrConfigEqual = Objects.equals(bl1.getLearningRateSchedule(), bl2.getLearningRateSchedule());
                break;
            case Score:
                //TODO - might be ok sometimes??
                lrConfigEqual = false;
                break;
            default:
                throw new UnsupportedOperationException("Unknown learning rate schedule: " + lp1);
        }

        return lrConfigEqual;
    }
}
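
As an illustration of how this utility might be used (not part of the file above), the sketch below groups an ordered list of (layer, parameter) pairs into contiguous blocks whose updater configurations match, so that each block could share a single updater instance and its state. The ParamRef holder and groupByUpdater method are hypothetical names introduced for this sketch; only UpdaterUtils.updaterConfigurationsEquals comes from the class shown here.

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.updater.UpdaterUtils;

import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical usage sketch: split parameters into contiguous groups with matching
 * updater configurations, using UpdaterUtils.updaterConfigurationsEquals.
 */
public class UpdaterGroupingSketch {

    /** A (layer, parameter name) pair, e.g. a dense layer and its "W" weights. */
    public static class ParamRef {
        final Layer layer;
        final String paramName;

        public ParamRef(Layer layer, String paramName) {
            this.layer = layer;
            this.paramName = paramName;
        }
    }

    /**
     * Split an ordered list of parameters into contiguous groups whose updater configurations
     * are equal, so that each group could be driven by a single shared updater.
     */
    public static List<List<ParamRef>> groupByUpdater(List<ParamRef> params) {
        List<List<ParamRef>> blocks = new ArrayList<>();
        List<ParamRef> current = new ArrayList<>();
        for (ParamRef p : params) {
            if (!current.isEmpty()) {
                ParamRef last = current.get(current.size() - 1);
                //Start a new block as soon as the updater configuration stops matching
                if (!UpdaterUtils.updaterConfigurationsEquals(last.layer, last.paramName, p.layer, p.paramName)) {
                    blocks.add(current);
                    current = new ArrayList<>();
                }
            }
            current.add(p);
        }
        if (!current.isEmpty()) {
            blocks.add(current);
        }
        return blocks;
    }
}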



