org.deeplearning4j.nn.layers.BaseCudnnHelper
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.layers;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;
import static org.bytedeco.cuda.global.cudart.*;
import static org.bytedeco.cuda.global.cudnn.*;

/**
 * Functionality shared by all cuDNN-based helpers.
 *
 * @author saudet
 */
@Slf4j
public abstract class BaseCudnnHelper {

    protected static void checkCuda(int error) {
        if (error != cudaSuccess) {
            throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString());
        }
    }

    protected static void checkCudnn(int status) {
        if (status != CUDNN_STATUS_SUCCESS) {
            throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString());
        }
    }

    protected static class CudnnContext extends cudnnContext {

        protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
            Deallocator(CudnnContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        public CudnnContext() {
            // ensure that cuDNN initializes on the same device as ND4J for this thread
            Nd4j.create(1);
            AtomicAllocator.getInstance();
            // Subclasses need to call the following in their own constructors (see the illustrative sketch after this class):
            // createHandles();
            // deallocator(new Deallocator(this));
        }

        public CudnnContext(CudnnContext c) {
            super(c);
        }

        protected void createHandles() {
            checkCudnn(cudnnCreate(this));
        }

        protected void destroyHandles() {
            checkCudnn(cudnnDestroy(this));
        }
    }
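
    // Illustrative sketch (hypothetical class, not part of the original file): this is the
    // pattern the CudnnContext constructor comment refers to. A concrete helper's context
    // subclass creates the cuDNN handle (and any descriptors it needs) in its own constructor
    // and registers the Deallocator so that cudnnDestroy() runs when the pointer is
    // garbage collected.
    protected static class ExampleCudnnContext extends CudnnContext {
        public ExampleCudnnContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }
    }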

    protected static class DataCache extends Pointer {

        static class Deallocator extends DataCache implements Pointer.Deallocator {
            Deallocator(DataCache c) {
                super(c);
            }

            @Override
            public void deallocate() {
                checkCuda(cudaFree(this));
                setNull();
            }
        }

        static class HostDeallocator extends DataCache implements Pointer.Deallocator {
            HostDeallocator(DataCache c) {
                super(c);
            }

            @Override
            public void deallocate() {
                checkCuda(cudaFreeHost(this));
                setNull();
            }
        }

        public DataCache() {}

        public DataCache(long size) {
            position = 0;
            limit = capacity = size;
            int error = cudaMalloc(this, size);
            if (error != cudaSuccess) {
                log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error
                                + "), proceeding with host memory");
                checkCuda(cudaMallocHost(this, size));
                deallocator(new HostDeallocator(this));
            } else {
                deallocator(new Deallocator(this));
            }
        }

        public DataCache(DataCache c) {
            super(c);
        }
    }
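
    // Illustrative sketch (hypothetical helper, not part of the original file): cuDNN-based
    // helpers use DataCache for scratch buffers such as convolution workspaces. A typical
    // pattern is to reuse the current buffer when it is already large enough and otherwise
    // allocate a new one, which tries device memory first and falls back to page-locked host
    // memory (see the DataCache(long) constructor above).
    protected static DataCache ensureWorkspaceSize(DataCache current, long requiredBytes) {
        if (current == null || current.isNull() || current.capacity() < requiredBytes) {
            return new DataCache(requiredBytes);
        }
        return current;
    }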

    protected static class TensorArray extends PointerPointer {

        static class Deallocator extends TensorArray implements Pointer.Deallocator {
            Pointer owner;

            Deallocator(TensorArray a, Pointer owner) {
                this.address = a.address;
                this.capacity = a.capacity;
                this.owner = owner;
            }

            @Override
            public void deallocate() {
                for (int i = 0; !isNull() && i < capacity; i++) {
                    cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i);
                    checkCudnn(cudnnDestroyTensorDescriptor(t));
                }
                if (owner != null) {
                    owner.deallocate();
                    owner = null;
                }
                setNull();
            }
        }

        public TensorArray() {}

        public TensorArray(long size) {
            PointerPointer p = new PointerPointer(size);
            p.deallocate(false);
            this.address = p.address();
            this.limit = p.limit();
            this.capacity = p.capacity();

            cudnnTensorStruct t = new cudnnTensorStruct();
            for (int i = 0; i < capacity; i++) {
                checkCudnn(cudnnCreateTensorDescriptor(t));
                this.put(i, t);
            }
            deallocator(new Deallocator(this, p));
        }

        public TensorArray(TensorArray a) {
            super(a);
        }
    }
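
    // Illustrative sketch (hypothetical helper, not part of the original file): TensorArray is
    // intended for APIs that take one tensor descriptor per time step, such as the cuDNN RNN
    // functions. Each descriptor created by the TensorArray(long) constructor above still has
    // to be configured, e.g. with cudnnSetTensorNdDescriptor, before use.
    protected static TensorArray describeSequence(int timeSteps, int cudnnDataType,
                    int[] dims, int[] strides) {
        TensorArray descriptors = new TensorArray(timeSteps);
        for (int i = 0; i < timeSteps; i++) {
            cudnnTensorStruct descriptor = descriptors.get(cudnnTensorStruct.class, i);
            checkCudnn(cudnnSetTensorNdDescriptor(descriptor, cudnnDataType, dims.length, dims, strides));
        }
        return descriptors;
    }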

    protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;

    protected final DataType nd4jDataType;
    protected final int dataType;
    protected final int dataTypeSize;
    // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
    protected final Pointer alpha;
    protected final Pointer beta;
    protected SizeTPointer sizeInBytes = new SizeTPointer(1);

    public BaseCudnnHelper(@NonNull DataType dataType){
        this.nd4jDataType = dataType;
        this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE
                : dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
        this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2;
        // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta
        this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f);
        this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f);
    }

    public static int toCudnnDataType(DataType type){
        switch (type){
            case DOUBLE:
                return CUDNN_DATA_DOUBLE;
            case FLOAT:
                return CUDNN_DATA_FLOAT;
            case INT:
                return CUDNN_DATA_INT32;
            case HALF:
                return CUDNN_DATA_HALF;
            default:
                throw new RuntimeException("Cannot convert type: " + type);
        }
    }

    public boolean checkSupported() {
        // add general checks here, if any
        return true;
    }


    /**
     * From the cuDNN documentation:
     * "Tensors are restricted to having at least 4 dimensions... When working with lower dimensional data, it is
     * recommended that the user create a 4D tensor, and set the size along unused dimensions to 1."
     *
     * This method implements that recommendation: it pads the input array (shape or strides) with 1s at the end
     * until its length is 4, and returns it unmodified if its length is already 4 or more.
     *
     * @param shapeOrStrides shape or strides array to adapt
     * @return the input array if its length is at least 4, otherwise a new length-4 array padded with 1s
     */
    protected static int[] adaptForTensorDescr(int[] shapeOrStrides){
        if(shapeOrStrides.length >= 4)
            return shapeOrStrides;
        int[] out = new int[4];
        int i=0;
        for (; i < shapeOrStrides.length; i++) {
            out[i] = shapeOrStrides[i];
        }
        for (; i < 4; i++) {
            out[i] = 1;
        }
        return out;
    }
}
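
// Illustrative usage sketch (hypothetical class, not part of the original file): shows the
// effect of adaptForTensorDescr on a 2-d shape and its row-major strides.
class AdaptForTensorDescrExample {
    public static void main(String[] args) {
        int[] shape = {32, 784};        // e.g. a [miniBatch, features] activation matrix
        int[] stride = {784, 1};        // its row-major strides
        // Both arrays are padded with 1s up to length 4, as cuDNN tensor descriptors expect:
        // shape  -> [32, 784, 1, 1]
        // stride -> [784, 1, 1, 1]
        System.out.println(java.util.Arrays.toString(BaseCudnnHelper.adaptForTensorDescr(shape)));
        System.out.println(java.util.Arrays.toString(BaseCudnnHelper.adaptForTensorDescr(stride)));
    }
}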



