All downloads are free. Search and download functionality uses the official Maven repository.

deepboof.impl.backward.standard.BaseDBatchNorm_F64 Maven / Gradle / Ivy

/*
 * Copyright (c) 2016, Peter Abeles. All Rights Reserved.
 *
 * This file is part of DeepBoof
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package deepboof.impl.backward.standard;

import deepboof.DeepBoofConstants;
import deepboof.backward.DBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;

import java.util.List;

/**
 * Implements common functionality for all batch normalization implementations for {@link Tensor_F64}.
 *
 * <p>This base class owns the work tensors shared by concrete implementations, declares the output and
 * parameter shapes during initialization, and provides read access to the stored mean and variance.
 * Concrete classes decide which elements of the input are treated as independent variables by
 * implementing {@link #createShapeVariables(int[])}.</p>
 *
 * @author Peter Abeles
 */
public abstract class BaseDBatchNorm_F64 extends BaseDFunction<Tensor_F64>
        implements DBatchNorm<Tensor_F64>
{
    // If gamma and beta are used as parameters
    protected boolean requiresGammaBeta;

    // storage for mean and standard deviation tensor
    protected Tensor_F64 tensorMean = new Tensor_F64();
    protected Tensor_F64 tensorStd = new Tensor_F64();  // really is sqrt( stdev^2 + eps) ~= stdev
    // storage for the normalized input  (e.g. stdev = 1, mean = 0)
    protected Tensor_F64 tensorXhat = new Tensor_F64();

    // storage for gradient of variance, mean, and others
    protected Tensor_F64 tensorDVar = new Tensor_F64();
    protected Tensor_F64 tensorDMean = new Tensor_F64();
    protected Tensor_F64 tensorDXhat = new Tensor_F64();

    // x[i] - mean(x)
    protected Tensor_F64 tensorDiffX = new Tensor_F64();

    // temporary storage
    protected Tensor_F64 tensorTmp = new Tensor_F64();

    // shape of a tensor that just contains the variables
    protected int[] shapeVariables;

    // number of variables being normalized, i.e. number of elements in a tensor with shape 'shapeVariables'
    protected int D;

    // Internal storage for gamma and beta parameters.  Stored interleaved gamma then beta.  1 for each input variable
    // params = [ gamma[0], beta[0], gamma[1], beta[1],  ... , gamma[D-1], beta[D-1]]
    protected Tensor_F64 params = new Tensor_F64(0);

    // Small value folded into the stored standard deviation, see tensorStd
    protected double EPS = DeepBoofConstants.TEST_TOL_F64*0.1;

    /**
     * @param requiresGammaBeta true if gamma and beta are supplied as parameters, false if they are not used
     */
    public BaseDBatchNorm_F64(boolean requiresGammaBeta) {
        this.requiresGammaBeta = requiresGammaBeta;
    }

    /**
     * Computes the shape of the variables tensor, resizes the per-variable work tensors to match it,
     * declares the output shape to be identical to the input shape, and (when gamma and beta are used)
     * declares a single parameter tensor.
     *
     * <p>NOTE(review): {@code tensorXhat}, {@code tensorDXhat} and {@code tensorDiffX} are not resized
     * here; presumably subclasses size them once the mini-batch size is known - confirm.</p>
     */
    @Override
    public void _initialize() {
        shapeVariables = createShapeVariables(shapeInput);

        tensorMean.reshape(shapeVariables);
        tensorStd.reshape(shapeVariables);
        tensorDVar.reshape(shapeVariables);
        tensorDMean.reshape(shapeVariables);
        tensorTmp.reshape(shapeVariables);

        this.shapeOutput = shapeInput.clone();

        if( requiresGammaBeta ) {
            // presumably WI appends a trailing dimension of length 2 (gamma, beta) - verify against TensorOps
            int[] shapeParam = TensorOps.WI(shapeVariables, 2);
            this.shapeParameters.add(shapeParam);
            params.reshape(shapeParam);
        }

        D = TensorOps.tensorLength(shapeVariables);
    }

    /**
     * Create the shape for all the variables which are being normalized
     * @param shapeInput Shape of input tensor (without mini-batch)
     * @return shape of variables tensor
     */
    protected abstract int[] createShapeVariables( int[] shapeInput );

    /**
     * Copies the first tensor in the list into the internal gamma/beta storage when gamma and beta
     * are in use.
     *
     * @param parameters List of parameter tensors.  Must be empty when gamma and beta are turned off.
     * @throws IllegalArgumentException if gamma and beta are turned off and the list is not empty
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        if( requiresGammaBeta ) {
            params.setTo(parameters.get(0));
        } else if( parameters.size() != 0 ){
            throw new IllegalArgumentException("There are no parameters since gamma and beta have been turned off");
        }
    }

    /**
     * @return the EPS value currently used by this function
     */
    @Override
    public double getEPS() {
        return EPS;
    }

    /**
     * @param EPS new EPS value.  Only the field is updated; already stored statistics are not recomputed.
     */
    @Override
    public void setEPS(double EPS) {
        this.EPS = EPS;
    }

    /**
     * @return true if gamma and beta are used as parameters
     */
    @Override
    public boolean hasGammaBeta() {
        return requiresGammaBeta;
    }

    /**
     * @return the tensor type this implementation operates on, always {@link Tensor_F64}
     */
    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }

    /**
     * Returns a copy of the stored mean.
     *
     * @param output (Optional) storage for the copy.  If null a new tensor is created.
     * @return tensor holding a copy of the stored mean
     */
    @Override
    public Tensor_F64 getMean( Tensor_F64 output ) {
        if( output == null )
            output = tensorMean.createLike();

        output.setTo(tensorMean);

        return output;
    }

    /**
     * Recovers the variance from the stored standard deviation.  The stored value is
     * {@code sqrt(variance + EPS)}, so each element is squared and the current {@code EPS} subtracted.
     *
     * @param output (Optional) storage for the variance.  If null a new tensor is created.  It is
     *               reshaped to match the stored standard deviation.
     * @return tensor holding the variance of each variable
     */
    @Override
    public Tensor_F64 getVariance( Tensor_F64 output ) {
        if( output == null )
            output = tensorStd.createLike();

        output.reshape(tensorStd.getShape());

        final int length = tensorStd.length();

        // Honor the start offset of both tensors so that the read side is treated the same as the write side
        int indexStd = tensorStd.startIndex;
        int indexOut = output.startIndex;

        for (int i = 0; i < length; i++) {
            double d = tensorStd.d[indexStd++];
            output.d[indexOut++] = d*d - EPS;
        }

        return output;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy