// org.deeplearning4j.optimize.solvers.accumulation.BasicGradientsAccumulator Maven / Gradle / Ivy
package org.deeplearning4j.optimize.solvers.accumulation;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.optimize.api.StepFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* This class provides accumulation of gradients for both input (i.e. updates coming from the network) and output (coming from one or more models training at the same time)
*
* @author [email protected]
*/
@Slf4j
public class BasicGradientsAccumulator implements GradientsAccumulator {
    /** Handler that is asked to broadcast the accumulated storage in {@link #storeUpdate(INDArray)}. */
    protected MessageHandler handler;

    // here we'll store messages coming from "somewhere else"
    protected transient Queue<INDArray> gradients;

    // this field stores the currently accumulated values (filled by storeUpdate)
    protected transient INDArray storage;

    // this field stores updates that are ready to be applied (filled by receiveUpdate)
    protected transient INDArray updates;

    // this counter tracks the number of messages generated by this accumulator
    protected transient AtomicLong ownCounter = new AtomicLong(0);

    // this counter tracks the number of messages received from somewhere else
    protected transient AtomicLong extCounter = new AtomicLong(0);

    // FIXME: shape/ordering are captured from the first stored array; this mechanism should be improved
    protected int[] shape;
    protected char ordering;

    /** Number of threads that must arrive at {@link #barrier} before any of them may proceed. */
    protected int parties = 0;
    protected CyclicBarrier barrier;

    /** Id of the thread elected to do the single-threaded part of the current round, or -1 if none yet. */
    protected AtomicLong firstOne = new AtomicLong(-1L);

    /** Arrays handed to {@link #storeUpdate(INDArray)} during the current round. */
    protected List<INDArray> candidates = new CopyOnWriteArrayList<>();

    /** Guards {@link #updates}: appliers take the read lock, {@link #receiveUpdate(INDArray)}/{@link #reset()} the write lock. */
    protected ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();

    /** Set once {@link #updates} has received at least one array since it was last cleared. */
    protected AtomicBoolean hasSomething = new AtomicBoolean(false);

    /**
     * Creates a new accumulator that uses a {@link LocalHandler} for communication.
     *
     * @param parties number of threads that take part in every barrier-synchronized call
     */
    public BasicGradientsAccumulator(int parties) {
        this(parties, new LocalHandler());
    }

    /**
     * Creates a new accumulator with a custom message handler.
     *
     * @param parties number of threads that take part in every barrier-synchronized call
     * @param handler MessageHandler instance that'll be used for communication purposes
     */
    public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler) {
        this.gradients = new LinkedTransferQueue<>();
        this.handler = handler;
        this.handler.initialize(this);
        this.parties = parties;
        barrier = new CyclicBarrier(parties);
    }

    /**
     * This method applies accumulated updates via given StepFunction.
     * <p>
     * The call blocks on the internal barrier, so it returns only after all parties have invoked it.
     *
     * @param function step function used to apply the pending updates to {@code params}
     * @param params   parameters to be updated
     * @param grad     not used by this implementation
     * @throws RuntimeException wrapping any failure, including interruption or a broken barrier
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray grad) {
        try {
            updatesLock.readLock().lock();
            try {
                firstOne.compareAndSet(-1L, Thread.currentThread().getId());

                if (hasSomething.get()) {
                    function.step(params, updates);
                }

                awaitAndClearUpdates();
            } finally {
                // always release, otherwise a failure above would block writers forever
                updatesLock.readLock().unlock();
            }

            barrier.await();
        } catch (InterruptedException e) {
            // keep the interrupt visible to the caller
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * This method applies accumulated updates via given StepFunction, using the given step size.
     * <p>
     * The call blocks on the internal barrier, so it returns only after all parties have invoked it.
     *
     * @param function step function used to apply the pending updates to {@code params}
     * @param params   parameters to be updated
     * @param grad     not used by this implementation
     * @param alpha    value forwarded to {@code function.step(params, updates, alpha)}
     * @throws RuntimeException wrapping any failure, including interruption or a broken barrier
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha) {
        try {
            updatesLock.readLock().lock();
            try {
                firstOne.compareAndSet(-1L, Thread.currentThread().getId());

                if (hasSomething.get()) {
                    function.step(params, updates, alpha);
                }

                awaitAndClearUpdates();
            } finally {
                // always release, otherwise a failure above would block writers forever
                updatesLock.readLock().unlock();
            }

            barrier.await();
        } catch (InterruptedException e) {
            // keep the interrupt visible to the caller
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Waits until every party has passed the apply step, then lets the elected thread zero the pending
     * updates and re-arm the election for the next round.
     */
    private void awaitAndClearUpdates() throws InterruptedException, java.util.concurrent.BrokenBarrierException {
        barrier.await();

        if (firstOne.get() == Thread.currentThread().getId()) {
            // one thread just nullifies this array; it may still be unallocated if nothing was received yet
            if (updates != null) {
                updates.assign(0.0);
            }

            hasSomething.set(false);
            firstOne.set(-1L);
        }
    }

    /**
     * This method accepts updates suitable for StepFunction, and accumulates/propagates it across all workers.
     * <p>
     * Every party contributes its array to the candidate list; after all parties have arrived, one elected
     * thread accumulates the candidates into the storage array and passes it to the handler for broadcasting.
     * The call returns only after all parties have reached the final barrier.
     *
     * @param array update produced by the calling worker
     * @throws RuntimeException wrapping any failure, including interruption or a broken barrier
     */
    @Override
    public void storeUpdate(INDArray array) {
        try {
            // commit should happen in each individual thread
            Nd4j.getExecutioner().commit();

            firstOne.compareAndSet(-1L, Thread.currentThread().getId());

            // TODO: since we know number of elements in advance, we don't really need CopyOnWrite list here.
            candidates.add(array);

            barrier.await();

            if (firstOne.get() == Thread.currentThread().getId()) {
                // if storage is null, let's just create it here
                if (storage == null) {
                    shape = array.shape();
                    ordering = array.ordering();

                    // we don't want state array to be attached to any workspace
                    try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                        // TODO: if p2p isn't supported, this should become HOST-only allocation
                        storage = Nd4j.create(shape, ordering);
                    }
                }

                // accumulate the values contributed by all parties
                Nd4j.accumulate(storage, candidates);

                // we ensure storage was updated successfully
                Nd4j.getExecutioner().commit();

                // if there's something to send - send it. Skip otherwise!!!
                if (handler.broadcastUpdates(storage)) {
                    ownCounter.getAndIncrement();
                }

                // reset "first one" and drop this round's candidates
                firstOne.set(-1L);
                candidates.clear();
            }

            barrier.await();
        } catch (InterruptedException e) {
            // keep the interrupt visible to the caller
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * This method accepts updates suitable for StepFunction and adds them to the pending updates array,
     * which is later consumed by {@code applyUpdate}.
     *
     * PLEASE NOTE: array is expected to be ready for use and match params dimensionality
     *
     * @param array update to be added to the pending updates
     */
    @Override
    public void receiveUpdate(INDArray array) {
        extCounter.getAndIncrement();

        updatesLock.writeLock().lock();
        try {
            if (updates == null) {
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    // TODO: this one has to be HOST-only if P2P is NOT supported
                    updates = Nd4j.create(array.shape(), array.ordering());
                }
            }

            hasSomething.compareAndSet(false, true);

            // if P2P is NOT supported - this call should be handled with cpu
            updates.addi(array);

            // we have to ensure, all operations were finished here
            Nd4j.getExecutioner().commit();
        } finally {
            // release even if allocation or accumulation failed, so appliers are not blocked forever
            updatesLock.writeLock().unlock();
        }
    }

    /**
     * This method resets all accumulated updates (if any)
     */
    @Override
    public void reset() {
        updatesLock.writeLock().lock();
        try {
            if (storage != null) {
                storage.assign(0.0f);
            }

            if (updates != null) {
                updates.assign(0.0f);
            }
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    /**
     * This method does initialization of given worker wrt Thread-Device Affinity
     */
    @Override
    public void touch() {
        // no-op
    }

    /**
     * Not supported by this implementation.
     *
     * @param source ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setExternalSource(Queue source) {
        // TODO: to be implemented
        throw new UnsupportedOperationException();
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy