All downloads are free. Search and download functionalities use the official Maven repository.

org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.optimize.solvers.accumulation;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.StepFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.ThresholdCompression;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.AtomicThrowable;

import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;

/**
 * This GradientsAccumulator is suited for CUDA backend.
 *
 * @author [email protected]
 */
@Slf4j
public class EncodedGradientsAccumulator implements GradientsAccumulator, Registerable {
    protected ThreadLocal accumulator = new ThreadLocal<>();

    protected int parties;
    protected MessageHandler handler;
    protected List> messages = new ArrayList<>();
    protected List workspaces = new ArrayList<>();
    protected List locks = new ArrayList<>();

    protected AtomicInteger workersCounter = new AtomicInteger(0);
    protected ThreadLocal index = new ThreadLocal<>();
    protected long initialMemory = 100 * 1024 * 1024L;
    protected int queueSize = 5;
    protected Double boundary = 1.0;

    protected Queue externalSource;

    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    protected AtomicBoolean isDone = new AtomicBoolean(true);

    protected AtomicInteger barrier = new AtomicInteger(0);
    protected AtomicInteger secondary = new AtomicInteger(0);
    protected AtomicBoolean registered = new AtomicBoolean(false);
    protected AtomicBoolean bypassMode = new AtomicBoolean(false);
    protected final AtomicInteger currentConsumers = new AtomicInteger(0);

    protected final AtomicThrowable throwable = new AtomicThrowable();

    protected boolean isDebug = false;
    protected final boolean relocatable;

    protected WorkspaceConfiguration appliedConfiguration = WorkspaceConfiguration.builder().minSize(5 * 1024 * 1024L)
                    .overallocationLimit(0.3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE)
                    .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).build();

    /**
     * Creates an accumulator with one party per available compute device and the default 1e-3 threshold.
     *
     * NOTE(review): the {@code parties} argument is ignored — the device count from the AffinityManager
     * is used instead. Looks like a leftover overload; confirm before relying on the parameter.
     *
     * @param parties ignored (see note above)
     */
    public EncodedGradientsAccumulator(double parties) {
        this(Nd4j.getAffinityManager().getNumberOfDevices(), 1e-3);
    }

    // TODO: delete this one maybe?
    /**
     * Creates an accumulator for the given number of parties, using the default encoding threshold of 1e-3.
     *
     * @param parties number of local workers participating in gradients exchange
     */
    public EncodedGradientsAccumulator(int parties) {
        this(parties, 1e-3);
    }

    /**
     * Creates an accumulator with an {@link EncodingHandler} for the given threshold,
     * 100MB of buffer memory, queue size of 10, and no updates boundary (1.0).
     *
     * @param parties   number of local workers participating in gradients exchange
     * @param threshold encoding threshold passed to the EncodingHandler
     */
    public EncodedGradientsAccumulator(int parties, double threshold) {
        this(parties, new EncodingHandler(threshold), 100 * 1024 * 1024L, 10, 1.0);
    }

    /**
     * Full constructor: allocates one bounded message queue, one cyclic workspace and one lock
     * per party, then initializes the message handler against this accumulator.
     *
     * @param parties       number of local workers; must not exceed device count on multi-device systems
     * @param handler       handler used to encode & broadcast updates
     * @param initialMemory size (bytes) of each party's cyclic receive workspace
     * @param queueSize     capacity of each party's message queue
     * @param boundary      optional limit for max number of updates per message
     * @throws ND4JIllegalStateException if parties exceeds the number of devices (multi-device only)
     */
    protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize,
                                          Double boundary) {
        this.parties = parties;
        this.handler = handler;
        this.initialMemory = initialMemory;
        this.queueSize = queueSize;
        this.boundary = boundary;

        // maybe not the best idea in the world, but we'll use cyclic workspace of 25MB to receive updates
        WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(initialMemory)
                        .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyAllocation(AllocationPolicy.STRICT)
                        .policySpill(SpillPolicy.FAIL).policyLearning(LearningPolicy.NONE).build();


        // we want to know, if we'll have to relocate data if accessed from different threads/devices
        relocatable = Nd4j.getAffinityManager().getNumberOfDevices() > 1
                        && !Nd4j.getAffinityManager().isCrossDeviceAccessSupported();

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        // we are going to take single-device systems as edge case: cpu & small models at single-gpu systems.
        if (parties > numDevices && numDevices != 1)
            throw new ND4JIllegalStateException("Number of parties [" + parties
                            + "] should be less or equal to number of devices [" + numDevices + "]");

        // pre-create Queues for local workers; remember the caller's device so we can restore it afterwards
        int curDev = Nd4j.getAffinityManager().getDeviceForCurrentThread();

        for (int i = 0; i < parties; i++) {
            messages.add(new LinkedBlockingQueue<>(queueSize));

            // we don't want device index to step out of boundaries here
            int cDevice = numDevices > 1 ? i % numDevices : 0;

            Nd4j.getAffinityManager().unsafeSetDevice(cDevice);
            MemoryWorkspace ws = Nd4j.getWorkspaceManager().createNewWorkspace(configuration, "CGA-" + i, cDevice);
            //ws.enableDebug(true);
            workspaces.add(ws);

            locks.add(new ReentrantLock());
        }
        // restore the device this thread was originally bound to
        Nd4j.getAffinityManager().unsafeSetDevice(curDev);

        handler.initialize(this);
    }

    /**
     * Switches this accumulator into (or out of) single-consumer bypass mode,
     * forwarding the toggle to the external source when that source supports it.
     *
     * @param reallyFallback true to enable bypass mode, false to restore normal operation
     */
    @Override
    public void fallbackToSingleConsumerMode(boolean reallyFallback) {
        // instanceof is null-safe, so a null externalSource is simply skipped
        if (externalSource instanceof Registerable) {
            ((Registerable) externalSource).fallbackToSingleConsumerMode(reallyFallback);
        }

        bypassMode.set(reallyFallback);
    }

    /**
     * Registers the number of consumer threads participating in the upcoming cycle.
     *
     * Busy-waits while a previous registration is still active ("double spending" guard);
     * the flag is released by the last thread leaving synchronize(..., true). Rethrows the
     * first worker exception recorded in {@link #throwable} while waiting.
     *
     * @param numConsumers number of consumer threads for this cycle
     */
    @Override
    public void registerConsumers(int numConsumers) {
        // we don't want double spending here
        if (registered.get()) {
            if (isDebug)
                log.info("Master thread locks at RC");

            while (registered.get()) {
                LockSupport.parkNanos(100L);
                if (throwable.isTriggered())
                    throw new RuntimeException(throwable.get());
            }

            if (isDebug)
                log.info("Master thread unlocks at RC");
        }

        // we're passing number of consumers for current session to externalSource, if applicable
        if (externalSource != null && externalSource instanceof Registerable)
            ((Registerable) externalSource).registerConsumers(numConsumers);

        currentConsumers.set(numConsumers);
        registered.set(true);
    }

    /**
     * Convenience overload: barrier without releasing the registration flag.
     *
     * @param consumers number of threads expected at the barrier
     */
    protected void synchronize(int consumers) {
        synchronize(consumers, false);
    }

    /**
     * Two-phase spin barrier: blocks until {@code consumers} threads have arrived.
     *
     * Phase 1 (barrier / isDone) waits until every thread has arrived; phase 2
     * (secondary / isFirst) prevents a fast thread from overrunning the isDone flag
     * of the next cycle. With a single consumer, or in bypass mode, the barrier is
     * a no-op. Waiting threads rethrow the first recorded worker exception.
     *
     * @param consumers number of threads expected at the barrier
     * @param finalLock if true, the last thread also clears the {@code registered} flag,
     *                  letting registerConsumers() start the next cycle
     */
    protected void synchronize(int consumers, boolean finalLock) {
        if (consumers == 1 || bypassMode.get()) {
            if (finalLock)
                registered.set(false);

            return;
        }

        if (isDebug)
            log.info("thread {} locking at CGA: {}", Thread.currentThread().getId(), currentConsumers.get());

        // any first thread entering this block - will reset this field to false
        isDone.compareAndSet(true, false);

        // last thread will set isDone to true
        if (barrier.incrementAndGet() == consumers) {
            secondary.set(0);
            barrier.set(0);
            isFirst.set(false);
            isDone.set(true);
        } else {
            // just wait, till last thread will set isDone to true
            while (!isDone.get()) {
                LockSupport.parkNanos(1000L);
                if (throwable.isTriggered())
                    throw new RuntimeException(throwable.get());
            }

        }

        // second lock here needed only to ensure we won't get overrun over isDone flag
        if (secondary.incrementAndGet() == consumers) {
            if (finalLock)
                registered.set(false);

            isFirst.set(true);
        } else {
            while (!isFirst.get()) {
                LockSupport.parkNanos(1000L);
                if (throwable.isTriggered())
                    throw new RuntimeException(throwable.get());
            }
        }

        if (isDebug)
            log.info("thread {} unlocking at CGA: {}", Thread.currentThread().getId(), currentConsumers.get());

    }

    /**
     * This method applies accumulated updates via given StepFunction
     *
     * @param function step function used to apply the aggregated updates to params
     * @param params   model parameters, updated in-place via the step function
     * @param updates  scratch buffer the decoded updates are accumulated into; zeroed on entry
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates) {
        try {
            // nullify given updates first
            Nd4j.getMemoryManager().memset(updates);
            //updates.assign(0.0);

            // drain this worker's local queue of compressed messages
            int cnt = 0;
            while (!messages.get(index.get()).isEmpty()) {
                decodeAndAccumulate(messages.get(index.get()).poll(), updates);
                cnt++;
            }

            if (cnt > 0 && isDebug)
                log.info("Local updates to be applied: {}", cnt);

            if (externalSource != null) {
                int ent = 0;
                while (!externalSource.isEmpty()) {
                    INDArray compressed = externalSource.poll();

                    // if we have multiple devices without p2p support - just duplicate messages right from host side
                    if (relocatable) {
                        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                                        .getAndActivateWorkspace(appliedConfiguration, "CGA_APPLY")) {
                            // unsafeDuplication(true) copies the whole buffer, header included
                            decodeAndAccumulate(compressed.unsafeDuplication(true), updates);
                        }
                    } else {
                        decodeAndAccumulate(compressed, updates);
                    }
                    cnt++;
                    ent++;
                }

                if (isDebug)
                    log.info("thread {} finished at Externals", Thread.currentThread().getId());

                if (ent > 0 && isDebug)
                    log.info("External updates to be applied: {}", ent);
            }

            // wait for all consumers and release the registration flag (finalLock = true)
            synchronize(currentConsumers.get(), true);

            // TODO: average updates probably?

            if (cnt > 0)
                function.step(params, updates);
        } catch (Exception e) {
            throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Decodes a single compressed message and accumulates it into the updates buffer.
     * The compression header is stored as int #3 of the message's data buffer.
     *
     * @param compressed encoded message, either threshold ("flexible") or bitmap encoded
     * @param updates    buffer the decoded gradients are accumulated into
     * @throws DL4JInvalidConfigException if the compression header is unknown
     */
    private void decodeAndAccumulate(INDArray compressed, INDArray updates) {
        int encoding = compressed.data().getInt(3);
        if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
            Nd4j.getExecutioner().thresholdDecode(compressed, updates);
        else if (encoding == ThresholdCompression.BITMAP_ENCODING)
            Nd4j.getExecutioner().bitmapDecode(compressed, updates);
        else
            throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
    }

    /**
     * This method applies accumulated updates via given StepFunction, scaled by alpha.
     *
     * NOTE(review): the decode branch below is duplicated three times here (and again in the
     * non-alpha overload); a shared private helper would remove the repetition.
     *
     * @param function step function used to apply the aggregated updates to params
     * @param params   model parameters, updated in-place via the step function
     * @param updates  scratch buffer the decoded updates are accumulated into; zeroed on entry
     * @param alpha    scaling factor passed through to the step function
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates, double alpha) {
        try {
            // nullify given updates first
            Nd4j.getMemoryManager().memset(updates);
            //updates.assign(0.0);

            // drain this worker's local queue of compressed messages
            int cnt = 0;
            while (!messages.get(index.get()).isEmpty()) {
                INDArray compressed = messages.get(index.get()).poll();

                // compression header is stored as int #3 of the message's data buffer
                int encoding = compressed.data().getInt(3);
                if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
                    Nd4j.getExecutioner().thresholdDecode(compressed, updates);
                else if (encoding == ThresholdCompression.BITMAP_ENCODING)
                    Nd4j.getExecutioner().bitmapDecode(compressed, updates);
                else
                    throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);

                cnt++;
            }

            if (cnt > 0 && isDebug)
                log.info("Local updates to be applied: {}", cnt);

            if (externalSource != null) {
                int ent = 0;
                while (!externalSource.isEmpty()) {
                    INDArray compressed = externalSource.poll();


                    // if we have multiple devices without p2p support - just duplicate messages right from host side
                    if (relocatable) {
                        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                                        .getAndActivateWorkspace(appliedConfiguration, "CGA_APPLY")) {
                            INDArray compressed_copy = compressed.unsafeDuplication(true);
                            int encoding = compressed.data().getInt(3);
                            if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
                                Nd4j.getExecutioner().thresholdDecode(compressed_copy, updates);
                            else if (encoding == ThresholdCompression.BITMAP_ENCODING)
                                Nd4j.getExecutioner().bitmapDecode(compressed_copy, updates);
                            else
                                throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
                        }
                    } else {
                        int encoding = compressed.data().getInt(3);
                        if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
                            Nd4j.getExecutioner().thresholdDecode(compressed, updates);
                        else if (encoding == ThresholdCompression.BITMAP_ENCODING)
                            Nd4j.getExecutioner().bitmapDecode(compressed, updates);
                        else
                            throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
                    }
                    cnt++;
                    ent++;
                }

                if (ent > 0 && isDebug)
                    log.info("External updates to be applied: {}", ent);
            }

            // wait for all consumers and release the registration flag (finalLock = true)
            synchronize(currentConsumers.get(), true);

            // TODO: average updates? might have sense

            if (cnt > 0)
                function.step(params, updates, alpha);
        } catch (Exception e) {
            throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * This method allows to pass external updates to accumulator, they will be populated across all workers using this GradientsAccumulator instance
     *
     * @param source queue of compressed update messages produced outside of this accumulator
     */
    @Override
    public void setExternalSource(Queue<INDArray> source) {
        this.externalSource = source;
    }

    /**
     * This method does initialization of given worker wrt Thread-Device Affinity:
     * assigns the calling thread a party index, once, on first call.
     */
    @Override
    public void touch() {
        // already initialized for this thread?
        if (index.get() != null)
            return;

        int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();

        if (deviceCount > 1 && parties > 1) {
            // multi-device setup: bind this worker to the workspace of its own device,
            // exactly as provided via AffinityManager
            index.set(Nd4j.getAffinityManager().getDeviceForCurrentThread());
        } else {
            // single device (cpu system or lone gpu): hand out flat indexes in arrival order
            index.set(workersCounter.getAndIncrement());
        }
    }

    /**
     * This method accepts updates suitable for StepFunction, and accumulates/propagates it across all workers.
     *
     * Adds the array into this thread's residual accumulator, blocks until the master thread
     * has registered the consumer count for this cycle (unless in bypass mode), broadcasts
     * the accumulated updates via the handler, then waits at the barrier for all consumers.
     *
     * @param array gradient updates to accumulate and broadcast
     */
    @Override
    public void storeUpdate(INDArray array) {
        try {
            if (accumulator.get() == null) {
                // we don't want accumulator to be attached to workspaces
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    accumulator.set(Nd4j.create(array.shape(), array.ordering()));
                }
            }

            // accumulate gradients updates in residental array
            accumulator.get().addi(array);

            if (isDebug)
                log.info("thread {} locking at Register", Thread.currentThread().getId());

            // block until ParallelWrapper sends us message about number of threads in this cycle
            if (!bypassMode.get())
                while (!registered.get()) {
                    LockSupport.parkNanos(100L);
                    if (throwable.isTriggered())
                        throw new RuntimeException(throwable.get());
                }

            if (isDebug)
                log.info("thread {} unlocking at Register", Thread.currentThread().getId());

            // propagate changes & modify accumulator
            handler.broadcastUpdates(accumulator.get());

            // we're blocking here, untill all done broadcasting updates
            synchronize(currentConsumers.get());
        } catch (Exception e) {
            throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * This method accepts updates suitable for StepFunction and puts them to the queue, which is used in backpropagation loop
     * 

* PLEASE NOTE: array is expected to be ready for use and match params dimensionality * * @param array */ @Override public void receiveUpdate(INDArray array) { try { // we're replicating COMPRESSED MESSAGES, decompression will be thread-local for (int i = 0; i < parties; i++) { // we don't want to have same workspace to be accessible by 2 different threads for now /* With synchronized external data, it's impossible to deadlock here. Each worker is guaranteed to have at least NUM_WORKERS slots in buffer. So we use this lock just to ensure thread-safety of corresponding workspaces */ locks.get(i).lock(); try (MemoryWorkspace workspace = workspaces.get(i).notifyScopeEntered()) { // we might just scope out of workspace here, instead of throwing error out if (array.data().length() > (initialMemory / queueSize) / Nd4j.sizeOfDataType(array.data().dataType())) throw new ND4JIllegalStateException("Not enough memory to handle update: [" + array.data().length() * Nd4j.sizeOfDataType(array.data().dataType()) + " bytes required]. 
Please increase memory amount for GradientsAccumulator"); INDArray compressed = array.unsafeDuplication(); try { messages.get(i).put(compressed); } catch (InterruptedException e) { log.info("Something bad at index_{}", i); throw new RuntimeException(e); } } locks.get(i).unlock(); } } catch (Exception e) { throwable.setIfFirst(e); throw new RuntimeException(e); } } /** * This method resets all accumulated updates (if any) */ @Override public void reset() { // just replace accumulator, gc will do the rest accumulator = new ThreadLocal<>(); // resetting this counter too workersCounter.set(0); // throw away message queues for (int i = 0; i < parties; i++) { messages.get(i).clear(); } } public static class Builder { protected int parties; protected double threshold = 1e-3; protected long initialMemory = 100 * 1024 * 1024L; protected int queueSize = 5; protected MessageHandler handler; protected Double boundary = null; /** * This * @param parties */ public Builder(int parties) { if (parties < 1) throw new DL4JInvalidConfigException( "Number of parties for GradientsAccumulation should be positive value"); this.parties = parties; } /** * This method allows to specify MessageHandler instance * * Default value: EncodingHandler * @param handler * @return */ public Builder messageHandler(@NonNull MessageHandler handler) { this.handler = handler; return this; } /** * This method allows to set encoding threshold for this accumulator instance * * Default value: 1e-3 * @param threshold * @return */ public Builder encodingThreshold(double threshold) { this.threshold = threshold; return this; } /** * This method enables optional limit for max number of updates per message * * Default value: 1.0 (no limit) * @param boundary positive value in range 0..1 * @return */ public Builder updatesBoundary(double boundary) { if (boundary >= 1.0) return this; if (boundary <= 0.0) throw new DL4JInvalidConfigException("Boundary should have positive value"); this.boundary = boundary; return this; } 
/** * This method allows to define buffer memory parameters for this GradientsAccumulator * * Default values: 100MB initialMemory, 5 queueSize * @param initialMemory * @param queueSize * @return */ public Builder memoryParameters(long initialMemory, int queueSize) { this.initialMemory = initialMemory; this.queueSize = queueSize; return this; } public EncodedGradientsAccumulator build() { if (handler == null) { if (boundary == null) handler = new EncodingHandler(threshold); else handler = new EncodingHandler(threshold, boundary); } EncodedGradientsAccumulator accumulator = new EncodedGradientsAccumulator(parties, handler, initialMemory, queueSize, boundary); return accumulator; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy