// org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator Maven / Gradle / Ivy
package org.deeplearning4j.optimize.solvers.accumulation;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.StepFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.ThresholdCompression;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.AtomicThrowable;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
/**
 * This GradientsAccumulator is suited for the CUDA backend.
 *
 * @author [email protected]
 */
public class EncodedGradientsAccumulator implements GradientsAccumulator, Registerable {
// NOTE(review): generic type arguments, annotations and some continuation lines appear to be missing from this
// listing (see the List declarations and the unterminated builder chain below) - confirm against the original file.

// per-thread array; storeUpdate() creates it with the shape and ordering of the first array stored by that thread
protected ThreadLocal accumulator = new ThreadLocal<>();
// number of workers: the constructor creates one message queue, one workspace and one lock per party
protected int parties;
// message handler supplied to the constructor; the public constructors pass an EncodingHandler
protected MessageHandler handler;
// per-party queues of encoded messages; applyUpdate() drains the queue selected by the calling thread index
protected List> messages = new ArrayList<>();
// per-party workspaces; receiveUpdate() enters them while duplicating the incoming array
protected List workspaces = new ArrayList<>();
// one ReentrantLock per party, added in the constructor
protected List locks = new ArrayList<>();
// NOTE(review): not referenced in the visible code; presumably hands out flat worker indices in touch() - confirm
protected AtomicInteger workersCounter = new AtomicInteger(0);
// per-thread position of the calling worker in the messages list; touch() checks it for null
protected ThreadLocal index = new ThreadLocal<>();
// initial size handed to the workspace configuration; receiveUpdate() divides it by queueSize
protected long initialMemory = 100 * 1024 * 1024L;
// capacity of every per-party LinkedBlockingQueue
protected int queueSize = 5;
// stored from the constructor argument; not read anywhere in the visible code
protected Double boundary = 1.0;
// optional queue of external messages, drained by applyUpdate(); set through setExternalSource()
protected Queue externalSource;
// flags and counters used by synchronize() to coordinate consumers
protected AtomicBoolean isFirst = new AtomicBoolean(false);
protected AtomicBoolean isDone = new AtomicBoolean(true);
protected AtomicInteger barrier = new AtomicInteger(0);
protected AtomicInteger secondary = new AtomicInteger(0);
// storeUpdate() spins until this is true (unless bypassMode is set); registerConsumers() spins while it is true
protected AtomicBoolean registered = new AtomicBoolean(false);
// when set, storeUpdate() skips the wait for registration and synchronize() takes its single-consumer branch
protected AtomicBoolean bypassMode = new AtomicBoolean(false);
// consumer count that applyUpdate() passes to synchronize()
protected final AtomicInteger currentConsumers = new AtomicInteger(0);
// checked inside every spin loop; when triggered its content is rethrown wrapped in a RuntimeException
protected final AtomicThrowable throwable = new AtomicThrowable();
// gates the debug message lines throughout the class
protected boolean isDebug = false;
// assigned in the constructor: true when there is more than one device and cross-device access is not supported
protected final boolean relocatable;
// configuration used for the CGA_APPLY workspace in applyUpdate()
// NOTE(review): the rest of this builder chain is missing from the listing
protected WorkspaceConfiguration appliedConfiguration = WorkspaceConfiguration.builder().minSize(5 * 1024 * 1024L)
// Convenience constructor using a threshold of 1e-3.
// NOTE(review): the double argument is never read; the number of devices reported by the AffinityManager is used as
// the number of parties instead - confirm this is intended.
public EncodedGradientsAccumulator(double parties) {
this(Nd4j.getAffinityManager().getNumberOfDevices(), 1e-3);
// TODO: delete this one maybe?
// Convenience constructor: the given number of parties with a threshold of 1e-3.
public EncodedGradientsAccumulator(int parties) {
this(parties, 1e-3);
// Convenience constructor: wraps the threshold in an EncodingHandler and uses 100MB of initial memory,
// a queue size of 10 and a boundary of 1.0.
public EncodedGradientsAccumulator(int parties, double threshold) {
this(parties, new EncodingHandler(threshold), 100 * 1024 * 1024L, 10, 1.0);
protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize,
Double boundary) {
// Main constructor: stores the settings, decides whether messages must be relocated between devices and
// creates one bounded queue, one workspace and one lock per party.
// Throws ND4JIllegalStateException when parties exceeds the number of devices on a multi-device system.
this.parties = parties;
this.handler = handler;
this.initialMemory = initialMemory;
this.queueSize = queueSize;
this.boundary = boundary;
// maybe not the best idea in the world, but we'll use cyclic workspace of 25MB to receive updates
// NOTE(review): the size actually passed is initialMemory (100MB by default), not 25MB, and the remainder of this
// builder chain is missing from the listing
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(initialMemory)
// we want to know, if we'll have to relocate data if accessed from different threads/devices
relocatable = Nd4j.getAffinityManager().getNumberOfDevices() > 1
&& !Nd4j.getAffinityManager().isCrossDeviceAccessSupported();
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
// we are going to take single-device systems as edge case: cpu & small models at single-gpu systems.
if (parties > numDevices && numDevices != 1)
throw new ND4JIllegalStateException("Number of parties [" + parties
+ "] should be less or equal to number of devices [" + numDevices + "]");
// pre-create Queues for local workers
// NOTE(review): curDev is not read in the visible lines
int curDev = Nd4j.getAffinityManager().getDeviceForCurrentThread();
for (int i = 0; i < parties; i++) {
// bounded queue: at most queueSize pending messages per party
messages.add(new LinkedBlockingQueue(queueSize));
// we don't want device index to step out of boundaries here
int cDevice = numDevices > 1 ? i % numDevices : 0;
// NOTE(review): ws is never added to workspaces in the visible lines, although receiveUpdate() reads
// workspaces.get(i) - a line is presumably missing here
MemoryWorkspace ws = Nd4j.getWorkspaceManager().createNewWorkspace(configuration, "CGA-" + i, cDevice);
locks.add(new ReentrantLock());
public void fallbackToSingleConsumerMode(boolean reallyFallback) {
if (externalSource != null && externalSource instanceof Registerable)
((Registerable) externalSource).fallbackToSingleConsumerMode(reallyFallback);
public void registerConsumers(int numConsumers) {
// Waits while a previous registration is still active, then forwards the consumer count to the external
// source when that source is Registerable.
// Throws RuntimeException wrapping the stored throwable if one is triggered while waiting.
// NOTE(review): the logger calls, the wait statement inside the loop and the code that records numConsumers
// and raises the registered flag are not present in this listing - confirm against the original file.
// we don't want double spending here
if (registered.get()) {
if (isDebug)"Master thread locks at RC");
// spin until the flag is cleared elsewhere, bailing out if a worker reported a failure
while (registered.get()) {
if (throwable.isTriggered())
throw new RuntimeException(throwable.get());
if (isDebug)"Master thread unlocks at RC");
// we're passing number of consumers for current session to externalSource, if applicable
if (externalSource != null && externalSource instanceof Registerable)
((Registerable) externalSource).registerConsumers(numConsumers);
protected void synchronize(int consumers) {
// shorthand for synchronize(consumers, false)
synchronize(consumers, false);
protected void synchronize(int consumers, boolean finalLock) {
// Two-stage spin barrier for the given number of consumers, built on the isDone/barrier and isFirst/secondary
// pairs. Throws RuntimeException wrapping the stored throwable if one is triggered while spinning.
// NOTE(review): the bodies of several branches (what happens under finalLock, what the last arriving thread
// resets, the wait statement inside the loops, the early return) are missing from this listing - confirm
// against the original file before changing anything here.
// single consumer or bypass mode: handled by a dedicated branch
if (consumers == 1 || bypassMode.get()) {
if (finalLock)
if (isDebug)"thread {} locking at CGA: {}", Thread.currentThread().getId(), currentConsumers.get());
// any first thread entering this block - will reset this field to false
isDone.compareAndSet(true, false);
// last thread will set isDone to true
if (barrier.incrementAndGet() == consumers) {
} else {
// just wait, till last thread will set isDone to true
while (!isDone.get()) {
if (throwable.isTriggered())
throw new RuntimeException(throwable.get());
// second lock here needed only to ensure we won't get overrun over isDone flag
if (secondary.incrementAndGet() == consumers) {
if (finalLock)
} else {
// wait until the isFirst flag is raised
while (!isFirst.get()) {
if (throwable.isTriggered())
throw new RuntimeException(throwable.get());
if (isDebug)"thread {} unlocking at CGA: {}", Thread.currentThread().getId(), currentConsumers.get());
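/**
 * This method applies accumulated updates via the given StepFunction.
 *
 * @param function step function used to apply the decoded updates
 * @param params parameters array handed to the step function
 * @param updates array that the queued encoded messages are decoded into
 */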
public void applyUpdate(StepFunction function, INDArray params, INDArray updates) {
// Drains the calling thread queue and then the external source, decoding every message into updates,
// synchronizes with the other consumers and finally calls function.step(params, updates) when cnt > 0.
// Any exception is rethrown wrapped in a RuntimeException.
// NOTE(review): several statements are missing from this listing: the code that clears updates, the expression
// that reads the encoding header (int encoding =;), the else in front of each throw, the counter increments
// and the logger calls. As printed, cnt and ent never change - confirm against the original file.
try {
// nullify given updates first
int cnt = 0;
// local messages: index.get() selects the queue that belongs to the calling thread
while (!messages.get(index.get()).isEmpty()) {
INDArray compressed = messages.get(index.get()).poll();
int encoding =;
// dispatch on the compression header: threshold encoding or bitmap encoding
if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
Nd4j.getExecutioner().thresholdDecode(compressed, updates);
else if (encoding == ThresholdCompression.BITMAP_ENCODING)
Nd4j.getExecutioner().bitmapDecode(compressed, updates);
throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
if (cnt > 0 && isDebug)"Local updates to be applied: {}", cnt);
if (externalSource != null) {
int ent = 0;
while (!externalSource.isEmpty()) {
INDArray compressed = externalSource.poll();
// if we have multiple devices without p2p support - just duplicate messages right from host side
if (relocatable) {
// the copy is made inside the CGA_APPLY workspace, which is closed by try-with-resources
try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
.getAndActivateWorkspace(appliedConfiguration, "CGA_APPLY")) {
INDArray compressed_copy = compressed.unsafeDuplication(true);
int encoding =;
if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
Nd4j.getExecutioner().thresholdDecode(compressed_copy, updates);
else if (encoding == ThresholdCompression.BITMAP_ENCODING)
Nd4j.getExecutioner().bitmapDecode(compressed_copy, updates);
throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
} else {
// no relocation needed: decode the received message directly
int encoding =;
if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
Nd4j.getExecutioner().thresholdDecode(compressed, updates);
else if (encoding == ThresholdCompression.BITMAP_ENCODING)
Nd4j.getExecutioner().bitmapDecode(compressed, updates);
throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
if (isDebug)"thread {} finished at Externals", Thread.currentThread().getId());
if (ent > 0 && isDebug)"External updates to be applied: {}", ent);
// wait for the other consumers of this session, using the final-lock variant of the barrier
synchronize(currentConsumers.get(), true);
// TODO: average updates probably?
// the step is only taken when local messages were counted; ent is not part of this condition
if (cnt > 0)
function.step(params, updates);
} catch (Exception e) {
throw new RuntimeException(e);
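/**
 * This method applies accumulated updates via the given StepFunction.
 *
 * @param function step function used to apply the decoded updates
 * @param params parameters array handed to the step function
 * @param updates array that the queued encoded messages are decoded into
 * @param alpha value handed through to the step function
 */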
public void applyUpdate(StepFunction function, INDArray params, INDArray updates, double alpha) {
// Same flow as the three-argument overload, except that the final call is function.step(params, updates, alpha).
// Any exception is rethrown wrapped in a RuntimeException.
// NOTE(review): the same statements are missing here as in the other overload (clearing of updates, the
// encoding header read, the else in front of each throw, counter increments, logger calls) - confirm against
// the original file.
try {
// nullify given updates first
int cnt = 0;
// local messages: index.get() selects the queue that belongs to the calling thread
while (!messages.get(index.get()).isEmpty()) {
INDArray compressed = messages.get(index.get()).poll();
int encoding =;
if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
Nd4j.getExecutioner().thresholdDecode(compressed, updates);
else if (encoding == ThresholdCompression.BITMAP_ENCODING)
Nd4j.getExecutioner().bitmapDecode(compressed, updates);
throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
if (cnt > 0 && isDebug)"Local updates to be applied: {}", cnt);
if (externalSource != null) {
int ent = 0;
while (!externalSource.isEmpty()) {
INDArray compressed = externalSource.poll();
// if we have multiple devices without p2p support - just duplicate messages right from host side
if (relocatable) {
// the copy is made inside the CGA_APPLY workspace, which is closed by try-with-resources
try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
.getAndActivateWorkspace(appliedConfiguration, "CGA_APPLY")) {
INDArray compressed_copy = compressed.unsafeDuplication(true);
int encoding =;
if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
Nd4j.getExecutioner().thresholdDecode(compressed_copy, updates);
else if (encoding == ThresholdCompression.BITMAP_ENCODING)
Nd4j.getExecutioner().bitmapDecode(compressed_copy, updates);
throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
} else {
// no relocation needed: decode the received message directly
int encoding =;
if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
Nd4j.getExecutioner().thresholdDecode(compressed, updates);
else if (encoding == ThresholdCompression.BITMAP_ENCODING)
Nd4j.getExecutioner().bitmapDecode(compressed, updates);
throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
if (ent > 0 && isDebug)"External updates to be applied: {}", ent);
// wait for the other consumers of this session, using the final-lock variant of the barrier
synchronize(currentConsumers.get(), true);
// TODO: average updates? might have sense
// the step is only taken when local messages were counted; ent is not part of this condition
if (cnt > 0)
function.step(params, updates, alpha);
} catch (Exception e) {
throw new RuntimeException(e);
/**
 * This method allows passing external updates to the accumulator; they will be propagated to all workers using this GradientsAccumulator instance.
 *
 * @param source queue that external updates are read from
 */
public void setExternalSource(Queue source) {
// replaces any previously set source; applyUpdate() drains it and the Registerable callbacks forward to it
this.externalSource = source;
/**
 * This method initializes the calling worker with respect to Thread-Device Affinity.
 */
public void touch() {
// Runs only once per thread: nothing happens when the thread-local index is already set.
// NOTE(review): the statements that actually store the index are missing from both branches in this listing,
// and the unmarked text line below appears to be a comment that lost its delimiters - confirm against the
// original file.
if (index.get() == null) {
// set index
int numDevces = Nd4j.getAffinityManager().getNumberOfDevices();
if we have > 1 computational device, we assign workers to workspaces "as is", as provided via AffinityManager
if (numDevces > 1 && parties > 1) {
// multi-device case: the device of the current thread is looked up
int localIndex = Nd4j.getAffinityManager().getDeviceForCurrentThread();
} else {
// if we have only 1 device (like cpu system, or single gpu), just attach consumer via flat index
/**
 * This method accepts updates suitable for StepFunction, and accumulates/propagates them across all workers.
 *
 * @param array updates to store
 */
public void storeUpdate(INDArray array) {
// Lazily creates the per-thread accumulator, then waits until consumers are registered (unless bypass mode
// is on). Any exception is rethrown wrapped in a RuntimeException.
// NOTE(review): the statements that add array to the accumulator, broadcast it through the handler and block
// on the barrier are missing from this listing, as are the logger calls - confirm against the original file.
try {
if (accumulator.get() == null) {
// we don't want accumulator to be attached to workspaces
try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
// same shape and ordering as the first array stored by this thread
accumulator.set(Nd4j.create(array.shape(), array.ordering()));
// accumulate gradients updates in residual array
if (isDebug)"thread {} locking at Register", Thread.currentThread().getId());
// block until ParallelWrapper sends us message about number of threads in this cycle
if (!bypassMode.get())
while (!registered.get()) {
if (throwable.isTriggered())
throw new RuntimeException(throwable.get());
if (isDebug)"thread {} unlocking at Register", Thread.currentThread().getId());
// propagate changes & modify accumulator
// we're blocking here, until all done broadcasting updates
} catch (Exception e) {
throw new RuntimeException(e);
/**
 * This method accepts updates suitable for StepFunction and puts them into the queues that are used in the backpropagation loop.
 *
 * PLEASE NOTE: array is expected to be ready for use and match params dimensionality
 *
 * @param array update to replicate to every worker queue
 */
public void receiveUpdate(INDArray array) {
// For every party: enters that party workspace, rejects arrays that exceed the per-message budget
// (initialMemory / queueSize), and duplicates the array. Any exception is rethrown wrapped in a RuntimeException.
// NOTE(review): this listing is damaged here: the three unmarked text lines below appear to be a comment that
// lost its delimiters, the size expressions in the check and in the error message are incomplete, and the
// lock handling plus the statement that puts the copy into the queue are missing - confirm against the
// original file.
try {
// we're replicating COMPRESSED MESSAGES, decompression will be thread-local
for (int i = 0; i < parties; i++) {
// we don't want to have same workspace to be accessible by 2 different threads for now
With synchronized external data, it's impossible to deadlock here.
Each worker is guaranteed to have at least NUM_WORKERS slots in buffer.
So we use this lock just to ensure thread-safety of corresponding workspaces
try (MemoryWorkspace workspace = workspaces.get(i).notifyScopeEntered()) {
// we might just scope out of workspace here, instead of throwing error out
if ( > (initialMemory / queueSize)
/ Nd4j.sizeOfDataType(
throw new ND4JIllegalStateException("Not enough memory to handle update: ["
+ * Nd4j.sizeOfDataType(
+ " bytes required]. Please increase memory amount for GradientsAccumulator");
// copy made while the workspace of party i is active
INDArray compressed = array.unsafeDuplication();
try {
// NOTE(review): the interrupt is wrapped and rethrown without restoring the thread interrupt flag
} catch (InterruptedException e) {"Something bad at index_{}", i);
throw new RuntimeException(e);
} catch (Exception e) {
throw new RuntimeException(e);
/**
 * This method resets all accumulated updates (if any).
 */
public void reset() {
// Drops the per-thread accumulators by replacing the ThreadLocal holder.
// NOTE(review): the counter reset and the body of the per-party loop are missing from this listing - confirm
// against the original file.
// just replace accumulator, gc will do the rest
accumulator = new ThreadLocal<>();
// resetting this counter too
// throw away message queues
for (int i = 0; i < parties; i++) {
public static class Builder {
// Collects the settings that build() passes to the EncodedGradientsAccumulator constructor.
// number of parties; validated to be at least 1 by the Builder constructor
protected int parties;
// encoding threshold; only read by build() when it creates the default EncodingHandler
protected double threshold = 1e-3;
// initial memory handed to the accumulator (100MB by default)
protected long initialMemory = 100 * 1024 * 1024L;
// queue size handed to the accumulator
protected int queueSize = 5;
// optional custom handler; when left null, build() creates an EncodingHandler
protected MessageHandler handler;
// optional updates boundary; null means none was configured
protected Double boundary = null;
* This
* @param parties
public Builder(int parties) {
if (parties < 1)
throw new DL4JInvalidConfigException(
"Number of parties for GradientsAccumulation should be positive value");
this.parties = parties;
/**
 * This method allows specifying a MessageHandler instance.
 *
 * Default value: EncodingHandler
 *
 * @param handler handler to use; must not be null
 * @return this Builder instance
 */
public Builder messageHandler(@NonNull MessageHandler handler) {
// stores the handler; build() only creates a default EncodingHandler when no handler has been set
this.handler = handler;
return this;
/**
 * This method allows setting the encoding threshold for this accumulator instance.
 *
 * Default value: 1e-3
 *
 * @param threshold encoding threshold handed to the default EncodingHandler
 * @return this Builder instance
 */
public Builder encodingThreshold(double threshold) {
// the threshold is only read by build(), when it has to create the default EncodingHandler
this.threshold = threshold;
return this;
* This method enables optional limit for max number of updates per message
* Default value: 1.0 (no limit)
* @param boundary positive value in range 0..1
* @return
public Builder updatesBoundary(double boundary) {
if (boundary >= 1.0)
return this;
if (boundary <= 0.0)
throw new DL4JInvalidConfigException("Boundary should have positive value");
this.boundary = boundary;
return this;
* This method allows to define buffer memory parameters for this GradientsAccumulator
* Default values: 100MB initialMemory, 5 queueSize
* @param initialMemory
* @param queueSize
* @return
public Builder memoryParameters(long initialMemory, int queueSize) {
this.initialMemory = initialMemory;
this.queueSize = queueSize;
return this;
public EncodedGradientsAccumulator build() {
if (handler == null) {
if (boundary == null)
handler = new EncodingHandler(threshold);
handler = new EncodingHandler(threshold, boundary);
EncodedGradientsAccumulator accumulator =
new EncodedGradientsAccumulator(parties, handler, initialMemory, queueSize, boundary);
return accumulator;