All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.nativelibs4java.opencl.util.ReductionUtils Maven / Gradle / Ivy

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */

package com.nativelibs4java.opencl.util;

import com.nativelibs4java.opencl.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.Buffer;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import com.nativelibs4java.util.IOUtils;
import com.nativelibs4java.util.Pair;

import org.bridj.Pointer;
import org.bridj.Platform;
import static org.bridj.Pointer.*;

/**
 * Parallel reduction utils (max, min, sum and product computations on OpenCL buffers of any type)
 * @author Olivier
 */
public class ReductionUtils {
	static final int DEFAULT_MAX_REDUCTION_SIZE = 4;
	
    static String source;
    static final String sourcePath = ReductionUtils.class.getPackage().getName().replace('.', '/') + "/" + "Reduction.c";
    static synchronized String getSource() throws IOException {
        InputStream in = Platform.getClassLoader(ReductionUtils.class).getResourceAsStream(sourcePath);
        if (in == null)
            throw new FileNotFoundException(sourcePath);
        return source = IOUtils.readText(in);
    }

    public enum Operation {
        Add,
        Multiply,
        Min,
        Max;
    }
    
    public static Pair> getReductionCodeAndMacros(Operation op, OpenCLType valueType, int channels) throws IOException {
        Map macros = new LinkedHashMap();
        String cType = valueType.toCType() + (channels == 1 ? "" : channels);
        macros.put("OPERAND_TYPE", cType);
        String operation, seed;
        switch (op) {
            case Add:
                operation = "_add_";
                seed = "0";
                break;
            case Multiply:
                operation = "_mul_";
                seed = "1";
                break;
            case Min:
                operation = "_min_";
                switch (valueType) {
                    case Int:
                        seed = Integer.MAX_VALUE + "";
                        break;
                    case Long:
                        seed = Long.MAX_VALUE + "LL";
                        break;
                    case Short:
                        seed = Short.MAX_VALUE + "";
                        break;
                    case Float:
                        seed = "MAXFLOAT";
                        break;
                    case Double:
                        seed = "MAXDOUBLE";
                        break;
                    default:
                        throw new IllegalArgumentException("Unhandled seed type: " + valueType);
                }
                
                break;
            case Max:
                operation = "_max_";
                switch (valueType) {
                    case Int:
                        seed = Integer.MIN_VALUE + "";
                        break;
                    case Long:
                        seed = Long.MIN_VALUE + "LL";
                        break;
                    case Short:
                        seed = Short.MIN_VALUE + "";
                        break;
                    case Float:
                        seed = "-MAXFLOAT";
                        break;
                    case Double:
                        seed = "-MAXDOUBLE";
                        break;
                    default:
                        throw new IllegalArgumentException("Unhandled seed type: " + valueType);
                }
                break;
            default:
                throw new IllegalArgumentException("Unhandled operation: " + op);
        }
        macros.put("OPERATION", operation);
        macros.put("SEED", seed);
        return new Pair>(getSource(), macros);
    }
    public interface Reductor {
        /** Number of independent channels of the reductor */
        public int getChannels();
        public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor);
        public Pointer reduce(CLQueue queue, CLBuffer input, long inputLength, int maxReductionSize, CLEvent... eventsToWaitFor);
        public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, Pointer output, int maxReductionSize, CLEvent... eventsToWaitFor);
        /** 
         * Return the result of the reduction operation (with one value per channel).
         */
        public Pointer reduce(CLQueue queue, CLBuffer input, CLEvent... eventsToWaitFor);
        
    }
    /*public interface WeightedReductor {
        public CLEvent reduce(CLQueue queue, CLBuffer input, CLBuffer weights, long inputLength, B output, int maxReductionSize, CLEvent... eventsToWaitFor);
        public CLEvent reduce(CLQueue queue, CLBuffer input, CLBuffer weights, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor);
    }*/
    static int getNextPowerOfTwo(int i) {
        int shifted = 0;
        boolean lost = false;
        for (;;) {
            int next = i >> 1;
            if (next == 0) {
                if (lost)
                    return 1 << (shifted + 1);
                else
                    return 1 << shifted;
            }
            lost = lost || (next << 1 != i);
            shifted++;
            i = next;
        }
    }

    /**
     * Create a reductor for the provided operation and primitive type (on the provided number of independent channels).
* Channels are reduced independently, so that with 2 channels the max of elements { (1, 30), (2, 20), (3, 10) } would be (3, 30). */ public static Reductor createReductor(final CLContext context, Operation op, final OpenCLType valueType, final int valueChannels) throws CLBuildException { try { Pair> codeAndMacros = getReductionCodeAndMacros(op, valueType, valueChannels); CLProgram program = context.createProgram(codeAndMacros.getFirst()); program.defineMacros(codeAndMacros.getValue()); program.build(); CLKernel[] kernels = program.createKernels(); if (kernels.length != 1) throw new RuntimeException("Expected 1 kernel, found : " + kernels.length); final CLKernel kernel = kernels[0]; return new Reductor() { @Override public int getChannels() { return valueChannels; } @SuppressWarnings("unchecked") public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, Pointer output, int maxReductionSize, CLEvent... eventsToWaitFor) { Pair, CLEvent[]> outAndEvts = reduceHelper(queue, input, (int)inputLength, maxReductionSize, eventsToWaitFor); return outAndEvts.getFirst().read(queue, 0, valueChannels, output, false, outAndEvts.getSecond()); } @Override public Pointer reduce(CLQueue queue, CLBuffer input, long inputLength, int maxReductionSize, CLEvent... eventsToWaitFor) { Pointer output = Pointer.allocateArray((Class)valueType.type, valueChannels).order(context.getByteOrder()); CLEvent evt = reduce(queue, input, inputLength, output, maxReductionSize, eventsToWaitFor); //queue.finish(); //TODO evt.waitFor(); return output; } @Override public Pointer reduce(CLQueue queue, CLBuffer input, CLEvent... eventsToWaitFor) { return reduce(queue, input, input.getElementCount(), DEFAULT_MAX_REDUCTION_SIZE, eventsToWaitFor); } @Override public CLEvent reduce(CLQueue queue, CLBuffer input, long inputLength, CLBuffer output, int maxReductionSize, CLEvent... eventsToWaitFor) { Pair, CLEvent[]> outAndEvts = reduceHelper(queue, input, (int)inputLength, maxReductionSize, eventsToWaitFor); return outAndEvts.getFirst().copyTo(queue, 0, valueChannels, output, 0, outAndEvts.getSecond()); } @SuppressWarnings("unchecked") public Pair, CLEvent[]> reduceHelper(CLQueue queue, CLBuffer input, int inputLength, int maxReductionSize, CLEvent... eventsToWaitFor) { if (inputLength == 1) { return new Pair, CLEvent[]>(input, new CLEvent[0]); } if (inputLength == 1) { return new Pair, CLEvent[]>(input, new CLEvent[0]); } CLBuffer[] tempBuffers = new CLBuffer[2]; int depth = 0; CLBuffer currentOutput = null; CLEvent[] eventsArr = new CLEvent[1]; int[] blockCountArr = new int[1]; int maxWIS = (int)queue.getDevice().getMaxWorkItemSizes()[0]; while (inputLength > 1) { int blocksInCurrentDepth = inputLength / maxReductionSize; if (inputLength > blocksInCurrentDepth * maxReductionSize) blocksInCurrentDepth++; int iOutput = depth & 1; CLBuffer currentInput = depth == 0 ? input : tempBuffers[iOutput ^ 1]; currentOutput = (CLBuffer)tempBuffers[iOutput]; if (currentOutput == null) currentOutput = (CLBuffer)(tempBuffers[iOutput] = context.createBuffer(CLMem.Usage.InputOutput, valueType.type, blocksInCurrentDepth * valueChannels)); synchronized (kernel) { kernel.setArgs(currentInput, (long)blocksInCurrentDepth, (long)inputLength, (long)maxReductionSize, currentOutput); int workgroupSize = blocksInCurrentDepth; if (workgroupSize == 1) workgroupSize = 2; blockCountArr[0] = workgroupSize; eventsArr[0] = kernel.enqueueNDRange(queue, blockCountArr, null, eventsToWaitFor); } eventsToWaitFor = eventsArr; inputLength = blocksInCurrentDepth; depth++; } return new Pair, CLEvent[]>(currentOutput, eventsToWaitFor); } }; } catch (IOException ex) { Logger.getLogger(ReductionUtils.class.getName()).log(Level.SEVERE, null, ex); throw new RuntimeException("Failed to create a " + op + " reductor for type " + valueType + valueChannels, ex); } } }