
org.deeplearning4j.nn.params.BatchNormalizationParamInitializer

package org.deeplearning4j.nn.params;

import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.*;

/**
 * Parameter initializer for batch normalization layers: gamma, beta, and the global (running) mean and variance estimates.
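 * <p>
 * For illustration only, a rough sketch of how this initializer can be driven directly (normally the network
 * builder invokes it internally; {@code conf} is assumed to wrap a {@code BatchNormalization} layer configuration):
 * <pre>{@code
 * BatchNormalizationParamInitializer init = BatchNormalizationParamInitializer.getInstance();
 * int nParams = init.numParams(conf);          // 4*nOut, or 2*nOut when gamma/beta are locked
 * INDArray view = Nd4j.create(1, nParams);     // flattened row vector backing all parameters
 * Map<String, INDArray> params = init.init(conf, view, true);
 * // params.get("gamma"), params.get("beta"), params.get("mean"), params.get("var") are views into 'view'
 * }</pre>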
 */

public class BatchNormalizationParamInitializer implements ParamInitializer {

    private static final BatchNormalizationParamInitializer INSTANCE = new BatchNormalizationParamInitializer();

    public static BatchNormalizationParamInitializer getInstance() {
        return INSTANCE;
    }

    public static final String GAMMA = "gamma";
    public static final String BETA = "beta";
    public static final String GLOBAL_MEAN = "mean";
    public static final String GLOBAL_VAR = "var";

    public static List<String> keys() {
        return Arrays.asList(GAMMA, BETA, GLOBAL_MEAN, GLOBAL_VAR);
    }

    @Override
    public int numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    @Override
    public int numParams(Layer l) {
        BatchNormalization layer = (BatchNormalization) l;
        //Parameters in batch norm:
        //gamma, beta, global mean estimate, global variance estimate
        // latter 2 are treated as parameters, which greatly simplifies spark training and model serialization

        if (layer.isLockGammaBeta()) {
            //Special case: gamma and beta are fixed values for all outputs -> no gamma/beta parameters in this case
            return 2 * layer.getNOut();
        } else {
            //Standard case: gamma and beta are learned per output; additional 2*nOut for the global mean/variance estimates
            return 4 * layer.getNOut();
        }
    }

    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramView, boolean initializeParams) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        // TODO setup for RNN
        BatchNormalization layer = (BatchNormalization) conf.getLayer();
        int nOut = layer.getNOut();
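        //The flattened paramView is a row vector laid out as [gamma | beta | mean | var], each segment of length nOut;
        //when gamma/beta are locked, the gamma/beta segments are absent and the layout is just [mean | var]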

        int meanOffset = 0;
        if (!layer.isLockGammaBeta()) { //No gamma/beta parameters when gamma/beta are locked
            INDArray gammaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
            INDArray betaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));

            params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
            conf.addVariable(GAMMA);
            params.put(BETA, createBeta(conf, betaView, initializeParams));
            conf.addVariable(BETA);

            meanOffset = 2 * nOut;
        }

        INDArray globalMeanView =
                        paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut));
        INDArray globalVarView = paramView.get(NDArrayIndex.point(0),
                        NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut));

        if (initializeParams) {
            //Initialize the running estimates: mean to 0, variance to 1 (no data seen yet)
            globalMeanView.assign(0);
            globalVarView.assign(1);
        }

        params.put(GLOBAL_MEAN, globalMeanView);
        conf.addVariable(GLOBAL_MEAN);
        params.put(GLOBAL_VAR, globalVarView);
        conf.addVariable(GLOBAL_VAR);

        return params;
    }

    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        BatchNormalization layer = (BatchNormalization) conf.getLayer();
        int nOut = layer.getNOut();
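        //Gradient views use the same flattened layout as the parameter views: [gamma | beta | mean | var]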

        Map<String, INDArray> out = new LinkedHashMap<>();
        int meanOffset = 0;
        if (!layer.isLockGammaBeta()) {
            INDArray gammaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
            INDArray betaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));
            out.put(GAMMA, gammaView);
            out.put(BETA, betaView);
            meanOffset = 2 * nOut;
        }

        out.put(GLOBAL_MEAN,
                        gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut)));
        out.put(GLOBAL_VAR, gradientView.get(NDArrayIndex.point(0),
                        NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut)));

        return out;
    }

    private INDArray createBeta(NeuralNetConfiguration conf, INDArray betaView, boolean initializeParams) {
        BatchNormalization layer = (BatchNormalization) conf.getLayer();
        if (initializeParams)
            betaView.assign(layer.getBeta());
        return betaView;
    }

    private INDArray createGamma(NeuralNetConfiguration conf, INDArray gammaView, boolean initializeParams) {
        BatchNormalization layer = (BatchNormalization) conf.getLayer();
        if (initializeParams)
            gammaView.assign(layer.getGamma());
        return gammaView;
    }
}



