/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.cuda.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import java.util.HashMap;
import java.util.Map;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* cuDNN-based helper for the batch normalization layer.
*
* @author saudet
*/
@Slf4j
public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper {
public CudnnBatchNormalizationHelper(DataType dataType) {
super(dataType);
}
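    /**
     * Thin wrapper around the cuDNN handle plus the tensor descriptors needed for batch normalization
     * (input, output/dx, delta/dy, and gamma/beta). Descriptors are created eagerly in createHandles()
     * and released through the JavaCPP deallocator mechanism when this context is garbage collected.
     */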
private static class CudnnBatchNormalizationContext extends CudnnContext {
private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator {
Deallocator(CudnnBatchNormalizationContext c) {
super(c);
}
@Override
public void deallocate() {
destroyHandles();
}
}
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
deltaTensorDesc = new cudnnTensorStruct(), gammaBetaTensorDesc = new cudnnTensorStruct();
public CudnnBatchNormalizationContext() {
createHandles();
deallocator(new Deallocator(this));
}
public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) {
super(c);
srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
gammaBetaTensorDesc = new cudnnTensorStruct(c.gammaBetaTensorDesc);
}
@Override
protected void createHandles() {
super.createHandles();
checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnCreateTensorDescriptor(gammaBetaTensorDesc));
}
@Override
protected void destroyHandles() {
checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
checkCudnn(cudnnDestroyTensorDescriptor(gammaBetaTensorDesc));
super.destroyHandles();
}
}
protected final int batchNormMode = CUDNN_BATCHNORM_SPATIAL; // would need to increase rank of gamma and beta for CUDNN_BATCHNORM_PER_ACTIVATION
private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext();
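    //meanCache/varCache hold the batch mean and cuDNN's saved inverse variance statistic (1/sqrt(var + eps))
    // written by cudnnBatchNormalizationForwardTraining and reused by cudnnBatchNormalizationBackward.
    // They are kept in FP32 even for FP16 networks, as cuDNN requires. 'eps' records the epsilon of the
    // last call so getVarCache() can convert the inverse std dev back to a variance.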
private INDArray meanCache;
private INDArray varCache;
private double eps;
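    /**
     * Checks whether this cuDNN implementation can be used. On top of the base checkSupported() test,
     * epsilon values below CUDNN_BN_MIN_EPSILON are rejected, since cuDNN would fail for them; the
     * isFixedGammaBeta flag is not used here.
     */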
public boolean checkSupported(double eps, boolean isFixedGammaBeta) {
boolean supported = checkSupported();
if (eps < CUDNN_BN_MIN_EPSILON) {
supported = false;
log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + CUDNN_BN_MIN_EPSILON + ")");
}
return supported;
}
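    /**
     * Backward pass, implemented with cudnnBatchNormalizationBackward. Uses the batch mean and inverse
     * variance cached by the preceding training-mode forward pass (meanCache/varCache). For FP16 inputs,
     * gamma and the dGamma/dBeta gradient views are temporarily cast to FP32 (cuDNN requires FP32
     * scale/bias/statistics for FP16 tensors) and the results are copied back to the FP16 views afterwards.
     * Returns the parameter gradients (gamma, beta) together with the epsilon to pass to the layer below.
     */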
@Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
this.eps = eps;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val miniBatch = (int) input.size(0);
val depth = (int) input.size(chIdx);
val inH = (int) input.size(hIdx);
val inW = (int) input.size(wIdx);
final boolean isHalf = (input.dataType() == DataType.HALF);
INDArray gammaOrig = null;
INDArray dGammaViewOrig = null;
INDArray dBetaViewOrig = null;
if(isHalf) { //Convert FP16 to FP32 if required (CuDNN BN doesn't support FP16 for these params, only for input/output)
gammaOrig = gamma;
dGammaViewOrig = dGammaView;
dBetaViewOrig = dBetaView;
/*
From CuDNN docs: bnScale, resultBnScaleDiff, resultBnBiasDiff, savedMean, savedInvVariance
"Note: The data type of this tensor descriptor must be 'float' for FP16 and FP32 input tensors, and 'double'
for FP64 input tensors."
>> Last 2 are the meanCache and varCache; first 3 are below
*/
gamma = gamma.castTo(DataType.FLOAT);
dGammaView = dGammaView.castTo(DataType.FLOAT);
dBetaView = dBetaView.castTo(DataType.FLOAT);
}
Gradient retGradient = new DefaultGradient();
if (!Shape.hasDefaultStridesForShape(epsilon)) {
// apparently not supported by cuDNN
epsilon = epsilon.dup('c');
}
val srcStride = ArrayUtil.toInts(input.stride());
val deltaStride = ArrayUtil.toInts(epsilon.stride());
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
dGammaView, dBetaView);
Pointer srcData = allocator.getPointer(input, context);
Pointer epsData = allocator.getPointer(epsilon, context);
Pointer dstData = allocator.getPointer(nextEpsilon, context);
Pointer gammaData = allocator.getPointer(gamma, context);
Pointer dGammaData = allocator.getPointer(dGammaView, context);
Pointer dBetaData = allocator.getPointer(dBetaView, context);
Pointer meanCacheData = allocator.getPointer(meanCache, context);
Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
dBetaData, eps, meanCacheData, varCacheData));
allocator.getFlowController().registerActionAllWrite(context, input, epsilon, nextEpsilon, gamma, dGammaView,
dBetaView);
retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
context.syncOldStream();
//Convert back and assign, if required:
if(isHalf){
gammaOrig.assign(gamma.castTo(DataType.HALF));
dGammaViewOrig.assign(dGammaView.castTo(DataType.HALF));
dBetaViewOrig.assign(dBetaView.castTo(DataType.HALF));
}
return new Pair<>(retGradient, nextEpsilon);
}
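    /**
     * Forward pass. In training mode this calls cudnnBatchNormalizationForwardTraining, normalizing with
     * the current minibatch statistics and writing the batch mean and inverse variance into
     * meanCache/varCache for reuse in backpropGradient; the exponential average factor is forced to 0 so
     * cuDNN does not update the running mean/variance in place (those updates are applied via the
     * update/gradient vector instead, which keeps the layer safe for distributed gradient sharing).
     * In inference mode, cudnnBatchNormalizationForwardInference is used with the supplied running
     * mean/variance. FP16 gamma/beta/mean/var are cast to FP32 for the cuDNN call and copied back afterwards.
     */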
@Override
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.FLOAT16);
INDArray origGamma = gamma;
INDArray origBeta = beta;
INDArray origMean = mean;
INDArray origVar = var;
if(isHalf) {
gamma = gamma.castTo(DataType.FLOAT);
beta = beta.castTo(DataType.FLOAT);
mean = mean.castTo(DataType.FLOAT);
var = var.castTo(DataType.FLOAT);
}
//Notation difference between CuDNN and our implementation:
//Us: runningMean = (1-decay) * batchMean + decay * runningMean
//CuDNN: runningMean = decay * batchMean + (1-decay) * runningMean
//i.e., "decay" has a different meaning...
//Disable in-place updating of running mean/variance, so that all parameter changes are done via the update/gradient
// vector. This is necessary for BatchNormalization to be safe to use in distributed gradient sharing settings
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
val miniBatch = (int) x.size(0);
val inDepth = (int) x.size(chIdx);
val inH = (int) x.size(hIdx);
val inW = (int) x.size(wIdx);
val srcStride = ArrayUtil.toInts(x.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
val dstStride = ArrayUtil.toInts(activations.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context =
allocator.getFlowController().prepareActionAllWrite(x, activations, gamma, beta, mean, var);
Pointer srcData = allocator.getPointer(x, context);
Pointer dstData = allocator.getPointer(activations, context);
Pointer gammaData = allocator.getPointer(gamma, context);
Pointer betaData = allocator.getPointer(beta, context);
Pointer meanData = allocator.getPointer(mean, context);
Pointer varData = allocator.getPointer(var, context);
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
if (training) {
if(meanCache == null || meanCache.length() < mean.length()){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
meanCache = Nd4j.createUninitialized(x.dataType(), mean.length());
}
if(x.dataType() == DataType.HALF){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
meanCache = meanCache.castTo(DataType.FLOAT);
}
}
}
if(varCache == null || varCache.length() < mean.length()){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
varCache = Nd4j.createUninitialized(x.dataType(), mean.length());
}
if(nd4jDataType == DataType.HALF){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
varCache = varCache.castTo(DataType.FLOAT);
}
}
}
Pointer meanCacheData = allocator.getPointer(meanCache, context);
Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnBatchNormalizationForwardTraining(cudnnContext, batchNormMode, this.alpha, this.beta,
cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
cudnnContext.gammaBetaTensorDesc, gammaData, betaData, decay, meanData, varData, eps,
meanCacheData, varCacheData));
} else {
checkCudnn(cudnnBatchNormalizationForwardInference(cudnnContext, batchNormMode, this.alpha, this.beta,
cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData,
cudnnContext.gammaBetaTensorDesc, gammaData, betaData, meanData, varData, eps));
}
allocator.getFlowController().registerActionAllWrite(context, x, activations, gamma, beta, mean, var);
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
context.syncOldStream();
if(training) {
AtomicAllocator.getInstance().getAllocationPoint(meanCache).tickDeviceWrite();
AtomicAllocator.getInstance().getAllocationPoint(varCache).tickDeviceWrite();
}
if(training && isHalf){
//Update the running mean and variance arrays; also gamma/beta
origMean.assign(mean.castTo(DataType.HALF));
origVar.assign(var.castTo(DataType.HALF));
origGamma.assign(gamma.castTo(DataType.HALF));
origBeta.assign(beta.castTo(DataType.HALF));
}
return activations;
}
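    /**
     * Returns the batch mean saved by the last training-mode forward pass. The cache is held in FP32
     * (a cuDNN requirement for FP16 networks), so it is cast back to FP16 when that type is requested.
     */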
@Override
public INDArray getMeanCache(DataType dataType) {
if(dataType == DataType.HALF){
//Buffer is FP32
return meanCache.castTo(DataType.HALF);
}
return meanCache;
}
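    /**
     * Returns the batch variance corresponding to the last training-mode forward pass. cuDNN saves the
     * inverse standard deviation, 1/sqrt(var + eps), rather than the variance itself, so this converts
     * back via var = 1/invStd^2 - eps. As with the mean cache, the buffer is FP32 even for FP16 networks.
     */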
    @Override
    public INDArray getVarCache(DataType dataType) {
        if (dataType == DataType.HALF) {
            //Buffer is FP32
            INDArray vc = varCache.castTo(DataType.HALF);
            return vc.mul(vc).rdivi(1.0).subi(eps);
        }
        return varCache.mul(varCache).rdivi(1.0).subi(eps);
    }
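    /**
     * Reports the memory held by this helper itself: the sizes of the mean and variance caches, in bytes.
     */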
@Override
    public Map<String, Long> helperMemoryUse() {
        Map<String, Long> memUse = new HashMap<>();
memUse.put("meanCache", meanCache == null ? 0 : meanCache.length() * meanCache.data().getElementSize());
memUse.put("varCache", varCache == null ? 0 : varCache.length() * varCache.data().getElementSize());
return memUse;
}
}
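/*
 * Rough usage sketch (not part of the original source). In practice DL4J's BatchNormalization layer
 * constructs and drives this helper automatically when the CUDA backend and cuDNN are available; the
 * calls below only illustrate the method contracts defined above. The NCHW activation shape, the
 * [1, channels] gamma/beta shape, and the dldOut/dGammaView/dBetaView/workspaceMgr variables are
 * illustrative placeholders, not values taken from the original code.
 *
 *     CudnnBatchNormalizationHelper helper = new CudnnBatchNormalizationHelper(DataType.FLOAT);
 *     INDArray x     = Nd4j.rand(DataType.FLOAT, 16, 32, 8, 8);   // [minibatch, channels, height, width]
 *     INDArray gamma = Nd4j.ones(DataType.FLOAT, 1, 32);
 *     INDArray beta  = Nd4j.zeros(DataType.FLOAT, 1, 32);
 *     INDArray mean  = Nd4j.zeros(DataType.FLOAT, 1, 32);
 *     INDArray var   = Nd4j.ones(DataType.FLOAT, 1, 32);
 *
 *     //Training-mode forward pass: also fills the mean/variance caches used by backpropGradient
 *     INDArray out = helper.preOutput(x, true, new long[]{1, 32}, gamma, beta, mean, var,
 *             0.9, 1e-5, CNN2DFormat.NCHW, workspaceMgr);
 *
 *     //Backward pass: returns (dGamma/dBeta gradient, epsilon for the layer below)
 *     Pair<Gradient, INDArray> grads = helper.backpropGradient(x, dldOut, new long[]{1, 32}, gamma, beta,
 *             dGammaView, dBetaView, 1e-5, CNN2DFormat.NCHW, workspaceMgr);
 */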