All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.parallelism.parameterserver.ParameterServerTrainerContext Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.parallelism.parameterserver;

import io.aeron.driver.MediaDriver;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.factory.TrainerContext;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.parameterserver.client.ParameterServerClient;
import org.nd4j.parameterserver.node.ParameterServerNode;

/**
 * {@link TrainerContext} that wires {@link ParallelWrapper} workers to a
 * {@link ParameterServerNode} running on top of an embedded Aeron {@link MediaDriver}.
 *
 * <p>{@link #init(Model, Object...)} has to be invoked before
 * {@link #create(String, int, Model, int, boolean, ParallelWrapper, WorkspaceMode, int)},
 * because the latter reads the node that the former constructs.
 */
public class ParameterServerTrainerContext implements TrainerContext {

    /** Host name used for every endpoint configured by this context. */
    private static final String LOCAL_HOST = "localhost";

    /** Subscriber port handed to the client for thread 0; thread {@code i} gets this value plus {@code i}. */
    private static final int SUBSCRIBER_PORT_BASE = 40625;

    /** Subscriber stream handed to the client for thread 0; thread {@code i} gets this value plus {@code i}. */
    private static final int SUBSCRIBER_STREAM_BASE = 12;

    private ParameterServerNode parameterServerNode;
    private MediaDriver mediaDriver;
    private MediaDriver.Context mediaDriverContext;
    private int statusServerPort = 33000;
    private int numUpdatesPerEpoch = 1;
    private String[] parameterServerArgs;
    private int numWorkers = 1;

    /**
     * Launches an embedded media driver, creates the parameter server node on top of it and,
     * when no argument list has been set yet, fills in a default one.
     *
     * <p>NOTE(review): {@code parameterServerArgs} is populated here but is not read anywhere
     * else in this class - confirm whether the node is expected to be started with it.
     *
     * @param model the model whose parameter count goes into the default argument list
     * @param args  additional initialisation arguments (may be null); not read by this implementation
     */
    @Override
    public void init(Model model, Object... args) {
        mediaDriverContext = new MediaDriver.Context();
        mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
        parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);

        // Arguments that are already present are left untouched.
        if (parameterServerArgs != null) {
            return;
        }
        parameterServerArgs = defaultParameterServerArgs(model);
    }

    /**
     * Assembles the default flag/value argument list from the model's parameter count, the
     * media driver's Aeron directory, the status server port and the updates-per-epoch setting.
     */
    private String[] defaultParameterServerArgs(Model model) {
        String sizeValue = "1," + String.valueOf(model.numParams());
        String aeronDirectory = mediaDriver.aeronDirectoryName();
        String statusPort = String.valueOf(statusServerPort);
        String updatesPerEpoch = String.valueOf(numUpdatesPerEpoch);

        return new String[] {
            "-m", "true",
            "-s", sizeValue,
            "-p", "40323",
            "-h", LOCAL_HOST,
            "-id", "11",
            "-md", aeronDirectory,
            "-sh", LOCAL_HOST,
            "-sp", statusPort,
            "-u", updatesPerEpoch
        };
    }

    /**
     * Builds a {@link ParameterServerTrainer} for one worker thread, together with the
     * {@link ParameterServerClient} it talks through.
     *
     * @param uuid               identifier supplied by the caller; not used by this implementation
     * @param threadId           worker index; selects the node subscriber and offsets the
     *                           subscriber port and stream
     * @param model              passed to the trainer as both the original and the replicated model
     * @param rootDevice         the root device id; not used by this implementation
     * @param useMDS             whether to use MultiDataSet rather than DataSet
     * @param wrapper            the {@link ParallelWrapper} instance handed to the trainer
     * @param mode               workspace mode; not used by this implementation
     * @param averagingFrequency averaging frequency; not used by this implementation
     * @return the newly built trainer
     */
    @Override
    public Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper,
                    WorkspaceMode mode, int averagingFrequency) {
        // Client side: one subscriber slot of the node per worker thread.
        ParameterServerClient client = ParameterServerClient.builder()
                        .aeron(parameterServerNode.getAeron())
                        .ndarrayRetrieveUrl(
                                        parameterServerNode.getSubscriber()[threadId].getResponder().connectionUrl())
                        .ndarraySendUrl(parameterServerNode.getSubscriber()[threadId].getSubscriber().connectionUrl())
                        .subscriberHost(LOCAL_HOST)
                        .masterStatusHost(LOCAL_HOST)
                        .masterStatusPort(statusServerPort)
                        .subscriberPort(SUBSCRIBER_PORT_BASE + threadId)
                        .subscriberStream(SUBSCRIBER_STREAM_BASE + threadId)
                        .build();

        return ParameterServerTrainer.builder()
                        .originalModel(model)
                        .parameterServerClient(client)
                        .replicatedModel(model)
                        .threadId(threadId)
                        .parallelWrapper(wrapper)
                        .useMDS(useMDS)
                        .build();
    }

    /**
     * Does nothing in this implementation.
     *
     * @param originalModel ignored
     * @param models        ignored
     */
    @Override
    public void finalizeRound(Model originalModel, Model... models) {
        // intentionally empty
    }

    /**
     * Does nothing in this implementation.
     *
     * @param originalModel ignored
     * @param models        ignored
     */
    @Override
    public void finalizeTraining(Model originalModel, Model... models) {
        // intentionally empty
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy