// org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer (Maven / Gradle / Ivy)
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.jcublas.buffer;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.complex.CudaComplexConversion;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
/**
* Base class for a data buffer
*
* @author Adam Gibson
*/
public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer {
static AtomicLong allocated = new AtomicLong();
static AtomicLong totalAllocated = new AtomicLong();
private static Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
/**
* Pointers to contexts covers this buffer on the gpu at offset 0
* for each thread.
*
* The column key is for offsets. If we only have one device allocation per thread
* we will clobber anything that is already allocated on the gpu.
*
* This also allows us to make a simplifying assumption about how to allocate the data as follows:
*
* Always allocate for offset zero by default. This allows us to reuse the same pointer with an offset
* for each extra allocations (say for row wise operations)
*
* This also prevents duplicate uploads to the gpu.
*/
protected transient Table pointersToContexts = HashBasedTable.create();
protected AtomicBoolean modified = new AtomicBoolean(false);
protected Collection referencing = Collections.synchronizedSet(new HashSet());
protected transient WeakReference ref;
protected AtomicBoolean freed = new AtomicBoolean(false);
private Pointer hostPointer;
private Map copied = new ConcurrentHashMap<>();
public BaseCudaDataBuffer(ByteBuf buf, int length) {
super(buf, length);
}
public BaseCudaDataBuffer(float[] data, boolean copy) {
super(data, copy);
}
public BaseCudaDataBuffer(double[] data, boolean copy) {
super(data, copy);
}
public BaseCudaDataBuffer(int[] data, boolean copy) {
super(data, copy);
}
/**
* Base constructor
*
* @param length the length of the buffer
* @param elementSize the size of each element
*/
public BaseCudaDataBuffer(int length, int elementSize) {
super(length,elementSize);
}
public BaseCudaDataBuffer(int length) {
super(length);
}
public BaseCudaDataBuffer(float[] data) {
super(data);
}
public BaseCudaDataBuffer(int[] data) {
super(data);
}
public BaseCudaDataBuffer(double[] data) {
super(data);
}
@Override
public boolean copied(String name) {
Boolean copied = this.copied.get(name);
if(copied == null)
return false;
return this.copied.get(name);
}
@Override
public void setCopied(String name) {
copied.put(name, true);
}
@Override
public AllocationMode allocationMode() {
return allocationMode;
}
@Override
public ByteBuffer getHostBuffer() {
return dataBuffer.nioBuffer();
}
@Override
public void setHostBuffer(ByteBuffer hostBuffer) {
this.dataBuffer = Unpooled.wrappedBuffer(hostBuffer);
}
@Override
public Pointer getHostPointer() {
if(hostPointer == null) {
hostPointer = Pointer.to(asNio());
}
return hostPointer;
}
@Override
public Pointer getHostPointer(int offset) {
if(hostPointer == null) {
hostPointer = Pointer.to(asNio());
}
return hostPointer.withByteOffset(offset * getElementSize());
}
@Override
public void removeReferencing(String id) {
referencing.remove(id);
}
@Override
public Collection references() {
return referencing;
}
@Override
public int getElementSize() {
return elementSize;
}
@Override
public void addReferencing(String id) {
referencing.add(id);
}
@Override
public void put(int i, IComplexNumber result) {
modified.set(true);
if (dataType() == DataBuffer.Type.FLOAT) {
JCublas2.cublasSetVector(
length(),
getElementSize(),
PointerUtil.getPointer(CudaComplexConversion.toComplex(result.asFloat()))
, 1
, getHostPointer()
, 1);
}
else {
JCublas2.cublasSetVector(
length(),
getElementSize(),
PointerUtil.getPointer(CudaComplexConversion.toComplexDouble(result.asDouble()))
, 1
, getHostPointer()
, 1);
}
}
@Override
public Pointer getDevicePointer(int stride, int offset,int length) {
String name = Thread.currentThread().getName();
DevicePointerInfo devicePointerInfo = pointersToContexts.get(name,offset);
if(devicePointerInfo == null) {
int devicePointerLength = getElementSize() * length;
allocated.addAndGet(devicePointerLength);
totalAllocated.addAndGet(devicePointerLength);
log.trace("Allocating {} bytes, total: {}, overall: {}", devicePointerLength, allocated.get(), totalAllocated);
if(devicePointerInfo == null) {
/**
* Add zero first no matter what.
* Allocate the whole buffer on the gpu
* and use offsets for any other pointers that come in.
* This will allow us to set device pointers with offsets
*
* with no extra allocation.
*
* Notice here we ignore the length of the actual array.
*
* We are going to allocate the whole buffer on the gpu only once.
*
*/
if(!pointersToContexts.contains(name,0)) {
devicePointerInfo = (DevicePointerInfo)
ContextHolder.getInstance()
.getConf()
.getMemoryStrategy()
.alloc(this, 1, 0, this.length);
pointersToContexts.put(name, 0, devicePointerInfo);
}
if(offset > 0) {
/**
* Store the length for the offset of the pointer.
* Return the original pointer with an offset
* (these pointers can't be reused?)
*
* With the device pointer info,
* we want to store the original pointer.
* When retrieving the vector from the gpu later,
* we will use the recorded offset.
*
* Due to gpu instability (please correct me if I'm wrong here)
* we can't seem to reuse the pointers with the offset specified,
* therefore it is desirable to recreate this pointer later.
*
* This will prevent extra allocation as well
* as inform the length for retrieving data from the gpu
* for this particular offset and buffer.
*
*/
Pointer zero = pointersToContexts.get(name,0).getPointer();
Pointer ret = pointersToContexts.get(name,0).getPointer().withByteOffset(offset * getElementSize());
devicePointerInfo = new DevicePointerInfo(zero,length,stride,offset);
pointersToContexts.put(name, offset, devicePointerInfo);
return ret;
}
}
freed.set(false);
}
/**
* Return the device pointer with the specified offset.
* Regardless of whether the device pointer has been allocated,
* we need to return with it respect to the specified array
* not the array's underlying buffer.
*/
return devicePointerInfo.getPointer().withByteOffset(offset * getElementSize());
}
@Override
public void set(Pointer pointer) {
modified.set(true);
if (dataType() == DataBuffer.Type.DOUBLE) {
JCublas2.cublasDcopy(
ContextHolder.getInstance().getHandle(),
length(),
pointer,
1,
getHostPointer(),
1
);
} else {
JCublas2.cublasScopy(
ContextHolder.getInstance().getHandle(),
length(),
pointer,
1,
getHostPointer(),
1
);
}
}
@Override
public IComplexFloat getComplexFloat(int i) {
return Nd4j.createFloat(getFloat(i), getFloat(i + 1));
}
@Override
public IComplexDouble getComplexDouble(int i) {
return Nd4j.createDouble(getDouble(i), getDouble(i + 1));
}
@Override
public IComplexNumber getComplex(int i) {
return dataType() == DataBuffer.Type.FLOAT ? getComplexFloat(i) : getComplexDouble(i);
}
/**
* Set an individual element
*
* @param index the index of the element
* @param from the element to get data from
*/
protected void set(int index, int length, Pointer from, int inc) {
modified.set(true);
int offset = getElementSize() * index;
if (offset >= length() * getElementSize())
throw new IllegalArgumentException("Illegal offset " + offset + " with index of " + index + " and length " + length());
JCublas2.cublasSetVectorAsync(
length
, getElementSize()
, from
, inc
, getHostPointer().withByteOffset(offset)
, 1, ContextHolder.getInstance().getCudaStream());
ContextHolder.syncStream();
}
/**
* Set an individual element
*
* @param index the index of the element
* @param from the element to get data from
*/
protected void set(int index, int length, Pointer from) {
set(index, length, from, 1);
}
@Override
public void assign(DataBuffer data) {
JCudaBuffer buf = (JCudaBuffer) data;
set(0, buf.getHostPointer());
}
/**
* Set an individual element
*
* @param index the index of the element
* @param from the element to get data from
*/
protected void set(int index, Pointer from) {
set(index, 1, from);
}
@Override
public boolean freeDevicePointer(int offset) {
String name = Thread.currentThread().getName();
DevicePointerInfo devicePointerInfo = pointersToContexts.get(name,offset);
//nothing to free, there was no copy. Only the gpu pointer was reused with a different offset.
if(offset != 0)
pointersToContexts.remove(name,offset);
else if(offset == 0 && isPersist) {
return true;
}
else if (devicePointerInfo != null && !freed.get()) {
allocated.addAndGet(-devicePointerInfo.getLength());
log.trace("freeing {} bytes, total: {}", devicePointerInfo.getLength(), allocated.get());
ContextHolder.getInstance().getMemoryStrategy().free(this,offset);
freed.set(true);
copied.remove(name);
pointersToContexts.remove(name,offset);
return true;
}
return false;
}
@Override
public void copyToHost(int offset) {
DevicePointerInfo devicePointerInfo = pointersToContexts.get(Thread.currentThread().getName(),offset);
//prevent inconsistent pointers
if (devicePointerInfo.getOffset() != offset)
throw new IllegalStateException("Device pointer offset didn't match specified offset in pointer map");
if (devicePointerInfo != null) {
ContextHolder.syncStream();
JCublas2.cublasGetVectorAsync(
(int) devicePointerInfo.getLength()
, getElementSize()
, devicePointerInfo.getPointer().withByteOffset(offset * getElementSize())
, devicePointerInfo.getStride()
, getHostPointer(devicePointerInfo.getOffset())
, devicePointerInfo.getStride()
, ContextHolder.getInstance().getCudaStream());
ContextHolder.syncStream();
}
else
throw new IllegalStateException("No offset found to copy");
}
@Override
public void flush() {
throw new UnsupportedOperationException();
}
@Override
public void destroy() {
dataBuffer.clear();
}
private void writeObject(java.io.ObjectOutputStream stream)
throws IOException {
stream.writeInt(length);
stream.writeInt(elementSize);
stream.writeBoolean(isPersist);
if(dataType() == DataBuffer.Type.DOUBLE) {
double[] d = asDouble();
for(int i = 0; i < d.length; i++)
stream.writeDouble(d[i]);
}
else if(dataType() == DataBuffer.Type.FLOAT) {
float[] f = asFloat();
for(int i = 0; i < f.length; i++)
stream.writeFloat(f[i]);
}
}
private void readObject(java.io.ObjectInputStream stream)
throws IOException, ClassNotFoundException {
length = stream.readInt();
elementSize = stream.readInt();
isPersist = stream.readBoolean();
pointersToContexts = HashBasedTable.create();
referencing = Collections.synchronizedSet(new HashSet());
ref = new WeakReference(this,Nd4j.bufferRefQueue());
freed = new AtomicBoolean(false);
if(dataType() == DataBuffer.Type.DOUBLE) {
double[] d = new double[length];
for(int i = 0; i < d.length; i++)
d[i] = stream.readDouble();
} else if (dataType() == DataBuffer.Type.FLOAT) {
float[] f = new float[length];
for (int i = 0; i < f.length; i++)
f[i] = stream.readFloat();
BaseCudaDataBuffer buf = (BaseCudaDataBuffer) KernelFunctions.alloc(f);
setHostBuffer(buf.getHostBuffer());
}
}
@Override
public Table getPointersToContexts() {
return pointersToContexts;
}
public void setPointersToContexts( Table pointersToContexts) {
this.pointersToContexts = pointersToContexts;
}
@Override
public String toString() {
StringBuffer sb = new StringBuffer();
sb.append("[");
for(int i = 0; i < length(); i++) {
sb.append(getDouble(i));
if(i < length() - 1)
sb.append(",");
}
sb.append("]");
return sb.toString();
}
/**
* Provides information about a device pointer
*
* @author bam4d
*/
public static class DevicePointerInfo {
final private Pointer pointer;
final private long length;
final private int stride;
final private int offset;
private boolean freed = false;
public DevicePointerInfo(Pointer pointer, long length,int stride,int offset) {
this.pointer = pointer;
this.length = length;
this.stride = stride;
this.offset = offset;
}
public boolean isFreed() {
return freed;
}
public void setFreed(boolean freed) {
this.freed = freed;
}
public int getOffset() {
return offset;
}
public int getStride() {
return stride;
}
public Pointer getPointer() {
return pointer;
}
public long getLength() {
return length;
}
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy