All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.conf.layers.LayerValidation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn.conf.layers;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.OneTimeLogger;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.*;
import org.nd4j.linalg.learning.config.*;

import java.util.HashMap;
import java.util.Map;

/**
 * Created by Alex on 22/02/2017.
 */
@Slf4j
public class LayerValidation {

    /**
     * Boxed-argument entry point for updater validation.
     * <p>
     * Resolves the {@link BaseLayer} to validate (unwrapping a {@link FrozenLayer} whose inner layer is a
     * BaseLayer), converts every {@code null} hyper-parameter to {@code Double.NaN} (the "not set" marker used by
     * the primitive overload) and delegates to that overload. Layers that are neither a BaseLayer nor a
     * FrozenLayer around a BaseLayer are ignored.
     *
     * @param layerName        name of the layer, used in log messages
     * @param layer            layer to validate
     * @param learningRate     global learning rate, or null if not set
     * @param momentum         global momentum, or null if not set
     * @param momentumSchedule global momentum schedule, may be null
     * @param adamMeanDecay    global Adam mean decay, or null if not set
     * @param adamVarDecay     global Adam variance decay, or null if not set
     * @param rho              global rho, or null if not set
     * @param rmsDecay         global RMS decay, or null if not set
     * @param epsilon          global epsilon, or null if not set
     */
    public static void updaterValidation(String layerName, Layer layer, Double learningRate, Double momentum,
                    Map momentumSchedule, Double adamMeanDecay, Double adamVarDecay, Double rho,
                    Double rmsDecay, Double epsilon) {
        // The frozen-wrapper check is evaluated first, exactly as before, so a wrapper is always unwrapped
        final boolean wrapsBaseLayer =
                        layer instanceof FrozenLayer && ((FrozenLayer) layer).getLayer() instanceof BaseLayer;
        if (!wrapsBaseLayer && !(layer instanceof BaseLayer)) {
            return;
        }
        final BaseLayer target = wrapsBaseLayer ? (BaseLayer) ((FrozenLayer) layer).getLayer() : (BaseLayer) layer;

        // null means "not configured"; the primitive overload expresses that as NaN
        final double lr = (learningRate != null) ? learningRate : Double.NaN;
        final double mom = (momentum != null) ? momentum : Double.NaN;
        final double beta1 = (adamMeanDecay != null) ? adamMeanDecay : Double.NaN;
        final double beta2 = (adamVarDecay != null) ? adamVarDecay : Double.NaN;
        final double rhoValue = (rho != null) ? rho : Double.NaN;
        final double rmsDecayValue = (rmsDecay != null) ? rmsDecay : Double.NaN;
        final double eps = (epsilon != null) ? epsilon : Double.NaN;

        updaterValidation(layerName, target, lr, mom, momentumSchedule, beta1, beta2, rhoValue, rmsDecayValue,
                        eps);
    }

    /**
     * Validate the updater configuration of a layer, applying legacy (deprecated) settings and defaults.
     * <p>
     * The work is done in three phases, in this order:
     * <ol>
     * <li>Warn (once per message) about legacy settings that are set but do not match the layer's
     * {@link Updater} enumeration value as it is on entry.</li>
     * <li>Copy legacy values into the layer's {@code IUpdater} instance. A value set on the layer takes
     * precedence over the global value passed as an argument; when both are NaN the IUpdater keeps whatever it
     * already holds. The updater enumeration is then set to match the IUpdater type, or to null for a type
     * that is not recognised here (probably a custom updater).</li>
     * <li>Fill legacy fields on the layer that are still unset, using the global value or the updater default.
     * This is cosmetic only: the values actually used are those in the IUpdater.</li>
     * </ol>
     * NaN is the "not set" marker for every double argument.
     *
     * @param layerName        name of the layer, used in log messages
     * @param layer            layer whose updater configuration is validated and completed
     * @param learningRate     global learning rate, or NaN
     * @param momentum         global momentum, or NaN
     * @param momentumSchedule global momentum schedule, may be null
     * @param adamMeanDecay    global Adam/AdaMax mean decay (beta1), or NaN
     * @param adamVarDecay     global Adam/AdaMax variance decay (beta2), or NaN
     * @param rho              global AdaDelta rho, or NaN
     * @param rmsDecay         global RmsProp decay, or NaN
     * @param epsilon          global epsilon, or NaN
     */
    public static void updaterValidation(String layerName, BaseLayer layer, double learningRate, double momentum,
                    Map momentumSchedule, double adamMeanDecay, double adamVarDecay, double rho,
                    double rmsDecay, double epsilon) {
        // Order matters: the warnings read the updater enumeration before the second phase overwrites it
        warnAboutUnusedLegacySettings(layerName, layer, momentum, momentumSchedule, adamMeanDecay, adamVarDecay,
                        rho, rmsDecay);
        applyLegacySettingsToUpdater(layer, learningRate, momentum, momentumSchedule, adamMeanDecay,
                        adamVarDecay, rho, rmsDecay, epsilon);
        fillLegacyFields(layer, momentum, momentumSchedule, adamMeanDecay, adamVarDecay, rho, rmsDecay, epsilon);
    }

    /** Returns {@code preferred} unless it is NaN, in which case {@code fallback} (which may itself be NaN). */
    private static double firstNonNaN(double preferred, double fallback) {
        return Double.isNaN(preferred) ? fallback : preferred;
    }

    /** Logs a one-time warning for each legacy setting that is set but unused by the layer's updater enumeration. */
    private static void warnAboutUnusedLegacySettings(String layerName, BaseLayer layer, double momentum,
                    Map momentumSchedule, double adamMeanDecay, double adamVarDecay, double rho,
                    double rmsDecay) {
        final Updater configured = layer.getUpdater();
        final String prefix = "Layer \"" + layerName + "\" ";

        if (configured != Updater.NESTEROVS) {
            if (!Double.isNaN(momentum) || !Double.isNaN(layer.getMomentum())) {
                OneTimeLogger.warn(log, prefix
                                + "momentum has been set but will not be applied unless the updater is set to NESTEROVS.");
            }
            if (momentumSchedule != null || layer.getMomentumSchedule() != null) {
                OneTimeLogger.warn(log, prefix
                                + "momentum schedule has been set but will not be applied unless the updater is set to NESTEROVS.");
            }
        }

        // The decay values are applied to both Adam and AdaMax updaters, so neither of them warrants a warning
        if (configured != Updater.ADAM && configured != Updater.ADAMAX) {
            if (!Double.isNaN(adamVarDecay) || !Double.isNaN(layer.getAdamVarDecay())) {
                OneTimeLogger.warn(log, prefix
                                + "adamVarDecay is set but will not be applied unless the updater is set to Adam or AdaMax.");
            }
            if (!Double.isNaN(adamMeanDecay) || !Double.isNaN(layer.getAdamMeanDecay())) {
                OneTimeLogger.warn(log, prefix
                                + "adamMeanDecay is set but will not be applied unless the updater is set to Adam or AdaMax.");
            }
        }

        if (configured != Updater.ADADELTA && (!Double.isNaN(rho) || !Double.isNaN(layer.getRho()))) {
            OneTimeLogger.warn(log, prefix
                            + "rho is set but will not be applied unless the updater is set to ADADELTA.");
        }
        if (configured != Updater.RMSPROP && (!Double.isNaN(rmsDecay) || !Double.isNaN(layer.getRmsDecay()))) {
            OneTimeLogger.warn(log, prefix
                            + "rmsdecay is set but will not be applied unless the updater is set to RMSPROP.");
        }
    }

    /**
     * Copies legacy values (layer value first, then global value) into the layer's IUpdater and sets the
     * updater enumeration to match the IUpdater type.
     */
    private static void applyLegacySettingsToUpdater(BaseLayer layer, double learningRate, double momentum,
                    Map momentumSchedule, double adamMeanDecay, double adamVarDecay, double rho,
                    double rmsDecay, double epsilon) {
        IUpdater u = layer.getIUpdater();

        // The learning rate is pushed through applySchedules, which is available on every IUpdater
        // (including custom ones), rather than through a type-specific setter
        final double lr = firstNonNaN(layer.getLearningRate(), learningRate);
        if (!Double.isNaN(lr)) {
            u.applySchedules(0, lr);
        }

        if (u instanceof Sgd) {
            layer.setUpdater(Updater.SGD);

        } else if (u instanceof Adam) {
            Adam a = (Adam) u;
            final double eps = firstNonNaN(layer.getEpsilon(), epsilon);
            if (!Double.isNaN(eps)) {
                a.setEpsilon(eps);
            }
            final double beta1 = firstNonNaN(layer.getAdamMeanDecay(), adamMeanDecay);
            if (!Double.isNaN(beta1)) {
                a.setBeta1(beta1);
            }
            final double beta2 = firstNonNaN(layer.getAdamVarDecay(), adamVarDecay);
            if (!Double.isNaN(beta2)) {
                a.setBeta2(beta2);
            }
            layer.setUpdater(Updater.ADAM);

        } else if (u instanceof AdaDelta) {
            AdaDelta a = (AdaDelta) u;
            final double rhoValue = firstNonNaN(layer.getRho(), rho);
            if (!Double.isNaN(rhoValue)) {
                a.setRho(rhoValue);
            }
            final double eps = firstNonNaN(layer.getEpsilon(), epsilon);
            if (!Double.isNaN(eps)) {
                a.setEpsilon(eps);
            }
            layer.setUpdater(Updater.ADADELTA);

        } else if (u instanceof Nesterovs) {
            Nesterovs n = (Nesterovs) u;
            final double mom = firstNonNaN(layer.getMomentum(), momentum);
            if (!Double.isNaN(mom)) {
                n.setMomentum(mom);
            }
            // An empty schedule is treated the same as no schedule
            if (layer.getMomentumSchedule() != null && !layer.getMomentumSchedule().isEmpty()) {
                n.setMomentumSchedule(layer.getMomentumSchedule());
            } else if (momentumSchedule != null && !momentumSchedule.isEmpty()) {
                n.setMomentumSchedule(momentumSchedule);
            }
            layer.setUpdater(Updater.NESTEROVS);

        } else if (u instanceof AdaGrad) {
            AdaGrad a = (AdaGrad) u;
            final double eps = firstNonNaN(layer.getEpsilon(), epsilon);
            if (!Double.isNaN(eps)) {
                a.setEpsilon(eps);
            }
            layer.setUpdater(Updater.ADAGRAD);

        } else if (u instanceof RmsProp) {
            RmsProp r = (RmsProp) u;
            final double eps = firstNonNaN(layer.getEpsilon(), epsilon);
            if (!Double.isNaN(eps)) {
                r.setEpsilon(eps);
            }
            final double decay = firstNonNaN(layer.getRmsDecay(), rmsDecay);
            if (!Double.isNaN(decay)) {
                r.setRmsDecay(decay);
            }
            layer.setUpdater(Updater.RMSPROP);

        } else if (u instanceof AdaMax) {
            AdaMax a = (AdaMax) u;
            final double eps = firstNonNaN(layer.getEpsilon(), epsilon);
            if (!Double.isNaN(eps)) {
                a.setEpsilon(eps);
            }
            final double beta1 = firstNonNaN(layer.getAdamMeanDecay(), adamMeanDecay);
            if (!Double.isNaN(beta1)) {
                a.setBeta1(beta1);
            }
            final double beta2 = firstNonNaN(layer.getAdamVarDecay(), adamVarDecay);
            if (!Double.isNaN(beta2)) {
                a.setBeta2(beta2);
            }
            layer.setUpdater(Updater.ADAMAX);

        } else if (u instanceof NoOp) {
            layer.setUpdater(Updater.NONE);
        } else {
            //Probably a custom updater
            layer.setUpdater(null);
        }
    }

    /**
     * Fills legacy fields on the layer that are still unset (NaN / null) with the global value or, failing
     * that, the default of the selected updater. Purely cosmetic, to avoid leaving NaNs that might confuse users.
     */
    private static void fillLegacyFields(BaseLayer layer, double momentum, Map momentumSchedule,
                    double adamMeanDecay, double adamVarDecay, double rho, double rmsDecay, double epsilon) {
        if (layer.getUpdater() == null) {
            //May be null with custom updaters etc
            return;
        }
        switch (layer.getUpdater()) {
            case NESTEROVS:
                if (Double.isNaN(layer.getMomentum())) {
                    layer.setMomentum(firstNonNaN(momentum, Nesterovs.DEFAULT_NESTEROV_MOMENTUM));
                }
                if (layer.getMomentumSchedule() == null) {
                    layer.setMomentumSchedule(momentumSchedule != null ? momentumSchedule : new HashMap());
                }
                break;
            case ADAM:
                if (Double.isNaN(layer.getAdamMeanDecay())) {
                    layer.setAdamMeanDecay(firstNonNaN(adamMeanDecay, Adam.DEFAULT_ADAM_BETA1_MEAN_DECAY));
                }
                if (Double.isNaN(layer.getAdamVarDecay())) {
                    layer.setAdamVarDecay(firstNonNaN(adamVarDecay, Adam.DEFAULT_ADAM_BETA2_VAR_DECAY));
                }
                if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(firstNonNaN(epsilon, Adam.DEFAULT_ADAM_EPSILON));
                }
                break;
            case ADADELTA:
                if (Double.isNaN(layer.getRho())) {
                    layer.setRho(firstNonNaN(rho, AdaDelta.DEFAULT_ADADELTA_RHO));
                }
                if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(firstNonNaN(epsilon, AdaDelta.DEFAULT_ADADELTA_EPSILON));
                }
                break;
            case ADAGRAD:
                if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(firstNonNaN(epsilon, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
                }
                break;
            case RMSPROP:
                if (Double.isNaN(layer.getRmsDecay())) {
                    layer.setRmsDecay(firstNonNaN(rmsDecay, RmsProp.DEFAULT_RMSPROP_RMSDECAY));
                }
                if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(firstNonNaN(epsilon, RmsProp.DEFAULT_RMSPROP_EPSILON));
                }
                break;
            case ADAMAX:
                if (Double.isNaN(layer.getAdamMeanDecay())) {
                    layer.setAdamMeanDecay(firstNonNaN(adamMeanDecay, AdaMax.DEFAULT_ADAMAX_BETA1_MEAN_DECAY));
                }
                if (Double.isNaN(layer.getAdamVarDecay())) {
                    layer.setAdamVarDecay(firstNonNaN(adamVarDecay, AdaMax.DEFAULT_ADAMAX_BETA2_VAR_DECAY));
                }
                if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(firstNonNaN(epsilon, AdaMax.DEFAULT_ADAMAX_EPSILON));
                }
                break;
            default:
                // Remaining updaters have no legacy fields to fill in
                break;
        }
    }

    /**
     * Boxed-argument entry point for general (regularization / dropout / distribution) validation.
     * <p>
     * Converts {@code null} arguments to the values the primitive overload expects and delegates to it:
     * a null dropout becomes 0.0, while null L1/L2 coefficients become {@code Double.NaN} ("not set").
     *
     * @param layerName         name of the layer, used in log messages
     * @param layer             layer to validate
     * @param useRegularization whether regularization is enabled
     * @param useDropConnect    whether drop connect is enabled
     * @param dropOut           global dropout value, or null
     * @param l2                global L2 coefficient, or null if not set
     * @param l2Bias            global L2 bias coefficient, or null if not set
     * @param l1                global L1 coefficient, or null if not set
     * @param l1Bias            global L1 bias coefficient, or null if not set
     * @param dist              global weight distribution, may be null
     */
    public static void generalValidation(String layerName, Layer layer, boolean useRegularization,
                    boolean useDropConnect, Double dropOut, Double l2, Double l2Bias, Double l1, Double l1Bias,
                    Distribution dist) {
        // Dropout is the only argument whose "missing" value is 0.0 rather than NaN
        final double dropOutValue = (dropOut != null) ? dropOut : 0.0;
        final double l2Value = (l2 != null) ? l2 : Double.NaN;
        final double l2BiasValue = (l2Bias != null) ? l2Bias : Double.NaN;
        final double l1Value = (l1 != null) ? l1 : Double.NaN;
        final double l1BiasValue = (l1Bias != null) ? l1Bias : Double.NaN;

        generalValidation(layerName, layer, useRegularization, useDropConnect, dropOutValue, l2Value, l2BiasValue,
                        l1Value, l1BiasValue, dist);
    }

    /**
     * Validate the general (non-updater) configuration of a layer.
     * <p>
     * Logs one-time warnings about inconsistent drop connect / dropout settings and then, if the layer is a
     * {@link BaseLayer} or a {@link FrozenLayer} wrapping a BaseLayer, hands that BaseLayer to
     * {@code configureBaseLayer} for regularization and weight-distribution handling.
     * A null layer is ignored. NaN is the "not set" marker for the double arguments.
     *
     * @param layerName         name of the layer, used in log messages
     * @param layer             layer to validate, may be null
     * @param useRegularization whether regularization is enabled
     * @param useDropConnect    whether drop connect is enabled
     * @param dropOut           global dropout value, or NaN
     * @param l2                global L2 coefficient, or NaN
     * @param l2Bias            global L2 bias coefficient, or NaN
     * @param l1                global L1 coefficient, or NaN
     * @param l1Bias            global L1 bias coefficient, or NaN
     * @param dist              global weight distribution, may be null
     */
    public static void generalValidation(String layerName, Layer layer, boolean useRegularization,
                    boolean useDropConnect, double dropOut, double l2, double l2Bias, double l1, double l1Bias,
                    Distribution dist) {
        if (layer == null) {
            return;
        }

        if (useDropConnect) {
            if (Double.isNaN(dropOut) && Double.isNaN(layer.getDropOut())) {
                OneTimeLogger.warn(log, "Layer \"" + layerName
                                + "\" dropConnect is set to true but dropout rate has not been added to configuration.");
            }
            if (layer.getDropOut() == 0.0) {
                // The closing quote after the layer name was missing here, unlike every other message
                OneTimeLogger.warn(log, "Layer \"" + layerName
                                + "\" dropConnect is set to true but dropout rate is set to 0.0");
            }
        }

        // A plain BaseLayer is checked before the frozen wrapper, matching the original precedence
        BaseLayer bLayer = null;
        if (layer instanceof BaseLayer) {
            bLayer = (BaseLayer) layer;
        } else if (layer instanceof FrozenLayer && ((FrozenLayer) layer).getLayer() instanceof BaseLayer) {
            bLayer = (BaseLayer) ((FrozenLayer) layer).getLayer();
        }

        if (bLayer != null) {
            configureBaseLayer(layerName, bLayer, useRegularization, useDropConnect, dropOut, l2, l2Bias, l1,
                            l1Bias, dist);
        }
    }

    /**
     * Apply global regularization and weight-distribution settings to a {@link BaseLayer} and warn about
     * inconsistent configuration.
     * <p>
     * A value already set on the layer always wins over the global value. NaN is the "not set" marker.
     * <ul>
     * <li>With regularization enabled, unset layer L1/L2 (and bias) coefficients inherit the global values; a
     * warning is logged if no coefficient and no dropout is configured anywhere.</li>
     * <li>With regularization disabled, a warning is logged if any coefficient is greater than zero.</li>
     * <li>A coefficient that is unset both globally and on the layer is set to 0.0 on the layer.</li>
     * <li>For {@code WeightInit.DISTRIBUTION}, an unset layer distribution inherits {@code dist}, or falls back
     * to {@code new NormalDistribution(0, 1)} with a warning. For any other weight init a warning is logged
     * if a distribution is set.</li>
     * </ul>
     *
     * @param layerName         name of the layer, used in log messages
     * @param bLayer            layer to configure
     * @param useRegularization whether regularization is enabled
     * @param useDropConnect    not used by this method
     * @param dropOut           global dropout value, or NaN
     * @param l2                global L2 coefficient, or NaN
     * @param l2Bias            global L2 bias coefficient, or NaN
     * @param l1                global L1 coefficient, or NaN
     * @param l1Bias            global L1 bias coefficient, or NaN
     * @param dist              global weight distribution, may be null
     */
    private static void configureBaseLayer(String layerName, BaseLayer bLayer, boolean useRegularization,
                    boolean useDropConnect, double dropOut, double l2, double l2Bias, double l1, double l1Bias,
                    Distribution dist) {
        if (useRegularization) {
            // Layer-level bias coefficients count as configured regularization too
            final boolean noPenaltySet = Double.isNaN(l1) && Double.isNaN(bLayer.getL1()) && Double.isNaN(l2)
                            && Double.isNaN(bLayer.getL2()) && Double.isNaN(l1Bias)
                            && Double.isNaN(bLayer.getL1Bias()) && Double.isNaN(l2Bias)
                            && Double.isNaN(bLayer.getL2Bias());
            final boolean noDropOutSet = (Double.isNaN(dropOut) || dropOut == 0.0)
                            && (Double.isNaN(bLayer.getDropOut()) || bLayer.getDropOut() == 0.0);
            if (noPenaltySet && noDropOutSet) {
                OneTimeLogger.warn(log, "Layer \"" + layerName
                                + "\" regularization is set to true but l1, l2 or dropout has not been added to configuration.");
            }

            // Inherit global coefficients only where the layer has none of its own
            if (!Double.isNaN(l1) && Double.isNaN(bLayer.getL1())) {
                bLayer.setL1(l1);
            }
            if (!Double.isNaN(l2) && Double.isNaN(bLayer.getL2())) {
                bLayer.setL2(l2);
            }
            if (!Double.isNaN(l1Bias) && Double.isNaN(bLayer.getL1Bias())) {
                bLayer.setL1Bias(l1Bias);
            }
            if (!Double.isNaN(l2Bias) && Double.isNaN(bLayer.getL2Bias())) {
                bLayer.setL2Bias(l2Bias);
            }
        } else {
            // "x > 0.0" is false for NaN, so no separate NaN check is needed
            final boolean anyPenaltySet = l1 > 0.0 || bLayer.getL1() > 0.0 || l2 > 0.0 || bLayer.getL2() > 0.0
                            || l1Bias > 0.0 || bLayer.getL1Bias() > 0.0 || l2Bias > 0.0
                            || bLayer.getL2Bias() > 0.0;
            if (anyPenaltySet) {
                OneTimeLogger.warn(log, "Layer \"" + layerName
                                + "\" l1 or l2 has been added to configuration but useRegularization is set to false.");
            }
        }

        // Coefficients that are unset both globally and on the layer default to 0.0
        if (Double.isNaN(l2) && Double.isNaN(bLayer.getL2())) {
            bLayer.setL2(0.0);
        }
        if (Double.isNaN(l1) && Double.isNaN(bLayer.getL1())) {
            bLayer.setL1(0.0);
        }
        if (Double.isNaN(l2Bias) && Double.isNaN(bLayer.getL2Bias())) {
            bLayer.setL2Bias(0.0);
        }
        if (Double.isNaN(l1Bias) && Double.isNaN(bLayer.getL1Bias())) {
            bLayer.setL1Bias(0.0);
        }

        if (bLayer.getWeightInit() == WeightInit.DISTRIBUTION) {
            if (bLayer.getDist() == null) {
                if (dist != null) {
                    bLayer.setDist(dist);
                } else {
                    bLayer.setDist(new NormalDistribution(0, 1));
                    OneTimeLogger.warn(log, "Layer \"" + layerName
                                    + "\" distribution is automatically set to normal distribution with mean 0 and variance 1.");
                }
            }
        } else if (dist != null || bLayer.getDist() != null) {
            OneTimeLogger.warn(log, "Layer \"" + layerName
                            + "\" distribution is set but will not be applied unless weight init is set to WeightInit.DISTRIBUTION.");
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy