All downloads are FREE. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.kernel.KernelFunctions Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas.kernel;


import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaMemcpyKind;
import jcuda.utils.KernelLauncher;
import org.nd4j.linalg.jcublas.SimpleJCublas;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.springframework.core.io.ClassPathResource;

import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;

import static jcuda.driver.JCudaDriver.cuCtxSynchronize;
import static jcuda.driver.JCudaDriver.cuLaunchKernel;

/**
 * Kernel functions.
 * 

* Derived from: * http://www.jcuda.org/samples/JCudaVectorAdd.java * * @author Adam Gibson */ public class KernelFunctions { public final static String NAME_SPACE = "org.nd4j.linalg.jcuda.jcublas"; public final static String DOUBLE = NAME_SPACE + ".double.functions"; public final static String FLOAT = NAME_SPACE + ".float.functions"; public final static String REDUCE = NAME_SPACE + ".reducefunctions"; public final static String SHARED_MEM_KEY = NAME_SPACE + ".sharedmem"; public final static String THREADS_KEY = NAME_SPACE + ".threads"; public final static String BLOCKS_KEY = NAME_SPACE + ".blocks"; public static int SHARED_MEM = 512; public static int THREADS = 128; public static int BLOCKS = 512; private static Set reduceFunctions = new ConcurrentSkipListSet<>(); private KernelFunctions() {} static { try { register(); } catch (Exception e) { throw new RuntimeException(e); } } /** * Called at initialization in the static context. * Registers cuda functions based on * the cudafunctions.properties in the classpath * * @throws IOException */ public static void register() throws Exception { ClassPathResource res = new ClassPathResource("/cudafunctions.properties"); if (!res.exists()) throw new IllegalStateException("Please put a cudafunctions.properties in your class path"); Properties props = new Properties(); props.load(res.getInputStream()); KernelFunctionLoader.getInstance().load(); String reduceFunctionsList = props.getProperty(REDUCE); for (String function : reduceFunctionsList.split(",")) reduceFunctions.add(function); SHARED_MEM = Integer.parseInt(props.getProperty(SHARED_MEM_KEY, "512")); THREADS = Integer.parseInt(props.getProperty(THREADS_KEY, "128")); BLOCKS = Integer.parseInt(props.getProperty(BLOCKS_KEY, "64")); } /** * Invoke a function with the given number of parameters * * @param blocks the number of blocks to launch the kernel * @param threadsPerBlock the number of threads per block * @param kernelParameters the parameters * @param dataType the data 
type to use */ public static void invoke(int blocks, int threadsPerBlock, String functionName,String dataType,Object...kernelParameters) { // Call the kernel function. //dot<<>>( dev_a, dev_b,dev_partial_c ); CUstream stream = ContextHolder.getInstance().getStream(); int sharedMemSize = threadsPerBlock * (dataType.equals("float") ? Sizeof.FLOAT : Sizeof.DOUBLE); KernelLauncher launcher = KernelFunctionLoader.launcher(functionName, dataType); if(launcher == null) throw new IllegalArgumentException("Launcher for function " + functionName + " and data type " + dataType + " does not exist!"); launcher.forFunction(functionName + "_" + dataType) .setBlockSize(threadsPerBlock,1,1) .setGridSize(blocks,1,1).setStream(stream) .setSharedMemSize(sharedMemSize) .call(kernelParameters); ContextHolder.syncStream(); } /** * Allocate a pointer of a given data type * * @param data the data for the pointer * @return the pointer */ public static JCudaBuffer alloc(double[] data) { // Allocate the device input data, and copy the // host input data to the device JCudaBuffer doubleBuffer = new CudaDoubleDataBuffer(data); return doubleBuffer; } /** * Allocate a pointer of a given data type * * @param data the data for the pointer * @return the pointer */ public static JCudaBuffer alloc(float[] data) { // Allocate the device input data, and copy the // host input data to the device JCudaBuffer floatBuffer = new CudaFloatDataBuffer(data); return floatBuffer; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy