// org.deeplearning4j.nn.layers.BaseCudnnHelper Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudart.*;
import static org.bytedeco.cuda.global.cudnn.*;
/**
* Functionality shared by all cuDNN-based helpers.
*
* @author saudet
*/
@Slf4j
public abstract class BaseCudnnHelper {
/**
 * Validates a CUDA runtime return code, translating any failure into an
 * unchecked exception carrying both the numeric code and its textual form.
 *
 * @param error return value of a CUDA runtime call ({@code cudaSuccess} means OK)
 * @throws RuntimeException if {@code error} is not {@code cudaSuccess}
 */
protected static void checkCuda(int error) {
    if (error == cudaSuccess) {
        return;
    }
    throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString());
}
/**
 * Validates a cuDNN status code, translating any failure into an unchecked
 * exception carrying both the numeric status and its textual form.
 *
 * @param status return value of a cuDNN call ({@code CUDNN_STATUS_SUCCESS} means OK)
 * @throws RuntimeException if {@code status} is not {@code CUDNN_STATUS_SUCCESS}
 */
protected static void checkCudnn(int status) {
    if (status == CUDNN_STATUS_SUCCESS) {
        return;
    }
    throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString());
}
/**
 * Wrapper around a cuDNN library handle ({@code cudnnContext}) that ties the native
 * handle's lifetime to JavaCPP's deallocator mechanism, so the handle is destroyed
 * when the owning Java object becomes unreachable.
 */
protected static class CudnnContext extends cudnnContext {
    // JavaCPP idiom: the deallocator is itself a copy of the pointer (made via the
    // copy constructor) so it can still reach the native handle after the original
    // owner has been garbage-collected.
    protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
        Deallocator(CudnnContext c) {
            super(c);
        }
        @Override
        public void deallocate() {
            // Releases the native cuDNN handle via cudnnDestroy.
            destroyHandles();
        }
    }
    public CudnnContext() {
        // ensure that cuDNN initializes on the same device as ND4J for this thread
        Nd4j.create(1);
        AtomicAllocator.getInstance();
        // This needs to be called in subclasses:
        // createHandles();
        // deallocator(new Deallocator(this));
    }
    // Copy constructor: shares the same underlying native handle; no new handle is created.
    public CudnnContext(CudnnContext c) {
        super(c);
    }
    // Creates the native cuDNN handle for this context (must precede any cuDNN call on it).
    protected void createHandles() {
        checkCudnn(cudnnCreate(this));
    }
    // Destroys the native cuDNN handle; invoked from Deallocator.deallocate().
    protected void destroyHandles() {
        checkCudnn(cudnnDestroy(this));
    }
}
protected static class DataCache extends Pointer {
static class Deallocator extends DataCache implements Pointer.Deallocator {
Deallocator(DataCache c) {
super(c);
}
@Override
public void deallocate() {
checkCuda(cudaFree(this));
setNull();
}
}
static class HostDeallocator extends DataCache implements Pointer.Deallocator {
HostDeallocator(DataCache c) {
super(c);
}
@Override
public void deallocate() {
checkCuda(cudaFreeHost(this));
setNull();
}
}
public DataCache() {}
public DataCache(long size) {
position = 0;
limit = capacity = size;
int error = cudaMalloc(this, size);
if (error != cudaSuccess) {
log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error
+ "), proceeding with host memory");
checkCuda(cudaMallocHost(this, size));
deallocator(new HostDeallocator(this));
} else {
deallocator(new Deallocator(this));
}
}
public DataCache(DataCache c) {
super(c);
}
}
/**
 * An array of cuDNN tensor descriptors backed by a native pointer array. Each slot
 * holds a descriptor created with {@code cudnnCreateTensorDescriptor}; all of them,
 * plus the backing array itself, are released through the registered deallocator.
 */
protected static class TensorArray extends PointerPointer {
    static class Deallocator extends TensorArray implements Pointer.Deallocator {
        // Backing PointerPointer whose native array must be freed after the descriptors.
        Pointer owner;
        Deallocator(TensorArray a, Pointer owner) {
            // Copy only the raw address/capacity so this deallocator can walk the
            // array after the original TensorArray is garbage-collected.
            this.address = a.address;
            this.capacity = a.capacity;
            this.owner = owner;
        }
        @Override
        public void deallocate() {
            // Destroy every descriptor first, then release the backing array.
            for (int i = 0; !isNull() && i < capacity; i++) {
                cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i);
                checkCudnn(cudnnDestroyTensorDescriptor(t));
            }
            if (owner != null) {
                owner.deallocate();
                owner = null;
            }
            setNull();
        }
    }
    public TensorArray() {}
    /**
     * Creates {@code size} tensor descriptors stored in a freshly allocated
     * native pointer array.
     */
    public TensorArray(long size) {
        PointerPointer p = new PointerPointer(size);
        // NOTE(review): deallocate(false) presumably detaches p's own automatic
        // deallocator so ownership transfers to this TensorArray's Deallocator
        // (which frees it via owner.deallocate()) — confirm against JavaCPP docs.
        p.deallocate(false);
        this.address = p.address();
        this.limit = p.limit();
        this.capacity = p.capacity();
        cudnnTensorStruct t = new cudnnTensorStruct();
        for (int i = 0; i < capacity; i++) {
            // cudnnCreateTensorDescriptor overwrites t with each new descriptor,
            // which is then stored into slot i.
            checkCudnn(cudnnCreateTensorDescriptor(t));
            this.put(i, t);
        }
        deallocator(new Deallocator(this, p));
    }
    // Copy constructor: shares the same native array; does not register a new deallocator.
    public TensorArray(TensorArray a) {
        super(a);
    }
}
// Tensor layout used by this helper: NCHW (batch, channels, height, width).
protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
// ND4J data type this helper instance was constructed for.
protected final DataType nd4jDataType;
// Corresponding CUDNN_DATA_* constant (see constructor mapping).
protected final int dataType;
// Size in bytes of one element of the chosen data type (8, 4, or 2).
protected final int dataTypeSize;
// both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
protected final Pointer alpha;
protected final Pointer beta;
// Single-element size_t buffer — presumably reused as the output argument of
// cuDNN workspace-size queries in subclasses; TODO confirm.
protected SizeTPointer sizeInBytes = new SizeTPointer(1);
public BaseCudnnHelper(@NonNull DataType dataType){
this.nd4jDataType = dataType;
this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE
: dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2;
// both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f);
this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f);
}
/**
 * Maps an ND4J {@link DataType} to the equivalent CUDNN_DATA_* constant.
 *
 * @param type ND4J data type (DOUBLE, FLOAT, INT, or HALF)
 * @return the matching cuDNN data-type constant
 * @throws RuntimeException if the type has no cuDNN equivalent
 */
public static int toCudnnDataType(DataType type) {
    if (type == DataType.DOUBLE) {
        return CUDNN_DATA_DOUBLE;
    }
    if (type == DataType.FLOAT) {
        return CUDNN_DATA_FLOAT;
    }
    if (type == DataType.INT) {
        return CUDNN_DATA_INT32;
    }
    if (type == DataType.HALF) {
        return CUDNN_DATA_HALF;
    }
    throw new RuntimeException("Cannot convert type: " + type);
}
/**
 * Reports whether this cuDNN helper can be used in the current configuration.
 * The base implementation performs no checks and always returns {@code true};
 * subclasses may override to add operation-specific support checks.
 *
 * @return {@code true} in this base implementation
 */
public boolean checkSupported() {
    // add general checks here, if any
    return true;
}
/**
* From CuDNN documentation -
* "Tensors are restricted to having at least 4 dimensions... When working with lower dimensional data, it is
* recommended that the user create a 4Dtensor, and set the size along unused dimensions to 1."
*
* This method implements that - basically appends 1s to the end (shape or stride) to make it length 4,
* or leaves it unmodified if the length is already 4 or more.
* This method can be used for both shape and strides
*
* @param shapeOrStrides
* @return
*/
protected static int[] adaptForTensorDescr(int[] shapeOrStrides){
if(shapeOrStrides.length >= 4)
return shapeOrStrides;
int[] out = new int[4];
int i=0;
for(; i