org.nd4j.linalg.jcublas.util.KernelParamsWrapper Maven / Gradle / Ivy
package org.nd4j.linalg.jcublas.util;
import static jcuda.driver.JCudaDriver.cuMemGetInfo;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaMemcpyKind;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.complex.ComplexDouble;
import org.nd4j.linalg.jcublas.complex.ComplexFloat;
import org.nd4j.linalg.jcublas.ops.executioner.JCudaExecutioner;
/**
* Wraps the generation of kernel parameters,
* creating, copying and destroying any CUDA device allocations.
* @author bam4d
*
*/
public class KernelParamsWrapper implements AutoCloseable {
/**
* List of processed kernel parameters ready to be passed to the kernel
*/
final public Object[] kernelParameters;
/**
* The pointers that need to be freed as part of this closable resource
*/
final Set pointersToFree;
/**
* The pointers that have results that need to be passed back to host buffers
*/
final Set resultPointers;
/**
* The operation that should receive the result
*/
private Op resultOp;
/**
 * Accessor for the processed kernel parameters, i.e. the values that are
 * meant to be handed to the CUDA kernel.
 *
 * @return the {@link #kernelParameters} array held by this wrapper
 *         (the array itself, not a copy)
 */
public Object[] getKernelParameters() {
    return this.kernelParameters;
}
/**
* Maps each array to its assigned cublas pointer
*/
private Map arrayToPointer;
/**
 * Registers the array that will contain the results.
 *
 * The pointer assigned to {@code array} is looked up in {@code arrayToPointer}
 * and added to {@code resultPointers}. Per the original contract, if no result
 * array is set, data from the device will not be copied back to the host
 * (that copy happens outside this method).
 *
 * @param array the result array; it must have a pointer registered in
 *              {@code arrayToPointer}
 * @return this wrapper, to allow call chaining
 * @throws IllegalArgumentException if no pointer is registered for {@code array}
 */
public KernelParamsWrapper setResultArray(INDArray array) {
    // Explicit cast: arrayToPointer is declared without type arguments, so get()
    // returns Object. The cast is harmless if the field is parameterized.
    final CublasPointer resultPointer = (CublasPointer) arrayToPointer.get(array);
    if (resultPointer == null) {
        // IllegalArgumentException is a RuntimeException, so existing callers
        // catching RuntimeException keep working.
        throw new IllegalArgumentException("Results array must be supplied as a kernel parameter");
    }
    resultPointers.add(resultPointer);
    return this;
}
/**
 * Records the operation that should receive the result and registers the
 * array that holds that result.
 *
 * The operation is stored first; the array is then registered through
 * {@link #setResultArray(INDArray)}, so any exception thrown there
 * propagates to the caller after {@code resultOp} has been assigned.
 *
 * @param op     the accumulation that the result belongs to
 * @param result the array that will contain the result
 * @return this wrapper, to allow call chaining
 */
public KernelParamsWrapper setResultOp(Accumulation op, INDArray result) {
    this.resultOp = op;
    this.setResultArray(result);
    return this;
}
/**
* Create a new wrapper for the kernel parameters.
*
* This wrapper manages the communication between host and device.
*
* To set the result on a specific operation, use setResultOp()
* To set the array which is the result INDArray, use setResultArray()
* @param kernelParams
*/
public KernelParamsWrapper(Object... kernelParams) {
kernelParameters = new Object[kernelParams.length];
arrayToPointer = new HashMap<>();
pointersToFree = new HashSet<>();
resultPointers = new HashSet<>();
for(int i = 0; i
© 2015 - 2024 Weber Informatics LLC | Privacy Policy