
deepboof.impl.backward.standard.DSpatialBatchNorm_F64 Maven / Gradle / Ivy
/*
* Copyright (c) 2016, Peter Abeles. All Rights Reserved.
*
* This file is part of DeepBoof
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package deepboof.impl.backward.standard;
import deepboof.backward.DSpatialBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;
/**
* Implementation of {@link DSpatialBatchNorm} for {@link Tensor_F64}.
*
* @author Peter Abeles
*/
public class DSpatialBatchNorm_F64 extends BaseDBatchNorm_F64
implements DSpatialBatchNorm
{
int numChannels;
int numPixels;
// Number of elements statistics are computed from
// M_var is M-1 for computing unbiased variance
double M,M_var;
public DSpatialBatchNorm_F64(boolean requiresGammaBeta) {
super(requiresGammaBeta);
}
@Override
protected int[] createShapeVariables(int[] shapeInput) {
return new int[]{shapeInput[0]}; // one variable for each channel
}
// TODO push into base class?
@Override
public void _forward(Tensor_F64 input, Tensor_F64 output) {
if( input.length(0) <= 1 )
throw new IllegalArgumentException("There must be more than 1 minibatch");
tensorDiffX.reshape( input.shape );
tensorXhat.reshape( input.shape );
// just compute these variables onces. They are used all over the place
numChannels = input.length(1);
numPixels = TensorOps.outerLength(input.shape,2);
M = miniBatchSize*numPixels;
M_var = M-1;
if( learningMode ) {
forwardLearning(input, output);
} else {
forwardEvaluate(input, output);
}
}
private void forwardLearning(Tensor_F64 input, Tensor_F64 output) {
computeStatisticsAndNormalize(input);
if (requiresGammaBeta) {
applyGammaBeta(output);
} else {
// is gamma and beta are not adjustable then the output is the normalized x_hat
output.setTo(tensorXhat);
}
}
public void forwardEvaluate(Tensor_F64 input, Tensor_F64 output) {
int C = input.length(1);
int W = input.length(2);
int H = input.length(3);
int D = W*H;
int indexIn = input.startIndex;
int indexOut = output.startIndex;
if( hasGammaBeta() ) {
for (int batch = 0; batch < miniBatchSize; batch++) {
int indexP = params.startIndex;
for( int channel = 0; channel < C; channel++ ) {
double mean = tensorMean.d[channel];
double stdev_eps = tensorStd.d[channel];
double gamma = params.d[indexP++];
double beta = params.d[indexP++];
int end = indexIn + D;
while (indexIn < end) {
output.d[indexOut++] = (input.d[indexIn++] - mean)*(gamma / stdev_eps) + beta;
}
}
}
} else {
for (int batch = 0; batch < miniBatchSize; batch++) {
for (int channel = 0; channel < C; channel++) {
double mean = tensorMean.d[channel];
double stdev_eps = tensorStd.d[channel];
int end = indexIn + D;
while (indexIn < end) {
output.d[indexOut++] = (input.d[indexIn++] - mean) / stdev_eps;
}
}
}
}
}
/**
* Apply gamma and beta to normalized input x_hat
*/
private void applyGammaBeta(Tensor_F64 output) {
int indexOut = output.startIndex;
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double gamma = params.d[channel*2];
double beta = params.d[channel*2+1];
for (int pixel = 0; pixel < numPixels; pixel++) {
output.d[indexOut++] = gamma*tensorXhat.d[indexTensor++] + beta;
}
}
}
}
/**
* Computes and stores mean, standard deviation, and x_hat the normalized input vector
*/
private void computeStatisticsAndNormalize(Tensor_F64 input) {
tensorMean.zero();
tensorStd.zero();
tensorXhat.zero();
// compute the mean
int indexIn = input.startIndex;
for (int stack = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double sum = 0;
for (int pixel = 0; pixel < numPixels; pixel++) {
sum += input.d[indexIn++];
}
tensorMean.d[channel] += sum;
}
}
for (int channel = 0; channel < numChannels; channel++) {
tensorMean.d[channel] /= M;
}
// compute the unbiased standard deviation with EPS for numerical reasons
indexIn = input.startIndex;
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double sum = 0;
double channelMean = tensorMean.d[channel];
for (int pixel = 0; pixel < numPixels; pixel++, indexTensor++ ) {
double d = input.d[indexIn++] - channelMean;
tensorDiffX.d[indexTensor] = d;
sum += d*d;
}
tensorStd.d[channel] += sum;
}
}
for (int channel = 0; channel < numChannels; channel++) {
tensorStd.d[channel] = Math.sqrt( tensorStd.d[channel]/M_var + EPS);
}
// normalize so that mean is 1 and variance is 1
// x_hat = (x - mu)/std
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double channelStd = tensorStd.d[channel];
for (int pixel = 0; pixel < numPixels; pixel++, indexTensor++ ) {
tensorXhat.d[indexTensor] = tensorDiffX.d[indexTensor] / channelStd;
}
}
}
}
@Override
protected void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput, List gradientParameters) {
// NOTE: @l/@y = dout
tensorDXhat.reshape( input.shape );
if( requiresGammaBeta ) {
partialXHat(dout);
} else {
// if gamma and beta is not required then gamma effectively = 1 and Dxhat = dout
tensorDXhat.setTo(dout);
}
partialVariance();
partialMean();
partialX(gradientInput);
if( requiresGammaBeta ) {
partialParameters(gradientParameters.get(0),dout);
}
}
/**
* compute partial of gamma and Beta
*
* @l/@gamma = sum( @l/y[i] * x_hat[i] )
* @l/@Beta = sum( @l/y[i] )
*/
private void partialParameters(Tensor_F64 tensorDParam , Tensor_F64 dout) {
tensorDParam.zero();
int indexDOut = dout.startIndex;
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
int indexDParam = 0;
for (int channel = 0; channel < numChannels; channel++) {
double sumDGamma = 0;
double sumDBeta = 0;
for (int pixel = 0; pixel < numPixels; pixel++, indexTensor++, indexDOut++) {
double d = dout.d[indexDOut];
sumDGamma += d*tensorXhat.d[indexTensor];
sumDBeta += d;
}
tensorDParam.d[indexDParam++] += sumDGamma;
tensorDParam.d[indexDParam++] += sumDBeta;
}
}
}
/**
* compute partial to x_hat
*
* @l/@x_hat[i] = @l/@y[i] * gamma
*/
private void partialXHat(Tensor_F64 dout) {
int indexDOut = dout.startIndex;
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double gamma = params.d[channel*2];
for (int pixel = 0; pixel < numPixels; pixel++) {
tensorDXhat.d[indexTensor++] = dout.d[indexDOut++]*gamma;
}
}
}
}
/**
* compute partial of the input x
*
* @l/@x[i] = @l/@x_hat[i] / sqrt(sigma^2 + eps) + @l/@var * 2*(x[i]-mean)/M + @l/@mean * 1/M
*/
private void partialX( Tensor_F64 tensorDX ) {
int indexDX = tensorDX.startIndex;
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double stdev = tensorStd.d[channel];
double dvar = tensorDVar.d[channel];
double dmean = tensorDMean.d[channel];
for (int pixel = 0; pixel < numPixels; pixel++, indexTensor++, indexDX++) {
double val = tensorDXhat.d[indexTensor]/stdev;
val += dvar*2*tensorDiffX.d[indexTensor]/M_var + dmean/M;
tensorDX.d[indexDX] = val;
}
}
}
}
/**
* compute the mean partial
*
* @l/@mean = (sum( @l/@x_hat[i] * (-1/sqrt(var + EPS)) ) - @l/@var * (2/M) * sum( x[i] - mean )
*/
private void partialMean() {
tensorDMean.zero();
tensorTmp.zero();
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double sumTmp = 0;
double sumDMean = 0;
for (int pixel = 0; pixel < numPixels; pixel++, indexTensor++) {
// sum( x[i] - mean )
sumTmp += tensorDiffX.d[indexTensor];
// @l/@x[i] * (-1)
sumDMean -= tensorDXhat.d[indexTensor];
}
tensorTmp.d[channel] += sumTmp;
tensorDMean.d[channel] += sumDMean;
}
}
for (int channel = 0; channel < numChannels; channel++) {
tensorDMean.d[channel] /= tensorStd.d[channel];
tensorDMean.d[channel] -= 2.0*tensorDVar.d[channel]*tensorTmp.d[channel]/M_var;
}
}
/**
* compute the variance partial
*
* @l/@var = sum( @l/@x_hat[i] * (x[i] - x_mean) *(-1/2)*(var + EPS)^(-3/2)
*/
private void partialVariance() {
tensorDVar.zero();
for (int stack = 0, indexTensor = 0; stack < miniBatchSize; stack++) {
for (int channel = 0; channel < numChannels; channel++) {
double sumDVar = 0;
for (int pixel = 0; pixel < numPixels; pixel++, indexTensor++) {
// @l/@x_hat[i] * (x[i] - x_mean)
sumDVar += tensorDXhat.d[indexTensor]*tensorDiffX.d[indexTensor];
}
tensorDVar.d[channel] += sumDVar;
}
}
// (-1/2)*(var + EPS)^(-3/2)
for (int channel = 0; channel < numChannels; channel++) {
double sigmaPow3 = tensorStd.d[channel];
sigmaPow3 = sigmaPow3*sigmaPow3*sigmaPow3;
tensorDVar.d[channel] /= (-2.0*sigmaPow3);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy