All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.nd4j.linalg.jcublas.buffer;

import java.io.IOException;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUresult;
import jcuda.jcublas.JCublas;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaMemcpyKind;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.complex.CudaComplexConversion;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
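 * Base class for CUDA-backed data buffers: the data lives in a host-side
 * {@link ByteBuffer} and device memory is allocated on demand.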
 *
 * @author Adam Gibson
 */
public abstract class BaseCudaDataBuffer implements JCudaBuffer {

	/** Byte size of the current device allocation; reset to -1 when the device pointer is freed. */
	protected transient long devicePointerLength;
	/** JCuda pointer wrapping {@link #hostBuffer}. */
	protected transient Pointer hostPointer;
	/** Host-side storage holding the buffer contents. */
	protected transient ByteBuffer hostBuffer;
	/** Device allocations, keyed by the name of the thread that requested them. */
	protected transient Map<String, CUdeviceptr> pointersToContexts = new ConcurrentHashMap<>();
    /** Set to true by the mutating put/set methods. */
    protected AtomicBoolean modified = new AtomicBoolean(false);
    /** Number of elements in the buffer. */
    protected int length;
    /** Size of a single element in bytes. */
    protected int elementSize;
    /** Ids registered through {@link #addReferencing(String)}. */
    protected Collection<String> referencing = Collections.synchronizedSet(new HashSet<String>());
    /** Weak self-reference; assigned during deserialization (and possibly by subclasses). */
    protected transient WeakReference<?> ref;
    /** Flag toggled by {@link #persist()}. */
    protected boolean isPersist = false;
    /** True once the device pointer has been released. */
    protected AtomicBoolean freed = new AtomicBoolean(false);
    /** Bytes of device memory currently accounted for across all buffers. */
    static AtomicLong allocated = new AtomicLong();
    /** Bytes of device memory ever requested across all buffers. */
    static AtomicLong totalAllocated = new AtomicLong();

    private static final Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
    
    /**
     * Replaces the host-side storage of this buffer.
     * A JCuda {@link Pointer} to the new buffer is kept alongside it and the
     * buffer is switched to the platform's native byte order.
     *
     * @param hostBuffer the buffer that becomes the new host storage
     */
    @Override
    public void setHostBuffer(ByteBuffer hostBuffer) {
        this.hostBuffer = hostBuffer;
        this.hostPointer = Pointer.to(hostBuffer);
        hostBuffer.order(ByteOrder.nativeOrder());
    }
    
    /**
     * @return the host-side byte buffer backing this data buffer
     *         (null after {@link #destroy()})
     */
    @Override
    public ByteBuffer getHostBuffer() {
        return this.hostBuffer;
    }
    
    /**
     * @return the JCuda pointer wrapping the host buffer
     *         (null after {@link #destroy()})
     */
    @Override
    public Pointer getHostPointer() {
        return this.hostPointer;
    }

    /**
     * Raises the persist flag reported by {@link #isPersist()}.
     */
    @Override
    public void persist() {
        this.isPersist = true;
    }

    /**
     * @return true once {@link #persist()} has been called on this buffer
     *         (or the flag was restored as true during deserialization)
     */
    @Override
    public boolean isPersist() {
        return this.isPersist;
    }

    /**
     * Base constructor
     *
     * @param length      the length of the buffer
     * @param elementSize the size of each element
     */
    public BaseCudaDataBuffer(int length, int elementSize) {
        this.length = length;
        this.elementSize = elementSize;
        hostBuffer = ByteBuffer.allocate(length*elementSize);
        hostBuffer.order(ByteOrder.nativeOrder());
        hostPointer = Pointer.to(hostBuffer);
    }

    /**
     * Removes an id from the set of referencing ids.
     *
     * @param id the id to drop; ignored when it is not registered
     */
    @Override
    public void removeReferencing(String id) {
        this.referencing.remove(id);
    }

    /**
     * @return the live (not copied) collection of ids registered via
     *         {@link #addReferencing(String)}
     */
    @Override
    public Collection references() {
        return this.referencing;
    }

    /**
     * Registers an id in the set of referencing ids.
     *
     * @param id the id to record
     */
    @Override
    public void addReferencing(String id) {
        this.referencing.add(id);
    }

    /**
     * Stores a complex number through {@code JCublas.cublasSetVector} and marks
     * the buffer as modified. Float buffers use the single-precision
     * representation; every other data type uses the double-precision one.
     *
     * @param i      passed to cublasSetVector as the offset argument
     * @param result the complex value to store
     */
    @Override
    public void put(int i, IComplexNumber result) {
        modified.set(true);

        // NOTE(review): the element count passed is length() although the source
        // array holds a single value, and the target is the host pointer -
        // confirm this is the intended cublasSetVector usage.
        final boolean singlePrecision = dataType() == DataBuffer.Type.FLOAT;
        final int count = length();
        if (singlePrecision) {
            cuComplex[] value = {CudaComplexConversion.toComplex(result.asFloat())};
            JCublas.cublasSetVector(count, value, i, 1, hostPointer, 1);
        } else {
            cuDoubleComplex[] value = {CudaComplexConversion.toComplexDouble(result.asDouble())};
            JCublas.cublasSetVector(count, value, i, 1, hostPointer, 1);
        }
    }

    /**
     * Copies the host buffer contents into a new float array.
     *
     * <p>The previous implementation called {@code array()} on the float view
     * of the byte buffer; such a view has no accessible backing array, so the
     * call always failed with {@link UnsupportedOperationException}. The
     * elements are now copied one by one, mirroring {@link #asDouble()}.
     *
     * @return a new array holding the first {@code length()} floats of the host buffer
     */
    @Override
    public float[] asFloat() {
        float[] ret = new float[length()];
        java.nio.FloatBuffer buf = getFloatBuffer();
        for (int i = 0; i < ret.length; i++) {
            ret[i] = buf.get(i);
        }
        return ret;
    }

    /**
     * Copies the host buffer contents into a new double array.
     *
     * @return a new array holding the first {@code length()} doubles of the host buffer
     */
    @Override
    public double[] asDouble() {
        final double[] values = new double[length()];
        final DoubleBuffer view = getDoubleBuffer();
        int idx = 0;
        while (idx < length()) {
            values[idx] = view.get(idx);
            idx++;
        }
        return values;
    }

    /**
     * Copies the host buffer contents into a new int array.
     *
     * <p>The previous implementation called {@code array()} on the int view of
     * the byte buffer; such a view has no accessible backing array, so the call
     * always failed with {@link UnsupportedOperationException}. The elements
     * are now read one by one from the start of the host buffer.
     *
     * @return a new array holding the first {@code length()} ints of the host buffer
     */
    @Override
    public int[] asInt() {
        int[] ret = new int[length()];
        // getHostBuffer(0) rewinds the shared buffer so the view starts at element 0
        java.nio.IntBuffer buf = getHostBuffer(0).asIntBuffer();
        for (int i = 0; i < ret.length; i++) {
            ret[i] = buf.get(i);
        }
        return ret;
    }

    /**
     * Returns the device pointer registered for the calling thread, allocating
     * device memory for the whole buffer on first use.
     *
     * <p>The byte size is computed in {@code long} arithmetic so it cannot wrap
     * around, and the allocation counters are only updated after
     * {@code cudaMalloc} has succeeded, so a failed allocation no longer skews
     * the accounting.
     *
     * @return the device pointer for the current thread (looked up by thread name)
     * @throws CudaException if the device allocation fails
     */
    @Override
    public CUdeviceptr getDevicePointer() {
        final String threadKey = Thread.currentThread().getName();
        CUdeviceptr devicePointer = pointersToContexts.get(threadKey);

        if (devicePointer == null) {
            devicePointer = new CUdeviceptr();
            // widen before multiplying: int * int would overflow silently
            final long bytes = (long) getElementSize() * length();
            checkResult(JCuda.cudaMalloc(devicePointer, bytes));

            devicePointerLength = bytes;
            allocated.addAndGet(bytes);
            totalAllocated.addAndGet(bytes);
            log.trace("Allocating {} bytes, total: {}, overall: {}", devicePointerLength, allocated.get(), totalAllocated);
            pointersToContexts.put(threadKey, devicePointer);
            freed.set(false);
        }

        return devicePointer;
    }


    /**
     * Copies {@code length()} elements from the given pointer into the host
     * pointer using the cuBLAS copy routine matching the data type
     * (double precision for DOUBLE buffers, single precision otherwise),
     * and marks the buffer as modified.
     *
     * @param pointer the source of the data
     */
    @Override
    public void set(Pointer pointer) {
        modified.set(true);

        final boolean doublePrecision = dataType() == DataBuffer.Type.DOUBLE;
        final int count = length();
        if (doublePrecision) {
            JCublas.cublasDcopy(count, pointer, 1, hostPointer, 1);
        } else {
            JCublas.cublasScopy(count, pointer, 1, hostPointer, 1);
        }
    }


    /**
     * Copies every element of this buffer into another buffer, one element at
     * a time, reading each value as a double.
     *
     * @param to the buffer that receives the data
     */
    protected void copyTo(JCudaBuffer to) {
        int idx = 0;
        while (idx < length()) {
            to.put(idx, getDouble(idx));
            idx++;
        }
    }

    /**
     * Assigns a value by delegating to the two-argument overload with 0 as the
     * second argument.
     *
     * @param value the value to assign
     */
    @Override
    public void assign(Number value) {
        this.assign(value, 0);
    }





    /**
     * Builds a single-precision complex number from two consecutive elements.
     *
     * @param i index of the real part; the imaginary part is read from {@code i + 1}
     * @return the complex number formed by those two elements
     */
    @Override
    public IComplexFloat getComplexFloat(int i) {
        final float real = getFloat(i);
        final float imaginary = getFloat(i + 1);
        return Nd4j.createFloat(real, imaginary);
    }

    /**
     * Builds a double-precision complex number from two consecutive elements.
     *
     * @param i index of the real part; the imaginary part is read from {@code i + 1}
     * @return the complex number formed by those two elements
     */
    @Override
    public IComplexDouble getComplexDouble(int i) {
        final double real = getDouble(i);
        final double imaginary = getDouble(i + 1);
        return Nd4j.createDouble(real, imaginary);
    }

    /**
     * Reads a complex number using the precision that matches the buffer:
     * single precision for FLOAT buffers, double precision for everything else.
     *
     * @param i index of the real part
     * @return the complex number starting at {@code i}
     */
    @Override
    public IComplexNumber getComplex(int i) {
        if (dataType() == DataBuffer.Type.FLOAT) {
            return getComplexFloat(i);
        }
        return getComplexDouble(i);
    }

    /**
     * Copies elements from a pointer into the host pointer, starting at the
     * given element index, and marks the buffer as modified.
     *
     * @param index  the element index at which writing starts
     * @param length the number of elements to copy
     * @param from   the pointer to read the data from
     * @param inc    the increment between consecutive source elements
     * @throws IllegalArgumentException if the byte offset of {@code index} lies
     *                                  at or beyond the end of the buffer
     */
    protected void set(int index, int length, Pointer from, int inc) {
        modified.set(true);

        final int byteOffset = getElementSize() * index;
        final int totalBytes = length() * getElementSize();
        if (byteOffset >= totalBytes) {
            throw new IllegalArgumentException("Illegal offset " + byteOffset + " with index of " + index + " and length " + length());
        }

        final int elementBytes = getElementSize();
        JCublas.cublasSetVector(length, elementBytes, from, inc, hostPointer.withByteOffset(byteOffset), 1);
    }

    /**
     * Copies contiguous elements (source increment of 1) from a pointer into
     * this buffer.
     *
     * @param index  the element index at which writing starts
     * @param length the number of elements to copy
     * @param from   the pointer to read the data from
     */
    protected void set(int index, int length, Pointer from) {
        final int unitIncrement = 1;
        set(index, length, from, unitIncrement);
    }

    /**
     * Assigns from another buffer by passing its host pointer to
     * {@code set(0, pointer)}.
     *
     * @param data the source buffer; must be a {@link JCudaBuffer}
     *             (anything else fails with a ClassCastException)
     */
    @Override
    public void assign(DataBuffer data) {
        final Pointer source = ((JCudaBuffer) data).getHostPointer();
        set(0, source);
    }
    /**
     * @return the host buffer, as returned by {@link #getBuffer(long)} for offset 0
     */
    protected ByteBuffer getBuffer() {
        final long start = 0;
        return getBuffer(start);
    }

    /**
     * Positions the host buffer at the given float offset and returns a float
     * view starting there.
     *
     * @param offset the offset, counted in floats
     * @return a float view of the host buffer beginning at {@code offset}
     */
    protected java.nio.FloatBuffer getFloatBuffer(long offset) {
        final long byteOffset = offset * Sizeof.FLOAT;
        final ByteBuffer positioned = getHostBuffer(byteOffset);
        return positioned.asFloatBuffer();
    }

    /**
     * @return a float view of the host buffer starting at its first element
     */
    protected java.nio.FloatBuffer getFloatBuffer() {
        final long start = 0;
        return getFloatBuffer(start);
    }

    /**
     * Positions the host buffer at the given double offset and returns a double
     * view starting there.
     *
     * @param offset the offset, counted in doubles
     * @return a double view of the host buffer beginning at {@code offset}
     */
    protected java.nio.DoubleBuffer getDoubleBuffer(long offset) {
        final long byteOffset = offset * Sizeof.DOUBLE;
        final ByteBuffer positioned = getHostBuffer(byteOffset);
        return positioned.asDoubleBuffer();
    }

    /**
     * @return a double view of the host buffer starting at its first element
     */
    protected java.nio.DoubleBuffer getDoubleBuffer() {
        final long start = 0;
        return getDoubleBuffer(start);
    }

    /**
     * Returns the host buffer after switching it to native byte order.
     *
     * @param offset currently ignored: the whole host buffer is returned
     *               regardless of this value
     * @return the host buffer in native byte order
     */
    protected ByteBuffer getBuffer(long offset) {
        // order(..) returns the buffer it was invoked on
        return hostBuffer.order(ByteOrder.nativeOrder());
    }
    /**
     * Sets a single element from a pointer.
     *
     * @param index the index of the element to write
     * @param from  the pointer to read the element from
     */
    protected void set(int index, Pointer from) {
        final int singleElement = 1;
        set(index, singleElement, from);
    }
    
    /**
     * Turns a non-success CUDA status code into an exception.
     *
     * <p>NOTE(review): the code is decoded with the driver-API {@code CUresult}
     * names, while callers in this class pass results of runtime-API calls
     * ({@code JCuda.cudaMalloc} etc.) - confirm the messages are the intended ones.
     *
     * @param cuResult the status code returned by a CUDA call
     * @throws CudaException if {@code cuResult} is not {@code CUDA_SUCCESS}
     */
    public static void checkResult(int cuResult) {
        if (cuResult == CUresult.CUDA_SUCCESS) {
            return;
        }
        throw new CudaException(CUresult.stringFor(cuResult));
    }
    
    /**
     * Releases the device memory registered for the calling thread, if any.
     *
     * @return true if device memory was freed; false when the current thread
     *         has no device pointer or the buffer is already marked as freed
     * @throws CudaException if {@code cudaFree} reports an error
     */
    @Override
    public boolean freeDevicePointer() {
        final CUdeviceptr current = pointersToContexts.get(Thread.currentThread().getName());
        if (current == null || freed.get()) {
            return false;
        }

        allocated.addAndGet(-devicePointerLength);
        log.trace("freeing {} bytes, total: {}", devicePointerLength, allocated.get());
        checkResult(JCuda.cudaFree(current));

        freed.set(true);
        devicePointerLength = -1;
        pointersToContexts.remove(Thread.currentThread().getName());
        return true;
    }
    
    /**
     * Copies {@code devicePointerLength} bytes from the calling thread's device
     * pointer into the host pointer. Does nothing when the current thread has
     * no device pointer registered.
     *
     * @throws CudaException if the copy fails
     */
    @Override
    public void copyToHost() {
        final CUdeviceptr source = pointersToContexts.get(Thread.currentThread().getName());
        if (source == null) {
            return;
        }
        checkResult(JCuda.cudaMemcpy(hostPointer, source, devicePointerLength, cudaMemcpyKind.cudaMemcpyDeviceToHost));
    }


    /**
     * Delegates to the three-argument overload, passing 1 as the middle argument.
     *
     * @param offset the offset to start reading at
     * @param length the number of values to read
     * @return the values returned by the three-argument overload
     */
    @Override
    public double[] getDoublesAt(int offset, int length) {
        return this.getDoublesAt(offset, 1, length);
    }

    /**
     * Delegates to the three-argument overload, passing 1 as the middle argument.
     *
     * @param offset the offset to start reading at
     * @param length the number of values to read
     * @return the values returned by the three-argument overload
     */
    @Override
    public float[] getFloatsAt(int offset, int length) {
        return this.getFloatsAt(offset, 1, length);
    }

    /**
     * @return the size of one element in bytes, as given at construction
     */
    @Override
    public int getElementSize() {
        return this.elementSize;
    }

    /**
     * @return the number of elements in this buffer
     */
    @Override
    public int length() {
        return this.length;
    }




    /**
     * Not supported by this base class.
     *
     * @param i       the index to write to
     * @param element the float value to store
     * @throws UnsupportedOperationException always
     */
    @Override
    public void put(int i, float element) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this base class.
     *
     * @param i       the index to write to
     * @param element the double value to store
     * @throws UnsupportedOperationException always
     */
    @Override
    public void put(int i, double element) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this base class.
     *
     * @param i       the index to write to
     * @param element the int value to store
     * @throws UnsupportedOperationException always
     */
    @Override
    public void put(int i, int element) {
        throw new UnsupportedOperationException();
    }


    /**
     * Stub: integer element access is not implemented here.
     *
     * @param ix the requested index (ignored)
     * @return always 0
     */
    @Override
    public int getInt(int ix) {
        // NOTE(review): silently returns 0 for every index - confirm whether
        // this should read from the host buffer or throw instead.
        return 0;
    }



    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void flush() {
        throw new UnsupportedOperationException();
    }


    /**
     * Delegates to the four-argument overload, passing 1 as the last argument.
     *
     * @param indices    the indices to assign to
     * @param data       the float values to assign
     * @param contiguous forwarded unchanged to the four-argument overload
     */
    @Override
    public void assign(int[] indices, float[] data, boolean contiguous) {
        this.assign(indices, data, contiguous, 1);
    }

    /**
     * Delegates to the four-argument overload, passing 1 as the last argument.
     *
     * @param indices    the indices to assign to
     * @param data       the double values to assign
     * @param contiguous forwarded unchanged to the four-argument overload
     */
    @Override
    public void assign(int[] indices, double[] data, boolean contiguous) {
        this.assign(indices, data, contiguous, 1);
    }
    
    /**
     * Moves the position of the shared host buffer to the given byte offset and
     * returns that same buffer (not a copy or slice).
     *
     * @param byteOffset the new position, in bytes; narrowed to int
     * @return the host buffer positioned at {@code byteOffset}, or null when
     *         there is no host buffer
     */
    private ByteBuffer getHostBuffer(long byteOffset) {
        // The field is declared as ByteBuffer, so a null check covers the
        // cases in which no usable buffer exists.
        if (hostBuffer == null) {
            return null;
        }
        final int newPosition = (int) byteOffset;
        hostBuffer.position(newPosition);
        return hostBuffer;
    }

    /**
     * Compares two buffers by their state.
     *
     * <p>The {@code modified} and {@code freed} flags are compared by their
     * boolean value: {@link AtomicBoolean} does not override {@code equals},
     * so comparing the holders themselves only ever matched the same instance.
     * The weak self-reference {@code ref} is left out for the same reason
     * ({@link WeakReference} compares by identity), and because comparing its
     * referent could recurse back into this method.
     *
     * @param o the object to compare with
     * @return true if {@code o} is a buffer of the same class with equal state
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        BaseCudaDataBuffer that = (BaseCudaDataBuffer) o;

        if (devicePointerLength != that.devicePointerLength) return false;
        if (length != that.length) return false;
        if (elementSize != that.elementSize) return false;
        if (isPersist != that.isPersist) return false;
        if (flagState(modified) != flagState(that.modified)) return false;
        if (flagState(freed) != flagState(that.freed)) return false;
        if (hostPointer != null ? !hostPointer.equals(that.hostPointer) : that.hostPointer != null) return false;
        if (hostBuffer != null ? !hostBuffer.equals(that.hostBuffer) : that.hostBuffer != null) return false;
        if (pointersToContexts != null ? !pointersToContexts.equals(that.pointersToContexts) : that.pointersToContexts != null)
            return false;
        return referencing != null ? referencing.equals(that.referencing) : that.referencing == null;
    }

    /**
     * Computes a hash code from the same state that {@link #equals(Object)}
     * compares, so that equal buffers produce equal hash codes.
     *
     * @return the hash code of this buffer
     */
    @Override
    public int hashCode() {
        int result = (int) (devicePointerLength ^ (devicePointerLength >>> 32));
        result = 31 * result + (hostPointer != null ? hostPointer.hashCode() : 0);
        result = 31 * result + (hostBuffer != null ? hostBuffer.hashCode() : 0);
        result = 31 * result + (pointersToContexts != null ? pointersToContexts.hashCode() : 0);
        result = 31 * result + flagState(modified);
        result = 31 * result + length;
        result = 31 * result + elementSize;
        result = 31 * result + (referencing != null ? referencing.hashCode() : 0);
        result = 31 * result + (isPersist ? 1 : 0);
        result = 31 * result + flagState(freed);
        return result;
    }

    /** Maps a flag holder to a comparable value: -1 for null, 1 for true, 0 for false. */
    private static int flagState(AtomicBoolean flag) {
        if (flag == null) {
            return -1;
        }
        return flag.get() ? 1 : 0;
    }

    /**
     * Copies the given buffers one after another into this buffer's device
     * memory with cuBLAS copy routines, then copies the device memory back to
     * the host and frees the device pointer.
     *
     * @param offsets per-buffer element offset at which reading starts
     * @param strides per-buffer increment between consecutive source elements
     * @param n       not used by this implementation
     * @param buffers the source buffers; each must be a {@link JCudaBuffer}
     * @throws IllegalArgumentException if a buffer is not a {@link JCudaBuffer}
     * @throws RuntimeException         if anything fails while a buffer is being copied
     */
    @Override
    public void assign(int[] offsets, int[] strides, int n, DataBuffer... buffers) {
        // running element position in the destination
        int written = 0;
        for (int idx = 0; idx < buffers.length; idx++) {
            DataBuffer candidate = buffers[idx];
            if (!(candidate instanceof JCudaBuffer))
                throw new IllegalArgumentException("Only jcuda data buffers allowed");

            JCudaBuffer source = (JCudaBuffer) candidate;
            try (CublasPointer sourcePointer = new CublasPointer(source)) {
                if (source.dataType() == DataBuffer.Type.DOUBLE) {
                    JCublas.cublasDcopy(
                            source.length(),
                            sourcePointer.withByteOffset(source.getElementSize() * offsets[idx]),
                            strides[idx],
                            getDevicePointer().withByteOffset(written * source.getElementSize()),
                            1);
                } else {
                    JCublas.cublasScopy(
                            source.length(),
                            sourcePointer.withByteOffset(source.getElementSize() * offsets[idx]),
                            strides[idx],
                            getDevicePointer().withByteOffset(written * source.getElementSize()),
                            1);
                }
                // advance by the number of strided elements between the offset and the end
                written += (source.length() - 1 - offsets[idx]) / strides[idx] + 1;
            } catch (Exception e) {
                throw new RuntimeException("Could not run cublas command", e);
            }
        }
        copyToHost();
        freeDevicePointer();
    }

    /**
     * Assigns the given buffers back to back: every buffer is read from
     * offset 0 with a stride of 1.
     *
     * @param buffers the source buffers
     */
    @Override
    public void assign(DataBuffer... buffers) {
        final int count = buffers.length;
        // int arrays start out zero-filled, which is the offset wanted here
        final int[] zeroOffsets = new int[count];
        final int[] unitStrides = new int[count];
        java.util.Arrays.fill(unitStrides, 1);
        assign(zeroOffsets, unitStrides, buffers);
    }

    /**
     * Delegates to the four-argument overload, passing this buffer's length
     * as {@code n}.
     *
     * @param offsets per-buffer element offsets
     * @param strides per-buffer strides
     * @param buffers the source buffers
     */
    @Override
    public void assign(int[] offsets, int[] strides, DataBuffer... buffers) {
        final int n = length();
        assign(offsets, strides, n, buffers);
    }
    
    /**
     * Frees the calling thread's device pointer (if any) and drops the
     * references to the host buffer and host pointer.
     */
    @Override
    public void destroy() {
        freeDevicePointer();
        this.hostPointer = null;
        this.hostBuffer = null;
    }

    /**
     * Serialization hook: writes length, element size and the persist flag,
     * followed by every element for DOUBLE and FLOAT buffers. Buffers of any
     * other data type get only the three header values.
     *
     * @param stream the stream to write to
     * @throws IOException if writing to the stream fails
     */
    private void writeObject(java.io.ObjectOutputStream stream)
            throws IOException {
        stream.writeInt(length);
        stream.writeInt(elementSize);
        stream.writeBoolean(isPersist);

        if (dataType() == DataBuffer.Type.DOUBLE) {
            for (double value : asDouble()) {
                stream.writeDouble(value);
            }
        } else if (dataType() == DataBuffer.Type.FLOAT) {
            for (float value : asFloat()) {
                stream.writeFloat(value);
            }
        }
    }


    /**
     * Deserialization hook: reads the header written by {@code writeObject},
     * recreates the bookkeeping state and rebuilds the host storage from the
     * serialized elements.
     *
     * <p>This method already recreated {@code freed}, {@code referencing} and
     * {@code pointersToContexts} by hand. {@code modified} and
     * {@code hostBuffer} were missing from that list and stayed null, so the
     * first mutating or reading call on a deserialized buffer failed with a
     * NullPointerException. Both are now restored.
     *
     * @param stream the stream to read from
     * @throws IOException            if reading from the stream fails
     * @throws ClassNotFoundException declared for the serialization contract
     */
    private void readObject(java.io.ObjectInputStream stream)
            throws IOException, ClassNotFoundException {
        length = stream.readInt();
        elementSize = stream.readInt();
        isPersist = stream.readBoolean();

        // State that is not part of the stream has to be rebuilt by hand.
        pointersToContexts = new ConcurrentHashMap<>();
        referencing = Collections.synchronizedSet(new HashSet<String>());
        ref = new WeakReference(this,Nd4j.bufferRefQueue());
        freed = new AtomicBoolean(false);
        modified = new AtomicBoolean(false);

        if(dataType() == DataBuffer.Type.DOUBLE) {
            double[] d = new double[length];
            for(int i = 0; i < d.length; i++)
                d[i] = stream.readDouble();
            BaseCudaDataBuffer  buf = (BaseCudaDataBuffer) KernelFunctions.alloc(d);
            // take over both halves of the freshly allocated host storage
            hostPointer = buf.hostPointer;
            hostBuffer = buf.hostBuffer;
        }
        else if(dataType() == DataBuffer.Type.FLOAT) {
            float[] f = new float[length];
            for(int i = 0; i < f.length; i++)
                f[i] = stream.readFloat();
            BaseCudaDataBuffer  buf = (BaseCudaDataBuffer) KernelFunctions.alloc(f);
            hostPointer = buf.hostPointer;
            hostBuffer = buf.hostBuffer;
        }
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy