All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.jcublas.util.KernelParamsWrapper Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas.util;

import static jcuda.driver.JCudaDriver.cuMemGetInfo;

import java.util.*;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaMemcpyKind;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.complex.ComplexDouble;
import org.nd4j.linalg.jcublas.complex.ComplexFloat;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.ops.executioner.JCudaExecutioner;

/**
 * Wraps the generation of CUDA kernel parameters, creating, copying and destroying any
 * device allocations they require.
 *
 * <p>{@link JCudaBuffer} and {@link INDArray} arguments are replaced by the device pointer of a
 * newly created {@link CublasPointer}; every other argument is passed through unchanged.
 * Calling {@link #close()} copies any registered results back to the host and releases every
 * pointer created here. Because this class is {@link AutoCloseable} it can be used in a
 * try-with-resources block.
 *
 * @author bam4d
 */
public class KernelParamsWrapper implements AutoCloseable {

    /** Set once {@link #close()} has started, so a repeated call does not free pointers twice. */
    private boolean closeInvoked = false;

    /**
     * List of processed kernel parameters ready to be passed to the kernel.
     */
    final public Object[] kernelParameters;

    /**
     * The pointers that need to be freed as part of this closable resource.
     */
    final List<CublasPointer> pointersToFree;

    /**
     * The pointers whose device contents must be copied back to the host on close.
     */
    final List<CublasPointer> resultPointers;

    /**
     * The operation that should receive the result, or {@code null} if result pointers are
     * copied straight back into their host buffers.
     */
    private Op resultOp;

    /**
     * Maps each {@link INDArray} kernel argument to the pointer(s) created for it.
     */
    private final Multimap<INDArray, CublasPointer> arrayToPointer;

    /**
     * Returns the processed kernel parameters that should be passed to the CUDA kernel.
     *
     * @return the processed parameters, with device pointers in place of buffer/array arguments
     */
    public Object[] getKernelParameters() {
        return kernelParameters;
    }

    /**
     * Marks {@code array} as holding the kernel's result. Only arrays marked this way (directly or
     * via {@link #setResultOp}) have their device data copied back on {@link #close()}; if no
     * result array is set, no data is copied from the device to the host.
     *
     * @param array an array that was passed to the constructor as a kernel parameter
     * @return this wrapper, for chaining
     * @throws RuntimeException if {@code array} was not one of the kernel parameters
     */
    public KernelParamsWrapper setResultArray(INDArray array) {
        // Multimap.get returns an empty collection (never null) for an unknown key, so test for
        // emptiness; calling iterator().next() unconditionally would throw NoSuchElementException.
        Collection<CublasPointer> candidates = arrayToPointer.get(array);
        if (candidates.isEmpty()) {
            throw new RuntimeException("Results array must be supplied as a kernel parameter");
        }
        // If the same array was passed more than once, the first pointer created for it is used.
        resultPointers.add(candidates.iterator().next());
        return this;
    }

    /**
     * Sets the accumulation that should receive the result, and marks {@code result} as the
     * array that holds it on the device.
     *
     * @param op     the accumulation whose current result is set on {@link #close()}
     * @param result an array that was passed to the constructor as a kernel parameter
     * @return this wrapper, for chaining
     * @throws RuntimeException if {@code result} was not one of the kernel parameters
     */
    public KernelParamsWrapper setResultOp(Accumulation op, INDArray result) {
        // Validate the array first so a failed call leaves resultOp untouched.
        setResultArray(result);
        resultOp = op;
        return this;
    }

    /**
     * Create a new wrapper for the kernel parameters.
     *
     * <p>This wrapper manages communication between the host and the device. To set the result
     * on a specific operation, use {@link #setResultOp}; to set the array that holds the result,
     * use {@link #setResultArray}.
     *
     * <p>If creating a device pointer fails part-way through, the pointers already created are
     * closed before the failure is rethrown.
     *
     * @param kernelParams the raw kernel arguments; {@link JCudaBuffer} and {@link INDArray}
     *                     values are replaced by device pointers, everything else is kept as-is
     */
    public KernelParamsWrapper(Object... kernelParams) {
        kernelParameters = new Object[kernelParams.length];
        arrayToPointer = ArrayListMultimap.create();
        pointersToFree = new ArrayList<>();
        resultPointers = new ArrayList<>();

        try {
            for (int i = 0; i < kernelParams.length; i++) {
                Object arg = kernelParams[i];

                if (arg instanceof JCudaBuffer) {
                    // Buffers are assigned to the device; the pointer is freed on close().
                    CublasPointer pointerToFree = new CublasPointer((JCudaBuffer) arg);
                    kernelParameters[i] = pointerToFree.getDevicePointer();
                    pointersToFree.add(pointerToFree);
                } else if (arg instanceof INDArray) {
                    // Arrays are also remembered so they can later be chosen as the result.
                    INDArray array = (INDArray) arg;
                    CublasPointer pointerToFree = new CublasPointer(array);
                    kernelParameters[i] = pointerToFree.getDevicePointer();
                    pointersToFree.add(pointerToFree);
                    arrayToPointer.put(array, pointerToFree);
                } else {
                    kernelParameters[i] = arg;
                }
            }
        } catch (RuntimeException | Error e) {
            // The caller never receives this instance, so nobody else could close these pointers.
            for (CublasPointer allocated : pointersToFree) {
                try {
                    allocated.close();
                } catch (Exception suppressed) {
                    e.addSuppressed(suppressed);
                }
            }
            throw e;
        }
    }

    /**
     * Copies registered results back to the host and frees every device pointer created by this
     * wrapper. Later calls only synchronize the stream.
     *
     * <p>Every pointer is closed even if copying a result or closing another pointer fails; the
     * first such failure is rethrown afterwards, with later ones attached as suppressed.
     *
     * @throws Exception if copying a result back or closing a pointer fails
     */
    @Override
    public void close() throws Exception {
        ContextHolder.syncStream();
        if (closeInvoked) {
            return;
        }
        // Mark as closed up front so a failure part-way through cannot cause a double free on retry.
        closeInvoked = true;

        Exception failure = null;
        for (CublasPointer cublasPointer : pointersToFree) {
            try {
                if (resultPointers.contains(cublasPointer)) {
                    if (resultOp != null) {
                        setResultForOp(resultOp, cublasPointer);
                    } else {
                        cublasPointer.copyToHost();
                    }
                }
            } catch (Exception e) {
                failure = chainFailure(failure, e);
            }
            try {
                cublasPointer.close();
            } catch (Exception e) {
                failure = chainFailure(failure, e);
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Combines failures seen during {@link #close()}.
     *
     * @return {@code next} if there is no earlier failure, otherwise {@code first} with
     *         {@code next} attached as a suppressed exception
     */
    private static Exception chainFailure(Exception first, Exception next) {
        if (first == null) {
            return next;
        }
        first.addSuppressed(next);
        return first;
    }

    /**
     * Set the result within the accumulation operation.
     *
     * <p>Reads the first two values of the device buffer (as doubles or floats, depending on the
     * buffer's data type). If {@code acc} is an {@link Accumulation}, the first value becomes its
     * current result, and the pair becomes the two constructor arguments of its complex result.
     *
     * @param acc           the operation to receive the result
     * @param devicePointer the device pointer holding the result
     */
    private void setResultForOp(Op acc, CublasPointer devicePointer) {
        if (devicePointer.getBuffer().dataType() == DataBuffer.Type.DOUBLE) {
            double[] data = new double[2];
            copyDeviceToHost(Pointer.to(data), devicePointer, 2 * Sizeof.DOUBLE);

            if (acc instanceof Accumulation) {
                Accumulation acc2 = (Accumulation) acc;
                acc2.setCurrentResult(data[0]);
                acc2.setCurrentResultComplex(new ComplexDouble(data[0], data[1]));
            }
        } else {
            float[] data = new float[2];
            // Uses the same synchronized copy as the double branch; without the trailing sync the
            // array could be read before the asynchronous copy has completed.
            copyDeviceToHost(Pointer.to(data), devicePointer, 2 * Sizeof.FLOAT);

            if (acc instanceof Accumulation) {
                Accumulation acc2 = (Accumulation) acc;
                acc2.setCurrentResult(data[0]);
                acc2.setCurrentResultComplex(new ComplexFloat(data[0], data[1]));
            }
        }
    }

    /**
     * Copies {@code byteCount} bytes from {@code devicePointer} into the host memory behind
     * {@code hostPointer} on the context's CUDA stream.
     *
     * <p>The stream is synchronized both before and after the asynchronous copy, so the host data
     * is complete once this method returns.
     */
    private static void copyDeviceToHost(Pointer hostPointer, CublasPointer devicePointer, long byteCount) {
        ContextHolder.syncStream();
        JCuda.cudaMemcpyAsync(
                hostPointer,
                devicePointer.getDevicePointer(),
                byteCount,
                cudaMemcpyKind.cudaMemcpyDeviceToHost,
                ContextHolder.getInstance().getCudaStream());
        ContextHolder.syncStream();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy