// com.nativelibs4java.opencl.util.ParallelMath (Maven / Gradle / Ivy)
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.util;
import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLDoubleBuffer;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.nativelibs4java.opencl.util.ReductionUtils;
import com.nativelibs4java.opencl.util.ReductionUtils.Reductor;
import com.nativelibs4java.util.IOUtils;
import com.ochafik.util.listenable.Pair;
import static com.nativelibs4java.util.NIOUtils.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.nio.DoubleBuffer;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
*
* @author ochafik
*/
@SuppressWarnings("unused")
public class ParallelMath {
protected CLContext context;
protected CLQueue queue;
public ParallelMath() {
this(JavaCL.createBestContext().createDefaultQueue());
}
public ParallelMath(CLQueue queue) {
this.queue = queue;
CLContext context = queue.getContext();
}
public CLQueue getQueue() {
return queue;
}
public CLContext getContext() {
return getQueue().getContext();
}
public enum Fun1 {
log,
exp,
sqrt,
sin,
cos,
tan,
atan,
asin,
acos,
sinh,
cosh,
tanh,
asinh,
acosh,
atanh;
void expr(String a, StringBuilder out) {
out.append(name()).append('(').append(a).append(")");
}
}
public enum Fun2 {
atan2,
dist,
modulo("%"),
rshift(">>"),
lshift("<<"),
add("+"),
substract("-"),
multiply("*"),
divide("/");
String infixOp;
Fun2() {}
Fun2(String infixOp) {
this.infixOp = infixOp;
}
void expr(String a, String b, StringBuilder out) {
if (infixOp == null)
out.append(name()).append('(').append(a).append(", ").append(b).append(")");
else
out.append(a).append(' ').append(infixOp).append(' ').append(b);
}
}
public enum Primitive {
Float,
Double,
Long,
Int,
Short,
Byte,
Float2,
Double2,
Long2,
Int2,
Short2,
Byte2,
Float3,
Double3,
Long3,
Int3,
Short3,
Byte3,
Float4,
Double4,
Long4,
Int4,
Short4,
Byte4,
Float8,
Double8,
Long8,
Int8,
Short8,
Byte8,
Float16,
Double16,
Long16,
Int16,
Short16,
Byte16;
String type() {
return name().toLowerCase();
}
}
protected String createVectFun1Source(Fun1 function, Primitive type, StringBuilder out, boolean inPlace) {
String t = type.type();
String kernelName = "vect_" + function.name() + "_" + t + (inPlace ? "_inplace" : "");
out.append("__kernel void " + kernelName + "(\n");
if (!inPlace)
out.append("\t__global const " + t + "* in,\n");
out.append("\t__global " + t + "* out\n");
out.append(") {\n");
out.append("\tint i = get_global_id(0);\n");
out.append("\tout[i] = ");
function.expr(inPlace ? "out" : "in", out);
out.append("[i]);\n");
out.append("}\n");
return kernelName;
}
protected String createVectFun2Source(Fun2 function, Primitive type1, Primitive type2, Primitive typeOut, StringBuilder out) {
String t1 = type1.type(), t2 = type2.type(), to = typeOut.type();
String kernelName = "vect_" + function.name() + "_" + t1 + "_" + t2 + "_" + to;
out.append("__kernel void " + kernelName + "(\n");
out.append("\t__global const " + t1 + "* in1,\n");
out.append("\t__global const " + t2 + "* in2,\n");
out.append("\t__global " + to + "* out\n");
out.append(") {\n");
out.append("\tint i = get_global_id(0);\n");
out.append("\tout[i] = (" + to + ")");
function.expr("in1[i]", "in2[i]", out);
out.append(";\n");
out.append("}\n");
return kernelName;
}
private static class Fun1Kernels {
CLKernel inPlace, notInPlace;
}
private EnumMap> fun1Kernels = new EnumMap>(Fun1.class);
public synchronized CLKernel getKernel(Fun1 op, Primitive prim, boolean inPlace) throws CLBuildException {
EnumMap m = fun1Kernels.get(op);
if (m == null)
fun1Kernels.put(op, m = new EnumMap(Primitive.class));
Fun1Kernels kers = m.get(prim);
if (kers == null) {
StringBuilder out = new StringBuilder(300);
String inPlaceName = createVectFun1Source(op, prim, out, true);
String notInPlaceName = createVectFun1Source(op, prim, out, false);
CLProgram prog = getContext().createProgram(out.toString()).build();
kers = new Fun1Kernels();
kers.inPlace = prog.createKernel(inPlaceName);
kers.notInPlace = prog.createKernel(notInPlaceName);
m.put(prim, kers);
}
return inPlace ? kers.inPlace : kers.notInPlace;
}
static class PrimitiveTrio extends Pair> {
public PrimitiveTrio(Primitive a, Primitive b, Primitive c) {
super(a, new Pair(b, c));
}
}
private EnumMap> fun2Kernels = new EnumMap>(Fun2.class);
public synchronized CLKernel getKernel(Fun2 op, Primitive prim) throws CLBuildException {
return getKernel(op, prim, prim, prim);
}
public synchronized CLKernel getKernel(Fun2 op, Primitive prim1, Primitive prim2, Primitive primOut) throws CLBuildException {
Map m = fun2Kernels.get(op);
if (m == null)
fun2Kernels.put(op, m = new HashMap());
PrimitiveTrio key = new PrimitiveTrio(prim1, prim2, primOut);
CLKernel ker = m.get(key);
if (ker == null) {
StringBuilder out = new StringBuilder(300);
String name = createVectFun2Source(op, prim1, prim2, primOut, out);
CLProgram prog = getContext().createProgram(out.toString()).build();
ker = prog.createKernel(name);
m.put(key, ker);
}
return ker;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy