All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.parallelism.MagicQueue Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.parallelism;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Limited Queue implementation, suited for multi-gpu prefetch.
 *
 * Basic idea is simple: DataSets are coming from DataSetIterator, and their device location is unknown.
 * So, for better performance DataSets should be transparently moved to the devices where they will be used later, and this should be done in background.
 *
 *
 * PLEASE NOTE: This class is pending removal, since better behavior was implemented as InterleavedCallback for AsyncDataSetIterator
 * @author raver119
 */
@Slf4j
@Deprecated
public class MagicQueue implements BlockingQueue {
    /**
     * Consumer-side strategy used by {@code take()} and {@code poll()} to pick a backing queue.
     */
    public enum Mode {
        /**
         * The backing queue is chosen by the device the calling thread is attached to
         * (or the only queue, when there is a single bucket).
         */
        THREADED,

        /** Backing queues are visited one after another, driven by an internal counter. */
        SEQUENTIAL
    }

    /**
     * Kind of element carried by the queue; the background handlers use it to decide
     * which class a buffered element is cast to.
     */
    public enum Type {
        /** Elements are handled as {@code DataSet} instances. */
        DS,

        /** Elements are handled as {@code MultiDataSet} instances. */
        MDS
    }

    protected final List> backingQueues;
    protected final AtomicInteger nextBucket = new AtomicInteger(0);
    protected final int numberOfBuckets;
    protected final List handlers;
    protected int capacity = 10;
    protected Mode mode = Mode.THREADED;
    protected Type type = null;
    protected AtomicInteger interleavedCounter = new AtomicInteger(0);
    protected AtomicInteger interleavedPutter = new AtomicInteger(0);

    protected AtomicLong cntPut = new AtomicLong(0);
    protected AtomicLong cntGet = new AtomicLong(0);



    /**
     * Creates the queue and, when more than one flow is requested, starts one background
     * handler thread per flow.
     *
     * @param numberOfFlows number of backing queues; for a value of 1 or less a single
     *                      unbounded queue is created and no handler threads are started
     * @param capacity      capacity of every backing queue and handler buffer (multi-flow case only)
     * @param type          kind of element the handlers should expect
     */
    protected MagicQueue(int numberOfFlows, int capacity, Type type) {
        backingQueues = new ArrayList<>();
        this.type = type;
        this.capacity = capacity;
        handlers = new ArrayList<>();
        if (numberOfFlows > 1) {
            for (int i = 0; i < numberOfFlows; i++) {
                LinkedBlockingQueue<T> queue = new LinkedBlockingQueue<>(capacity);
                backingQueues.add(queue);

                QueueHandler handler = new QueueHandler(queue, capacity, i, type);

                // the handler index doubles as the device index passed to the affinity manager
                Nd4j.getAffinityManager().attachThreadToDevice(handler, i);

                handler.start();
                handlers.add(handler);
            }
        } else {
            // single flow: elements go straight into one unbounded queue, no handler involved
            LinkedBlockingQueue<T> queue = new LinkedBlockingQueue<>();
            backingQueues.add(queue);
        }

        numberOfBuckets = numberOfFlows;
    }

    /**
     * Reports how many elements are queued.
     * <p>
     * In THREADED mode with several buckets the result is the total over all backing queues
     * divided by the number of buckets (integer division); with a single bucket it is the size
     * of that bucket. In any other mode it is the difference between the put and get counters.
     *
     * @return the size computed as described above
     */
    @Override
    public int size() {
        if (mode != Mode.THREADED) {
            return (int) (cntPut.get() - cntGet.get());
        }

        if (numberOfBuckets <= 1) {
            return backingQueues.get(0).size();
        }

        long total = 0;
        for (int bucket = 0; bucket < numberOfBuckets; bucket++) {
            total += backingQueues.get(bucket).size();
        }

        // integer division already truncates, which is the floor for a non-negative total
        return (int) (total / numberOfBuckets);
    }

    /**
     * Returns the number of elements currently held by the backing queue of one device.
     *
     * @param deviceId index of the backing queue
     * @return size of that backing queue
     * @throws RuntimeException if deviceId is not smaller than the number of backing queues
     */
    protected int size(int deviceId) {
        final boolean known = deviceId < backingQueues.size();
        if (!known) {
            throw new RuntimeException("DeviceID exceeds number of actual backing queues");
        }

        return backingQueues.get(deviceId).size();
    }

    /**
     * Tells whether {@link #size()} reports no elements.
     * <p>
     * Because size() uses integer division over the buckets in THREADED mode, this can return
     * {@code true} while some backing queue still holds elements.
     *
     * @return {@code true} if size() is zero or negative
     */
    @Override
    public boolean isEmpty() {
        return size() <= 0;
    }

    /**
     * Membership lookup is not available on this queue.
     *
     * @param o ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean contains(Object o) {
        throw new UnsupportedOperationException();
    }

    /**
     * This method isn't supported
     *
     * @param c ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public int drainTo(Collection<? super T> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * This method isn't supported
     *
     * @param c ignored
     * @param maxElements ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public int drainTo(Collection<? super T> c, int maxElements) {
        throw new UnsupportedOperationException();
    }

    /**
     * This method isn't supported
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Iterator<T> iterator() {
        throw new UnsupportedOperationException();
    }

    /**
     * Array snapshots are not available on this queue.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    /**
     * This method isn't supported
     *
     * @param a ignored
     * @param <T> component type of the requested array
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public <T> T[] toArray(T[] a) {
        throw new UnsupportedOperationException();
    }

    /**
     * Adds an element to the queue.
     * <p>
     * With several buckets the element is handed to the next handler in round-robin order
     * (the call may block while that handler's buffer is full); with a single bucket it goes
     * straight into the only backing queue.
     *
     * @param dataSet element to add
     * @return always {@code true}
     */
    @Override
    public boolean add(T dataSet) {
        cntPut.incrementAndGet();
        if (numberOfBuckets > 1) {
            int bucket;
            // The wrap-around check and the increment must happen under the same lock;
            // otherwise two producers can both pass the check and one of them ends up
            // with an index equal to the number of handlers.
            synchronized (this) {
                if (nextBucket.get() >= backingQueues.size())
                    nextBucket.set(0);

                bucket = nextBucket.getAndIncrement();
            }
            handlers.get(bucket).put(dataSet);
        } else {
            backingQueues.get(0).add(dataSet);
        }

        return true;
    }

    /**
     * Removing a specific element is not available on this queue.
     *
     * @param o ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean remove(Object o) {
        throw new UnsupportedOperationException();
    }

    /**
     * This method isn't supported
     *
     * @param c ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean containsAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * Adds every element of the given collection through {@link #add(Object)}, in iteration order.
     *
     * @param c elements to add
     * @return {@code false} as soon as one add() call reports failure, {@code true} otherwise
     */
    @Override
    public boolean addAll(Collection<? extends T> c) {
        for (T ds : c) {
            if (!add(ds))
                return false;
        }

        return true;
    }

    /**
     * This method isn't supported
     *
     * @param c ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean removeAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * This method isn't supported
     *
     * @param c ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /**
     * Empties every backing queue and resets the put/get counters to zero.
     * Only the backing queues are touched here; the handlers' own buffers are not.
     */
    @Override
    public void clear() {
        for (int bucket = 0; bucket < backingQueues.size(); bucket++) {
            backingQueues.get(bucket).clear();
        }

        cntPut.set(0);
        cntGet.set(0);
    }

    /**
     * Offers an element directly to a backing queue without waiting and without going
     * through a handler: the queue of the calling thread's device when several buckets exist,
     * otherwise the only queue.
     *
     * @param dataSet element to offer
     * @return {@code true} if the backing queue accepted the element
     */
    @Override
    public boolean offer(T dataSet) {
        final int bucket = numberOfBuckets > 1 ? Nd4j.getAffinityManager().getDeviceForCurrentThread() : 0;
        final boolean accepted = backingQueues.get(bucket).offer(dataSet);

        // the put counter only tracks elements that actually made it into a queue
        if (accepted) {
            cntPut.incrementAndGet();
        }

        return accepted;
    }

    /**
     * Inserts an element into the queue.
     * <p>
     * With several buckets the element is handed to the next handler in round-robin order
     * (the call may block while that handler's buffer is full); with a single bucket it goes
     * straight into the only backing queue.
     *
     * @param dataSet element to insert
     * @throws InterruptedException declared by the interface
     */
    @Override
    public void put(T dataSet) throws InterruptedException {

        if (numberOfBuckets > 1) {
            int bucket;
            // The wrap-around check and the increment must happen under the same lock;
            // otherwise two producers can both pass the check and one of them ends up
            // with an index equal to the number of handlers.
            synchronized (this) {
                if (nextBucket.get() >= backingQueues.size())
                    nextBucket.set(0);

                bucket = nextBucket.getAndIncrement();
            }

            handlers.get(bucket).put(dataSet);
        } else {
            backingQueues.get(0).add(dataSet);
        }
        cntPut.incrementAndGet();
    }

    /**
     * Offers an element directly to a backing queue, waiting up to the given time for space:
     * the queue of the calling thread's device when several buckets exist, otherwise the
     * only queue. Handlers are not involved.
     *
     * @param dataSet element to offer
     * @param timeout how long to wait for space
     * @param unit    unit of the timeout
     * @return {@code true} if the backing queue accepted the element
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public boolean offer(T dataSet, long timeout, TimeUnit unit) throws InterruptedException {
        final int bucket = numberOfBuckets > 1 ? Nd4j.getAffinityManager().getDeviceForCurrentThread() : 0;
        final boolean accepted = backingQueues.get(bucket).offer(dataSet, timeout, unit);

        // the put counter only tracks elements that actually made it into a queue
        if (accepted) {
            cntPut.incrementAndGet();
        }

        return accepted;
    }

    /**
     * Retrieves and removes the head of a backing queue, waiting until an element is available.
     * <p>
     * In THREADED mode the queue is the one of the calling thread's device (or the only queue
     * when there is a single bucket); otherwise queues are visited in turn.
     *
     * @return the element taken
     * @throws InterruptedException if interrupted while waiting; the get counter is not changed in that case
     */
    @Override
    public T take() throws InterruptedException {
        final T ds;
        if (mode == Mode.THREADED) {
            if (numberOfBuckets > 1) {
                int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
                ds = backingQueues.get(deviceId).take();
            } else {
                ds = backingQueues.get(0).take();
            }
        } else {
            int bucket = interleavedCounter.getAndIncrement();
            // Wrap the counter before blocking: if take() is interrupted the counter must
            // not be left equal to the number of queues, or the next call would index out of range.
            if (interleavedCounter.get() >= backingQueues.size())
                interleavedCounter.set(0);

            ds = backingQueues.get(bucket).take();
        }

        // reached only when an element was actually removed
        cntGet.incrementAndGet();
        return ds;
    }

    /**
     * Head removal through this method is not available on this queue.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public T remove() {
        throw new UnsupportedOperationException();
    }


    /**
     * This method is supposed to be called from managed thread, attached to specific device.
     * It returns 1 DataSet element from head of the queue, and deletes that element from Queue.
     * If queue is empty, the call waits up to the given time for an element to appear.
     *
     * Please note: if there's nothing available in Queue - NULL will be returned
     * @param time time to wait for something appear in queue
     * @param timeUnit TimeUnit for time param
     * @return head of the selected backing queue, or null if nothing arrived in time
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public T poll(long time, TimeUnit timeUnit) throws InterruptedException {
        final int bucket;
        if (mode == Mode.THREADED) {
            bucket = numberOfBuckets > 1 ? Nd4j.getAffinityManager().getDeviceForCurrentThread() : 0;
        } else {
            bucket = interleavedCounter.getAndIncrement();
            // Wrap the counter before waiting: if the wait is interrupted the counter must
            // not be left equal to the number of queues, or the next call would index out of range.
            if (interleavedCounter.get() >= backingQueues.size())
                interleavedCounter.set(0);
        }

        T ds = backingQueues.get(bucket).poll(time, timeUnit);

        if (ds != null)
            cntGet.incrementAndGet();

        return ds;
    }

    /**
     * Always reports zero; the value is a constant and is not derived from the backing queues.
     *
     * @return 0
     */
    @Override
    public int remainingCapacity() {
        return 0;
    }

    /**
     * Non-blocking retrieval, intended to be called from a managed thread attached to a device.
     * Removes and returns the head of the selected backing queue: in THREADED mode the queue of
     * the calling thread's device (or the only queue for a single bucket), otherwise the queue
     * the interleaved counter currently points at.
     *
     * Please note: if there's nothing available in Queue - NULL will be returned
     *
     * @return the head element, or null when the selected queue is empty
     */
    @Override
    public T poll() {
        final boolean interleaved = mode != Mode.THREADED;

        final int bucket;
        if (interleaved) {
            bucket = interleavedCounter.getAndIncrement();
        } else if (numberOfBuckets > 1) {
            bucket = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        } else {
            bucket = 0;
        }

        final T head = backingQueues.get(bucket).poll();

        // start over at the first queue once the counter has run past the last one
        if (interleaved && interleavedCounter.get() >= backingQueues.size()) {
            interleavedCounter.set(0);
        }

        if (head != null) {
            cntGet.incrementAndGet();
        }

        return head;
    }

    /**
     * Inspecting the head through this method is not available on this queue.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public T element() {
        throw new UnsupportedOperationException();
    }

    /**
     * Peeking at the head is not available on this queue.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public T peek() {
        throw new UnsupportedOperationException();
    }

    /**
     * Fluent builder for {@link MagicQueue}. Defaults: one bucket per device reported by the
     * affinity manager, capacity 16 per flow, THREADED mode, DS element type.
     */
    public static class Builder {
        private int numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
        private int capacity = 16;
        private Mode mode = Mode.THREADED;
        private Type type = Type.DS;

        /** Creates a builder holding the default settings. */
        public Builder() {
            // nothing to do: defaults come from the field initializers
        }

        /**
         * Sets how many buckets (flows) the queue should have.
         *
         * @param number requested number of buckets; values below 1 are replaced by the
         *               number of devices when {@link #build()} is called
         * @return this builder
         */
        public Builder setNumberOfBuckets(int number) {
            numberOfBuckets = number;
            return this;
        }

        /**
         * Sets the kind of element the queue will carry.
         *
         * @param type element type, must not be null
         * @return this builder
         */
        public Builder setType(@NonNull Type type) {
            this.type = type;
            return this;
        }

        /**
         * Sets the consumer-side queue selection strategy.
         *
         * @param mode mode to use, must not be null
         * @return this builder
         */
        public Builder setMode(@NonNull Mode mode) {
            this.mode = mode;
            return this;
        }

        /**
         * Sets the capacity passed to the queue constructor.
         *
         * @param capacityPerFlow capacity per flow, must be positive
         * @return this builder
         * @throws ND4JIllegalStateException if the value is zero or negative
         */
        public Builder setCapacityPerFlow(int capacityPerFlow) {
            if (capacityPerFlow < 1) {
                throw new ND4JIllegalStateException("Capacity per flow value should be positive value");
            }

            capacity = capacityPerFlow;
            return this;
        }

        /**
         * Creates the queue from the current settings.
         *
         * @return a new MagicQueue with this builder's mode applied
         */
        public MagicQueue build() {
            // fall back to one bucket per device when no usable bucket count was given
            if (numberOfBuckets < 1) {
                numberOfBuckets = Nd4j.getAffinityManager().getNumberOfDevices();
            }

            MagicQueue result = new MagicQueue(numberOfBuckets, capacity, type);
            result.mode = mode;
            return result;
        }
    }

    /**
     * Background thread serving one bucket: it takes elements from its private buffer, migrates
     * them while a memory workspace is active, and hands them over to the bucket's backing queue.
     */
    private class QueueHandler extends Thread {
        private final BlockingQueue<T> targetQueue;
        private final LinkedBlockingQueue<T> bufferQueue;
        private final int device;
        private final int capacity;
        private final Type type;

        /**
         * @param queue    backing queue that receives migrated elements
         * @param capacity capacity of the internal buffer; also used to size the workspace
         * @param device   index used for the thread name and log output
         * @param type     kind of element expected in the buffer
         */
        public QueueHandler(BlockingQueue<T> queue, int capacity, int device, Type type) {
            this.targetQueue = queue;
            this.type = type;
            this.bufferQueue = new LinkedBlockingQueue<>(capacity);
            this.capacity = capacity;
            this.device = device;

            this.setDaemon(true);
            this.setName("MQ_THREAD " + device);
        }


        /**
         * Places an element into this handler's buffer, blocking while the buffer is full.
         * If the calling thread is interrupted while waiting, the element is not buffered;
         * the interrupt status is restored so the caller can observe it.
         *
         * @param dataSet element to buffer
         */
        public void put(T dataSet) {
            try {
                bufferQueue.put(dataSet);
            } catch (InterruptedException e) {
                // the signature cannot propagate the exception, so keep the interrupt flag set
                Thread.currentThread().interrupt();
                log.warn("Interrupted while buffering element for device [{}]; element was not queued", device);
            }
        }

        /**
         * Polls the buffer once per second until interrupted; every element received is migrated
         * and then put into the target queue.
         */
        @Override
        public void run() {
            // NOTE(review): presumably forces backend initialisation on this thread - confirm
            Nd4j.create(1);
            WorkspaceConfiguration configuration = null;
            String id = "MQAD_THREAD";
            log.info("MQAD_THREAD started on device [{}/{}]", device,
                            Nd4j.getAffinityManager().getDeviceForCurrentThread());

            while (true) {
                try {
                    DataSet ds = null;
                    MultiDataSet mds = null;

                    if (type == Type.DS)
                        ds = (DataSet) bufferQueue.poll(1, TimeUnit.SECONDS);
                    else
                        mds = (MultiDataSet) bufferQueue.poll(1, TimeUnit.SECONDS);

                    if (ds != null) {
                        // the workspace is sized once, from the first element seen
                        if (configuration == null)
                            configuration = buildConfiguration(ds.getMemoryFootprint());

                        try (MemoryWorkspace workspace =
                                        Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, id)) {
                            // now we initialize dataset on target device (if applicable)
                            ds.migrate();
                        }

                        targetQueue.put((T) ds);
                    } else if (mds != null) {
                        // the workspace is sized once, from the first element seen
                        if (configuration == null)
                            configuration = buildConfiguration(mds.getMemoryFootprint());

                        try (MemoryWorkspace workspace =
                                        Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, id)) {
                            // individual mask entries may be absent, so they are checked one by one
                            if (mds.getFeaturesMaskArrays() != null)
                                for (int i = 0; i < mds.getFeaturesMaskArrays().length; i++)
                                    if (mds.getFeaturesMaskArrays()[i] != null)
                                        mds.getFeaturesMaskArrays()[i] = mds.getFeaturesMaskArrays()[i].migrate();

                            if (mds.getLabelsMaskArrays() != null)
                                for (int i = 0; i < mds.getLabelsMaskArrays().length; i++)
                                    if (mds.getLabelsMaskArrays()[i] != null)
                                        mds.getLabelsMaskArrays()[i] = mds.getLabelsMaskArrays()[i].migrate();

                            if (mds.getLabels() != null)
                                for (int i = 0; i < mds.getLabels().length; i++)
                                    mds.getLabels()[i] = mds.getLabels()[i].migrate();

                            if (mds.getFeatures() != null)
                                for (int i = 0; i < mds.getFeatures().length; i++)
                                    mds.getFeatures()[i] = mds.getFeatures()[i].migrate();

                            targetQueue.put((T) mds);
                        }
                    }
                } catch (InterruptedException e) {
                    log.warn("Got InterruptedException...");
                    break;
                }
            }
        }

        /**
         * Builds the workspace configuration used for migration: initial size is the element
         * footprint times the buffer capacity, but never less than 10 MB.
         */
        private WorkspaceConfiguration buildConfiguration(long footprint) {
            long initSize = Math.max(footprint * capacity, 10 * 1024L * 1024L);

            return WorkspaceConfiguration.builder().initialSize(initSize)
                            .overallocationLimit(1.0).policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                            .policyAllocation(AllocationPolicy.OVERALLOCATE).build();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy