All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas.buffer;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.complex.CudaComplexConversion;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Base class for a data buffer
 *
 * @author Adam Gibson
 */
public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer {

    static AtomicLong allocated = new AtomicLong();
    static AtomicLong totalAllocated = new AtomicLong();
    private static Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
    /**
     * Pointers to contexts covers this buffer on the gpu at offset 0
     * for each thread.
     *
     * The column key is for offsets. If we only have one device allocation per thread
     * we will clobber anything that is already allocated on the gpu.
     *
     * This also allows us to make a simplifying assumption about how to allocate the data as follows:
     *
     * Always allocate for offset zero by default. This allows us to reuse the same pointer with an offset
     * for each extra allocations (say for row wise operations)
     *
     * This also prevents duplicate uploads to the gpu.
     */
    protected transient Table pointersToContexts = HashBasedTable.create();
    protected AtomicBoolean modified = new AtomicBoolean(false);
    protected Collection referencing = Collections.synchronizedSet(new HashSet());
    protected transient WeakReference ref;
    protected AtomicBoolean freed = new AtomicBoolean(false);
    private Pointer hostPointer;
    private Map copied = new ConcurrentHashMap<>();

    public BaseCudaDataBuffer(ByteBuf buf, int length) {
        super(buf, length);
    }

    public BaseCudaDataBuffer(float[] data, boolean copy) {
        super(data, copy);
    }

    public BaseCudaDataBuffer(double[] data, boolean copy) {
        super(data, copy);
    }

    public BaseCudaDataBuffer(int[] data, boolean copy) {
        super(data, copy);
    }

    /**
     * Base constructor
     *
     * @param length      the length of the buffer
     * @param elementSize the size of each element
     */
    public BaseCudaDataBuffer(int length, int elementSize) {
        super(length,elementSize);
    }

    public BaseCudaDataBuffer(int length) {
        super(length);
    }

    public BaseCudaDataBuffer(float[] data) {
        super(data);
    }

    public BaseCudaDataBuffer(int[] data) {
        super(data);
    }

    public BaseCudaDataBuffer(double[] data) {
        super(data);
    }

    @Override
    public boolean copied(String name) {
        Boolean copied = this.copied.get(name);
        if(copied == null)
            return false;
        return this.copied.get(name);
    }

    @Override
    public void setCopied(String name) {
        copied.put(name, true);
    }

    @Override
    public AllocationMode allocationMode() {
        return allocationMode;
    }

    @Override
    public ByteBuffer getHostBuffer() {
        return dataBuffer.nioBuffer();
    }

    @Override
    public void setHostBuffer(ByteBuffer hostBuffer) {
        this.dataBuffer = Unpooled.wrappedBuffer(hostBuffer);
    }

    @Override
    public Pointer getHostPointer() {
        if(hostPointer == null) {
            hostPointer = Pointer.to(asNio());
        }
        return hostPointer;
    }

    @Override
    public Pointer getHostPointer(int offset) {
        if(hostPointer == null) {
            hostPointer = Pointer.to(asNio());
        }
        return hostPointer.withByteOffset(offset * getElementSize());
    }


    @Override
    public void removeReferencing(String id) {
        referencing.remove(id);
    }

    @Override
    public Collection references() {
        return referencing;
    }

    @Override
    public int getElementSize() {
        return elementSize;
    }


    @Override
    public void addReferencing(String id) {
        referencing.add(id);
    }

    @Override
    public void put(int i, IComplexNumber result) {

        modified.set(true);
        if (dataType() == DataBuffer.Type.FLOAT) {
            JCublas2.cublasSetVector(
                    length(),
                    getElementSize(),
                    PointerUtil.getPointer(CudaComplexConversion.toComplex(result.asFloat()))
                    , 1
                    , getHostPointer()
                    , 1);
        }
        else {
            JCublas2.cublasSetVector(
                    length(),
                    getElementSize(),
                    PointerUtil.getPointer(CudaComplexConversion.toComplexDouble(result.asDouble()))
                    , 1
                    , getHostPointer()
                    , 1);
        }
    }




    @Override
    public Pointer getDevicePointer(int stride, int offset,int length) {
        String name = Thread.currentThread().getName();
        DevicePointerInfo devicePointerInfo = pointersToContexts.get(name,offset);
        if(devicePointerInfo == null) {
            int devicePointerLength = getElementSize() * length;
            allocated.addAndGet(devicePointerLength);
            totalAllocated.addAndGet(devicePointerLength);
            log.trace("Allocating {} bytes, total: {}, overall: {}", devicePointerLength, allocated.get(), totalAllocated);
            if(devicePointerInfo == null) {
                /**
                 * Add zero first no matter what.
                 * Allocate the whole buffer on the gpu
                 * and use offsets for any other pointers that come in.
                 * This will allow us to set device pointers with offsets
                 *
                 * with no extra allocation.
                 *
                 * Notice here we ignore the length of the actual array.
                 *
                 * We are going to allocate the whole buffer on the gpu only once.
                 *
                 */
                if(!pointersToContexts.contains(name,0)) {
                    devicePointerInfo = (DevicePointerInfo)
                            ContextHolder.getInstance()
                                    .getConf()
                                    .getMemoryStrategy()
                                    .alloc(this, 1, 0, this.length);

                    pointersToContexts.put(name, 0, devicePointerInfo);
                }

                if(offset > 0) {
                    /**
                     * Store the length for the offset of the pointer.
                     * Return the original pointer with an offset
                     * (these pointers can't be reused?)
                     *
                     * With the device pointer info,
                     * we want to store the original pointer.
                     * When retrieving the vector from the gpu later,
                     * we will use the recorded offset.
                     *
                     * Due to gpu instability (please correct me if I'm wrong here)
                     * we can't seem to reuse the pointers with the offset specified,
                     * therefore it is desirable to recreate this pointer later.
                     *
                     * This will prevent extra allocation as well
                     * as inform the length for retrieving data from the gpu
                     * for this particular offset and buffer.
                     *
                     */
                    Pointer zero = pointersToContexts.get(name,0).getPointer();
                    Pointer ret =  pointersToContexts.get(name,0).getPointer().withByteOffset(offset * getElementSize());
                    devicePointerInfo = new DevicePointerInfo(zero,length,stride,offset);
                    pointersToContexts.put(name, offset, devicePointerInfo);
                    return ret;

                }



            }


            freed.set(false);
        }

        /**
         * Return the device pointer with the specified offset.
         * Regardless of whether the device pointer has been allocated,
         * we need to return with it respect to the specified array
         * not the array's underlying buffer.
         */
        return devicePointerInfo.getPointer().withByteOffset(offset * getElementSize());
    }




    @Override
    public void set(Pointer pointer) {

        modified.set(true);

        if (dataType() == DataBuffer.Type.DOUBLE) {
            JCublas2.cublasDcopy(
                    ContextHolder.getInstance().getHandle(),
                    length(),
                    pointer,
                    1,
                    getHostPointer(),
                    1
            );
        } else {
            JCublas2.cublasScopy(
                    ContextHolder.getInstance().getHandle(),
                    length(),
                    pointer,
                    1,
                    getHostPointer(),
                    1
            );
        }


    }





    @Override
    public IComplexFloat getComplexFloat(int i) {
        return Nd4j.createFloat(getFloat(i), getFloat(i + 1));
    }

    @Override
    public IComplexDouble getComplexDouble(int i) {
        return Nd4j.createDouble(getDouble(i), getDouble(i + 1));
    }

    @Override
    public IComplexNumber getComplex(int i) {
        return dataType() == DataBuffer.Type.FLOAT ? getComplexFloat(i) : getComplexDouble(i);
    }

    /**
     * Set an individual element
     *
     * @param index the index of the element
     * @param from  the element to get data from
     */
    protected void set(int index, int length, Pointer from, int inc) {

        modified.set(true);

        int offset = getElementSize() * index;
        if (offset >= length() * getElementSize())
            throw new IllegalArgumentException("Illegal offset " + offset + " with index of " + index + " and length " + length());

        JCublas2.cublasSetVectorAsync(
                length
                , getElementSize()
                , from
                , inc
                , getHostPointer().withByteOffset(offset)
                , 1, ContextHolder.getInstance().getCudaStream());

        ContextHolder.syncStream();

    }

    /**
     * Set an individual element
     *
     * @param index the index of the element
     * @param from  the element to get data from
     */
    protected void set(int index, int length, Pointer from) {
        set(index, length, from, 1);
    }

    @Override
    public void assign(DataBuffer data) {
        JCudaBuffer buf = (JCudaBuffer) data;
        set(0, buf.getHostPointer());
    }





    /**
     * Set an individual element
     *
     * @param index the index of the element
     * @param from  the element to get data from
     */
    protected void set(int index, Pointer from) {
        set(index, 1, from);
    }

    @Override
    public boolean freeDevicePointer(int offset) {
        String name = Thread.currentThread().getName();
        DevicePointerInfo devicePointerInfo = pointersToContexts.get(name,offset);

        //nothing to free, there was no copy. Only the gpu pointer was reused with a different offset.
        if(offset != 0)
            pointersToContexts.remove(name,offset);
        else if(offset == 0 && isPersist) {
            return true;
        }
        else if (devicePointerInfo != null && !freed.get()) {
            allocated.addAndGet(-devicePointerInfo.getLength());
            log.trace("freeing {} bytes, total: {}", devicePointerInfo.getLength(), allocated.get());
            ContextHolder.getInstance().getMemoryStrategy().free(this,offset);
            freed.set(true);
            copied.remove(name);
            pointersToContexts.remove(name,offset);
            return true;


        }

        return false;
    }

    @Override
    public void copyToHost(int offset) {
        DevicePointerInfo devicePointerInfo = pointersToContexts.get(Thread.currentThread().getName(),offset);
        //prevent inconsistent pointers
        if (devicePointerInfo.getOffset() != offset)
            throw new IllegalStateException("Device pointer offset didn't match specified offset in pointer map");

        if (devicePointerInfo != null) {
            ContextHolder.syncStream();

            JCublas2.cublasGetVectorAsync(
                    (int) devicePointerInfo.getLength()
                    , getElementSize()
                    , devicePointerInfo.getPointer().withByteOffset(offset * getElementSize())
                    , devicePointerInfo.getStride()
                    , getHostPointer(devicePointerInfo.getOffset())
                    , devicePointerInfo.getStride()
                    , ContextHolder.getInstance().getCudaStream());

            ContextHolder.syncStream();


        }

        else
            throw new IllegalStateException("No offset found to copy");

    }






    @Override
    public void flush() {
        throw new UnsupportedOperationException();
    }






    @Override
    public void destroy() {
        dataBuffer.clear();
    }

    private void writeObject(java.io.ObjectOutputStream stream)
            throws IOException {
        stream.writeInt(length);
        stream.writeInt(elementSize);
        stream.writeBoolean(isPersist);
        if(dataType() == DataBuffer.Type.DOUBLE) {
            double[] d = asDouble();
            for(int i = 0; i < d.length; i++)
                stream.writeDouble(d[i]);
        }
        else if(dataType() == DataBuffer.Type.FLOAT) {
            float[] f = asFloat();
            for(int i = 0; i < f.length; i++)
                stream.writeFloat(f[i]);
        }


    }

    private void readObject(java.io.ObjectInputStream stream)
            throws IOException, ClassNotFoundException {
        length = stream.readInt();
        elementSize = stream.readInt();
        isPersist = stream.readBoolean();
        pointersToContexts = HashBasedTable.create();
        referencing = Collections.synchronizedSet(new HashSet());
        ref = new WeakReference(this,Nd4j.bufferRefQueue());
        freed = new AtomicBoolean(false);
        if(dataType() == DataBuffer.Type.DOUBLE) {
            double[] d = new double[length];
            for(int i = 0; i < d.length; i++)
                d[i] = stream.readDouble();
        } else if (dataType() == DataBuffer.Type.FLOAT) {
            float[] f = new float[length];
            for (int i = 0; i < f.length; i++)
                f[i] = stream.readFloat();
            BaseCudaDataBuffer buf = (BaseCudaDataBuffer) KernelFunctions.alloc(f);
            setHostBuffer(buf.getHostBuffer());
        }
    }



    @Override
    public Table getPointersToContexts() {
        return pointersToContexts;
    }

    public void setPointersToContexts( Table pointersToContexts) {
        this.pointersToContexts = pointersToContexts;
    }

    @Override
    public String toString() {
        StringBuffer sb = new StringBuffer();
        sb.append("[");
        for(int i = 0; i < length(); i++) {
            sb.append(getDouble(i));
            if(i < length() - 1)
                sb.append(",");
        }
        sb.append("]");
        return sb.toString();

    }

    /**
     * Provides information about a device pointer
     *
     * @author bam4d
     */
    public static class DevicePointerInfo {
        final private Pointer pointer;
        final private long length;
        final private int stride;
        final private int offset;
        private boolean freed = false;

        public DevicePointerInfo(Pointer pointer, long length,int stride,int offset) {
            this.pointer = pointer;
            this.length = length;
            this.stride = stride;
            this.offset = offset;
        }

        public boolean isFreed() {
            return freed;
        }

        public void setFreed(boolean freed) {
            this.freed = freed;
        }

        public int getOffset() {
            return offset;
        }



        public int getStride() {
            return stride;
        }

        public Pointer getPointer() {
            return pointer;
        }

        public long getLength() {
            return length;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy