![JAR search and dependency download from the Maven repository](/logo.png)
com.nativelibs4java.opencl.util.ReductionUtils Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of javacl Show documentation
JavaCL is an Object-Oriented API that makes the C OpenCL API available to Java in a very natural way.
It hides away the complexity of cross-platform C bindings, has a clean OO design (with generics, Java enums, NIO buffers, fully typed exceptions...), provides high-level features (OpenGL-interop, array reductions) and comes with samples and demos.
For more info, please visit http://code.google.com/p/nativelibs4java/wiki/OpenCL.
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.util;
import com.nativelibs4java.opencl.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.Buffer;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.nativelibs4java.util.IOUtils;
import com.nativelibs4java.util.Pair;
import org.bridj.Pointer;
import org.bridj.Platform;
import static org.bridj.Pointer.*;
/**
* Parallel reduction utils (max, min, sum and product computations on OpenCL buffers of any type)
* @author Olivier
*/
public class ReductionUtils {
static final int DEFAULT_MAX_REDUCTION_SIZE = 4;
static String source;
static final String sourcePath = ReductionUtils.class.getPackage().getName().replace('.', '/') + "/" + "Reduction.c";
static synchronized String getSource() throws IOException {
InputStream in = Platform.getClassLoader(ReductionUtils.class).getResourceAsStream(sourcePath);
if (in == null)
throw new FileNotFoundException(sourcePath);
return source = IOUtils.readText(in);
}
public enum Operation {
Add,
Multiply,
Min,
Max;
}
public static Pair> getReductionCodeAndMacros(Operation op, OpenCLType valueType, int channels) throws IOException {
Map macros = new LinkedHashMap();
String cType = valueType.toCType() + (channels == 1 ? "" : channels);
macros.put("OPERAND_TYPE", cType);
String operation, seed;
switch (op) {
case Add:
operation = "_add_";
seed = "0";
break;
case Multiply:
operation = "_mul_";
seed = "1";
break;
case Min:
operation = "_min_";
switch (valueType) {
case Int:
seed = Integer.MAX_VALUE + "";
break;
case Long:
seed = Long.MAX_VALUE + "LL";
break;
case Short:
seed = Short.MAX_VALUE + "";
break;
case Float:
seed = "MAXFLOAT";
break;
case Double:
seed = "MAXDOUBLE";
break;
default:
throw new IllegalArgumentException("Unhandled seed type: " + valueType);
}
break;
case Max:
operation = "_max_";
switch (valueType) {
case Int:
seed = Integer.MIN_VALUE + "";
break;
case Long:
seed = Long.MIN_VALUE + "LL";
break;
case Short:
seed = Short.MIN_VALUE + "";
break;
case Float:
seed = "-MAXFLOAT";
break;
case Double:
seed = "-MAXDOUBLE";
break;
default:
throw new IllegalArgumentException("Unhandled seed type: " + valueType);
}
break;
default:
throw new IllegalArgumentException("Unhandled operation: " + op);
}
macros.put("OPERATION", operation);
macros.put("SEED", seed);
return new Pair>(getSource(), macros);
}
public interface Reductor {
/** Number of independent channels of the reductor */
public int getChannels();
public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor);
public Pointer reduce(CLQueue queue, CLBuffer input, long inputLength, int maxReductionSize, CLEvent... eventsToWaitFor);
public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, Pointer output, int maxReductionSize, CLEvent... eventsToWaitFor);
/**
* Return the result of the reduction operation (with one value per channel).
*/
public Pointer reduce(CLQueue queue, CLBuffer input, CLEvent... eventsToWaitFor);
}
/*public interface WeightedReductor {
public CLEvent reduce(CLQueue queue, CLBuffer input, CLBuffer weights, long inputLength, B output, int maxReductionSize, CLEvent... eventsToWaitFor);
public CLEvent reduce(CLQueue queue, CLBuffer input, CLBuffer weights, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor);
}*/
static int getNextPowerOfTwo(int i) {
int shifted = 0;
boolean lost = false;
for (;;) {
int next = i >> 1;
if (next == 0) {
if (lost)
return 1 << (shifted + 1);
else
return 1 << shifted;
}
lost = lost || (next << 1 != i);
shifted++;
i = next;
}
}
/**
* Create a reductor for the provided operation and primitive type (on the provided number of independent channels).
* Channels are reduced independently, so that with 2 channels the max of elements { (1, 30), (2, 20), (3, 10) } would be (3, 30).
*/
public static Reductor createReductor(final CLContext context, Operation op, final OpenCLType valueType, final int valueChannels) throws CLBuildException {
try {
Pair> codeAndMacros = getReductionCodeAndMacros(op, valueType, valueChannels);
CLProgram program = context.createProgram(codeAndMacros.getFirst());
program.defineMacros(codeAndMacros.getValue());
program.build();
CLKernel[] kernels = program.createKernels();
if (kernels.length != 1)
throw new RuntimeException("Expected 1 kernel, found : " + kernels.length);
final CLKernel kernel = kernels[0];
return new Reductor() {
@Override
public int getChannels() {
return valueChannels;
}
@SuppressWarnings("unchecked")
public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, Pointer output, int maxReductionSize, CLEvent... eventsToWaitFor) {
Pair, CLEvent[]> outAndEvts = reduceHelper(queue, input, (int)inputLength, maxReductionSize, eventsToWaitFor);
return outAndEvts.getFirst().read(queue, 0, valueChannels, output, false, outAndEvts.getSecond());
}
@Override
public Pointer reduce(CLQueue queue, CLBuffer input, long inputLength, int maxReductionSize, CLEvent... eventsToWaitFor) {
Pointer output = Pointer.allocateArray((Class)valueType.type, valueChannels).order(context.getByteOrder());
CLEvent evt = reduce(queue, input, inputLength, output, maxReductionSize, eventsToWaitFor);
//queue.finish();
//TODO
evt.waitFor();
return output;
}
@Override
public Pointer reduce(CLQueue queue, CLBuffer input, CLEvent... eventsToWaitFor) {
return reduce(queue, input, input.getElementCount(), DEFAULT_MAX_REDUCTION_SIZE, eventsToWaitFor);
}
@Override
public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor) {
Pair, CLEvent[]> outAndEvts = reduceHelper(queue, input, (int)inputLength, maxReductionSize, eventsToWaitFor);
return outAndEvts.getFirst().copyTo(queue, 0, valueChannels, output, 0, outAndEvts.getSecond());
}
@SuppressWarnings("unchecked")
public Pair, CLEvent[]> reduceHelper(CLQueue queue, CLBuffer input, int inputLength, int maxReductionSize, CLEvent... eventsToWaitFor) {
if (inputLength == 1) {
return new Pair, CLEvent[]>(input, new CLEvent[0]);
}
if (inputLength == 1) {
return new Pair, CLEvent[]>(input, new CLEvent[0]);
}
CLBuffer>[] tempBuffers = new CLBuffer>[2];
int depth = 0;
CLBuffer currentOutput = null;
CLEvent[] eventsArr = new CLEvent[1];
int[] blockCountArr = new int[1];
int maxWIS = (int)queue.getDevice().getMaxWorkItemSizes()[0];
while (inputLength > 1) {
int blocksInCurrentDepth = inputLength / maxReductionSize;
if (inputLength > blocksInCurrentDepth * maxReductionSize)
blocksInCurrentDepth++;
int iOutput = depth & 1;
CLBuffer> currentInput = depth == 0 ? input : tempBuffers[iOutput ^ 1];
currentOutput = (CLBuffer)tempBuffers[iOutput];
if (currentOutput == null)
currentOutput = (CLBuffer)(tempBuffers[iOutput] = context.createBuffer(CLMem.Usage.InputOutput, valueType.type, blocksInCurrentDepth * valueChannels));
synchronized (kernel) {
kernel.setArgs(currentInput, (long)blocksInCurrentDepth, (long)inputLength, (long)maxReductionSize, currentOutput);
int workgroupSize = blocksInCurrentDepth;
if (workgroupSize == 1)
workgroupSize = 2;
blockCountArr[0] = workgroupSize;
eventsArr[0] = kernel.enqueueNDRange(queue, blockCountArr, null, eventsToWaitFor);
}
eventsToWaitFor = eventsArr;
inputLength = blocksInCurrentDepth;
depth++;
}
return new Pair, CLEvent[]>(currentOutput, eventsToWaitFor);
}
};
} catch (IOException ex) {
Logger.getLogger(ReductionUtils.class.getName()).log(Level.SEVERE, null, ex);
throw new RuntimeException("Failed to create a " + op + " reductor for type " + valueType + valueChannels, ex);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy