
org.nd4j.linalg.jcublas.util.KernelParamsWrapper Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.jcublas.util;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaMemcpyKind;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.buffer.allocation.PinnedMemoryStrategy;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.JCudaExecutioner;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
/**
* Wraps the generation of kernel parameters,
* creating, copying
* and destroying any CUDA device allocations.
*
* @author bam4d
*
*/
public class KernelParamsWrapper implements AutoCloseable {
private boolean closeInvoked = false;
private boolean closeContext;
private CudaContext context;
private boolean scalarResult;
/**
* List of processed kernel parameters ready to be passed to the kernel
*/
final public Object[] kernelParameters;
/**
* The pointers that need to be freed as part of this closable resource
*/
final List pointersToFree;
/**
* The pointers that have results that need to be passed back to host buffers
*/
final List resultPointers;
/**
* The operation that should receive the result
*/
private Op resultOp;
/**
 * Returns the processed kernel parameters that are ready to be handed to the CUDA kernel.
 * <p>
 * The array returned is this wrapper's own {@code kernelParameters} field, not a copy.
 *
 * @return the processed kernel parameters
 */
public Object[] getKernelParameters() {
    final Object[] processed = this.kernelParameters;
    return processed;
}
/**
* conversion list of arrays to their assigned cublas pointer
*/
private Multimap arrayToPointer;
private int resultLength = 1;
/**
 * Marks the given array as the one that receives the kernel's results.
 * <p>
 * The array must be one of the arrays supplied as a kernel parameter, i.e. it must have an
 * entry in {@code arrayToPointer} (presumably populated by the constructor). Its first
 * registered {@link CublasPointer} is flagged as a result pointer and added to
 * {@code resultPointers}. If no result array is set, data from the device is not copied
 * back to the host.
 *
 * @param array the result array; must have been supplied as a kernel parameter
 * @return this wrapper, for call chaining
 * @throws IllegalArgumentException if {@code array} was not supplied as a kernel parameter
 */
public KernelParamsWrapper setResultArray(INDArray array) {
    // Multimap.get never returns null: an unknown key yields an empty collection. Check for
    // emptiness before touching the pointer, instead of relying on a null check that cannot fire
    // (and that previously ran only after the pointer had already been dereferenced).
    Iterator<CublasPointer> registered = arrayToPointer.get(array).iterator();
    if (!registered.hasNext()) {
        throw new IllegalArgumentException("Results array must be supplied as a kernel parameter");
    }
    CublasPointer resultPointer = registered.next();
    resultPointer.setResultPointer(true);
    resultPointers.add(resultPointer);
    return this;
}
/**
 * Sets the accumulation op whose result this wrapper collects and registers its result array.
 * <p>
 * The result array is registered (via {@code setResultArray}) before any other state is
 * updated, so a failure leaves this wrapper's result configuration untouched.
 *
 * @param op the accumulation op that should receive the result
 * @param result the array holding the result; must also have been supplied as a kernel parameter
 * @param dimension the dimensions of the accumulation; a {@code null} or empty array, or one whose
 *     first element is {@link Integer#MAX_VALUE}, marks the result as scalar
 * @return this wrapper, for call chaining
 * @throws NullPointerException if {@code result} is null
 * @throws RuntimeException if {@code result} was not supplied as a kernel parameter
 *     (propagated from {@code setResultArray})
 */
public KernelParamsWrapper setResultOp(Accumulation op, INDArray result, int... dimension) {
    Objects.requireNonNull(result, "result");
    // Validate/register the result array first so no field is mutated if it is rejected.
    setResultArray(result);
    resultOp = op;
    resultLength = result.length();
    scalarResult = dimension == null || dimension.length < 1 || dimension[0] == Integer.MAX_VALUE;
    return this;
}
/**
 * Creates a wrapper for the given kernel parameters with {@code closeContext} set to
 * {@code false}; equivalent to {@code KernelParamsWrapper(op, false, kernelParams)}.
 * <p>
 * The wrapper handles communication between host and device for these parameters.
 * Call {@code setResultOp()} to attach the result to a specific operation, or
 * {@code setResultArray()} to mark which INDArray holds the result.
 *
 * @param op the operation these kernel parameters belong to
 * @param kernelParams the raw kernel parameters to process
 */
public KernelParamsWrapper(Op op, Object... kernelParams) {
    this(op, /* closeContext= */ false, kernelParams);
}
/**
* Create a new wrapper for the kernel parameters.
*
* This wrapper manages the communication between host and device.
*
* To set the result on a specific operation, use setResultOp()
* To set the array which is the result INDArray, use setResultArray()
* @param kernelParams
*/
public KernelParamsWrapper(Op op,boolean closeContext,Object... kernelParams) {
kernelParameters = new Object[kernelParams.length];
arrayToPointer = ArrayListMultimap.create();
pointersToFree = new ArrayList<>();
resultPointers = new ArrayList<>();
context = new CudaContext(closeContext);
context.initOldStream();
context.initStream();
this.closeContext = closeContext;
Map
© 2015 - 2025 Weber Informatics LLC | Privacy Policy