All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.optimize.solvers.accumulation.FancyBlockingQueue Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.optimize.solvers.accumulation;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.util.ThreadUtils;

import java.util.Collection;
import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * {@link BlockingQueue} decorator that lets a group of consumer threads each receive every element.
 *
 * <p>Producer-side and inspection methods ({@code add}, {@code offer}, {@code put}, {@code peek},
 * {@code size}, ...) delegate directly to {@link #backingQueue}. In the default multi-consumer
 * mode, {@link #poll()} hands the current head element to each of the {@link #currentConsumers}
 * polling threads. Two spin barriers ({@link #synchronize(int)}) make sure every consumer has
 * peeked the head before the last one to arrive removes it. Calling
 * {@link #fallbackToSingleConsumerMode(boolean)} with {@code true} switches {@link #poll()} and
 * {@link #isEmpty()} to plain delegation.
 *
 * <p>In multi-consumer mode, {@link #registerConsumers(int)} must be called with the real consumer
 * count before polling. With the constructor default of {@code -1}, the barrier counter in
 * {@link #synchronize(int)} can never equal the party count, so waiting threads would spin forever.
 *
 * <p>{@link #put(Object)} holds the read lock and {@link #registerConsumers(int)} holds the write
 * lock, so a registration cannot interleave with an in-flight {@code put}.
 *
 * @param <E> element type
 */
@Slf4j
public class FancyBlockingQueue<E> implements BlockingQueue<E>, Registerable {
    /** Queue that all storage operations delegate to. */
    protected BlockingQueue<E> backingQueue;
    /** Consumer count last set by the constructor or {@link #registerConsumers(int)}. */
    protected volatile int consumers;

    /**
     * Per-thread count of head elements this consumer has already polled. Lazily initialised to
     * {@code -1} on the thread's first multi-consumer {@link #poll()}.
     */
    protected ThreadLocal<AtomicLong> currentStep = new ThreadLocal<>();
    /** Global count of head elements removed by {@link #poll()}; reset to 0 by {@link #clear()}. */
    protected final AtomicLong step = new AtomicLong(0);
    /** Number of consumers that have consumed the current head; the last one removes it. */
    protected final AtomicInteger state = new AtomicInteger(0);
    /** Number of consumers participating in the barriers in {@link #poll()}. */
    protected final AtomicInteger currentConsumers = new AtomicInteger(0);

    /** Second-phase barrier release flag used by {@link #synchronize(int)}. */
    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    /** First-phase barrier release flag used by {@link #synchronize(int)}. */
    protected AtomicBoolean isDone = new AtomicBoolean(true);

    /** Arrival counter for the first barrier phase. */
    protected AtomicInteger barrier = new AtomicInteger(0);
    /** Arrival counter for the second barrier phase. */
    protected AtomicInteger secondary = new AtomicInteger(0);

    /** Backing-queue size snapshotted by the last {@link #registerConsumers(int)} call. */
    protected AtomicInteger numElementsReady = new AtomicInteger(0);
    /** Number of head elements removed by multi-consumer {@link #poll()} since registration. */
    protected AtomicInteger numElementsDrained = new AtomicInteger(0);
    /** When {@code true}, {@link #poll()} and {@link #isEmpty()} delegate directly. */
    protected AtomicBoolean bypassMode = new AtomicBoolean(false);

    /** Enables per-call {@code log.info} tracing of barrier activity. */
    protected boolean isDebug = false;
    /** Read lock guards {@link #put(Object)}; write lock guards {@link #registerConsumers(int)}. */
    protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();


    /**
     * Wraps {@code queue} with an unset consumer count of {@code -1}.
     *
     * @param queue backing queue; must not be null
     */
    public FancyBlockingQueue(@NonNull BlockingQueue<E> queue) {
        this(queue, -1);
    }

    /**
     * Wraps {@code queue} for the given number of consumers. Note that this does not snapshot
     * {@link #numElementsReady}; only {@link #registerConsumers(int)} does.
     *
     * @param queue     backing queue; must not be null
     * @param consumers number of consumer threads that will call {@link #poll()}
     */
    public FancyBlockingQueue(@NonNull BlockingQueue<E> queue, int consumers) {
        this.backingQueue = queue;
        this.consumers = consumers;
        this.currentConsumers.set(consumers);
    }


    /** Delegates to the backing queue. */
    @Override
    public boolean add(E e) {
        return backingQueue.add(e);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean offer(E e) {
        return backingQueue.offer(e);
    }


    /** Delegates to the backing queue (used by {@link #poll()} to drop the shared head). */
    @Override
    public E remove() {
        return backingQueue.remove();
    }

    /**
     * Switches between multi-consumer mode and plain delegation for {@link #poll()} and
     * {@link #isEmpty()}.
     *
     * @param reallyFallback {@code true} to bypass the consumer barriers
     */
    @Override
    public void fallbackToSingleConsumerMode(boolean reallyFallback) {
        bypassMode.set(reallyFallback);
    }

    /**
     * Sets the consumer count and snapshots the current backing-queue size as the number of
     * elements that {@link #isEmpty()} will wait to see drained.
     *
     * @param consumers number of consumer threads that will call {@link #poll()}
     */
    @Override
    public void registerConsumers(int consumers) {
        lock.writeLock().lock();
        try {
            this.numElementsReady.set(backingQueue.size());
            this.numElementsDrained.set(0);
            this.consumers = consumers;
            this.currentConsumers.set(consumers);
        } finally {
            lock.writeLock().unlock();
        }
    }

    /**
     * Inserts {@code e} into the backing queue, waiting for space if needed. Holds the read lock
     * so it cannot interleave with {@link #registerConsumers(int)}.
     *
     * @throws InterruptedException if interrupted while waiting on the backing queue
     */
    @Override
    public void put(E e) throws InterruptedException {
        lock.readLock().lock();
        try {
            log.trace("Adding value to the buffer. Current size: [{}]", backingQueue.size());
            backingQueue.put(e);
        } finally {
            // Always release: a leaked read lock would block registerConsumers() forever.
            lock.readLock().unlock();
        }
    }

    /**
     * In bypass mode, delegates to the backing queue. Otherwise reports empty once the number of
     * heads removed by {@link #poll()} reaches the size snapshotted by the last
     * {@link #registerConsumers(int)} call; elements added after that call are not counted.
     */
    @Override
    public boolean isEmpty() {
        if (bypassMode.get()) {
            return backingQueue.isEmpty();
        }

        boolean res = numElementsDrained.get() >= numElementsReady.get();

        if (isDebug) {
            log.info("thread {} queries isEmpty: {}", Thread.currentThread().getId(), res);
        }

        return res;
    }

    /**
     * Two-phase spin barrier. Returns only after {@code parties} threads have entered it. It is a
     * no-op for a single party or in bypass mode. Waiting threads poll their release flag,
     * calling {@code ThreadUtils.uncheckedSleep(1)} between checks.
     *
     * @param parties number of threads that must arrive before any is released
     */
    protected void synchronize(int parties) {
        if (parties == 1 || bypassMode.get()) {
            return;
        }

        if (isDebug) {
            log.info("thread {} locking at FBQ", Thread.currentThread().getId());
        }

        // any first thread entering this block - will reset this field to false
        isDone.compareAndSet(true, false);

        // last thread to arrive resets the counters and releases the first phase
        if (barrier.incrementAndGet() == parties) {
            secondary.set(0);
            barrier.set(0);
            isFirst.set(false);
            isDone.set(true);
        } else {
            // just wait, till last thread will set isDone to true
            while (!isDone.get()) {
                ThreadUtils.uncheckedSleep(1);
            }
        }

        // second phase ensures no thread races ahead and re-enters before isDone is observed
        if (secondary.incrementAndGet() == parties) {
            isFirst.set(true);
        } else {
            while (!isFirst.get()) {
                ThreadUtils.uncheckedSleep(1);
            }
        }

        if (isDebug) {
            log.info("thread {} unlocking at FBQ", Thread.currentThread().getId());
        }
    }

    /**
     * In bypass mode, delegates to the backing queue. Otherwise every participating consumer
     * receives the same current head. The head is removed by the last consumer to process it,
     * and each consumer waits until the global step has moved past its own before taking the
     * next element.
     *
     * @return the current head element (may be {@code null} if the backing queue is empty)
     */
    @Override
    public E poll() {
        if (bypassMode.get()) {
            return backingQueue.poll();
        }

        // first poll on this thread: start one step behind the global counter
        AtomicLong localStep = currentStep.get();
        if (localStep == null) {
            localStep = new AtomicLong(-1);
            currentStep.set(localStep);
        }

        // block until the previous head has been removed by the group
        while (step.get() == localStep.get()) {
            ThreadUtils.uncheckedSleep(1);
        }

        E object = peek();

        // wait until all consumers have peeked this object
        synchronize(currentConsumers.get());

        localStep.incrementAndGet();

        // last consumer removes the shared head and advances the global step
        if (state.incrementAndGet() == currentConsumers.get()) {
            remove();
            numElementsDrained.incrementAndGet();
            state.set(0);
            step.incrementAndGet();
        }

        // wait until all consumers observe the updated queue (for isEmpty())
        synchronize(currentConsumers.get());

        return object;
    }

    /** Delegates to the backing queue. */
    @Override
    public E element() {
        return backingQueue.element();
    }

    /**
     * Clears the backing queue and resets the global step to 0.
     *
     * <p>NOTE(review): per-thread {@link #currentStep} counters, {@link #state} and the
     * ready/drained counters are not reset. A consumer whose local step equals 0 would spin in
     * {@link #poll()} afterwards; confirm this is intended.
     */
    @Override
    public void clear() {
        backingQueue.clear();
        step.set(0);
    }

    /** Delegates to the backing queue. */
    @Override
    public int size() {
        return backingQueue.size();
    }

    /** Delegates to the backing queue. */
    @Override
    public E peek() {
        return backingQueue.peek();
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean offer(E e, long timeout, TimeUnit unit) throws InterruptedException {
        return backingQueue.offer(e, timeout, unit);
    }

    /**
     * Not implemented: always returns {@code null} immediately.
     *
     * <p>NOTE(review): this violates the {@link BlockingQueue#take()} contract (which never
     * returns null). Confirm no caller relies on {@code take()} before changing it.
     */
    @Override
    public E take() throws InterruptedException {
        return null;
    }

    /** Delegates to the backing queue, bypassing the multi-consumer barriers. */
    @Override
    public E poll(long timeout, TimeUnit unit) throws InterruptedException {
        return backingQueue.poll(timeout, unit);
    }


    /** Delegates to the backing queue. */
    @Override
    public int remainingCapacity() {
        return backingQueue.remainingCapacity();
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean remove(Object o) {
        return backingQueue.remove(o);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean containsAll(Collection<?> c) {
        return backingQueue.containsAll(c);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean addAll(Collection<? extends E> c) {
        return backingQueue.addAll(c);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean removeAll(Collection<?> c) {
        return backingQueue.removeAll(c);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean retainAll(Collection<?> c) {
        return backingQueue.retainAll(c);
    }

    /** Delegates to the backing queue. */
    @Override
    public boolean contains(Object o) {
        return backingQueue.contains(o);
    }


    /** Unsupported. */
    @Override
    public Iterator<E> iterator() {
        throw new UnsupportedOperationException();
    }

    /** Unsupported. */
    @Override
    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    /** Unsupported. */
    @Override
    public <T> T[] toArray(T[] a) {
        throw new UnsupportedOperationException();
    }

    /** Unsupported. */
    @Override
    public int drainTo(Collection<? super E> c) {
        throw new UnsupportedOperationException();
    }

    /** Unsupported. */
    @Override
    public int drainTo(Collection<? super E> c, int maxElements) {
        throw new UnsupportedOperationException();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy