
com.nativelibs4java.opencl.blas.CLKernels Maven / Gradle / Ivy
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.blas;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLMem.Usage;
import com.nativelibs4java.opencl.CLPlatform.DeviceFeature;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.nativelibs4java.opencl.LocalSize;
import com.nativelibs4java.opencl.util.Fun1;
import com.nativelibs4java.opencl.util.Fun2;
import com.nativelibs4java.opencl.util.LinearAlgebraUtils;
import com.nativelibs4java.opencl.util.ParallelMath;
import com.nativelibs4java.opencl.util.Primitive;
import static com.nativelibs4java.opencl.blas.CLMatrixUtils.roundUp;
import static org.bridj.Pointer.pointerToInt;
/**
*
* @author ochafik
*/
public class CLKernels {
protected final LinearAlgebraUtils kernels;
protected final ParallelMath math;
protected final CLContext context;
protected final CLQueue queue;
private static volatile CLKernels instance;
public static synchronized void setInstance(CLKernels kernels) {
instance = kernels;
}
public static synchronized CLKernels getInstance() {
if (instance == null) {
try {
instance = new CLKernels();
} catch (Throwable ex) {
throw new RuntimeException(ex);
}
}
return instance;
}
public CLKernels() throws IOException, CLBuildException {
this(
JavaCL.createBestContext(
DeviceFeature.DoubleSupport,
DeviceFeature.MaxComputeUnits
).createDefaultQueue()
);
}
public CLKernels(CLQueue queue) throws IOException, CLBuildException {
kernels = new LinearAlgebraUtils(queue);
math = new ParallelMath(queue);
context = queue.getContext();
this.queue = queue;
}
public CLEvent op1(Primitive prim, Fun1 fun, CLBuffer a, long rows, long columns, long stride, CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
long length = rows * stride;
if (out == null || out.getElementCount() < length)
throw new IllegalArgumentException("Expected buffer of length >= " + length + ", got " + out);
//if (out != null)
// out = (CLBuffer)context.createBuffer(Usage.Output, prim.primitiveType, length);
CLKernel kernel = math.getKernel(fun, prim);
synchronized (kernel) {
kernel.setArgs(a, out, length);
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)length }, eventsToWaitFor);
return evt;
}
}
public CLEvent op2(Primitive prim, Fun2 fun, CLBuffer a, CLBuffer b, long rows, long columns, long stride, CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
long length = rows * stride;
if (out == null || out.getElementCount() < length)
throw new IllegalArgumentException("Expected buffer of length >= " + length + ", got " + out.getElementCount());
//if (out != null)
// out = (CLBuffer)context.createBuffer(Usage.Output, prim.primitiveType, length);
CLKernel kernel = math.getKernel(fun, prim, false);
synchronized (kernel) {
kernel.setArgs(a, b, out, length);
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)length }, eventsToWaitFor);
return evt;
}
}
public CLEvent op2(Primitive prim, Fun2 fun, CLBuffer a, T b, long rows, long columns, long stride, CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
long length = rows * stride;
if (out == null || out.getElementCount() < length)
throw new IllegalArgumentException("Expected buffer of length >= " + length + ", got " + out.getElementCount());
//if (out != null)
// out = (CLBuffer)context.createBuffer(Usage.Output, prim.primitiveType, length);
CLKernel kernel = math.getKernel(fun, prim, true);
synchronized (kernel) {
kernel.setArgs(a, b, out, length);
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)length }, eventsToWaitFor);
return evt;
}
}
Map containsValueKernels = new HashMap();
public boolean containsValue(Primitive primitive, CLBuffer buffer, long length, V value, CLEvent... eventsToWaitFor) throws CLBuildException {
CLKernel kernel;
synchronized (containsValueKernels) {
kernel = containsValueKernels.get(primitive);
if (kernel == null) {
kernel = context.createProgram((
primitive.getRequiredPragmas() +
"__kernel void containsValue( \n" +
" __global const double* a, \n" +
" int length, \n" +
" double value, \n" +
" __global int* pOut \n" +
") { \n" +
" int i = get_global_id(0);\n" +
" if (i >= length) \n" +
" return; \n" +
" \n" +
" if (a[i] == value) \n" +
" *pOut = 1; \n" +
"} \n"
).replaceAll("double", primitive.clTypeName())).createKernel("containsValue");
containsValueKernels.put(primitive, kernel);
}
}
synchronized(kernel) {
CLBuffer pOut = context.createBuffer(Usage.Output, pointerToInt(0));
kernel.setArgs(buffer, (int)length, value, pOut);
kernel.enqueueNDRange(queue, new int[] { (int)length }, eventsToWaitFor).waitFor();
return pOut.read(queue).getInt() != 0;
}
}
Map clearKernels = new HashMap();
public CLEvent clear(Primitive primitive, CLBuffer buffer, long length, CLEvent... eventsToWaitFor) throws CLBuildException {
CLKernel kernel;
synchronized (clearKernels) {
kernel = clearKernels.get(primitive);
if (kernel == null) {
kernel = context.createProgram((
primitive.getRequiredPragmas() +
"__kernel void clear_buffer( \n" +
" __global double* a, \n" +
" int length \n" +
") { \n" +
" int i = get_global_id(0); \n" +
" if (i >= length) \n" +
" return; \n" +
" \n" +
" a[i] = (double)0; \n" +
"} \n"
).replaceAll("double", primitive.clTypeName())).createKernel("clear_buffer");
clearKernels.put(primitive, kernel);
}
}
synchronized(kernel) {
kernel.setArgs(buffer, (int)length);
CLEvent evt = kernel.enqueueNDRange(queue, new int[] { (int)length }, eventsToWaitFor);
//Object array = buffer.read(queue, evt).getArray();
return evt;
}
}
Map matrixMultiplyKernels = new HashMap();
public CLEvent matrixMultiply(Primitive prim,
CLBuffer a, long aRows, long aColumns, long aStride, int aBlockSize,
CLBuffer b, long bRows, long bColumns, long bStride, int bBlockSize,
CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
boolean useBlocks = false;
int blockSize = aBlockSize;
if (blockSize > 1 && blockSize == bBlockSize) {
long[] maxWorkItemSizes = queue.getDevice().getMaxWorkItemSizes();
useBlocks = maxWorkItemSizes.length >= 2 &&
maxWorkItemSizes[0] >= blockSize &&
maxWorkItemSizes[1] >= blockSize;
}
if (useBlocks) {
return blockMatrixMultiply(
blockSize, prim,
a, roundUp(aRows, blockSize), roundUp(aColumns, blockSize),
b, roundUp(bRows, blockSize), roundUp(bColumns, blockSize),
out, eventsToWaitFor);
} else {
return naiveMatrixMultiply(prim, a, aRows, aColumns, aStride, b, bRows, bColumns, bStride, out, eventsToWaitFor);
}
}
public CLEvent blockMatrixMultiply(int blockSize, Primitive prim, CLBuffer a, long aRows, long aColumns, CLBuffer b, long bRows, long bColumns, CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
if (out == null)
throw new IllegalArgumentException("Null output matrix !");
//if (out != null)
// out = (CLBuffer)context.createBuffer(Usage.Output, prim.primitiveType, aRows * bColumns);
CLKernel kernel;
String key = "block_" + blockSize + "_" + prim;
synchronized (matrixMultiplyKernels) {
kernel = matrixMultiplyKernels.get(key);
if (kernel == null) {
String src = prim.getRequiredPragmas() +
"#define BLOCK_SIZE " + blockSize + "\n" +
"#define AS(i, j) As[j + i * BLOCK_SIZE]\n" +
"#define BS(i, j) Bs[j + i * BLOCK_SIZE]\n" +
"\n" +
"__kernel void mulMat( " +
" __global const double* A, int aColumns, " +
" __global const double* B, int bColumns, " +
" __global double* C, " +
" __local double* As, " +
" __local double* Bs " +
") { " +
" // Block index\n" +
" int bx = get_group_id(0);\n" +
" int by = get_group_id(1);\n" +
"\n" +
" // Thread index\n" +
" int tx = get_local_id(0);\n" +
" int ty = get_local_id(1);\n" +
"\n" +
" // Index of the first sub-matrix of A processed by the block\n" +
" int aBegin = aColumns * BLOCK_SIZE * by + aColumns * ty + tx;\n" +
"\n" +
" // Index of the last sub-matrix of A processed by the block\n" +
" int aEnd = aBegin + aColumns;\n" +
"\n" +
" // Step size used to iterate through the sub-matrices of A\n" +
" int aStep = BLOCK_SIZE;\n" +
"\n" +
" // Index of the first sub-matrix of B processed by the block\n" +
" int bBegin = BLOCK_SIZE * bx + bColumns * ty + tx;\n" +
"\n" +
" // Step size used to iterate through the sub-matrices of B\n" +
" int bStep = BLOCK_SIZE * bColumns;\n" +
"\n" +
" // total is used to store the element of the block sub-matrix\n" +
" // that is computed by the thread\n" +
" float total = 0.0f;\n" +
"\n" +
" // Loop over all the sub-matrices of A and B\n" +
" // required to compute the block sub-matrix\n" +
" for (int a = aBegin, b = bBegin;\n" +
" a < aEnd;\n" +
" a += aStep, b += bStep) {\n" +
"\n" +
" // Load the matrices from device memory\n" +
" // to shared memory; each thread loads\n" +
" // one element of each matrix\n" +
" AS(ty, tx) = A[a];\n" +
" BS(ty, tx) = B[b];\n" +
"\t\n" +
" // Synchronize to make sure the matrices are loaded\n" +
" barrier(CLK_LOCAL_MEM_FENCE);\n" +
"\n" +
" // Multiply the two matrices together;\n" +
" // each thread computes one element\n" +
" // of the block sub-matrix \n" +
" #pragma unroll\n" +
" for (int k = 0; k < BLOCK_SIZE; ++k)\n" +
" total += AS(ty, k) * BS(k, tx);\n" +
"\n" +
" // Synchronize to make sure that the preceding\n" +
" // computation is done before loading two new\n" +
" // sub-matrices of A and B in the next iteration\n" +
" barrier(CLK_LOCAL_MEM_FENCE);\n" +
" }\n" +
"\n" +
" C[get_global_id(1) * get_global_size(0) + get_global_id(0)] = total;\n" +
"} "
;
String clTypeName = prim.clTypeName();
src = src.replaceAll("double", clTypeName);
kernel = context.createProgram(src).createKernel("mulMat");
matrixMultiplyKernels.put(key, kernel);
}
}
synchronized (kernel) {
kernel.setArgs(a, (int) aColumns, b, (int) bColumns, out,
LocalSize.ofFloatArray(blockSize * blockSize),
LocalSize.ofFloatArray(blockSize * blockSize));
CLEvent evt = kernel.enqueueNDRange(queue,
new int[]{(int) aRows, (int) bColumns},
new int[]{blockSize, blockSize},
eventsToWaitFor);
return evt;
}
}
public CLEvent naiveMatrixMultiply(Primitive prim,
CLBuffer a, long aRows, long aColumns, long aStride,
CLBuffer b, long bRows, long bColumns, long bStride,
CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
if (out == null)
throw new IllegalArgumentException("Null output matrix !");
//if (out != null)
// out = (CLBuffer)context.createBuffer(Usage.Output, prim.primitiveType, aRows * bColumns);
CLKernel kernel;
String key = "naive_" + prim;
synchronized (matrixMultiplyKernels) {
kernel = matrixMultiplyKernels.get(key);
if (kernel == null) {
String src = prim.getRequiredPragmas() +
"__kernel void mulMat( " +
" __global const double* a, int aRows, int aColumns, int aStride, " +
" __global const double* b, int bColumns, int bStride, " +
" __global double* c " +
") { " +
" int i = get_global_id(0); " +
" int j = get_global_id(1); " +
" " +
" if (i >= aRows || j >= bColumns) return; " +
" double total = 0; " +
" size_t iOff = i * (size_t)aStride; " +
" for (long k = 0; k < aColumns; k++) { " +
" total += a[iOff + k] * b[k * (size_t)bStride + j]; " +
" } " +
" c[i * (size_t)bStride + j] = total; " +
"} "
;
String clTypeName = prim.clTypeName();
src = src.replaceAll("double", clTypeName);
kernel = context.createProgram(src).createKernel("mulMat");
matrixMultiplyKernels.put(key, kernel);
}
}
synchronized (kernel) {
// assert aStride == aColumns: ("Weird a stride: aStride = " + aStride + ", aColumns = " + aColumns);
// assert bStride == bColumns: ("Weird b stride: bStride = " + bStride + ", bColumns = " + bColumns);
kernel.setArgs(a, (int)aRows, (int)aColumns, (int)aStride, b, (int)bColumns, (int)bStride, out);
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)aRows, (int)bColumns }, eventsToWaitFor);
return evt;
}
}
Map matrixTransposeKernels = new HashMap();
public CLEvent matrixTranspose(Primitive prim, CLBuffer a, long aRows, long aColumns, long aStride, CLBuffer out, CLEvent... eventsToWaitFor) throws CLBuildException {
if (out == null)
throw new IllegalArgumentException("Null output matrix !");
//if (out != null)
// out = (CLBuffer)context.createBuffer(Usage.Output, prim.primitiveType, aRows * aColumns);
CLKernel[] kernels;
synchronized (matrixTransposeKernels) {
kernels = matrixTransposeKernels.get(prim);
if (kernels == null) {
String src =
prim.getRequiredPragmas() +
"__kernel void transposeSelf( \n" +
" __global double* a, int aRows, int aColumns, int aStride \n" +
") { \n" +
" int i = get_global_id(0); \n" +
" int j = get_global_id(1); \n" +
" \n" +
" if (i >= aRows || j >= aColumns || j >= i) return; \n" +
" \n" +
" size_t aIndex = i * aStride + j; \n" +
" size_t outIndex = j * aRows + i; \n" +
" double temp = a[outIndex]; \n" +
" a[outIndex] = a[aIndex]; \n" +
" a[aIndex] = temp; \n" +
"} \n" +
"__kernel void transposeOther( \n" +
" __global const double* a, int aRows, int aColumns, int aStride, \n" +
" __global double* out \n" +
") { \n" +
" int i = get_global_id(0); \n" +
" int j = get_global_id(1); \n" +
" \n" +
" if (i >= aRows || j >= aColumns) return; \n" +
" \n" +
" size_t aIndex = i * aStride + j; \n" +
" size_t outIndex = j * aRows + i; \n" +
" out[outIndex] = a[aIndex]; \n" +
"} \n"
;
String clTypeName = prim.clTypeName();
src = src.replaceAll("double", clTypeName);
CLProgram program = context.createProgram(src);
kernels = new CLKernel[] { program.createKernel("transposeSelf"), program.createKernel("transposeOther") };
matrixTransposeKernels.put(prim, kernels);
}
}
boolean self = a.equals(out);
CLKernel kernel = kernels[self ? 0 : 1];
synchronized (kernel) {
if (self)
kernel.setArgs(a, (int)aRows, (int)aColumns, (int)aStride);
else
kernel.setArgs(a, (int)aRows, (int)aColumns, (int)aStride, out);
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)aRows, (int)aColumns }, eventsToWaitFor);
return evt;
}
}
public CLContext getContext() {
return context;
}
public CLQueue getQueue() {
return queue;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy