All Downloads are FREE. Search and download functionalities are using the official Maven repository.

deepboof.impl.backward.standard.DFunctionBatchNorm_F64 Maven / Gradle / Ivy

/*
 * Copyright (c) 2016, Peter Abeles. All Rights Reserved.
 *
 * This file is part of DeepBoof
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package deepboof.impl.backward.standard;

import deepboof.backward.DFunctionBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;

import java.util.List;

/**
 * Implementation of {@link DFunctionBatchNorm} for {@link Tensor_F64}.  Intermediate variables are cached in the
 * forward pass.
 *
 * @author Peter Abeles
 */
public class DFunctionBatchNorm_F64 extends BaseDBatchNorm_F64
        implements DFunctionBatchNorm
{
	public DFunctionBatchNorm_F64(boolean requiresGammaBeta) {
		super(requiresGammaBeta);
	}

	@Override
	protected int[] createShapeVariables(int[] shapeInput) {
		return shapeInput;
	}

	@Override
	public void _forward(Tensor_F64 input, Tensor_F64 output) {
		if( input.length(0) <= 1 )
			throw new IllegalArgumentException("There must be more than 1 minibatch");

		if( learningMode ) {
			forwardsLearning(input, output);
		} else {
			forwardsEvaluate(input, output);
		}
	}

	private void forwardsLearning(Tensor_F64 input, Tensor_F64 output) {
		tensorDiffX.reshape( input.shape );
		tensorXhat.reshape( input.shape );
		computeStatisticsAndNormalize(input);

		if (requiresGammaBeta) {
			applyGammaBeta(output);
		} else {
			// is gamma and beta are not adjustable then the output is the normalized x_hat
			output.setTo(tensorXhat);
		}
	}

	public void forwardsEvaluate(Tensor_F64 input, Tensor_F64 output) {
		int D = TensorOps.outerLength(input.shape,1);

		int indexIn  = input.startIndex;
		int indexOut = output.startIndex;

		if( requiresGammaBeta ) {
			for (int batch = 0; batch < miniBatchSize; batch++) {
				int indexVar = 0;
				int indexP  = params.startIndex;
				int end = indexIn + D;
				while (indexIn < end) {
					double mean  = tensorMean.d[indexVar];
					double stdev_eps = tensorStd.d[indexVar];
					double gamma = params.d[indexP++];
					double beta  = params.d[indexP++];

					output.d[indexOut++] = (input.d[indexIn++] - mean)*(gamma / stdev_eps) + beta;
					indexVar++;
				}
			}
		} else {
			for (int stack = 0; stack < miniBatchSize; stack++) {
				int indexVar = 0;
				int end = indexIn + D;
				while (indexIn < end) {
					double mean  = tensorMean.d[indexVar];
					double stdev_eps = tensorStd.d[indexVar];

					output.d[indexOut++] = (input.d[indexIn++] - mean) / stdev_eps;
					indexVar++;
				}
			}
		}
	}

	/**
	 * Apply gamma and beta to normalized input x_hat
	 */
	private void applyGammaBeta(Tensor_F64 output) {
		int indexOut = output.startIndex;
		int indexTensor = 0;
		int end = params.length();

		for (int stack = 0; stack < miniBatchSize; stack++) {
			int indexParam = params.startIndex;
			while (indexParam < end) {
				double gamma = params.d[indexParam++];
				double beta = params.d[indexParam++];

				output.d[indexOut++] = gamma*tensorXhat.d[indexTensor++] + beta;
			}
		}
	}

	/**
	 * Computes and stores mean, standard deviation, and x_hat the normalized input vector
	 */
	private void computeStatisticsAndNormalize(Tensor_F64 input) {
		tensorMean.zero();
		tensorStd.zero();
		tensorXhat.zero();

		double M_var = miniBatchSize-1; // unbiased variance division, mean is computed with miniBatchSize

		// compute the mean
		int indexIn = input.startIndex;
		for (int stack = 0; stack < miniBatchSize; stack++) {
			int indexVar = 0;
			while (indexVar < D) {
				tensorMean.d[indexVar++] += input.d[indexIn++];
			}
		}
		for (int indexVar = 0; indexVar < D; indexVar++ ) {
			tensorMean.d[indexVar] /= miniBatchSize;
		}

		// compute the unbiased standard deviation with EPS for numerical reasons
		indexIn = input.startIndex;
		int indexTensor = 0;
		for (int stack = 0; stack < miniBatchSize; stack++) {
			for (int indexVar = 0; indexVar < D; indexVar++, indexTensor++ ) {
				double d = input.d[indexIn++] - tensorMean.d[indexVar];
				tensorDiffX.d[indexTensor] = d;
				tensorStd.d[indexVar] += d*d;
			}
		}
		for (int indexVar = 0; indexVar < D; indexVar++ ) {
			tensorStd.d[indexVar] = Math.sqrt( tensorStd.d[indexVar]/M_var + EPS);
		}

		// normalize so that mean is 1 and variance is 1
		// x_hat = (x - mu)/std
		indexTensor = 0;
		for (int stack = 0; stack < miniBatchSize; stack++) {
			for (int indexVar = 0; indexVar < D; indexVar++, indexTensor++ ) {
				tensorXhat.d[indexTensor] = tensorDiffX.d[indexTensor] / tensorStd.d[indexVar];
			}
		}
	}

	@Override
	protected void _backwards(Tensor_F64 input, Tensor_F64 dout,
							  Tensor_F64 gradientInput,
							  List gradientParameters)
	{
		// NOTE: @l/@y = dout
		tensorDXhat.reshape( input.shape );

		if( requiresGammaBeta ) {
			partialXHat(dout);
		} else {
			// if gamma and beta is not required then gamma effectively = 1 and Dxhat = dout
			tensorDXhat.setTo(dout);
		}

		partialVariance();
		partialMean();
		partialX(gradientInput);

		if( requiresGammaBeta ) {
			partialParameters(gradientParameters.get(0),dout);
		}
	}

	/**
	 * compute partial of gamma and Beta
	 *
	 * 
 @l/@gamma = sum( @l/y[i]  * x_hat[i] ) 
*
 @l/@Beta = sum( @l/y[i] )              
*/ private void partialParameters(Tensor_F64 tensorDParam , Tensor_F64 dout) { tensorDParam.zero(); int indexDOut = dout.startIndex; for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) { int indexDParam = 0; for (int indexVar = 0; indexVar < D; indexVar++, indexTensor++, indexDOut++) { double d = dout.d[indexDOut]; tensorDParam.d[indexDParam++] += d*tensorXhat.d[indexTensor]; tensorDParam.d[indexDParam++] += d; } } } /** * compute partial to x_hat * *
 @l/@x_hat[i] = @l/@y[i] * gamma  
*/ private void partialXHat(Tensor_F64 dout) { int indexDOut = dout.startIndex; for (int stack = 0,indexTensor = 0; stack < miniBatchSize; stack++) { for( int indexVar = 0; indexVar < D; indexVar++ , indexTensor++) { // see encoding of params tensorDXhat.d[indexTensor] = dout.d[indexDOut++]*params.d[indexVar*2]; } } } /** * compute partial of the input x * *
 @l/@x[i] = @l/@x_hat[i] / sqrt(sigma^2 + eps) + @l/@var * 2*(x[i]-mean)/M + @l/@mean * 1/M 
*/ private void partialX( Tensor_F64 tensorDX ) { double M_var = miniBatchSize-1; int indexDX = tensorDX.startIndex; for (int stack = 0,indexTensor = 0; stack < miniBatchSize; stack++) { for (int indexVar = 0; indexVar < D; indexVar++, indexTensor++, indexDX++ ) { double val = tensorDXhat.d[indexTensor] / tensorStd.d[indexVar]; val += tensorDVar.d[indexVar]*2*tensorDiffX.d[indexTensor]/M_var + tensorDMean.d[indexVar]/miniBatchSize; tensorDX.d[indexDX] = val; } } } /** * compute the mean partial * *
 @l/@mean = (sum( @l/@x_hat[i] * (-1/sqrt(var + EPS)) ) - @l/@var * (2/M) * sum( x[i] - mean )
*/ private void partialMean() { tensorDMean.zero(); tensorTmp.zero(); double M_var = miniBatchSize-1; for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) { for( int indexVar = 0; indexVar < D; indexVar++, indexTensor++ ) { // sum( x[i] - mean ) tensorTmp.d[indexVar] += tensorDiffX.d[indexTensor]; // @l/@x[i] * (-1) tensorDMean.d[indexVar] -= tensorDXhat.d[indexTensor]; } } for( int indexVar = 0; indexVar < D; indexVar++ ) { tensorDMean.d[indexVar] /= tensorStd.d[indexVar]; tensorDMean.d[indexVar] -= 2.0*tensorDVar.d[indexVar]*tensorTmp.d[indexVar]/M_var; } } /** * compute the variance partial * *
 @l/@var = sum( @l/@x_hat[i] * (x[i] - x_mean) *(-1/2)*(var + EPS)^(-3/2) 
*/ private void partialVariance() { tensorDVar.zero(); for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) { for( int indexVar = 0; indexVar < D; indexVar++, indexTensor++ ) { // @l/@x_hat[i] * (x[i] - x_mean) tensorDVar.d[indexVar] += tensorDXhat.d[indexTensor]*tensorDiffX.d[indexTensor]; } } // (-1/2)*(var + EPS)^(-3/2) for( int indexVar = 0; indexVar < D; indexVar++ ) { double sigmaPow3 = tensorStd.d[indexVar]; sigmaPow3 = sigmaPow3*sigmaPow3*sigmaPow3; tensorDVar.d[indexVar] /= (-2.0*sigmaPow3); } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy