All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.conf.layers.LayerValidation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn.conf.layers;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.learning.*;

import java.util.HashMap;
import java.util.Map;

/**
 * Created by Alex on 22/02/2017.
 */
@Slf4j
public class LayerValidation {

    /**
     * Validate the updater configuration - setting the default updater values, if necessary
     */
    public static void updaterValidation(String layerName, Layer layer, Double momentum,
                    Map momentumSchedule, Double adamMeanDecay, Double adamVarDecay, Double rho,
                    Double rmsDecay, Double epsilon) {
        updaterValidation(layerName, layer, momentum == null ? Double.NaN : momentum, momentumSchedule,
                        adamMeanDecay == null ? Double.NaN : adamMeanDecay,
                        adamVarDecay == null ? Double.NaN : adamVarDecay, rho == null ? Double.NaN : rho,
                        rmsDecay == null ? Double.NaN : rmsDecay, epsilon == null ? Double.NaN : epsilon);

    }

    /**
     * Validate the updater configuration - setting the default updater values, if necessary
     */
    public static void updaterValidation(String layerName, Layer layer, double momentum,
                    Map momentumSchedule, double adamMeanDecay, double adamVarDecay, double rho,
                    double rmsDecay, double epsilon) {
        if ((!Double.isNaN(momentum) || !Double.isNaN(layer.getMomentum())) && layer.getUpdater() != Updater.NESTEROVS)
            log.warn("Layer \"" + layerName
                            + "\" momentum has been set but will not be applied unless the updater is set to NESTEROVS.");
        if ((momentumSchedule != null || layer.getMomentumSchedule() != null)
                        && layer.getUpdater() != Updater.NESTEROVS)
            log.warn("Layer \"" + layerName
                            + "\" momentum schedule has been set but will not be applied unless the updater is set to NESTEROVS.");
        if ((!Double.isNaN(adamVarDecay) || (!Double.isNaN(layer.getAdamVarDecay())))
                        && layer.getUpdater() != Updater.ADAM)
            log.warn("Layer \"" + layerName
                            + "\" adamVarDecay is set but will not be applied unless the updater is set to Adam.");
        if ((!Double.isNaN(adamMeanDecay) || !Double.isNaN(layer.getAdamMeanDecay()))
                        && layer.getUpdater() != Updater.ADAM)
            log.warn("Layer \"" + layerName
                            + "\" adamMeanDecay is set but will not be applied unless the updater is set to Adam.");
        if ((!Double.isNaN(rho) || !Double.isNaN(layer.getRho())) && layer.getUpdater() != Updater.ADADELTA)
            log.warn("Layer \"" + layerName
                            + "\" rho is set but will not be applied unless the updater is set to ADADELTA.");
        if ((!Double.isNaN(rmsDecay) || (!Double.isNaN(layer.getRmsDecay()))) && layer.getUpdater() != Updater.RMSPROP)
            log.warn("Layer \"" + layerName
                            + "\" rmsdecay is set but will not be applied unless the updater is set to RMSPROP.");

        switch (layer.getUpdater()) {
            case NESTEROVS:
                if (Double.isNaN(momentum) && Double.isNaN(layer.getMomentum())) {
                    layer.setMomentum(Nesterovs.DEFAULT_NESTEROV_MOMENTUM);
                    log.warn("Layer \"" + layerName + "\" momentum is automatically set to "
                                    + Nesterovs.DEFAULT_NESTEROV_MOMENTUM
                                    + ". Add momentum to configuration to change the value.");
                } else if (Double.isNaN(layer.getMomentum()))
                    layer.setMomentum(momentum);
                if (momentumSchedule != null && layer.getMomentumSchedule() == null)
                    layer.setMomentumSchedule(momentumSchedule);
                else if (momentumSchedule == null && layer.getMomentumSchedule() == null)
                    layer.setMomentumSchedule(new HashMap());
                break;
            case ADAM:
                if (Double.isNaN(adamMeanDecay) && Double.isNaN(layer.getAdamMeanDecay())) {
                    layer.setAdamMeanDecay(Adam.DEFAULT_ADAM_BETA1_MEAN_DECAY);
                    log.warn("Layer \"" + layerName + "\" adamMeanDecay is automatically set to "
                                    + Adam.DEFAULT_ADAM_BETA1_MEAN_DECAY
                                    + ". Add adamVarDecay to configuration to change the value.");
                } else if (Double.isNaN(layer.getAdamMeanDecay()))
                    layer.setAdamMeanDecay(adamMeanDecay);

                if (Double.isNaN(adamVarDecay) && Double.isNaN(layer.getAdamVarDecay())) {
                    layer.setAdamVarDecay(Adam.DEFAULT_ADAM_BETA2_VAR_DECAY);
                    log.warn("Layer \"" + layerName + "\" adamVarDecay is automatically set to "
                                    + Adam.DEFAULT_ADAM_BETA2_VAR_DECAY
                                    + ". Add adamVarDecay to configuration to change the value.");
                } else if (Double.isNaN(layer.getAdamVarDecay()))
                    layer.setAdamVarDecay(adamVarDecay);

                if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(Adam.DEFAULT_ADAM_EPSILON);
                } else if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(epsilon);
                }
                break;
            case ADADELTA:
                if (Double.isNaN(layer.getRho()))
                    layer.setRho(rho);

                if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(AdaDelta.DEFAULT_ADADELTA_EPSILON);
                    log.warn("Layer \"" + layerName + "\" AdaDelta epsilon is automatically set to "
                                    + AdaDelta.DEFAULT_ADADELTA_EPSILON
                                    + ". Add epsilon to configuration to change the value.");
                } else if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(epsilon);
                }
                break;
            case ADAGRAD:
                if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(AdaGrad.DEFAULT_ADAGRAD_EPSILON);
                } else if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(epsilon);
                }
                break;
            case RMSPROP:
                if (Double.isNaN(rmsDecay) && Double.isNaN(layer.getRmsDecay())) {
                    layer.setRmsDecay(RmsProp.DEFAULT_RMSPROP_RMSDECAY);
                    log.warn("Layer \"" + layerName
                                    + "\" rmsDecay is automatically set to 0.95. Add rmsDecay to configuration to change the value.");
                } else if (Double.isNaN(layer.getRmsDecay()))
                    layer.setRmsDecay(rmsDecay);

                if (Double.isNaN(epsilon) && Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(RmsProp.DEFAULT_RMSPROP_EPSILON);
                } else if (Double.isNaN(layer.getEpsilon())) {
                    layer.setEpsilon(epsilon);
                }

                break;
        }
    }

    public static void generalValidation(String layerName, Layer layer, boolean useRegularization,
                    boolean useDropConnect, Double dropOut, Double l2, Double l2Bias, Double l1, Double l1Bias,
                    Distribution dist) {
        generalValidation(layerName, layer, useRegularization, useDropConnect, dropOut == null ? 0.0 : dropOut,
                        l2 == null ? Double.NaN : l2, l2Bias == null ? Double.NaN : l2Bias,
                        l1 == null ? Double.NaN : l1, l1Bias == null ? Double.NaN : l1Bias, dist);
    }

    public static void generalValidation(String layerName, Layer layer, boolean useRegularization,
                    boolean useDropConnect, double dropOut, double l2, double l2Bias, double l1, double l1Bias,
                    Distribution dist) {
        if (useDropConnect && (Double.isNaN(dropOut) && (Double.isNaN(layer.getDropOut()))))
            log.warn("Layer \"" + layerName
                            + "\" dropConnect is set to true but dropout rate has not been added to configuration.");
        if (useDropConnect && dropOut == 0.0)
            log.warn("Layer \"" + layerName + " dropConnect is set to true but dropout rate is set to 0.0");
        if (useRegularization && (Double.isNaN(l1) && layer != null && Double.isNaN(layer.getL1()) && Double.isNaN(l2)
                        && Double.isNaN(layer.getL2()) && Double.isNaN(l2Bias) && Double.isNaN(l1Bias)
                        && (Double.isNaN(dropOut) || dropOut == 0.0)
                        && (Double.isNaN(layer.getDropOut()) || layer.getDropOut() == 0.0)))
            log.warn("Layer \"" + layerName
                            + "\" regularization is set to true but l1, l2 or dropout has not been added to configuration.");

        if (layer != null) {
            if (useRegularization) {
                if (!Double.isNaN(l1) && Double.isNaN(layer.getL1())) {
                    layer.setL1(l1);
                }
                if (!Double.isNaN(l2) && Double.isNaN(layer.getL2())) {
                    layer.setL2(l2);
                }
                if (!Double.isNaN(l1Bias) && Double.isNaN(layer.getL1Bias())) {
                    layer.setL1Bias(l1Bias);
                }
                if (!Double.isNaN(l2Bias) && Double.isNaN(layer.getL2Bias())) {
                    layer.setL2Bias(l2Bias);
                }
            } else if (!useRegularization && ((!Double.isNaN(l1) && l1 > 0.0)
                            || (!Double.isNaN(layer.getL1()) && layer.getL1() > 0.0) || (!Double.isNaN(l2) && l2 > 0.0)
                            || (!Double.isNaN(layer.getL2()) && layer.getL2() > 0.0)
                            || (!Double.isNaN(l1Bias) && l1Bias > 0.0)
                            || (!Double.isNaN(layer.getL1Bias()) && layer.getL1Bias() > 0.0)
                            || (!Double.isNaN(l2Bias) && l2Bias > 0.0)
                            || (!Double.isNaN(layer.getL2Bias()) && layer.getL2Bias() > 0.0))) {
                log.warn("Layer \"" + layerName
                                + "\" l1 or l2 has been added to configuration but useRegularization is set to false.");
            }

            if (Double.isNaN(l2) && Double.isNaN(layer.getL2())) {
                layer.setL2(0.0);
            }
            if (Double.isNaN(l1) && Double.isNaN(layer.getL1())) {
                layer.setL1(0.0);
            }
            if (Double.isNaN(l2Bias) && Double.isNaN(layer.getL2Bias())) {
                layer.setL2Bias(0.0);
            }
            if (Double.isNaN(l1Bias) && Double.isNaN(layer.getL1Bias())) {
                layer.setL1Bias(0.0);
            }


            if (layer.getWeightInit() == WeightInit.DISTRIBUTION) {
                if (dist != null && layer.getDist() == null)
                    layer.setDist(dist);
                else if (dist == null && layer.getDist() == null) {
                    layer.setDist(new NormalDistribution(0, 1));
                    log.warn("Layer \"" + layerName
                                    + "\" distribution is automatically set to normalize distribution with mean 0 and variance 1.");
                }
            } else if ((dist != null || layer.getDist() != null)) {
                log.warn("Layer \"" + layerName
                                + "\" distribution is set but will not be applied unless weight init is set to WeighInit.DISTRIBUTION.");
            }
        }

    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy