All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.datasets.iterator;

import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ConcurrentModificationException;
import java.util.NoSuchElementException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Async prefetching iterator wrapper for MultiDataSetIterator implementations
 * 

* PLEASE NOTE: If used together with CUDA backend, please use it with caution. * * @author Alex Black * @author [email protected] */ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { private final MultiDataSetIterator iterator; private final LinkedBlockingQueue queue; private IteratorRunnable runnable; private Thread thread; public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueLength) { if (queueLength <= 0) throw new IllegalArgumentException("Queue size must be > 0"); if (queueLength < 2) queueLength = 2; this.iterator = iterator; if (this.iterator.resetSupported()) this.iterator.reset(); this.queue = new LinkedBlockingQueue<>(queueLength); runnable = new IteratorRunnable(iterator.hasNext()); thread = new Thread(runnable); Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); thread.setDaemon(true); thread.start(); } @Override public MultiDataSet next(int num) { // TODO: why isn't supported? We could just check queue size throw new UnsupportedOperationException("Next(int) not supported for AsyncDataSetIterator"); } @Override public void setPreProcessor(MultiDataSetPreProcessor preProcessor) { iterator.setPreProcessor(preProcessor); } @Override public boolean resetSupported() { return iterator.resetSupported(); } @Override public boolean asyncSupported() { return false; } @Override public void reset() { if (!resetSupported()) throw new UnsupportedOperationException( "Cannot reset Async iterator wrapping iterator that does not support reset"); //Complication here: runnable could be blocking on either baseIterator.next() or blockingQueue.put() runnable.killRunnable = true; if (runnable.isAlive) { thread.interrupt(); } //Wait for runnable to exit, but should only have to wait very short period of time //This probably isn't necessary, but is included as a safeguard against race conditions try { runnable.runCompletedSemaphore.tryAcquire(5, TimeUnit.SECONDS); } catch (InterruptedException e) { } //Clear the queue, reset the base iterator, set up a new thread queue.clear(); iterator.reset(); runnable = new IteratorRunnable(iterator.hasNext()); thread = new Thread(runnable); Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); thread.setDaemon(true); thread.start(); } @Override public boolean hasNext() { if (!queue.isEmpty()) return true; if (runnable.isAlive) { //Empty blocking queue, but runnable is alive //(a) runnable is blocking on baseIterator.next() //(b) runnable is blocking on blockingQueue.put() //either way: there's at least 1 more element to come return runnable.hasLatch(); } else { if (!runnable.killRunnable && runnable.exception != null) { throw runnable.exception; //Something went wrong } //Runnable has exited, presumably because it has fetched all elements return runnable.hasLatch(); } } @Override public MultiDataSet next() { if (!hasNext()) { throw new NoSuchElementException(); } if (runnable.exception != null) { throw runnable.exception; } if (!queue.isEmpty()) { runnable.feeder.decrementAndGet(); return queue.poll(); //non-blocking, but returns null if empty } //Blocking queue is empty, but more to come //Possible reasons: // (a) runnable died (already handled - runnable.exception != null) // (b) baseIterator.next() hasn't returned yet -> wait for it try { //Normally: just do blockingQueue.take(), but can't do that here //Reason: what if baseIterator.next() throws an exception after // blockingQueue.take() is called? In this case, next() will never return while (runnable.exception == null) { MultiDataSet ds = queue.poll(5, TimeUnit.SECONDS); if (ds != null) { runnable.feeder.decrementAndGet(); return ds; } if (runnable.killRunnable) { //should never happen throw new ConcurrentModificationException("Reset while next() is waiting for element?"); } if (!runnable.isAlive && queue.isEmpty()) { if (runnable.exception != null) throw new RuntimeException("Exception thrown in base iterator", runnable.exception); throw new IllegalStateException( "Unexpected state occurred for AsyncMultiDataSetIterator: runnable died or no data available"); } } //exception thrown while getting data from base iterator throw runnable.exception; } catch (InterruptedException e) { throw new RuntimeException(e); //Shouldn't happen under normal circumstances } } @Override public void remove() { // no-op } /** * Shut down the async data set iterator thread * This is not typically necessary if using a single AsyncDataSetIterator * (thread is a daemon thread and so shouldn't block the JVM from exiting) * Behaviour of next(), hasNext() etc methods after shutdown of async iterator is undefined */ public void shutdown() { if (thread.isAlive()) { runnable.killRunnable = true; thread.interrupt(); } } private class IteratorRunnable implements Runnable { private volatile boolean killRunnable = false; private volatile boolean isAlive = true; private volatile RuntimeException exception; private Semaphore runCompletedSemaphore = new Semaphore(0); private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); private AtomicLong feeder = new AtomicLong(0); public IteratorRunnable(boolean hasNext) { this.isAlive = hasNext; } public boolean hasLatch() { /* This method was added to address possible race condition within runnable loop. Idea is simple: in 99% of cases semaphore won't lock in hasLatch calls, since method is called ONLY if there's nothing in queue, and if it's already locked within main runnable loop - we get fast TRUE. */ // this is added just to avoid expensive lock if (feeder.get() > 0 || !queue.isEmpty()) return true; try { lock.readLock().lock(); boolean result = iterator.hasNext() || feeder.get() != 0 || !queue.isEmpty(); if (!isAlive) return result; else while (isAlive) { // in normal scenario this cycle is possible to hit into feeder state, since readLock is taken result = feeder.get() != 0 || !queue.isEmpty() || iterator.hasNext(); if (result) return true; } return result; } finally { lock.readLock().unlock(); } } @Override public void run() { try { while (!killRunnable && iterator.hasNext()) { feeder.incrementAndGet(); lock.writeLock().lock(); MultiDataSet ds = iterator.next(); if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking(); // feeder is temporary state variable, that shows if we have something between backend iterator and buffer lock.writeLock().unlock(); queue.put(ds); } isAlive = false; } catch (InterruptedException e) { //thread.interrupt() while put(DataSet) was blocking if (killRunnable) { return; } else exception = new RuntimeException("Runnable interrupted unexpectedly", e); //Something else interrupted } catch (RuntimeException e) { exception = e; if (lock.writeLock().isHeldByCurrentThread()) { lock.writeLock().unlock(); } } finally { isAlive = false; runCompletedSemaphore.release(); } } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy