com.nativelibs4java.opencl.util.ParallelMath Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of javacl Show documentation
Show all versions of javacl Show documentation
JavaCL is an Object-Oriented API that makes the C OpenCL API available to Java in a very natural way.
It hides away the complexity of cross-platform C bindings, has a clean OO design (with generics, Java enums, NIO buffers, fully typed exceptions...), provides high-level features (OpenGL-interop, array reductions) and comes with samples and demos.
For more info, please visit http://code.google.com/p/nativelibs4java/wiki/OpenCL.
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.util;
import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.nativelibs4java.opencl.util.ReductionUtils;
import com.nativelibs4java.opencl.util.ReductionUtils.Reductor;
import com.nativelibs4java.util.IOUtils;
import com.nativelibs4java.util.Pair;
import static com.nativelibs4java.util.NIOUtils.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
*
* @author ochafik
*/
@SuppressWarnings("unused")
public class ParallelMath {
protected CLContext context;
protected CLQueue queue;
public ParallelMath() {
this(JavaCL.createBestContext().createDefaultQueue());
}
public ParallelMath(CLQueue queue) {
this.queue = queue;
CLContext context = queue.getContext();
}
public CLQueue getQueue() {
return queue;
}
public CLContext getContext() {
return getQueue().getContext();
}
protected String createVectFun1Source(Fun1 function, Primitive type, StringBuilder out) {
String t = type.clTypeName();
String kernelName = "vect_" + function.name() + "_" + t;// + (inPlace ? "_inplace" : "");
out.append("__kernel void " + kernelName + "(\n");
out.append("\t__global const " + t + "* in,\n");
out.append("\t__global " + t + "* out,\n");
out.append("\tlong length\n");
out.append(") {\n");
out.append("\tint i = get_global_id(0);\n");
out.append("\tif (i >= length) return;\n");
out.append("\tout[i] = ");
function.expr("in[i]", out);
out.append(";\n");
out.append("}\n");
return kernelName;
}
protected String createVectFun2Source(Fun2 function, Primitive type1, Primitive type2, Primitive typeOut, StringBuilder out, boolean secondOperandIsScalar) {
String t1 = type1.clTypeName(), t2 = type2.clTypeName(), to = typeOut.clTypeName();
String kernelName = "vect_" + function.name() + "_" + t1 + "_" + t2 + "_" + to;
out.append("__kernel void " + kernelName + "(\n");
out.append("\t__global const " + t1 + "* in1,\n");
if (secondOperandIsScalar)
out.append("\t" + t2 + " in2,\n");
else
out.append("\t__global const " + t2 + "* in2,\n");
out.append("\t__global " + to + "* out,\n");
out.append("\tlong length\n");
out.append(") {\n");
out.append("\tint i = get_global_id(0);\n");
out.append("\tif (i >= length) return;\n");
out.append("\tout[i] = (" + to + ")");
function.expr("in1[i]", (secondOperandIsScalar ? "in2" : "in2[i]"), out);
out.append(";\n");
out.append("}\n");
return kernelName;
}
private EnumMap> fun1Kernels = new EnumMap>(Fun1.class);
public synchronized CLKernel getKernel(Fun1 op, Primitive prim) throws CLBuildException {
EnumMap m = fun1Kernels.get(op);
if (m == null)
fun1Kernels.put(op, m = new EnumMap(Primitive.class));
CLKernel kers = m.get(prim);
if (kers == null) {
StringBuilder out = new StringBuilder(300);
String name = createVectFun1Source(op, prim, out);
CLProgram prog = getContext().createProgram(out.toString()).build();
kers = prog.createKernel(name);
m.put(prim, kers);
}
return kers;
}
static class PrimitiveTrio extends Pair, Pair> {
public PrimitiveTrio(Primitive a, Primitive b, Primitive c, boolean secondOperandIsScalar) {
super(new Pair(a, b), new Pair(c, secondOperandIsScalar));
}
}
private EnumMap> fun2Kernels = new EnumMap>(Fun2.class);
public synchronized CLKernel getKernel(Fun2 op, Primitive prim, boolean secondOperandIsScalar) throws CLBuildException {
return getKernel(op, prim, prim, prim, secondOperandIsScalar);
}
public synchronized CLKernel getKernel(Fun2 op, Primitive prim1, Primitive prim2, Primitive primOut, boolean secondOperandIsScalar) throws CLBuildException {
Map m = fun2Kernels.get(op);
if (m == null)
fun2Kernels.put(op, m = new HashMap());
PrimitiveTrio key = new PrimitiveTrio(prim1, prim2, primOut, secondOperandIsScalar);
CLKernel ker = m.get(key);
if (ker == null) {
StringBuilder out = new StringBuilder(300);
String name = createVectFun2Source(op, prim1, prim2, primOut, out, secondOperandIsScalar);
CLProgram prog = getContext().createProgram(out.toString()).build();
ker = prog.createKernel(name);
m.put(key, ker);
}
return ker;
}
}