All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.jcublas.CublasPointer Maven / Gradle / Ivy

The newest version!
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas;

import jcuda.Pointer;
import org.apache.commons.lang3.tuple.Triple;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.DevicePointerInfo;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.context.CudaContext;

import java.util.Arrays;

/**
 * Wraps the allocation
 * and freeing of resources on a cuda device.
 *
 * A pointer is built either directly from a {@link JCudaBuffer} or from an
 * {@link INDArray} whose data is a {@link JCudaBuffer}. Instances are
 * {@link AutoCloseable}: {@link #close()} frees the device pointer unless the
 * instance has been flagged as a result pointer (see {@link #setResultPointer(boolean)}).
 *
 * @author bam4d
 *
 */
public class CublasPointer  implements AutoCloseable {

    /**
     * The underlying cuda buffer that contains the host and device memory
     */
    private JCudaBuffer buffer;
    /** Device pointer obtained from {@link #buffer} at construction time. */
    private final Pointer devicePointer;
    /** Host pointer looked up from the buffer's per-thread pointer table. */
    private Pointer hostPointer;
    /** Set once {@link #destroy()} has released the device pointer. */
    private boolean closed = false;
    /** The array this pointer was created for; {@code null} when created from a raw buffer. */
    private INDArray arr;
    /** Context handed to the memory strategy for device/host transfers. */
    private final CudaContext cudaContext;
    /** When {@code true}, {@link #close()} does not free the device pointer. */
    private boolean resultPointer = false;


    /**
     * frees the underlying
     * device memory allocated for this pointer,
     * unless this is a result pointer (in which case nothing happens
     * and {@link #destroy()} has to be called explicitly).
     *
     * @throws Exception declared for {@link AutoCloseable} compatibility
     */
    @Override
    public void close() throws Exception {
        if (!isResultPointer()) {
            destroy();
        }
    }


    /**
     * The actual destroy method.
     *
     * Frees the device pointer on the underlying buffer. For an array backed
     * pointer the array's offset and length are used, otherwise the whole
     * buffer range is used. Calling this more than once is a no-op.
     */
    public void destroy() {
        if (closed) {
            return;
        }

        if (arr != null) {
            buffer.freeDevicePointer(arr.offset(), arr.length());
        }
        else {
            buffer.freeDevicePointer(0, buffer.length());
        }
        closed = true;
    }


    /**
     *
     * @return the underlying cuda buffer
     */
    public JCudaBuffer getBuffer() {
        return buffer;
    }

    /**
     *
     * @return the device pointer obtained from the buffer at construction time
     */
    public Pointer getDevicePointer() {
        return devicePointer;
    }

    /**
     *
     * @return the host pointer associated with this pointer
     */
    public Pointer getHostPointer() {
        return hostPointer;
    }

    /**
     * Overrides the host pointer associated with this pointer.
     * @param hostPointer the host pointer to use
     */
    public void setHostPointer(Pointer hostPointer) {
        this.hostPointer = hostPointer;
    }

    /**
     * copies the result to the host buffer.
     *
     * For an array backed pointer the array's offset and element wise stride
     * are passed along, and the length is doubled for complex arrays.
     * Otherwise the whole buffer is copied starting at offset 0.
     */
    public void copyToHost() {
        if (arr == null) {
            ContextHolder.getInstance().getMemoryStrategy().copyToHost(buffer, 0, cudaContext);
            return;
        }

        int compLength = arr instanceof IComplexNDArray ? arr.length() * 2 : arr.length();
        ContextHolder.getInstance().getMemoryStrategy().copyToHost(
                buffer,
                arr.offset(),
                arr.elementWiseStride(),
                compLength,
                cudaContext,
                arr.offset(),
                arr.elementWiseStride());
    }


    /**
     * Creates a CublasPointer
     * for a given JCudaBuffer.
     *
     * The device pointer is requested for the whole buffer, the buffer's data
     * is handed to the memory strategy via {@code setData}, and the buffer is
     * marked as copied for the current thread's name.
     *
     * @param buffer the buffer to create the pointer for
     * @param context the cuda context used for later transfers
     */
    public CublasPointer(JCudaBuffer buffer,CudaContext context) {
        this.buffer = buffer;
        // NOTE(review): arguments look like (stride, offset, length) - confirm against JCudaBuffer
        this.devicePointer = buffer.getDevicePointer(1, 0, buffer.length());
        this.cudaContext = context;
        context.initOldStream();

        //pointers are tracked per thread name
        String name = Thread.currentThread().getName();
        DevicePointerInfo info = buffer.getPointersToContexts().get(name, Triple.of(0, buffer.length(), 1));
        hostPointer = info.getPointers().getHostPointer();
        ContextHolder.getInstance().getMemoryStrategy().setData(devicePointer, 0, 1, buffer.length(), hostPointer);
        buffer.setCopied(name);
    }

    /**
     * Creates a CublasPointer for a given INDArray.
     *
     * This wrapper makes sure that the INDArray offset, stride
     * and memory pointers are accurate to the data being copied to and from the device.
     *
     * If the copyToHost function is used in this class,
     * the host buffer offset and data length is taken care of automatically
     *
     * @param array the array to create the pointer for; its data must be a {@link JCudaBuffer}
     * @param context the cuda context used for later transfers
     * @throws IllegalStateException if the array has a negative element wise
     *                               stride even after being duplicated
     */
    public CublasPointer(INDArray array,CudaContext context) {
        //we have to reset the pointer to be zero offset due to the fact that
        //vector based striding won't work with an array that looks like this
        if (array instanceof IComplexNDArray
                && array.length() * 2 < array.data().length()
                && !array.isVector()) {
            array = Shape.toOffsetZero(array);
        }

        this.cudaContext = context;
        buffer = (JCudaBuffer) array.data();

        //the name of this thread for knowing whether to copy data or not
        String name = Thread.currentThread().getName();
        this.arr = array;
        if (array.elementWiseStride() < 0) {
            //fall back to a copy of the array and use its buffer instead
            this.arr = array.dup();
            buffer = (JCudaBuffer) this.arr.data();
            if (this.arr.elementWiseStride() < 0) {
                throw new IllegalStateException("Unable to iterate over buffer");
            }
        }

        boolean complex = arr instanceof IComplexNDArray;
        //complex arrays use twice the length and half the blas stride
        int compLength = complex ? arr.length() * 2 : arr.length();
        int stride = complex ? BlasBufferUtil.getBlasStride(arr) / 2 : BlasBufferUtil.getBlasStride(arr);
        //no striding for upload if we are using the whole buffer
        this.devicePointer = buffer.getDevicePointer(
                this.arr,
                stride,
                this.arr.offset(),
                compLength);


        /*
         * Neat edge case here.
         *
         * The striding will overshoot the original array
         * when the offset is zero (the case being when offset is zero
         * say on a getRow(0) operation.
         *
         * We need to allocate the data differently here
         * due to how the striding works out.
         */
        // Copy the data to the device iff the whole buffer hasn't been copied
        if (!buffer.copied(name)) {
            ContextHolder.getInstance().getMemoryStrategy().setData(buffer, 0, 1, buffer.length());
            //mark the buffer copied
            buffer.setCopied(name);
        }

        DevicePointerInfo info = buffer.getPointersToContexts().get(name, Triple.of(0, buffer.length(), 1));
        hostPointer = info.getPointers().getHostPointer();
    }


    /**
     * Whether this is a result pointer or not
     * A result pointer means that this
     * pointer should not automatically be freed
     * but instead wait for results to accumulate
     * so they can be returned from
     * the gpu first
     * @return true if this pointer is flagged as a result pointer
     */
    public boolean isResultPointer() {
        return resultPointer;
    }

    /**
     * Sets whether this is a result pointer or not
     * A result pointer means that this
     * pointer should not automatically be freed
     * but instead wait for results to accumulate
     * so they can be returned from
     * the gpu first
     * @param resultPointer true to keep {@link #close()} from freeing the device pointer
     */
    public void setResultPointer(boolean resultPointer) {
        this.resultPointer = resultPointer;
    }

    /**
     * Renders the data behind this pointer by reading it back through the
     * memory strategy into a freshly created host buffer.
     *
     * @return the string form of the read back data, or
     *         {@code "No device pointer yet"} when there is no device pointer
     */
    @Override
    public String toString() {
        if (devicePointer == null) {
            return "No device pointer yet";
        }

        StringBuilder sb = new StringBuilder();
        if (arr == null) {
            //raw buffer: read everything, contiguously, from the start
            appendDeviceData(sb, buffer.dataType(), buffer.length(), 1, 0);
            return sb.toString();
        }

        boolean complex = arr instanceof IComplexNDArray;
        int length = complex ? arr.length() * 2 : arr.length();
        //same precedence as the un-parenthesized original: (complex && doubled length matches) || length matches
        boolean coversBuffer = (complex && arr.length() * 2 == buffer.length())
                || arr.length() == buffer.length();

        if (coversBuffer) {
            appendDeviceData(sb, arr.data().dataType(), length, 1, 0);
        }
        else {
            //the array is a view on the buffer: honor its stride and offset
            appendDeviceData(sb, arr.data().dataType(), length, arr.elementWiseStride(), arr.offset());
        }
        return sb.toString();
    }


    /**
     * Reads {@code length} elements from {@link #buffer} through the memory
     * strategy into a new host buffer of the matching primitive type and
     * appends that host buffer to the given builder.
     *
     * @param sb           the builder to append to
     * @param type         element type deciding the host array type
     *                     (double, int, otherwise float)
     * @param length       the number of elements to read
     * @param deviceStride the stride passed to the memory strategy for the source buffer
     * @param deviceOffset the offset passed to the memory strategy for the source buffer
     */
    private void appendDeviceData(StringBuilder sb, DataBuffer.Type type, int length, int deviceStride, int deviceOffset) {
        DataBuffer host;
        if (type == DataBuffer.Type.DOUBLE) {
            host = Nd4j.createBuffer(new double[length]);
        }
        else if (type == DataBuffer.Type.INT) {
            host = Nd4j.createBuffer(new int[length]);
        }
        else {
            host = Nd4j.createBuffer(new float[length]);
        }

        ContextHolder.getInstance().getMemoryStrategy().getData(host, 0, 1, length, buffer, cudaContext, deviceStride, deviceOffset);
        sb.append(host);
    }


    /**
     * Closes every given pointer on a best effort basis: a failure while
     * closing one pointer is printed and does not prevent the remaining
     * pointers from being closed. {@code null} entries (and a {@code null}
     * array) are skipped.
     *
     * @param pointers the pointers to close
     */
    public static void free(CublasPointer...pointers) {
        if (pointers == null) {
            return;
        }

        for (CublasPointer pointer : pointers) {
            if (pointer == null) {
                continue;
            }

            try {
                pointer.close();
            } catch (Exception e) {
                //deliberately best effort: keep releasing the remaining pointers
                e.printStackTrace();
            }
        }
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy