All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.optimize.solvers.accumulation;

import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.ThresholdCompression;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.AtomicThrowable;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;

@Slf4j
public class EncodedGradientsAccumulator implements GradientsAccumulator, Registerable {
    public static final long DEFAULT_INITIAL_MEMORY = 100 * 1024 * 1024L;
    protected ThreadLocal accumulator = new ThreadLocal<>();

    protected int parties;
    @Getter
    protected MessageHandler handler;
    @Getter
    protected List> messages = new ArrayList<>();
    protected List workspaces = new ArrayList<>();
    protected List locks = new ArrayList<>();

    protected AtomicInteger workersCounter = new AtomicInteger(0);
    protected ThreadLocal index = new ThreadLocal<>();
    protected long initialMemory = 100 * 1024 * 1024L;
    protected int queueSize = 5;
    protected Integer boundary = Integer.MAX_VALUE;
    protected boolean encodingDebugMode;

    protected IndexedTail externalSource;

    protected AtomicBoolean isFirst = new AtomicBoolean(false);
    protected AtomicBoolean isDone = new AtomicBoolean(true);

    protected AtomicInteger barrier = new AtomicInteger(0);
    protected AtomicInteger secondary = new AtomicInteger(0);
    protected AtomicBoolean registered = new AtomicBoolean(false);
    protected AtomicBoolean bypassMode = new AtomicBoolean(false);
    protected final AtomicInteger currentConsumers = new AtomicInteger(0);

    protected final AtomicThrowable throwable = new AtomicThrowable();

    protected boolean isDebug = false;
    protected final boolean relocatable;

    protected ThreadLocal updatesApplied = new ThreadLocal<>();

    protected AtomicBoolean externalUpdatesAvailable = new AtomicBoolean(false);

    protected WorkspaceConfiguration appliedConfiguration = WorkspaceConfiguration.builder().minSize(5 * 1024 * 1024L)
                    .overallocationLimit(0.3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE)
                    .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).build();

    /**
     * Creates an accumulator that uses an {@link AdaptiveThresholdAlgorithm} seeded with the given
     * threshold, a {@code ResidualClippingPostProcessor(5, 5)}, and encoding debug mode switched off.
     *
     * @param parties   number of local workers
     * @param threshold value handed to the {@link AdaptiveThresholdAlgorithm} constructor
     */
    public EncodedGradientsAccumulator(int parties, double threshold) {
        this(
                parties,
                new AdaptiveThresholdAlgorithm(threshold),
                new ResidualClippingPostProcessor(5, 5),
                false);
    }

    /**
     * Creates an accumulator backed by a new {@link EncodingHandler}, using
     * {@link #DEFAULT_INITIAL_MEMORY}, a queue size of 10 and no updates boundary
     * ({@code Integer.MAX_VALUE}).
     *
     * @param parties               number of local workers
     * @param thresholdAlgorithm    threshold algorithm passed to the {@link EncodingHandler}
     * @param residualPostProcessor residual post processor passed to the {@link EncodingHandler}
     * @param encodingDebugMode     debug flag passed to both the handler and this accumulator
     */
    public EncodedGradientsAccumulator(int parties, ThresholdAlgorithm thresholdAlgorithm,
                    ResidualPostProcessor residualPostProcessor, boolean encodingDebugMode) {
        this(
                parties,
                new EncodingHandler(thresholdAlgorithm, residualPostProcessor, Integer.MAX_VALUE, encodingDebugMode),
                DEFAULT_INITIAL_MEMORY,
                10,
                Integer.MAX_VALUE,
                encodingDebugMode);
    }

    /**
     * Creates an accumulator with explicit handler and buffer settings.
     * <p>
     * For every party a bounded message queue, a dedicated workspace and a lock are created up front.
     * While the workspaces are being created the calling thread is temporarily re-pinned to each
     * target device; its original device is restored afterwards, even if workspace creation fails.
     *
     * @param parties           number of local workers
     * @param handler           message handler; {@code initialize(this)} is invoked on it at the end
     * @param initialMemory     size, in bytes, of each per-party receive workspace
     * @param queueSize         capacity of each per-party message queue
     * @param boundary          updates boundary (stored only)
     * @param encodingDebugMode encoding debug flag (stored only)
     * @throws ND4JIllegalStateException if {@code parties} exceeds the number of devices on a
     *                                   multi-device system
     */
    public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
                    int queueSize, Integer boundary, boolean encodingDebugMode) {
        this.parties = parties;
        this.handler = handler;
        this.initialMemory = initialMemory;
        this.queueSize = queueSize;
        this.boundary = boundary;
        this.encodingDebugMode = encodingDebugMode;

        // maybe not the best idea in the world, but we'll use cyclic workspace of initialMemory bytes to receive updates
        WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(initialMemory)
                        .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyAllocation(AllocationPolicy.STRICT)
                        .policySpill(SpillPolicy.FAIL).policyLearning(LearningPolicy.NONE).build();


        // we want to know, if we'll have to relocate data if accessed from different threads/devices
        relocatable = Nd4j.getAffinityManager().getNumberOfDevices() > 1
                        && !Nd4j.getAffinityManager().isCrossDeviceAccessSupported();

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        // we are going to take single-device systems as edge case: cpu & small models at single-gpu systems.
        if (parties > numDevices && numDevices != 1)
            throw new ND4JIllegalStateException("Number of parties [" + parties
                            + "] should be less or equal to number of devices [" + numDevices + "]");

        // pre-create Queues for local workers
        int curDev = Nd4j.getAffinityManager().getDeviceForCurrentThread();

        try {
            for (int i = 0; i < parties; i++) {
                messages.add(new LinkedBlockingQueue<INDArray>(queueSize));

                // we don't want device index to step out of boundaries here
                int cDevice = numDevices > 1 ? i % numDevices : 0;

                Nd4j.getAffinityManager().unsafeSetDevice(cDevice);
                MemoryWorkspace ws = Nd4j.getWorkspaceManager().createNewWorkspace(configuration, "CGA-" + i, cDevice);
                workspaces.add(ws);

                locks.add(new ReentrantLock());
            }
        } finally {
            // always put the calling thread back on its own device, even if workspace creation failed
            Nd4j.getAffinityManager().unsafeSetDevice(curDev);
        }

        handler.initialize(this);
    }

    /**
     * Computes a buffer size, in bytes, for a model with the given number of parameters.
     * <p>
     * The result is {@code ((paramsLength / 16) + 65536) * numWorkers * queueSize * 4}. Per the
     * original author's note, a single update is assumed to take at most {@code paramsLength / 16}
     * int elements (bitmap encoding, 2 bits per gradient value, being the worst case), each int
     * taking 4 bytes; the extra 65536 elements are headroom.
     *
     * @param paramsLength number of parameters in the model
     * @param numWorkers   number of workers
     * @param queueSize    number of messages each queue may hold
     * @return buffer size in bytes
     */
    public static long getOptimalBufferSize(long paramsLength, int numWorkers, int queueSize) {
        // worst-case element count of one message, plus 65536 elements of headroom
        final long elementsPerMessage = (paramsLength / 16) + 65536;

        // one slot per worker per queue entry, 4 bytes per element
        return elementsPerMessage * numWorkers * queueSize * 4;
    }


    /**
     * Convenience overload that takes the parameter count from the model itself.
     *
     * @param model      model whose {@code params().length()} is used as the parameter count
     * @param numWorkers number of workers
     * @param queueSize  number of messages each queue may hold
     * @return buffer size in bytes, as computed by {@link #getOptimalBufferSize(long, int, int)}
     */
    public static long getOptimalBufferSize(Model model, int numWorkers, int queueSize) {
        final long paramsLength = model.params().length();
        return getOptimalBufferSize(paramsLength, numWorkers, queueSize);
    }

    /**
     * Switches bypass mode on or off. The request is first forwarded to the external source when
     * that source is itself {@link Registerable}.
     *
     * @param reallyFallback new value of the bypass flag
     */
    @Override
    public void fallbackToSingleConsumerMode(boolean reallyFallback) {
        // instanceof is false for null, so no separate null check is needed
        if (externalSource instanceof Registerable) {
            Registerable registerableSource = (Registerable) externalSource;
            registerableSource.fallbackToSingleConsumerMode(reallyFallback);
        }

        bypassMode.set(reallyFallback);
    }

    /**
     * Announces how many consumers take part in the coming cycle.
     * <p>
     * If a previous registration is still active, the caller spins until it has been cleared,
     * rethrowing any exception recorded in {@code throwable} meanwhile.
     *
     * @param numConsumers number of consumers for the coming cycle
     */
    @Override
    public void registerConsumers(int numConsumers) {
        // a previous cycle has not been released yet - wait for it
        if (registered.get()) {
            if (isDebug) {
                log.info("Master thread locks at RC");
            }

            while (registered.get()) {
                ThreadUtils.uncheckedSleep(1);
                if (throwable.isTriggered()) {
                    throw new RuntimeException(throwable.get());
                }
            }

            if (isDebug) {
                log.info("Master thread unlocks at RC");
            }
        }

        // let the external source know about the consumer count as well, if it cares
        if (externalSource != null && externalSource instanceof Registerable) {
            Registerable registerableSource = (Registerable) externalSource;
            registerableSource.registerConsumers(numConsumers);
        }

        currentConsumers.set(numConsumers);
        registered.set(true);
    }

    /**
     * @return the external updates source currently attached, or {@code null} if none was set
     */
    @Override
    public IndexedTail getExternalSource() {
        return this.externalSource;
    }

    /**
     * Records whether external updates are available.
     *
     * @param updatesAvailable value stored into the {@code externalUpdatesAvailable} flag
     */
    @Override
    public void markExternalUpdates(boolean updatesAvailable) {
        this.externalUpdatesAvailable.set(updatesAvailable);
    }

    /**
     * Barrier for the given number of consumers that leaves the registration flag untouched.
     *
     * @param consumers number of threads expected at the barrier
     */
    protected void synchronize(int consumers) {
        final boolean finalLock = false;
        synchronize(consumers, finalLock);
    }

    /**
     * Two-phase spin barrier: returns only after {@code consumers} threads have reached it.
     * <p>
     * With a single consumer, or when bypass mode is on, no waiting happens at all. Waiting threads
     * poll with {@code ThreadUtils.uncheckedSleep(1)} and rethrow, wrapped in a RuntimeException,
     * any exception recorded in {@code throwable}.
     *
     * @param consumers number of threads expected at the barrier
     * @param finalLock if true, the {@code registered} flag is cleared once the barrier is passed
     */
    protected void synchronize(int consumers, boolean finalLock) {
        // nothing to wait for: just release the registration if this is the final lock
        if (consumers == 1 || bypassMode.get()) {
            if (finalLock)
                registered.set(false);

            return;
        }

        if (isDebug)
            log.info("thread {} locking at CGA: {}", Thread.currentThread().getId(), currentConsumers.get());

        // any first thread entering this block - will reset this field to false
        isDone.compareAndSet(true, false);

        // phase 1: last thread to arrive resets both counters and isFirst, then sets isDone to true
        if (barrier.incrementAndGet() == consumers) {
            secondary.set(0);
            barrier.set(0);
            isFirst.set(false);
            isDone.set(true);
        } else {
            // just wait, till last thread will set isDone to true
            while (!isDone.get()) {
                ThreadUtils.uncheckedSleep(1);
                if (throwable.isTriggered())
                    throw new RuntimeException(throwable.get());
            }
        }

        // phase 2: second lock here needed only to ensure we won't get overrun over isDone flag
        if (secondary.incrementAndGet() == consumers) {
            // last thread through phase 2 releases the registration (if requested) and the waiters
            if (finalLock)
                registered.set(false);

            isFirst.set(true);
        } else {
            while (!isFirst.get()) {
                ThreadUtils.uncheckedSleep(1);
                if (throwable.isTriggered())
                    throw new RuntimeException(throwable.get());
            }
        }

        if (isDebug)
            log.info("thread {} unlocking at CGA: {}", Thread.currentThread().getId(), currentConsumers.get());

    }

    /**
     * Decodes every update queued for the calling worker (plus pending external updates, if any)
     * into {@code updates} and applies the result via the given StepFunction.
     * <p>
     * {@code updates} is zeroed first. When {@code isFinalStep} is true, the method waits at the
     * consumers barrier and releases the registration before stepping. The step is skipped when
     * nothing was decoded. Any exception is recorded in {@code throwable} and rethrown wrapped in
     * a RuntimeException.
     *
     * @param function    step function used to apply the decoded updates
     * @param params      parameters array passed to the step function
     * @param updates     scratch array that receives the decoded updates
     * @param isFinalStep whether this is the last applyUpdate call of the current cycle
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates, boolean isFinalStep) {
        if (updatesApplied.get() == null)
            updatesApplied.set(new AtomicLong(0));
        try {
            // nullify given updates first
            Nd4j.getMemoryManager().memset(updates);

            int cnt = decodeLocalUpdates(updates);

            if (externalSource != null) {
                int ent = drainExternalUpdates(updates);
                cnt += ent;

                if (isDebug)
                    log.info("thread {} finished at Externals", Thread.currentThread().getId());

                if (ent > 0 && isDebug)
                    log.info("External updates to be applied: {}", ent);
            }

            if (isFinalStep)
                synchronize(currentConsumers.get(), isFinalStep);

            // TODO: average updates probably?

            if (cnt > 0) {
                function.step(params, updates);
                updatesApplied.get().addAndGet(cnt);
                if (isDebug)
                    log.info("Total updates applied so far for thread [{}]: [{}]", Thread.currentThread().getName(), updatesApplied.get());
            }
        } catch (Exception e) {
            throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Decodes every update queued for the calling worker (plus pending external updates, if any)
     * into {@code updates} and applies the result via the given StepFunction and step size.
     * <p>
     * {@code updates} is zeroed first. This variant always waits at the consumers barrier as the
     * final lock before stepping. The step is skipped when nothing was decoded. Any exception is
     * recorded in {@code throwable} and rethrown wrapped in a RuntimeException.
     *
     * @param function step function used to apply the decoded updates
     * @param params   parameters array passed to the step function
     * @param updates  scratch array that receives the decoded updates
     * @param alpha    value forwarded to {@code function.step(params, updates, alpha)}
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray updates, double alpha) {
        try {
            // nullify given updates first
            Nd4j.getMemoryManager().memset(updates);

            int cnt = decodeLocalUpdates(updates);

            if (externalSource != null) {
                int ent = drainExternalUpdates(updates);
                cnt += ent;

                if (ent > 0 && isDebug)
                    log.info("External updates to be applied: {}", ent);
            }

            synchronize(currentConsumers.get(), true);

            // TODO: average updates? might have sense

            if (cnt > 0)
                function.step(params, updates, alpha);
        } catch (Exception e) {
            throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Polls the calling worker's queue until it is empty, decoding each message into {@code updates}.
     *
     * @param updates array that receives the decoded values
     * @return number of messages decoded
     * @throws DL4JInvalidConfigException if a message carries an unknown encoding header
     */
    private int decodeLocalUpdates(INDArray updates) {
        int cnt = 0;

        // poll-until-null rather than isEmpty()+poll(): the queue may be cleared concurrently
        // (e.g. by reset()), in which case poll() would hand back null after isEmpty() said false
        INDArray compressed;
        while ((compressed = messages.get(index.get()).poll()) != null) {
            // the encoding type is read from element 3 of the message buffer
            int encoding = compressed.data().getInt(3);
            if (encoding == ThresholdCompression.FLEXIBLE_ENCODING)
                Nd4j.getExecutioner().thresholdDecode(compressed, updates);
            else if (encoding == ThresholdCompression.BITMAP_ENCODING)
                Nd4j.getExecutioner().bitmapDecode(compressed, updates);
            else
                throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);

            cnt++;
        }

        if (cnt > 0 && isDebug)
            log.info("Local updates to be applied: {}", cnt);

        return cnt;
    }

    /**
     * Drains the external source into {@code updates} if it reports pending data.
     * Must only be called when {@code externalSource} is not null.
     *
     * @param updates array that receives the external updates
     * @return 1 if the external source was drained, 0 otherwise
     */
    private int drainExternalUpdates(INDArray updates) {
        if (!externalSource.hasAnything())
            return 0;

        externalSource.drainTo(updates);
        return 1;
    }

    /**
     * Attaches an external source of updates. Anything it provides is drained into the updates
     * array on applyUpdate calls, for every worker using this GradientsAccumulator instance.
     *
     * @param source external updates source
     */
    @Override
    public void setExternalSource(IndexedTail source) {
        externalSource = source;
    }

    /**
     * Assigns the calling thread its worker index, once. Later calls from the same thread are no-ops.
     * <p>
     * With several devices and several parties the index is the device the thread is attached to,
     * as reported by the AffinityManager; otherwise indices are handed out sequentially.
     */
    @Override
    public void touch() {
        // already initialised for this thread
        if (index.get() != null) {
            return;
        }

        final int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
        final boolean useDeviceAsIndex = deviceCount > 1 && parties > 1;

        if (useDeviceAsIndex) {
            // multi-device: workers map onto workspaces "as is", by their device
            int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            index.set(deviceId);
        } else {
            // single device (cpu system, or single gpu): attach consumer via flat index
            int flatIndex = workersCounter.getAndIncrement();
            index.set(flatIndex);
        }
    }

    /**
     * Adds the given update to the calling thread's accumulator, waits until consumers have been
     * registered for this cycle (unless bypass mode is on), asks the handler to broadcast the
     * accumulated values, and finally waits at the consumers barrier.
     * <p>
     * Any exception is recorded in {@code throwable} and rethrown wrapped in a RuntimeException.
     *
     * @param array           update to accumulate
     * @param iterationNumber forwarded to the handler
     * @param epochNumber     forwarded to the handler
     */
    @Override
    public void storeUpdate(INDArray array, int iterationNumber, int epochNumber) {
        try {
            // first call on this thread: create the accumulator outside of any workspace
            if (accumulator.get() == null) {
                try (MemoryWorkspace detached = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    accumulator.set(Nd4j.create(array.shape(), array.ordering()));
                }
            }

            // fold the new gradients into the per-thread accumulator
            accumulator.get().addi(array);

            if (isDebug) {
                log.info("thread {} locking at Register", Thread.currentThread().getId());
            }

            // hold here until the number of threads for this cycle has been announced
            if (!bypassMode.get()) {
                while (!registered.get()) {
                    ThreadUtils.uncheckedSleep(1);
                    if (throwable.isTriggered()) {
                        throw new RuntimeException(throwable.get());
                    }
                }
            }

            if (isDebug) {
                log.info("thread {} unlocking at Register", Thread.currentThread().getId());
            }

            // hand the accumulator to the handler for broadcasting
            handler.broadcastUpdates(accumulator.get(), iterationNumber, epochNumber);

            // wait until every consumer is done broadcasting
            synchronize(currentConsumers.get());
        } catch (Exception e) {
            throwable.setIfFirst(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * This method accepts updates suitable for StepFunction and puts them to the queue, which is used in backpropagation loop
     * 

* PLEASE NOTE: array is expected to be ready for use and match params dimensionality * * @param array */ @Override public void receiveUpdate(INDArray array) { try { // we're replicating COMPRESSED MESSAGES, decompression will be thread-local for (int i = 0; i < parties; i++) { // we don't want to have same workspace to be accessible by 2 different threads for now /* With synchronized external data, it's impossible to deadlock here. Each worker is guaranteed to have at least NUM_WORKERS slots in buffer. So we use this lock just to ensure thread-safety of corresponding workspaces */ locks.get(i).lock(); try (MemoryWorkspace workspace = workspaces.get(i).notifyScopeEntered()) { // we might just scope out of workspace here, instead of throwing error out if (array.data().length() > (initialMemory / queueSize) / Nd4j.sizeOfDataType(array.data().dataType())) throw new ND4JIllegalStateException("Not enough memory to handle update: [" + array.data().length() * Nd4j.sizeOfDataType(array.data().dataType()) + " bytes required]. 
Please increase memory amount for GradientsAccumulator"); INDArray compressed = array.unsafeDuplication(); try { messages.get(i).put(compressed); } catch (InterruptedException e) { Thread.currentThread().interrupt(); log.warn("Something bad at index_{}", i); throw new RuntimeException(e); } } locks.get(i).unlock(); } } catch (Exception e) { throwable.setIfFirst(e); throw new RuntimeException(e); } } /** * This method resets all accumulated updates (if any) */ @Override public void reset() { // just replace accumulator, gc will do the rest accumulator = new ThreadLocal<>(); // resetting this counter too workersCounter.set(0); // reset indexes too index = new ThreadLocal<>(); // throw away message queues for (int i = 0; i < parties; i++) { messages.get(i).clear(); } } @Override public boolean hasAnything() { return externalSource != null && externalSource.hasAnything(); //externalUpdatesAvailable.get(); } public static class Builder { protected int parties; protected ThresholdAlgorithm thresholdAlgorithm; protected ResidualPostProcessor residualPostProcessor; protected long initialMemory = DEFAULT_INITIAL_MEMORY; protected int queueSize = 5; protected MessageHandler handler; protected int boundary = Integer.MAX_VALUE; protected boolean encodingDebugMode; /** * This * @param parties */ public Builder(int parties) { if (parties < 1) throw new DL4JInvalidConfigException( "Number of parties for GradientsAccumulation should be positive value"); this.parties = parties; } /** * This method allows to specify MessageHandler instance * * Default value: EncodingHandler * @param handler * @return */ public Builder messageHandler(@NonNull MessageHandler handler) { this.handler = handler; return this; } /** * This method allows to set the ThresholdAlgorithm to be used for determining the threshold * @return */ public Builder thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) { this.thresholdAlgorithm = thresholdAlgorithm; return this; } /** * Set the residual post processor 
*/ public Builder residualPostProcessor(ResidualPostProcessor residualPostProcessor){ this.residualPostProcessor = residualPostProcessor; return this; } /** * This method enables optional limit for max number of updates per message * * Default value: Integer.MAX_VALUE (no limit) * @param boundary positive value in range 0..1 * @return */ public Builder updatesBoundary(int boundary) { if (boundary <= 0) throw new DL4JInvalidConfigException("Boundary should have positive value"); this.boundary = boundary; return this; } /** * This method allows to define buffer memory parameters for this GradientsAccumulator * * Default values: 100MB initialMemory, 5 queueSize * @param initialMemory * @param queueSize * @return */ public Builder memoryParameters(long initialMemory, int queueSize) { this.initialMemory = initialMemory; this.queueSize = queueSize; return this; } public Builder encodingDebugMode(boolean enable){ this.encodingDebugMode = enable; return this; } public EncodedGradientsAccumulator build() { if (handler == null) { Preconditions.checkNotNull(thresholdAlgorithm, "Both threshold algorithm and handler are null - one or the other must be set"); handler = new EncodingHandler(thresholdAlgorithm, residualPostProcessor, boundary, encodingDebugMode); } EncodedGradientsAccumulator accumulator = new EncodedGradientsAccumulator(parties, handler, initialMemory, queueSize, boundary, encodingDebugMode); return accumulator; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy