All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.ops.executioner.JCudaExecutioner Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.nd4j.linalg.jcublas.ops.executioner;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;
import jcuda.runtime.cudaMemcpyKind;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.SimpleJCublas;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.nd4j.linalg.jcublas.util.KernelParamsWrapper;
import org.nd4j.linalg.util.ArrayUtil;


/**
 * JCuda executioner.
 * 

* Runs ops directly on the gpu * * @author Adam Gibson */ public class JCudaExecutioner implements OpExecutioner { private JCudaBuffer dummyFloatPointer, dummyDoublePointer; public JCudaExecutioner() { SimpleJCublas.init(); dummyFloatPointer = KernelFunctions.alloc(new float[]{1}); dummyDoublePointer =KernelFunctions.alloc(new double[]{1}); } @Override public Op exec(Op op) { if (op instanceof TransformOp) { TransformOp t = (TransformOp) op; invoke(t); } else if (op instanceof Accumulation) { Accumulation acc = (Accumulation) op; invoke(acc); } else if (op instanceof ScalarOp) { ScalarOp sc = (ScalarOp) op; invoke(sc); } return op; } @Override public void iterateOverAllRows(Op op) { throw new UnsupportedOperationException(); } @Override public void iterateOverAllColumns(Op op) { throw new UnsupportedOperationException(); } private JCudaBuffer dummyDouble() { return dummyDoublePointer; } private JCudaBuffer dummyFloat() { return dummyFloatPointer; } @Override public INDArray execAndReturn(TransformOp op) { invoke(op); return op.z(); } @Override public Accumulation execAndReturn(Accumulation op) { return (Accumulation) exec(op); } @Override public INDArray execAndReturn(ScalarOp op) { return exec(op).z(); } @Override public Op exec(Op op, int dimension) { //only accumulate along a particular dimension if (op instanceof Accumulation) { Accumulation a = (Accumulation) op; return exec(a); } for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); exec(op2); if (op instanceof TransformOp) { TransformOp t = (TransformOp) op; TransformOp t2 = (TransformOp) op2; t.z().vectorAlongDimension(i, dimension).assign(t2.z()); } } return op; } @Override public INDArray exec(Accumulation op, int dimension) { if(dimension == Integer.MAX_VALUE) { op.setX(op.x().linearView()); if(op.x() instanceof IComplexNDArray) return Nd4j.scalar(execAndReturn(op).currentResultComplex()); else return Nd4j.scalar(execAndReturn(op).currentResult()); 
} else if(op.x().isScalar()) return op.x(); if(op.x() instanceof IComplexNDArray) { IComplexNDArray ret = Nd4j.createComplex(ArrayUtil.removeIndex(op.x().shape(), dimension)); IComplexNDArray linear = ret.linearView(); if(op.x().isRowVector()) { //same shape if(dimension == 0) { //no reduction return op.x(); } else if(dimension == 1) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } else if(op.x().isColumnVector()) { if(dimension == 0) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); IComplexNumber result = execAndReturn((Accumulation) op2).currentResultComplex(); linear.putScalar(i, result); } return ret; } else { if(op.x().isRowVector()) { //same shape if(dimension == 0) { //no reduction return op.x(); } else if(dimension == 1) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } else if(op.x().isColumnVector()) { if(dimension == 0) { return Nd4j.scalar(execAndReturn(op).currentResult()); } } INDArray ret = Nd4j.create(ArrayUtil.removeIndex(op.x().shape(), dimension)); INDArray linear = ret.linearView(); for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); Number result = execAndReturn((Accumulation) op2).currentResult(); linear.putScalar(i,result.doubleValue()); } return ret; } } @Override public INDArray execAndReturn(TransformOp op, int dimension) { for (int i = 0; i < op.x().vectorsAlongDimension(dimension); i++) { Op op2 = op.opForDimension(i, dimension); exec(op2); if (op instanceof TransformOp) { TransformOp t = op; TransformOp t2 = (TransformOp) op2; t.z().vectorAlongDimension(i, dimension).assign(t2.z()); } } return op.z(); } @Override public INDArray execAndReturn(ScalarOp op, int dimension) { return exec(op, dimension).z(); } /** * Converts the given parameters * in to extra arguments to * pass to the kernel * * @param extraArgs the extra arguments * @param 
dataType the data type * @return */ private JCudaBuffer toArgs(Object[] extraArgs, String dataType) { if (dataType.equals("double")) { if (extraArgs == null || extraArgs.length < 1) return dummyDouble(); return KernelFunctions.alloc(PointerUtil.toDoubles(extraArgs)); } else if (dataType.equals("float")) { if (extraArgs == null || extraArgs.length < 1) return dummyFloat(); return KernelFunctions.alloc(PointerUtil.toFloats(extraArgs)); } throw new IllegalArgumentException("Illegal datatype"); } private void invoke(Accumulation op) { INDArray result = null; if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) { result = Nd4j.create(2); } else { result = Nd4j.create(2); } if (op.y() != null) { //int n,int xOffset,int yOffset, double *dx, double *dy,int incx,int incy,double *result Object[] kernelParams = new Object[] { op.n(), op.x().offset(), op.y().offset(), op.x(), op.y(), op.x().majorStride(), op.y().majorStride(), toArgs(op.extraArgs(), getType(op)), result }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultOp(op, result)) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } else { //int n, int xOffset,double *dx,int incx,double result Object[] kernelParams = new Object[] { op.n(), op.x().offset(), op.x(), op.x().majorStride(), toArgs(op.extraArgs(), getType(op)), result }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultOp(op, result)) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } } private void invokeFunction(Op op, Object... kernelParams) { String functionName = op instanceof TransformOp || op instanceof Accumulation ? 
op.name() + "_strided" : op.name(); int blocks = PointerUtil.getNumBlocks(op.n(), KernelFunctions.BLOCKS, KernelFunctions.THREADS); int threads = PointerUtil.getNumThreads(op.n(), KernelFunctions.THREADS); KernelFunctions.invoke(blocks,threads,functionName,getType(op),kernelParams); } private void invoke(ScalarOp op) { if (op.y() != null) { Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), op.y().offset(), op.x(), op.y(), op.x().majorStride(), op.y().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } else { //int n,int idx,double *dy,int incy,double *result //int n, int idx,double dx,double *dy,int incy,double *result Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), PointerUtil.getPointer(op), op.x(), op.x().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } } private String getType(Op op) { return op.x().data().dataType() == DataBuffer.Type.DOUBLE ? 
"double" : "float"; } private void invoke(TransformOp op) { if (op.y() != null) { /** * Construct pointer arguments in the following order: * n * offset, * pointer to buffer * increment, * extraArgs, * result */ Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), op.y().offset(), op.x(), op.y(), op.x().majorStride(), op.y().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } else { //int n,int idx,double *dy,int incy,double *result Object[] kernelParams = new Object[]{ op.n(), op.x().offset(), op.x(), op.x().majorStride(), toArgs(op.extraArgs(), getType(op)), op.z() }; try(KernelParamsWrapper kParams = new KernelParamsWrapper(kernelParams).setResultArray(op.z())) { invokeFunction(op, kParams.getKernelParameters()); } catch(Exception e) { throw new RuntimeException("Could not execute kernel", e); } } } public static int checkResult(int result) { if (result != cudaError.cudaSuccess) { throw new CudaException(cudaError.stringFor(result)); } return result; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy