org.nd4j.linalg.jcublas.CublasPointer Maven / Gradle / Ivy
The newest version!
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.jcublas;
import jcuda.Pointer;
import org.apache.commons.lang3.tuple.Triple;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.DevicePointerInfo;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.context.CudaContext;
import java.util.Arrays;
/**
 * Wraps the allocation and freeing of resources on a cuda device.
 *
 * A pointer is created either for a whole {@link JCudaBuffer} or for an
 * {@link INDArray} whose data is a {@link JCudaBuffer}. Closing the pointer
 * releases the device pointer through the buffer, unless the pointer has
 * been flagged as a result pointer (see {@link #setResultPointer(boolean)}).
 *
 * @author bam4d
 */
public class CublasPointer implements AutoCloseable {

    /**
     * The underlying cuda buffer that contains the host and device memory
     */
    private final JCudaBuffer buffer;
    /** Device pointer obtained from {@link #buffer} at construction time. */
    private final Pointer devicePointer;
    /** Host pointer looked up from the buffer's per-thread pointer table. */
    private Pointer hostPointer;
    /** Set once {@link #destroy()} has released the device pointer. */
    private boolean closed = false;
    /** The array this pointer was created for; null when created from a raw buffer. */
    private final INDArray arr;
    private final CudaContext cudaContext;
    private boolean resultPointer = false;

    /**
     * Creates a CublasPointer for a given JCudaBuffer.
     *
     * The device pointer covers the whole buffer (stride 1, offset 0), the
     * host data is pushed to it via the memory strategy, and the buffer is
     * marked as copied for the current thread.
     *
     * @param buffer  the buffer to wrap
     * @param context the cuda context to associate with this pointer
     */
    public CublasPointer(JCudaBuffer buffer, CudaContext context) {
        this.buffer = buffer;
        this.arr = null;
        this.devicePointer = buffer.getDevicePointer(1, 0, buffer.length());
        this.cudaContext = context;
        context.initOldStream();

        String threadName = Thread.currentThread().getName();
        DevicePointerInfo info = buffer.getPointersToContexts()
                .get(threadName, Triple.of(0, buffer.length(), 1));
        this.hostPointer = info.getPointers().getHostPointer();

        ContextHolder.getInstance().getMemoryStrategy()
                .setData(devicePointer, 0, 1, buffer.length(), info.getPointers().getHostPointer());
        buffer.setCopied(threadName);
    }

    /**
     * Creates a CublasPointer for a given INDArray.
     *
     * This wrapper makes sure that the INDArray offset, stride
     * and memory pointers are accurate to the data being copied to and from the device.
     *
     * If the copyToHost function is used in this class,
     * the host buffer offset and data length is taken care of automatically.
     *
     * @param array   the array to create a device pointer for
     * @param context the cuda context to associate with this pointer
     * @throws IllegalStateException if the array still reports a negative
     *         element wise stride after being duplicated
     */
    public CublasPointer(INDArray array, CudaContext context) {
        // We have to reset the pointer to be zero offset due to the fact that
        // vector based striding won't work with a complex, non-vector array
        // that only covers part of its buffer.
        INDArray source = array;
        if (source instanceof IComplexNDArray
                && source.length() * 2 < source.data().length()
                && !source.isVector()) {
            source = Shape.toOffsetZero(source);
        }

        this.cudaContext = context;
        JCudaBuffer sourceBuffer = (JCudaBuffer) source.data();

        // the name of this thread for knowing whether to copy data or not
        String name = Thread.currentThread().getName();

        // A negative element wise stride cannot be iterated over directly:
        // fall back to a copy and give up if the copy has the same problem.
        if (source.elementWiseStride() < 0) {
            source = source.dup();
            sourceBuffer = (JCudaBuffer) source.data();
            if (source.elementWiseStride() < 0) {
                throw new IllegalStateException("Unable to iterate over buffer");
            }
        }
        this.arr = source;
        this.buffer = sourceBuffer;

        // Complex arrays use twice the length and half the blas stride.
        boolean complex = source instanceof IComplexNDArray;
        int compLength = complex ? source.length() * 2 : source.length();
        int blasStride = BlasBufferUtil.getBlasStride(source);
        int stride = complex ? blasStride / 2 : blasStride;

        this.devicePointer = sourceBuffer.getDevicePointer(source, stride, source.offset(), compLength);

        /*
         * Neat edge case here.
         *
         * The striding will overshoot the original array
         * when the offset is zero (say on a getRow(0) operation).
         *
         * We need to allocate the data differently here
         * due to how the striding works out.
         */
        // Copy the data to the device iff the whole buffer hasn't been copied
        if (!sourceBuffer.copied(name)) {
            ContextHolder.getInstance().getMemoryStrategy()
                    .setData(sourceBuffer, 0, 1, sourceBuffer.length());
            // mark the buffer copied
            sourceBuffer.setCopied(name);
        }

        DevicePointerInfo info = sourceBuffer.getPointersToContexts()
                .get(name, Triple.of(0, sourceBuffer.length(), 1));
        this.hostPointer = info.getPointers().getHostPointer();
    }

    /**
     * Frees the underlying device memory allocated for this pointer,
     * unless this pointer is flagged as a result pointer.
     */
    @Override
    public void close() throws Exception {
        if (!isResultPointer()) {
            destroy();
        }
    }

    /**
     * The actual destroy method. Releases the device pointer through the
     * buffer regardless of the result pointer flag; calling it again after
     * the first time does nothing.
     */
    public void destroy() {
        if (closed) {
            return;
        }
        if (arr != null) {
            // NOTE(review): the device pointer is acquired in the constructor with
            // a doubled length for complex arrays, but freed here with arr.length().
            // Confirm against JCudaBuffer.freeDevicePointer that this is intended.
            buffer.freeDevicePointer(arr.offset(), arr.length());
        } else {
            buffer.freeDevicePointer(0, buffer.length());
        }
        closed = true;
    }

    /**
     * @return the underlying cuda buffer wrapped by this pointer
     */
    public JCudaBuffer getBuffer() {
        return buffer;
    }

    /**
     * @return the device pointer obtained from the buffer at construction time
     */
    public Pointer getDevicePointer() {
        return devicePointer;
    }

    /**
     * @return the host pointer currently associated with this pointer
     */
    public Pointer getHostPointer() {
        return hostPointer;
    }

    /**
     * @param hostPointer the host pointer to associate with this pointer
     */
    public void setHostPointer(Pointer hostPointer) {
        this.hostPointer = hostPointer;
    }

    /**
     * Copies the result to the host buffer.
     *
     * When this pointer was created for an array, the array's offset and
     * element wise stride are passed to the memory strategy along with the
     * length (doubled for complex arrays); otherwise the whole buffer is
     * copied from offset 0.
     */
    public void copyToHost() {
        if (arr != null) {
            int compLength = arr instanceof IComplexNDArray ? arr.length() * 2 : arr.length();
            ContextHolder.getInstance().getMemoryStrategy().copyToHost(
                    buffer,
                    arr.offset(),
                    arr.elementWiseStride(),
                    compLength,
                    cudaContext,
                    arr.offset(),
                    arr.elementWiseStride());
        } else {
            ContextHolder.getInstance().getMemoryStrategy().copyToHost(buffer, 0, cudaContext);
        }
    }

    /**
     * Whether this is a result pointer or not.
     * A result pointer means that this
     * pointer should not automatically be freed
     * but instead wait for results to accumulate
     * so they can be returned from
     * the gpu first.
     *
     * @return true if {@link #close()} should leave the device memory alone
     */
    public boolean isResultPointer() {
        return resultPointer;
    }

    /**
     * Sets whether this is a result pointer or not.
     * A result pointer means that this
     * pointer should not automatically be freed
     * but instead wait for results to accumulate
     * so they can be returned from
     * the gpu first.
     *
     * @param resultPointer true to keep {@link #close()} from freeing this pointer
     */
    public void setResultPointer(boolean resultPointer) {
        this.resultPointer = resultPointer;
    }

    /**
     * Renders the data behind this pointer by reading it back through the
     * memory strategy into a freshly created buffer.
     *
     * Note that this is not a cheap call: it goes through
     * {@code getMemoryStrategy().getData(...)} every time it is invoked.
     *
     * @return the string form of the data read back, or
     *         {@code "No device pointer yet"} when there is no device pointer
     */
    @Override
    public String toString() {
        if (devicePointer == null) {
            return "No device pointer yet";
        }

        if (arr == null) {
            // Raw buffer: read everything, contiguous from the start.
            return String.valueOf(readDeviceData(buffer.dataType(), buffer.length(), 1, 0));
        }

        boolean complex = arr instanceof IComplexNDArray;
        int length = complex ? arr.length() * 2 : arr.length();
        boolean coversWholeBuffer = (complex && arr.length() * 2 == buffer.length())
                || arr.length() == buffer.length();

        if (coversWholeBuffer) {
            return String.valueOf(readDeviceData(arr.data().dataType(), length, 1, 0));
        }
        // The array is a view on a larger buffer: honour its stride and offset.
        return String.valueOf(
                readDeviceData(arr.data().dataType(), length, arr.elementWiseStride(), arr.offset()));
    }

    /**
     * Creates a zero-filled buffer of the given type and length and asks the
     * memory strategy to fill it from this pointer's buffer.
     *
     * @param type   data type deciding whether a double, int or float buffer is created
     *               (anything other than DOUBLE or INT yields a float buffer)
     * @param length number of elements to allocate and read
     * @param stride source stride handed to the memory strategy
     * @param offset source offset handed to the memory strategy
     * @return the buffer holding the data read back
     */
    private DataBuffer readDeviceData(DataBuffer.Type type, int length, int stride, int offset) {
        DataBuffer hostCopy;
        if (type == DataBuffer.Type.DOUBLE) {
            hostCopy = Nd4j.createBuffer(new double[length]);
        } else if (type == DataBuffer.Type.INT) {
            hostCopy = Nd4j.createBuffer(new int[length]);
        } else {
            hostCopy = Nd4j.createBuffer(new float[length]);
        }
        ContextHolder.getInstance().getMemoryStrategy()
                .getData(hostCopy, 0, 1, length, buffer, cudaContext, stride, offset);
        return hostCopy;
    }

    /**
     * Closes every given pointer on a best-effort basis: a failure while
     * closing one pointer is printed and does not prevent the remaining
     * pointers from being closed. A null array and null entries are skipped.
     *
     * @param pointers the pointers to close
     */
    public static void free(CublasPointer... pointers) {
        if (pointers == null) {
            return;
        }
        for (CublasPointer pointer : pointers) {
            if (pointer == null) {
                continue;
            }
            try {
                pointer.close();
            } catch (Exception e) {
                // Deliberately best effort: keep releasing the other pointers.
                e.printStackTrace();
            }
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy