All downloads are free. Search and download functionality uses the official Maven repository.

com.jogamp.opencl.util.concurrent.CLCommandQueuePool Maven / Gradle / Ivy

The newest version!
/*
 * Created on Tuesday, May 03 2011
 */
package com.jogamp.opencl.util.concurrent;

import com.jogamp.common.util.InterruptSource;
import com.jogamp.opencl.CLCommandQueue;
import com.jogamp.opencl.CLDevice;
import com.jogamp.opencl.CLResource;
import com.jogamp.opencl.util.CLMultiContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

/**
 * A multithreaded, fixed size pool of OpenCL command queues.
 * It serves as a multiplexer distributing tasks over N queues usually run on N devices.
 * The usage of this pool is similar to {@link ExecutorService} but it uses {@link CLTask}s
 * instead of {@link Callable}s and provides a per-queue context for resource sharing across all tasks of one queue.
 * @author Michael Bien
 */
public class CLCommandQueuePool implements CLResource {

    private List contexts;
    private ExecutorService excecutor;
    private FinishAction finishAction = FinishAction.DO_NOTHING;
    private boolean released;

    private CLCommandQueuePool(final CLQueueContextFactory factory, final Collection queues) {
        this.contexts = initContexts(queues, factory);
        initExecutor();
    }

    private List initContexts(final Collection queues, final CLQueueContextFactory factory) {
        final List newContexts = new ArrayList(queues.size());

        int index = 0;
        for (final CLCommandQueue queue : queues) {

            CLQueueContext old = null;
            if(this.contexts != null && !this.contexts.isEmpty()) {
                old = this.contexts.get(index++);
                old.release();
            }

            newContexts.add(factory.setup(queue, old));
        }
        return newContexts;
    }

    private void initExecutor() {
        this.excecutor = Executors.newFixedThreadPool(contexts.size(), new QueueThreadFactory(contexts));
    }

    public static  CLCommandQueuePool create(final CLQueueContextFactory factory, final CLMultiContext mc, final CLCommandQueue.Mode... modes) {
        return create(factory, mc.getDevices(), modes);
    }

    public static  CLCommandQueuePool create(final CLQueueContextFactory factory, final Collection devices, final CLCommandQueue.Mode... modes) {
        final List queues = new ArrayList(devices.size());
        for (final CLDevice device : devices) {
            queues.add(device.createCommandQueue(modes));
        }
        return create(factory, queues);
    }

    public static  CLCommandQueuePool create(final CLQueueContextFactory factory, final Collection queues) {
        return new CLCommandQueuePool(factory, queues);
    }

    /**
     * Submits this task to the pool for execution returning its {@link Future}.
     * @see ExecutorService#submit(java.util.concurrent.Callable)
     */
    public  Future submit(final CLTask task) {
        return excecutor.submit(new TaskWrapper(task, finishAction));
    }

    /**
     * Submits all tasks to the pool for execution and returns their {@link Future}.
     * Calls {@link #submit(com.jogamp.opencl.util.concurrent.CLTask)} for every task.
     */
    public  List> submitAll(final Collection> tasks) {
        final List> futures = new ArrayList>(tasks.size());
        for (final CLTask task : tasks) {
            futures.add(submit(task));
        }
        return futures;
    }

    /**
     * Submits all tasks to the pool for immediate execution (blocking) and returns their {@link Future} holding the result.
     * @see ExecutorService#invokeAll(java.util.Collection)
     */
    public  List> invokeAll(final Collection> tasks) throws InterruptedException {
        final List> wrapper = wrapTasks(tasks);
        return excecutor.invokeAll(wrapper);
    }

    /**
     * Submits all tasks to the pool for immediate execution (blocking) and returns their {@link Future} holding the result.
     * @see ExecutorService#invokeAll(java.util.Collection, long, java.util.concurrent.TimeUnit)
     */
    public  List> invokeAll(final Collection> tasks, final long timeout, final TimeUnit unit) throws InterruptedException {
        final List> wrapper = wrapTasks(tasks);
        return excecutor.invokeAll(wrapper, timeout, unit);
    }

    private  List> wrapTasks(final Collection> tasks) {
        final List> wrapper = new ArrayList>(tasks.size());
        for (final CLTask task : tasks) {
            if(task == null) {
                throw new NullPointerException("at least one task was null");
            }
            wrapper.add(new TaskWrapper(task, finishAction));
        }
        return wrapper;
    }

    /**
     * Switches the context of all queues - this operation can be expensive.
     * Blocks until all tasks finish and sets up a new context for all queues.
     * @return this
     */
    public CLCommandQueuePool switchContext(final CLQueueContextFactory factory) {

        excecutor.shutdown();
        finishQueues(); // just to be sure

        contexts = initContexts(getQueues(), factory);
        initExecutor();
        return this;
    }

    /**
     * Calls {@link CLCommandQueue#flush()} on all queues.
     */
    public void flushQueues() {
        for (final CLQueueContext context : contexts) {
            context.queue.flush();
        }
    }

    /**
     * Calls {@link CLCommandQueue#finish()} on all queues.
     */
    public void finishQueues() {
        for (final CLQueueContext context : contexts) {
            context.queue.finish();
        }
    }

    /**
     * Releases all queues.
     */
    @Override
    public void release() {
        if(released) {
            throw new RuntimeException(getClass().getSimpleName()+" already released");
        }
        released = true;
        excecutor.shutdown();
        for (final CLQueueContext context : contexts) {
            context.queue.finish().release();
            context.release();
        }
    }

    /**
     * Returns the command queues used in this pool.
     */
    public List getQueues() {
        final List queues = new ArrayList(contexts.size());
        for (final CLQueueContext context : contexts) {
            queues.add(context.queue);
        }
        return queues;
    }

    /**
     * Returns the size of this pool (number of command queues).
     */
    public int getSize() {
        return contexts.size();
    }

    public FinishAction getFinishAction() {
        return finishAction;
    }

    @Override
    public boolean isReleased() {
        return released;
    }

    /**
     * Sets the action which is run after every completed task.
     * This is mainly intended for debugging, default value is {@link FinishAction#DO_NOTHING}.
     */
    public void setFinishAction(final FinishAction action) {
        this.finishAction = action;
    }

    @Override
    public String toString() {
        return getClass().getSimpleName()+" [queues: "+contexts.size()+" on finish: "+finishAction+"]";
    }

    private static class QueueThreadFactory implements ThreadFactory {

        private final List context;
        private int index;

        private QueueThreadFactory(final List queues) {
            this.context = queues;
            this.index = 0;
        }

        public synchronized Thread newThread(final Runnable runnable) {

            final SecurityManager sm = System.getSecurityManager();
            final ThreadGroup group = (sm != null) ? sm.getThreadGroup() : Thread.currentThread().getThreadGroup();

            final CLQueueContext queue = context.get(index);
            final QueueThread thread = new QueueThread(group, runnable, queue, index++);
            thread.setDaemon(true);

            return thread;
        }

    }

    private static class QueueThread extends InterruptSource.Thread {
        private final CLQueueContext context;
        public QueueThread(final ThreadGroup group, final Runnable runnable, final CLQueueContext context, final int index) {
            super(group, runnable, "queue-worker-thread-"+index+"["+context+"]");
            this.context = context;
        }
    }

    private static class TaskWrapper implements Callable {

        private final CLTask task;
        private final FinishAction mode;

        public TaskWrapper(final CLTask task, final FinishAction mode) {
            this.task = task;
            this.mode = mode;
        }

        public R call() throws Exception {
            final CLQueueContext context = ((QueueThread)Thread.currentThread()).context;
            // we make sure to only wrap tasks on the correct kind of thread, so this
            // shouldn't fail (trying to genericize QueueThread properly becomes tricky)
            @SuppressWarnings("unchecked")
            final
            R result = task.execute((C)context);
            if(mode.equals(FinishAction.FLUSH)) {
                context.queue.flush();
            }else if(mode.equals(FinishAction.FINISH)) {
                context.queue.finish();
            }
            return result;
        }

    }

    /**
     * The action executed after a task completes.
     */
    public enum FinishAction {

        /**
         * Does nothing, the task is responsible to make sure all computations
         * have finished when the task finishes
         */
        DO_NOTHING,

        /**
         * Flushes the queue on task completion.
         */
        FLUSH,

        /**
         * Finishes the queue on task completion.
         */
        FINISH
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy