All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.parallelism.parameterserver.ParameterServerParallelWrapper Maven / Gradle / Ivy

package org.deeplearning4j.parallelism.parameterserver;

import io.aeron.driver.MediaDriver;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.agrona.CloseHelper;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.parameterserver.client.ParameterServerClient;
import org.nd4j.parameterserver.node.ParameterServerNode;

import java.io.Closeable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
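 * Parallel training wrapper that feeds minibatches to worker threads,
 * each of which fits its own copy of the model and exchanges parameters
 * with an embedded parameter server.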
 *
 * @author Adam Gibson
 */
@Builder
@Data
@Slf4j
public class ParameterServerParallelWrapper implements AutoCloseable {
    private ExecutorService executorService;
    private int numWorkers;
    private Trainer[] parameterServerClient;
    private ParameterServerNode parameterServerNode;
    private MediaDriver mediaDriver;
    private MediaDriver.Context mediaDriverContext;
    private boolean init = false;
    private Model model;
    private ComputationGraph computationGraph;
    private MultiLayerNetwork multiLayerNetwork;
    //work queue for datasets
    private LinkedBlockingQueue linkedBlockingQueue;
    private AtomicBoolean running;
    private int preFetchSize;
    private String[] parameterServerArgs;
    private int numUpdatesPerEpoch;
    private int numEpochs;
    private int statusServerPort = 33000;

    /**
     * Queues every minibatch of {@code source} for the workers, once per configured epoch.
     *
     * The wrapper is initialized lazily on the first call. When {@code preFetchSize} is
     * positive and the source reports async support, batches are read through an
     * {@link AsyncDataSetIterator}. After each pass the iterator is reset; the epoch log
     * line is written once that epoch's batches have been queued.
     *
     * @param source the minibatch source to train on
     */
    public void fit(DataSetIterator source) {
        if (!init) {
            init(source);
        }

        final boolean prefetch = preFetchSize > 0 && source.asyncSupported();
        final DataSetIterator batches = prefetch ? new AsyncDataSetIterator(source, preFetchSize) : source;

        int epoch = 0;
        while (epoch < numEpochs) {
            while (batches.hasNext()) {
                addObject(batches.next());
            }

            batches.reset();

            log.info(String.format("Completed epoch %d", epoch));
            epoch++;
        }
    }



    /**
     * Queues every minibatch of {@code multiDataSetIterator} for the workers, once per
     * configured epoch.
     *
     * The wrapper is initialized lazily on the first call. When {@code preFetchSize} is
     * positive and the source reports async support, batches are read through an
     * {@link AsyncMultiDataSetIterator}. The iterator is reset after every pass so that
     * {@code numEpochs} is honoured in the same way as in the {@code DataSetIterator}
     * overload; initialization already rejects iterators that do not support reset.
     *
     * @param multiDataSetIterator the minibatch source to train on
     */
    public void fit(MultiDataSetIterator multiDataSetIterator) {
        if (!init) {
            init(multiDataSetIterator);
        }

        MultiDataSetIterator iterator;
        if (preFetchSize > 0 && multiDataSetIterator.asyncSupported()) {
            iterator = new AsyncMultiDataSetIterator(multiDataSetIterator, preFetchSize);
        } else {
            iterator = multiDataSetIterator;
        }

        for (int i = 0; i < numEpochs; i++) {
            while (iterator.hasNext()) {
                org.nd4j.linalg.dataset.api.MultiDataSet next = iterator.next();
                addObject(next);
            }

            // Rewind so the next epoch (or a later fit call) starts from the first batch.
            iterator.reset();

            log.info("Completed epoch {}", i);
        }
    }

    /**
     * Hands one minibatch to the workers, blocking while the bounded work queue is full.
     *
     * A blocking {@code put} is used instead of a timed {@code offer} followed by a
     * sleep: the extra sleep only delayed the producer after the queue had already been
     * waited on. If the calling thread is interrupted while waiting, the item is not
     * queued, the interrupt flag is restored and a warning is logged.
     *
     * @param next the {@code DataSet} or {@code MultiDataSet} to queue
     */
    private void addObject(Object next) {
        try {
            linkedBlockingQueue.put(next);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("Interrupted while queueing a minibatch; the minibatch was not queued");
        }
    }


    /**
     * Counts the minibatches remaining in {@code iterator} by draining it, then resets it.
     *
     * @param iterator the iterator to measure
     * @return the number of {@code next()} calls made before {@code hasNext()} turned false
     * @throws IllegalStateException if the iterator does not support {@code reset()}
     */
    private int numUpdatesPerEpoch(MultiDataSetIterator iterator) {
        if (!iterator.resetSupported()) {
            throw new IllegalStateException("Iterator must support reset()");
        }

        int batches;
        for (batches = 0; iterator.hasNext(); batches++) {
            iterator.next();
        }

        iterator.reset();
        return batches;
    }


    /**
     * Counts the minibatches remaining in {@code iterator} by draining it, then resets it.
     *
     * @param iterator the iterator to measure
     * @return the number of {@code next()} calls made before {@code hasNext()} turned false
     * @throws IllegalStateException if the iterator does not support {@code reset()}
     */
    private int numUpdatesPerEpoch(DataSetIterator iterator) {
        if (!iterator.resetSupported()) {
            throw new IllegalStateException("Iterator must support reset()");
        }

        int batches;
        for (batches = 0; iterator.hasNext(); batches++) {
            iterator.next();
        }

        iterator.reset();
        return batches;
    }


    /**
     * One-time setup: counts the minibatches per epoch, launches the embedded media driver
     * and parameter server node, waits for the subscriber to come up, and starts one
     * {@link Trainer} per worker on a fixed thread pool.
     *
     * Each worker receives its own clone of {@code model}.
     *
     * NOTE(review): this class is built with Lombok {@code @Builder}; Lombok's documentation
     * states that field initializers are ignored by the builder unless marked
     * {@code @Builder.Default} - confirm that {@code statusServerPort} is set explicitly
     * by callers that rely on the 33000 default.
     *
     * @param iterator a {@code DataSetIterator} or {@code MultiDataSetIterator}
     * @throws IllegalStateException if {@code numEpochs < 1}, if {@code model} is neither a
     *         {@code ComputationGraph} nor a {@code MultiLayerNetwork}, if the iterator does
     *         not support reset, or if the thread is interrupted while waiting for startup
     * @throws IllegalArgumentException if {@code iterator} is of an unsupported type
     */
    private void init(Object iterator) {
        if (numEpochs < 1) {
            throw new IllegalStateException("numEpochs must be >= 1");
        }
        // Fail fast: a null or unsupported model would otherwise only surface later,
        // inside a worker thread, as a NullPointerException.
        if (!(model instanceof ComputationGraph) && !(model instanceof MultiLayerNetwork)) {
            throw new IllegalStateException("model must be a ComputationGraph or a MultiLayerNetwork");
        }

        //determine the number of updates per epoch (number of minibatches total for this iterator)
        //TODO: make this efficient
        if (iterator instanceof DataSetIterator) {
            numUpdatesPerEpoch = numUpdatesPerEpoch((DataSetIterator) iterator);
        } else if (iterator instanceof MultiDataSetIterator) {
            numUpdatesPerEpoch = numUpdatesPerEpoch((MultiDataSetIterator) iterator);
        } else {
            throw new IllegalArgumentException(
                            "Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
        }

        // Resolve the worker count BEFORE it is handed to the parameter server node;
        // previously the node was constructed with the unresolved value (possibly 0).
        if (numWorkers == 0) {
            numWorkers = Runtime.getRuntime().availableProcessors();
        }

        mediaDriverContext = new MediaDriver.Context();
        mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
        parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
        running = new AtomicBoolean(true);
        if (parameterServerArgs == null) {
            parameterServerArgs = new String[] {"-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p",
                            "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh",
                            "localhost", "-sp", String.valueOf(statusServerPort), "-u",
                            String.valueOf(numUpdatesPerEpoch)};
        }

        // The queue is bounded by the worker count so the producer blocks when workers are busy.
        linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);

        //pass through args for the parameter server subscriber
        parameterServerNode.runMain(parameterServerArgs);

        while (!parameterServerNode.subscriberLaunched()) {
            pauseForStartup();
        }

        // Additional fixed pause after the subscriber reports launched (kept from the original).
        pauseForStartup();

        log.info("Parameter server started");

        parameterServerClient = new Trainer[numWorkers];
        executorService = Executors.newFixedThreadPool(numWorkers);

        for (int i = 0; i < numWorkers; i++) {
            final Model replica;
            if (this.model instanceof ComputationGraph) {
                replica = ((ComputationGraph) this.model).clone();
            } else {
                // Guaranteed by the validation at the top of this method.
                replica = ((MultiLayerNetwork) this.model).clone();
            }

            parameterServerClient[i] = new Trainer(
                            ParameterServerClient.builder().aeron(parameterServerNode.getAeron())
                                            .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder()
                                                            .connectionUrl())
                                            .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber()
                                                            .connectionUrl())
                                            .subscriberHost("localhost").masterStatusHost("localhost")
                                            .masterStatusPort(statusServerPort).subscriberPort(40625 + i)
                                            .subscriberStream(12 + i).build(),
                            running, linkedBlockingQueue, replica);
            final int j = i;
            executorService.submit(() -> parameterServerClient[j].start());
        }

        init = true;
        log.info("Initialized wrapper");
    }

    /**
     * Sleeps for ten seconds while the parameter server starts. An interrupt aborts
     * initialization instead of being re-flagged and ignored, which would make every
     * subsequent sleep throw immediately and turn the wait loop into a busy spin.
     */
    private static void pauseForStartup() {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for the parameter server to start", e);
        }
    }

    /**
     * Signals the workers to stop and releases the executor, the embedded media driver
     * and the parameter server node. Resources that were never created are skipped.
     *
     * @throws Exception if this resource cannot be closed
     */
    @Override
    public void close() throws Exception {
        // Workers loop on this flag; shutdown() alone does not stop tasks that are
        // already running, so the flag must be cleared for them to exit.
        if (running != null) {
            running.set(false);
        }
        if (executorService != null) {
            executorService.shutdown();
        }
        if (mediaDriver != null) {
            CloseHelper.close(mediaDriver);
        }
        if (parameterServerNode != null) {
            parameterServerNode.close();
        }
    }

    /**
     * Worker that repeatedly takes a minibatch from the shared queue, refreshes its model
     * parameters from the parameter server when the client reports it is ready, fits the
     * minibatch, and pushes the updated parameters back.
     */
    @AllArgsConstructor
    public static class Trainer implements AutoCloseable {
        private ParameterServerClient parameterServerClient;
        private AtomicBoolean running;
        private LinkedBlockingQueue work;
        private Model model;

        /**
         * Runs the polling loop until {@code running} becomes false or the thread is
         * interrupted.
         *
         * @throws IllegalArgumentException if a multi dataset is received while the model
         *         is not a {@code ComputationGraph}
         */
        public void start() {
            log.info("Begin polling running queue");
            while (running.get()) {
                try {
                    Object next = work.poll(1, TimeUnit.SECONDS);
                    if (next == null) {
                        // Nothing queued within the timeout; re-check the running flag.
                        continue;
                    }

                    //send new parameters
                    if (parameterServerClient.isReadyForNext()) {
                        log.info("Retrieving new array");
                        //get the new parameters from the server
                        INDArray newParams = parameterServerClient.getArray();
                        model.setParams(newParams);
                        log.info("Set new params");
                    } else {
                        log.debug("Continuing training");
                    }

                    if (next instanceof DataSet) {
                        DataSet dataSet = (DataSet) next;
                        if (model instanceof ComputationGraph) {
                            ComputationGraph computationGraph = (ComputationGraph) model;
                            computationGraph.fit(dataSet);
                        } else {
                            MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) model;
                            log.info("Calling fit on multi layer network");
                            multiLayerNetwork.fit(dataSet);
                        }

                        log.info("About to send params in");
                        //send the updated params
                        parameterServerClient.pushNDArray(model.params());
                        log.info("Sent params");
                    } else {
                        // Cast to the interface type: the producer enqueues whatever the
                        // MultiDataSetIterator returns, which need not be the concrete class.
                        org.nd4j.linalg.dataset.api.MultiDataSet dataSet =
                                        (org.nd4j.linalg.dataset.api.MultiDataSet) next;
                        if (model instanceof ComputationGraph) {
                            ComputationGraph computationGraph = (ComputationGraph) model;
                            computationGraph.fit(dataSet);
                        } else {
                            throw new IllegalArgumentException("MultiLayerNetworks can't fit multi datasets");
                        }

                        log.info("Sending parameters");
                        //send the updated params
                        parameterServerClient.pushNDArray(model.params());
                    }
                } catch (InterruptedException e) {
                    // Restore the flag and leave the loop; continuing would make every
                    // subsequent poll throw immediately and spin.
                    log.warn("Worker interrupted while polling for work", e);
                    Thread.currentThread().interrupt();
                    break;
                }
            }

            log.info("Worker finished");
        }

        /**
         * No-op: this worker holds no resources of its own to release here.
         *
         * @throws Exception if this resource cannot be closed
         */
        @Override
        public void close() throws Exception {}
    }
}