All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.ops.executioner.JCudaExecutioner Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas.ops.executioner;



import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.complex.LinearViewComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.LinearViewNDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.SimpleJCublas;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.kernel.KernelFunctionLoader;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.nd4j.linalg.jcublas.util.KernelParamsWrapper;
import org.nd4j.linalg.util.ArrayUtil;


/**
 * JCuda executioner.
 * 

* Runs ops directly on the gpu * * @author Adam Gibson */ public class JCudaExecutioner extends DefaultOpExecutioner { private JCudaBuffer dummyFloatPointer, dummyDoublePointer; public JCudaExecutioner() { SimpleJCublas.init(); dummyFloatPointer = KernelFunctions.alloc(new float[]{1}); dummyDoublePointer =KernelFunctions.alloc(new double[]{1}); } @Override public Op exec(Op op) { //linear views and oblong offsets can't be handled by the gpu (due to the way the buffers are interpeted as vectors) if(op.x() instanceof LinearViewNDArray || op.x() instanceof IComplexNDArray || op.x().offset() > 0 && op.x().shape().length >= 2) return super.exec(op); if (op instanceof TransformOp) { TransformOp t = (TransformOp) op; invoke(t); } else if (op instanceof Accumulation) { Accumulation acc = (Accumulation) op; invoke(acc); } else if (op instanceof ScalarOp) { ScalarOp sc = (ScalarOp) op; invoke(sc); } return op; } @Override public void iterateOverAllRows(Op op) { persist(op); if(op.x().isRowVector()) { //reset the op in case op.setX(op.x()); if(op.y() != null) op.setY(op.y()); op.setZ(op.z()); exec(op); } //execute row wise else if(op.x().isMatrix() || op.x().isColumnVector()) { if(op.x() instanceof IComplexNDArray) { IComplexNDArray original = (IComplexNDArray) op.x(); IComplexNDArray originalZ = (IComplexNDArray) op.z(); IComplexNDArray y = (IComplexNDArray) op.y(); for(int i = 0; i < original.rows(); i++) { IComplexNDArray row = original.slice(i); IComplexNDArray zRow = originalZ.slice(i); IComplexNDArray rowRaveled = row.ravel(); IComplexNDArray zRowRaveled = zRow.ravel(); op.setX(rowRaveled); op.setZ(zRowRaveled); if(y != null) op.setY(y.slice(i)); exec(op); zRow.assign(op.z()); } } else { INDArray original = op.x(); INDArray originalZ = op.z(); INDArray y = op.y(); for(int i = 0; i < original.rows(); i++) { INDArray row = original.getRow(i); INDArray zRow = originalZ.getRow(i); op.setX(row); op.setZ(zRow); if(y != null) op.setY(y.getRow(i)); exec(op); 
zRow.assign(op.z()); } } } else { INDArray originalX = op.x(); INDArray originalZ = op.z(); for(int i = 0; i < originalX.slices(); i++) { INDArray slice = originalX.slice(i); INDArray zSlice = originalZ.slice(i); op.setX(slice); op.setZ(zSlice); iterateOverAllRows(op); } } //on the recursive case only free buffers where the buffer is the base case if(op.x().length() == op.x().data().length()) unPersistAndFree(op); } @Override public void iterateOverAllColumns(Op op) { //persist for the duration of the usage of the buffer persist(op); if(op.x().isRowVector()) { exec(op); } //execute row wise else if(op.x().isMatrix() || op.x().isColumnVector()) { exec(op,1); } else { if(op.x() instanceof IComplexNDArray) { IComplexNDArray originalX = (IComplexNDArray) op.x(); IComplexNDArray originalZ = (IComplexNDArray) op.z(); IComplexNDArray y = (IComplexNDArray) op.y(); for(int i = 0; i < op.x().slices(); i++) { op.setX(originalX.getColumn(i)); op.setZ(originalZ.getColumn(i)); if(y != null) op.setY(y.getColumn(i)); iterateOverAllColumns(op); } } else { INDArray originalX = op.x(); INDArray originalZ = op.z(); INDArray y = op.y(); for(int i = 0; i < op.x().slices(); i++) { op.setX(originalX.getColumn(i)); op.setZ(originalZ.getColumn(i)); if(y != null) op.setY(y.getColumn(i)); iterateOverAllColumns(op); } } } //only free once the whole recursion has expired if(op.x().data().length() == op.x().length()) unPersistAndFree(op); } private JCudaBuffer dummyDouble() { return dummyDoublePointer; } private JCudaBuffer dummyFloat() { return dummyFloatPointer; } @Override public INDArray execAndReturn(TransformOp op) { invoke(op); return op.z(); } @Override public Op exec(Op op, int dimension) { persist(op); //only accumulate along a particular dimension if (op instanceof Accumulation) { Accumulation a = (Accumulation) op; return exec(a); } for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); exec(op2); if (op instanceof TransformOp) { 
TransformOp t = (TransformOp) op; TransformOp t2 = (TransformOp) op2; t.z().vectorAlongDimension(i, dimension).assign(t2.z()); } } unPersistAndFree(op); return op; } @Override public INDArray exec(Accumulation op, int dimension) { if(dimension == Integer.MAX_VALUE) { op.setX(op.x().linearView()); if(op.x() instanceof IComplexNDArray) return Nd4j.scalar(execAndReturn(op).currentResultComplex()); else return Nd4j.scalar(execAndReturn(op).currentResult()); } else if(op.x().isScalar()) return op.x(); if(op.x() instanceof IComplexNDArray) { IComplexNDArray ret = Nd4j.createComplex(ArrayUtil.removeIndex(op.x().shape(), dimension)); IComplexNDArray linear = ret.linearView(); if(op.x().isRowVector()) { //same shape if(dimension == 0) { //no reduction return op.x(); } else if(dimension == 1) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } else if(op.x().isColumnVector()) { if(dimension == 0) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } //cache the whole buffer on each piece until after the processing along //each dimension is done persist(op); for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); IComplexNumber result = execAndReturn((Accumulation) op2).currentResultComplex(); linear.putScalar(i, result); } unPersistAndFree(op); return ret; } else { if(op.x().isRowVector()) { //same shape if(dimension == 0) { //no reduction return op.x(); } else if(dimension == 1) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } else if(op.x().isColumnVector()) { if(dimension == 0) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } INDArray ret = Nd4j.create(ArrayUtil.removeIndex(op.x().shape(), dimension)); INDArray linear = ret.linearView(); persist(op); for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); Number result = execAndReturn((Accumulation) op2).currentResult(); linear.putScalar(i,result.doubleValue()); } 
unPersistAndFree(op); return ret; } } @Override public INDArray execAndReturn(TransformOp op, int dimension) { //don't free device pointer until after operation is done persist(op); for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); exec(op2); if (op instanceof TransformOp) { TransformOp t = op; TransformOp t2 = (TransformOp) op2; t.z().vectorAlongDimension(i, dimension).assign(t2.z()); } } //don't cache buffers anymore unPersistAndFree(op); return op.z(); } //save the pointers till they are all done being used private void persist(Op op) { persist(op.x()); persist(op.y()); persist(op.z()); } //allow the pointers to be freed private void unPersistAndFree(Op op) { unPersistAndFree(op.x()); unPersistAndFree(op.y()); unPersistAndFree(op.z()); } //persist() on a jcuda buffer forces the buffer uploaded to the gpu to be cached //this is useful for when you have arrays with offsets all viewing the same buffer //and intend on doing operations on slices of the same buffer private void persist(INDArray arr) { if(arr == null) return; arr.data().persist(); } private void unPersistAndFree(INDArray buffer) { if(buffer == null) return; unPersistAndFree(buffer.data()); } private void unPersistAndFree(DataBuffer buffer) { buffer.unPersist(); //free the buffer after un persisting JCudaBuffer buf = (JCudaBuffer) buffer; buf.freeDevicePointer(0); } @Override public INDArray execAndReturn(ScalarOp op, int dimension) { return exec(op, dimension).z(); } /** * Converts the given parameters * in to extra arguments to * pass to the kernel * * @param extraArgs the extra arguments * @param dataType the data type * @return */ private JCudaBuffer toArgs(Object[] extraArgs, String dataType) { if (dataType.equals("double")) { if (extraArgs == null || extraArgs.length < 1) return dummyDouble(); return KernelFunctions.alloc(PointerUtil.toDoubles(extraArgs)); } else if (dataType.equals("float")) { if (extraArgs == null || extraArgs.length < 
1) return dummyFloat(); return KernelFunctions.alloc(PointerUtil.toFloats(extraArgs)); } throw new IllegalArgumentException("Illegal datatype"); } private void invoke(Accumulation op) { if(!KernelFunctionLoader.getInstance().exists(op.name())) super.exec(op); INDArray result = Nd4j.create(2); if (op.y() != null) { //int n,int xOffset,int yOffset, double *dx, double *dy,int incx,int incy,double *result Object[] kernelParams = new Object[] { op.n(), op.x().offset(), op.y().offset(), op.x(), op.y(), op.x().majorStride(), op.y().majorStride(), toArgs(op.extraArgs(), getType(op)), result }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultOp(op, result)) { invokeFunction(op, kParams.getKernelParameters()); kParams.close(); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } else { //int n, int xOffset,double *dx,int incx,double result //NOTE THE STRIDE HERE. The stride should be set to 1. //The reason for this is because we only upload //the vector itself that is needed to the gpu //If you never need to use the whole array //with striding change this back to arr.majorStride() Object[] kernelParams = new Object[] { op.n(), op.x().offset(), op.x(), op.x().majorStride(), toArgs(op.extraArgs(), getType(op)), result }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultOp(op, result)) { invokeFunction(op, kParams.getKernelParameters()); kParams.close(); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } } private void invokeFunction(Op op, Object... kernelParams) { /** * Invoke a cuda kernel by name. This will be wrt the function name. * Functions that are accumulations or transforms have names that end with _strided. * */ String functionName = op instanceof TransformOp || op instanceof Accumulation ? 
op.name() + "_strided" : op.name(); int blocks = PointerUtil.getNumBlocks(op.n(), KernelFunctions.BLOCKS, KernelFunctions.THREADS); int threads = PointerUtil.getNumThreads(op.n(), KernelFunctions.THREADS); KernelFunctions.invoke( blocks ,threads ,functionName ,getType(op) ,kernelParams); } private void invoke(ScalarOp op) { if(!KernelFunctionLoader.getInstance().exists(op.name())) super.exec(op); if (op.y() != null) { Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), op.y().offset(), op.x(), op.y(), op.x().majorStride(), op.y().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } else { Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), PointerUtil.getPointer(op), op.x(), op.x().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); kParams.close(); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } } private String getType(Op op) { return op.x().data().dataType() == DataBuffer.Type.DOUBLE ? 
"double" : "float"; } private void invoke(TransformOp op) { if(!KernelFunctionLoader.getInstance().exists(op.name()) || op.x() instanceof IComplexNDArray) { super.exec(op); return; } if (op.y() != null) { /** * Construct pointer arguments in the following order: * n * offset, * pointer to buffer * increment, * extraArgs, * result */ Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), op.y().offset(), op.x(), op.y(), op.x().majorStride(), op.y().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); kParams.close(); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } else { //int n,int idx,double *dy,int incy,double *result Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), op.x(), 1, toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); kParams.close(); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy