org.deeplearning4j.parallelism.trainer.DefaultTrainer

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.parallelism.trainer;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Trains a model on DataSets using a standard in-memory
 * parameter averaging technique.
 * Think of this worker as the simplest form of parameter averaging.
 *
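 * <p>
 * This class is normally not used directly; {@link ParallelWrapper} creates one trainer per worker thread
 * and feeds data into it. A minimal usage sketch (the worker count and averaging frequency are illustrative
 * values only, and {@code model} / {@code trainData} are placeholders):
 * <pre>{@code
 * ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)   // model: an initialized MultiLayerNetwork or ComputationGraph
 *                 .workers(4)                                      // number of concurrent trainer threads
 *                 .averagingFrequency(3)                           // average parameters every 3 minibatches
 *                 .build();
 *
 * wrapper.fit(trainData);                                          // trainData: a DataSetIterator
 * }</pre>
 *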
 * @author Adam Gibson
 */
@Builder
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
public class DefaultTrainer extends Thread implements Trainer {

    protected Model replicatedModel;

    // TODO: make queue size configurable
    @Builder.Default
    protected LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<>(1);
    @Builder.Default
    protected LinkedBlockingQueue<MultiDataSet> queueMDS = new LinkedBlockingQueue<>(1);
    @Builder.Default
    protected AtomicInteger running = new AtomicInteger(0);
    @Builder.Default
    protected AtomicBoolean shouldUpdate = new AtomicBoolean(false);
    @Builder.Default
    protected AtomicBoolean shouldStop = new AtomicBoolean(false);
    protected Exception thrownException;
    @Builder.Default
    protected volatile boolean useMDS = false;
    @Getter protected String uuid;
    @Builder.Default
    protected boolean onRootModel = false;
    @Builder.Default
    protected volatile AtomicLong lastEtlTime = new AtomicLong(0);

    @Builder.Default
    protected AtomicBoolean nullMode = new AtomicBoolean(false);
    protected DataSet nullDataSet;

    @Builder.Default
    protected AtomicBoolean isStopped = new AtomicBoolean(false);

    protected ParallelWrapper parallelWrapper;
    protected WorkspaceMode workspaceMode;
    protected int averagingFrequency;
    protected int threadId;
    protected Model originalModel;

    protected final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();

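    /**
     * Enqueues a single MultiDataSet for this worker to fit (blocking if the queue is full),
     * and records the ETL time for this feed.
     */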
    @Override
    public void feedMultiDataSet(@NonNull MultiDataSet dataSet, long etlTime) {
        setupIfNeccessary();
        try {
            queueMDS.put(dataSet);
            running.incrementAndGet();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            // do nothing
        }

        if (lastEtlTime == null)
            lastEtlTime = new AtomicLong(0);

        lastEtlTime.set(etlTime);
    }

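    /**
     * Enqueues a single DataSet for this worker to fit (blocking if the queue is full).
     * A null DataSet switches the worker into "null mode", a debugging path that repeatedly fits a dummy dataset.
     */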
    @Override
    public void feedDataSet(DataSet dataSet, long etlTime) {
        setupIfNeccessary();
        if (dataSet != null) {
            try {
                queue.put(dataSet);
                running.incrementAndGet();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                // do nothing
            }
        } else {
            if (nullMode == null)
                nullMode = new AtomicBoolean(false);

            nullMode.set(true);
        }

        if (lastEtlTime == null)
            lastEtlTime = new AtomicLong(0);

        lastEtlTime.set(etlTime);
    }

    @Override
    public Model getModel() {
        return replicatedModel;
    }

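    /**
     * Replaces this worker's replicated parameters and updater state with copies taken from the given model,
     * under the write lock so that training (which holds the read lock) cannot run concurrently.
     */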
    @Override
    public void updateModel(@NonNull Model model) {
        this.shouldUpdate.set(true);
        try {
            modelLock.writeLock().lock();
            if (replicatedModel instanceof MultiLayerNetwork) {


                replicatedModel.setParams(model.params().unsafeDuplication(true));

                Updater updater = ((MultiLayerNetwork) model).getUpdater();
                INDArray view = updater.getStateViewArray();

                if (view != null) {
                    updater = ((MultiLayerNetwork) replicatedModel).getUpdater();
                    INDArray viewD = view.dup();

                    Nd4j.getExecutioner().commit();

                    updater.setStateViewArray((MultiLayerNetwork) replicatedModel, viewD, false);
                }
            } else if (replicatedModel instanceof ComputationGraph) {
                replicatedModel.setParams(model.params().unsafeDuplication(true));

                ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
                INDArray view = updater.getStateViewArray();

                if (view != null) {
                    INDArray viewD = view.dup();

                    Nd4j.getExecutioner().commit();

                    updater = ((ComputationGraph) replicatedModel).getUpdater();
                    updater.setStateViewArray(viewD);
                }
            }

            Nd4j.getExecutioner().commit();
        } finally {
            modelLock.writeLock().unlock();
        }
    }



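    /**
     * Lazily initializes the queues, counters and flags in case any of them are still null.
     */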
    protected void setupIfNeccessary() {
        if (queue == null)
            queue = new LinkedBlockingQueue<>(1);
        if (queueMDS == null)
            queueMDS = new LinkedBlockingQueue<>(1);
        if (running == null)
            running = new AtomicInteger(0);
        if (shouldStop == null)
            shouldStop = new AtomicBoolean(false);
        if (shouldUpdate == null)
            shouldUpdate = new AtomicBoolean(false);
        if (isStopped == null)
            isStopped = new AtomicBoolean(false);
        if (lastEtlTime == null)
            lastEtlTime = new AtomicLong(0);
    }

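    /**
     * Rethrows any exception raised on the trainer thread; otherwise reports whether this worker has
     * no queued work left (i.e. the internal running counter is zero).
     */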
    @Override
    public boolean isRunning() {
        // if the trainer thread threw an exception during training - rethrow it here
        if (thrownException != null)
            throw new RuntimeException(thrownException);

        return running.get() == 0;
    }

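    /**
     * Signals the training loop to stop, waits until it has actually exited, then resets the stop flags.
     */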
    @Override
    public void shutdown() {
        shouldStop.set(true);
        while (!isStopped.get())
            LockSupport.parkNanos(1000L);

        shouldStop.set(false);
        isStopped.set(false);
    }

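    /**
     * Fits the replicated model on a single DataSet while holding the read lock, so parameter updates
     * from other threads cannot happen mid-iteration.
     */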
    protected void fit(DataSet dataSet) {
        if (replicatedModel instanceof MultiLayerNetwork) {
            if (lastEtlTime == null)
                lastEtlTime = new AtomicLong(0);

            ((MultiLayerNetwork) replicatedModel).setLastEtlTime(lastEtlTime.get());

            // hold the read lock so the model can't be updated concurrently while fitting
            try {
                modelLock.readLock().lock();
                ((MultiLayerNetwork) replicatedModel).fit(dataSet);
            } finally {
                modelLock.readLock().unlock();
            }
        } else if (replicatedModel instanceof ComputationGraph) {
            if (lastEtlTime == null)
                lastEtlTime = new AtomicLong(0);

            ((ComputationGraph) replicatedModel).setLastEtlTime(lastEtlTime.get());

            // hold the read lock so the model can't be updated concurrently while fitting
            try {
                modelLock.readLock().lock();
                ((ComputationGraph) replicatedModel).fit(dataSet);
            } finally {
                modelLock.readLock().unlock();
            }
        }
    }

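    /**
     * Fits the replicated ComputationGraph on a single MultiDataSet while holding the read lock.
     */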
    protected void fit(MultiDataSet dataSet) {
        if (lastEtlTime == null)
            lastEtlTime = new AtomicLong(0);

        ((ComputationGraph) replicatedModel).setLastEtlTime(lastEtlTime.get());

        // hold the read lock so the model can't be updated concurrently while fitting
        try {
            modelLock.readLock().lock();
            ((ComputationGraph) replicatedModel).fit(dataSet);
        } finally {
            modelLock.readLock().unlock();
        }
    }

    /**
     * This method performs post-initialization configuration of the Model.
     * It is a good place to configure listeners and similar concerns.
     */
    protected void postInit() {
        Collection<TrainingListener> oldListeners = new ArrayList<>();
        Collection<TrainingListener> replicatedListeners = new ArrayList<>();

        if (parallelWrapper.getListeners() != null) {
            oldListeners.addAll(parallelWrapper.getListeners());
        }
        configureListeners(uuid, oldListeners, replicatedListeners);

        this.replicatedModel.setListeners(replicatedListeners);
    }

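    /**
     * Main worker loop: builds (or, for the root worker, reuses) the replicated model, then repeatedly
     * polls the queue, fits each dataset, and - when an averaging round is due and cross-device access
     * is not supported - synchronizes parameters and updater state back to host memory.
     */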
    @Override
    public void run() {
        setupIfNeccessary();
        AtomicInteger iterationsCounter = new AtomicInteger(0);

        // FIXME: make this CUDA-compatible, and avoid the race condition around originalModel relocation
        if (threadId == 0)
            onRootModel = true;

        try {
            // we create a fresh network with the same configuration as the one initially created by the user;
            // however, we don't need a full clone here
            if (originalModel instanceof MultiLayerNetwork) {
                if (!onRootModel) {
                    MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(
                                    ((MultiLayerNetwork) originalModel).getLayerWiseConfigurations().toJson());
                    conf.setTrainingWorkspaceMode(workspaceMode);
                    this.replicatedModel = new MultiLayerNetwork(conf);

                    replicatedModel.init();

                    // we replicate the original model's params & updater state, in case it's a pre-trained model
                    try {
                        modelLock.writeLock().lock();
                        replicatedModel.setParams(originalModel.params().unsafeDuplication(true));

                        Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
                        Updater updaterOriginal = ((MultiLayerNetwork) originalModel).getUpdater();

                        if (updaterOriginal != null && updaterOriginal.getStateViewArray() != null)
                            updaterReplica.setStateViewArray((MultiLayerNetwork) replicatedModel,
                                    updaterOriginal.getStateViewArray().unsafeDuplication(true), false);

                        Nd4j.getExecutioner().commit();
                    } finally {
                        modelLock.writeLock().unlock();
                    }
                } else {
                    this.replicatedModel = originalModel;
                    if (!((MultiLayerNetwork) replicatedModel).isInitCalled())
                        this.replicatedModel.init();

                    ((MultiLayerNetwork) replicatedModel).getLayerWiseConfigurations()
                                    .setTrainingWorkspaceMode(workspaceMode);
                }
            } else if (originalModel instanceof ComputationGraph) {
                if (!onRootModel) {
                    ComputationGraphConfiguration conf = ComputationGraphConfiguration
                                    .fromJson(((ComputationGraph) originalModel).getConfiguration().toJson());
                    conf.setTrainingWorkspaceMode(workspaceMode);

                    this.replicatedModel = new ComputationGraph(conf);
                    this.replicatedModel.init();

                    // we replicate the original model's params & updater state, in case it's a pre-trained model
                    try {
                        modelLock.writeLock().lock();
                        replicatedModel.setParams(originalModel.params().unsafeDuplication(true));

                        ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater();
                        ComputationGraphUpdater updaterOriginal = ((ComputationGraph) originalModel).getUpdater();

                        if (updaterOriginal != null && updaterOriginal.getStateViewArray() != null)
                            updaterReplica.setStateViewArray(
                                    updaterOriginal.getStateViewArray().unsafeDuplication(true));

                        Nd4j.getExecutioner().commit();
                    } finally {
                        modelLock.writeLock().unlock();
                    }
                } else {
                    this.replicatedModel = originalModel;
                    this.replicatedModel.init();
                    ((ComputationGraph) replicatedModel).getConfiguration().setTrainingWorkspaceMode(workspaceMode);
                }
            }

            if (replicatedModel == null)
                log.error("replicatedModel is NULL at worker_{}", threadId);

            // classes that extend DefaultTrainer might hook something in here
            postInit();

            if (!useMDS) {
                while (!shouldStop.get()) {
                    DataSet dataSet = null;
                    if (nullMode == null || !nullMode.get())
                        dataSet = queue.poll(10, TimeUnit.MILLISECONDS);
                    else {
                        // this code branch is for debugging only, please ignore :)
                        if (nullDataSet == null)
                            nullDataSet = new org.nd4j.linalg.dataset.DataSet(Nd4j.create(64, 28 * 28),
                                            Nd4j.create(64, 10));

                        dataSet = nullDataSet;
                    }
                    if (dataSet != null) {

                        fit(dataSet);

                        // if cross-device access isn't supported (e.g. multi-GPU on Windows), sync back to the host
                        if (!Nd4j.getAffinityManager().isCrossDeviceAccessSupported() && (averagingFrequency == 0
                                        || iterationsCounter.incrementAndGet() % averagingFrequency == 0)
                                        && averagingRequired()) {
                            // we ensure all operations are finished in this training round
                            Nd4j.getExecutioner().commit();

                            // we ensure memory is updated on host side
                            Nd4j.getAffinityManager().ensureLocation(replicatedModel.params(),
                                            AffinityManager.Location.HOST);

                            if (replicatedModel instanceof MultiLayerNetwork) {
                                Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
                                if (updaterReplica.getStateViewArray() != null)
                                    Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),
                                                    AffinityManager.Location.HOST);
                            } else {
                                ComputationGraphUpdater updaterReplica =
                                                ((ComputationGraph) replicatedModel).getUpdater();

                                if (updaterReplica.getStateViewArray() != null)
                                    Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),
                                                    AffinityManager.Location.HOST);
                            }
                        }

                        running.decrementAndGet();
                    }
                }
            } else {
                // loop for MultiDataSet
                while (!shouldStop.get()) {
                    MultiDataSet dataSet = queueMDS.poll(10, TimeUnit.MILLISECONDS);
                    if (dataSet != null) {

                        // just fitting
                        fit(dataSet);

                        // if cross-device access isn't supported (e.g. multi-GPU on Windows), sync back to the host
                        if (!Nd4j.getAffinityManager().isCrossDeviceAccessSupported() && (averagingFrequency == 0
                                        || iterationsCounter.incrementAndGet() % averagingFrequency == 0)
                                        && averagingRequired()) {
                            // we ensure all operations are finished in this training round
                            Nd4j.getExecutioner().commit();

                            // we ensure memory is updated on host side
                            Nd4j.getAffinityManager().ensureLocation(replicatedModel.params(),
                                            AffinityManager.Location.HOST);

                            ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater();

                            if (updaterReplica.getStateViewArray() != null)
                                Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),
                                                AffinityManager.Location.HOST);
                        }

                        running.decrementAndGet();
                    }
                }
            }
        } catch (Exception e) {
            this.thrownException = e;
            throw new RuntimeException(e);
        } finally {
            log.debug("Terminating all workspaces for trainer_{}", threadId);
            Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
            isStopped.set(true);
        }
    }

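    /**
     * Blocks until this worker has consumed all queued work, rethrowing any exception raised on the
     * trainer thread.
     */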
    @Override
    public void waitTillRunning() {
        while (running.get() != 0) {
            // if the trainer thread threw an exception during training - rethrow it here
            //log.info("Thread {} running {}", Thread.currentThread().getId(), running.get());
            if (thrownException != null)
                throw new RuntimeException(thrownException);

            LockSupport.parkNanos(1000L);
        }
    }


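    /**
     * Overwrites the replicated model's parameters with a copy of the given array, under the write lock.
     */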
    @Override
    public void updateModelParams(INDArray params) {
        try {
            modelLock.writeLock().lock();

            // just set it right away
            replicatedModel.setParams(params.unsafeDuplication(true));
            Nd4j.getExecutioner().commit();
        } finally {
            modelLock.writeLock().unlock();
        }
    }

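    /**
     * Overwrites the replicated model's updater state with a copy of the given array, under the write lock.
     */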
    @Override
    public void updateUpdaterParams(INDArray params) {
        try {
            modelLock.writeLock().lock();

            if (replicatedModel instanceof ComputationGraph)
                ((ComputationGraph) replicatedModel).getUpdater().getStateViewArray().assign(params.unsafeDuplication(true));
            else if (replicatedModel instanceof MultiLayerNetwork)
                ((MultiLayerNetwork) replicatedModel).getUpdater().getStateViewArray().assign(params.unsafeDuplication(true));

            Nd4j.getExecutioner().commit();
        } finally {
            modelLock.writeLock().unlock();
        }
    }

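    /**
     * Whether this trainer participates in parameter averaging; always true for the default implementation.
     */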
    @Override
    public boolean averagingRequired() {
        return true;
    }

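    /**
     * Clones RoutingIterationListener instances so each worker gets its own copy; all other listeners
     * are reused as-is.
     */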
    protected static TrainingListener cloneListener(TrainingListener original) {
        if (original instanceof RoutingIterationListener) {
            return ((RoutingIterationListener) original).clone();
        }
        return original;
    }


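    /**
     * Builds the per-worker listener list: RoutingIterationListener instances are cloned and assigned
     * this worker's ID and a storage router; all other listeners are added unchanged.
     */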
    protected void configureListeners(String workerUUID, Collection<TrainingListener> oldListeners,
                    Collection<TrainingListener> replicatedListeners) {
        for (TrainingListener listener : oldListeners) {
            TrainingListener l = cloneListener(listener);

            if (l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                //We're assuming session ID is set by the original RoutingIterationListener constructor, which means
                // it will be synced across all cloned instances
                rl.setSessionID(((RoutingIterationListener) listener).getSessionID());
                rl.setWorkerID(workerUUID);

                StatsStorageRouter currentRouter = ((RoutingIterationListener) listener).getStorageRouter();
                if (currentRouter != null) {
                    //The user has set the router on the listener/model, instead of via the
                    // setListeners(StatsStorageRouter, ...) method
                    rl.setStorageRouter(currentRouter);
                } else {
                    rl.setStorageRouter(parallelWrapper.getStorageRouter());
                }

            }
            if (!replicatedListeners.contains(l)) {
                replicatedListeners.add(l);
            }
        }
    }


    public static class DefaultTrainerBuilder {
        public DefaultTrainerBuilder() {}
    }

}



