// com.nativelibs4java.opencl.util.ReductionUtils (Maven / Gradle / Ivy)
// Part of the javacl-jna artifact; all versions and documentation are listed on the hosting page.
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.util;
import com.nativelibs4java.opencl.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.Buffer;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.nativelibs4java.util.IOUtils;
import com.nativelibs4java.util.NIOUtils;
import com.ochafik.util.listenable.Pair;
/**
*
* @author Olivier
*/
public class ReductionUtils {
static String source;
static final String sourcePath = ReductionUtils.class.getPackage().getName().replace('.', '/') + "/" + "Reduction.c";
static synchronized String getSource() throws IOException {
InputStream in = ReductionUtils.class.getClassLoader().getResourceAsStream(sourcePath);
if (in == null)
throw new FileNotFoundException(sourcePath);
return source = IOUtils.readText(in);
}
public enum Operation {
Add,
Multiply,
Min,
Max;
}
public static Pair> getReductionCodeAndMacros(Operation op, OpenCLType valueType, int channels) throws IOException {
Map macros = new LinkedHashMap();
String cType = valueType.toCType() + (channels == 1 ? "" : channels);
macros.put("OPERAND_TYPE", cType);
String operation, seed;
switch (op) {
case Add:
operation = "_add_";
seed = "0";
break;
case Multiply:
operation = "_mul_";
seed = "1";
break;
case Min:
operation = "_min_";
switch (valueType) {
case Int:
seed = Integer.MAX_VALUE + "";
break;
case Long:
seed = Long.MAX_VALUE + "LL";
break;
case Short:
seed = Short.MAX_VALUE + "";
break;
case Float:
seed = "MAXFLOAT";
break;
case Double:
seed = "MAXDOUBLE";
break;
default:
throw new IllegalArgumentException("Unhandled seed type: " + valueType);
}
break;
case Max:
operation = "_max_";
switch (valueType) {
case Int:
seed = Integer.MIN_VALUE + "";
break;
case Long:
seed = Long.MIN_VALUE + "LL";
break;
case Short:
seed = Short.MIN_VALUE + "";
break;
case Float:
seed = "-MAXFLOAT";
break;
case Double:
seed = "-MAXDOUBLE";
break;
default:
throw new IllegalArgumentException("Unhandled seed type: " + valueType);
}
break;
default:
throw new IllegalArgumentException("Unhandled operation: " + op);
}
macros.put("OPERATION", operation);
macros.put("SEED", seed);
return new Pair>(getSource(), macros);
}
public interface Reductor {
public CLEvent reduce(CLQueue queue, CLBuffer
input, long inputLength, B output, int maxReductionSize, CLEvent... eventsToWaitFor);
public B reduce(CLQueue queue, CLBuffer
input, long inputLength, int maxReductionSize, CLEvent... eventsToWaitFor);
public CLEvent reduce(CLQueue queue, CLBuffer
input, long inputLength, CLBuffer
output, int maxReductionSize, CLEvent... eventsToWaitFor);
}
/*public interface WeightedReductor {
public CLEvent reduce(CLQueue queue, CLBuffer input, CLBuffer weights, long inputLength, B output, int maxReductionSize, CLEvent... eventsToWaitFor);
public CLEvent reduce(CLQueue queue, CLBuffer input, CLBuffer weights, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor);
}*/
static int getNextPowerOfTwo(int i) {
int shifted = 0;
boolean lost = false;
for (;;) {
int next = i >> 1;
if (next == 0) {
if (lost)
return 1 << (shifted + 1);
else
return 1 << shifted;
}
lost = lost || (next << 1 != i);
shifted++;
i = next;
}
}
public static Reductor
createReductor(final CLContext context, Operation op, OpenCLType valueType, final int valueChannels) throws CLBuildException {
try {
Pair> codeAndMacros = getReductionCodeAndMacros(op, valueType, valueChannels);
CLProgram program = context.createProgram(codeAndMacros.getFirst());
program.defineMacros(codeAndMacros.getValue());
program.build();
CLKernel[] kernels = program.createKernels();
if (kernels.length != 1)
throw new RuntimeException("Expected 1 kernel, found : " + kernels.length);
final CLKernel kernel = kernels[0];
return new Reductor() {
@SuppressWarnings("unchecked")
@Override
public CLEvent reduce(CLQueue queue, CLBuffer
input, long inputLength, B output, int maxReductionSize, CLEvent... eventsToWaitFor) {
input.getBufferClass().cast(output);
Pair, CLEvent[]> outAndEvts = reduceHelper(queue, input, (int)inputLength, maxReductionSize, eventsToWaitFor);
return outAndEvts.getFirst().read(queue, 0, valueChannels, output, false, outAndEvts.getSecond());
}
@Override
public B reduce(CLQueue queue, CLBuffer input, long inputLength, int maxReductionSize, CLEvent... eventsToWaitFor) {
B output = (B)NIOUtils.directBuffer((int)inputLength, context.getByteOrder(), (Class)input.getBufferClass());
CLEvent evt = reduce(queue, input, inputLength, output, maxReductionSize, eventsToWaitFor);
//queue.finish();
//TODO
evt.waitFor();
return output;
}
@Override
public CLEvent reduce(CLQueue queue, CLBuffer
input, long inputLength, CLBuffer
output, int maxReductionSize, CLEvent... eventsToWaitFor) {
Pair, CLEvent[]> outAndEvts = reduceHelper(queue, input, (int)inputLength, maxReductionSize, eventsToWaitFor);
return outAndEvts.getFirst().copyTo(queue, 0, valueChannels, output, 0, outAndEvts.getSecond());
}
@SuppressWarnings("unchecked")
public Pair, CLEvent[]> reduceHelper(CLQueue queue, CLBuffer input, int inputLength, int maxReductionSize, CLEvent... eventsToWaitFor) {
if (inputLength == 1) {
return new Pair, CLEvent[]>(input, new CLEvent[0]);
}
CLBuffer>[] tempBuffers = new CLBuffer>[2];
int depth = 0;
CLBuffer currentOutput = null;
CLEvent[] eventsArr = new CLEvent[1];
int[] blockCountArr = new int[1];
int maxWIS = (int)queue.getDevice().getMaxWorkItemSizes()[0];
while (inputLength > 1) {
int blocksInCurrentDepth = inputLength / maxReductionSize;
if (inputLength > blocksInCurrentDepth * maxReductionSize)
blocksInCurrentDepth++;
int iOutput = depth & 1;
CLBuffer> currentInput = depth == 0 ? input : tempBuffers[iOutput ^ 1];
currentOutput = (CLBuffer
)tempBuffers[iOutput];
if (currentOutput == null)
currentOutput = (CLBuffer
)(tempBuffers[iOutput] = context.createBuffer(CLMem.Usage.InputOutput, input.getElementClass(), blocksInCurrentDepth * valueChannels));
synchronized (kernel) {
kernel.setArgs(currentInput, (long)blocksInCurrentDepth, (long)inputLength, (long)maxReductionSize, currentOutput);
int workgroupSize = blocksInCurrentDepth;
if (workgroupSize == 1)
workgroupSize = 2;
blockCountArr[0] = workgroupSize;
eventsArr[0] = kernel.enqueueNDRange(queue, blockCountArr, null, eventsToWaitFor);
}
eventsToWaitFor = eventsArr;
inputLength = blocksInCurrentDepth;
depth++;
}
return new Pair, CLEvent[]>(currentOutput, eventsToWaitFor);
}
};
} catch (IOException ex) {
Logger.getLogger(ReductionUtils.class.getName()).log(Level.SEVERE, null, ex);
throw new RuntimeException("Failed to create a " + op + " reductor for type " + valueType + valueChannels, ex);
}
}
}