/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4j.parameterserver;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import com.beust.jcommander.Parameters;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import io.aeron.Aeron;
import io.aeron.driver.MediaDriver;
import io.aeron.driver.ThreadingMode;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.val;
import org.agrona.CloseHelper;
import org.agrona.concurrent.BusySpinIdleStrategy;
import org.json.JSONObject;
import org.nd4j.aeron.ipc.AeronNDArraySubscriber;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.aeron.ipc.NDArrayCallback;
import org.nd4j.aeron.ipc.NDArrayHolder;
import org.nd4j.aeron.ipc.response.AeronNDArrayResponder;
import org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.parameterserver.model.MasterConnectionInfo;
import org.nd4j.parameterserver.model.ServerState;
import org.nd4j.parameterserver.model.SlaveConnectionInfo;
import org.nd4j.parameterserver.model.SubscriberState;
import org.nd4j.parameterserver.updater.ParameterServerUpdater;
import org.nd4j.parameterserver.updater.SoftSyncParameterUpdater;
import org.nd4j.parameterserver.updater.SynchronousParameterUpdater;
import org.nd4j.parameterserver.util.CheckSocket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
/**
 * Subscriber main class for the parameter averaging server.
 *
 * @author Adam Gibson
 */
@Parameters(separators = ",")
public class ParameterServerSubscriber implements AutoCloseable {
/** Shared SLF4J logger for this class. */
private static final Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class);

/** Port the subscriber listens on; on master nodes the responder is started on {@code port + 1}. */
@Parameter(names = {"-p", "--port"}, description = "The port to listen on for the daemon", arity = 1)
private int port = 40123;

/** Stream id the subscriber listens on; on master nodes the responder uses {@code streamId + 1}. */
@Parameter(names = {"-id", "--streamId"}, description = "The stream id to listen on", arity = 1)
private int streamId = 10;

/** Host name handed to the subscriber and responder when they are started. */
@Parameter(names = {"-h", "--host"}, description = "Host for the server to bind to", arity = 1)
private String host = "localhost";

/** Whether the aeron directory should be deleted on startup. */
@Parameter(names = {"-d", "--deleteDirectoryOnStart"}, description = "Delete aeron directory on startup.",
                arity = 1)
private boolean deleteDirectoryOnStart = true;

/** {@code true} when this node acts as the master, {@code false} for a peer node. */
@Parameter(names = {"-m", "--master"}, description = "Whether this subscriber is a master node or not.", arity = 1)
private boolean master = false;

/** Colon-separated address of the master that a non-master node publishes to. */
@Parameter(names = {"-pm", "--publishmaster"},
                description = "Publish master url: host:port - this is for peer nodes needing to publish to another peer.",
                arity = 1)
private String publishMasterUrl = "localhost:40123";

/** Directory of an externally started media driver; when {@code null} an embedded driver may be launched. */
@Parameter(names = {"-md", "--mediadriverdirectory"},
                description = "The media driver directory opName. This is for when the media driver is started as a separate process.",
                arity = 1)
private String mediaDriverDirectoryName;

/** Port of the status server that receives heartbeats. */
@Parameter(names = {"-sp", "--statusserverport"}, description = "The status server port, defaults to 9000.",
                arity = 1)
private int statusServerPort = 9000;

/** Host of the status server that receives heartbeats. */
@Parameter(names = {"-sh", "--statusserverhost"}, description = "The status host, defaults to localhost.",
                arity = 1)
private String statusServerHost = "localhost";

/** Textual update type from the command line; converted to {@link UpdateType} when the server is run. */
@Parameter(names = {"-up", "--update"},
                description = "The update opType for this parameter server. Defaults to sync. You can specify custom and use a jvm argument -Dorg.nd4j.parameterserver.updatetype=your.fully.qualified.class if you want to use a custom class. This must be able to be instantiated from an empty constructor though.",
                arity = 1)
private String updateTypeString = UpdateType.SYNC.toString().toLowerCase();

/** Update type actually used; derived from {@link #updateTypeString}. */
private UpdateType updateType = UpdateType.SYNC;

/**
 * Shape of the ndarray held by the server.
 * The element type is declared explicitly (instead of a raw {@code List}) so that the values are
 * integers, which is what the {@code Ints.toArray(shape)} calls in this class require.
 */
@Parameter(names = {"-s", "--shape"}, description = "The shape of the ndarray", arity = 1)
private List<Integer> shape;

/** Delay between two heartbeats to the status server, in milliseconds. */
@Parameter(names = {"-hbi", "--heartbeatinterval"}, description = "Heartbeat interval in ms", arity = 1)
private int heartbeatMs = 1000;

/** Mapper used to serialize the subscriber state for the heartbeat. */
private ObjectMapper objectMapper = new ObjectMapper();

/** Scheduler that runs the heartbeat task; only created when a status server is reachable. */
private ScheduledExecutorService scheduledExecutorService;

/** Number of updates per epoch; mandatory on the command line. */
@Parameter(names = {"-u", "--updatesPerEpoch"}, description = "The number of updates per epoch", arity = 1,
                required = true)
private int updatesPerEpoch;
* Specify a custom class as a jvm arg.
* Note that this class must be a fully qualified classname
public final static String CUSTOM_UPDATE_TYPE = "org.nd4j.parameterserver.updatetype";
/**
 * Update types select which kind of parameter updater is instantiated.
 */
public enum UpdateType {
// Media driver in use: either injected through the constructor or launched embedded by run(String[]).
private MediaDriver mediaDriver;
// Responder started by run(String[]) on master nodes to answer "get" queries.
private AeronNDArrayResponder responder;
// Aeron subscriber started by run(String[]).
private AeronNDArraySubscriber subscriber;
// Callback handed to the subscriber: a ParameterServerListener on master nodes
// (unless a callback was already set), a PublishingListener on all other nodes.
private NDArrayCallback callback;
// Alias for the callback where relevant: the same object cast to ParameterServerListener.
// Only assigned in the master branch of run(String[]).
private ParameterServerListener parameterServerListener;
// Aeron client; connected in run(String[]) when it is still null at that point.
private Aeron aeron;
// NOTE(review): never read or written in the code visible in this file - confirm whether it is still needed.
private ScheduledExecutorService heartbeat;
/**
 * Creates a subscriber on top of a media driver that already exists,
 * instead of letting {@link #run(String[])} launch an embedded one.
 *
 * @param mediaDriver the already running media driver to use
 */
public ParameterServerSubscriber(MediaDriver mediaDriver) {
    this.mediaDriver = mediaDriver;
}
/**
 * Returns the current {@link SubscriberState} of this subscriber.
 *
 * @return the current state of this subscriber
 */
public SubscriberState asState() {
// Assembles a snapshot of this node through the SubscriberState builder.
return SubscriberState.builder()
// When no listener has been assigned there is no updater to query, so an empty status map is reported.
.parameterUpdaterStatus(parameterServerListener == null ? Collections.emptyMap()
: parameterServerListener.getUpdater().status())
// Master and non-master nodes expose different connection-info objects; either one is stored as a string.
.connectionInfo(isMaster() ? masterConnectionInfo().toString()
: slaveConnectionInfo().toString())
// NOTE(review): the conditional below is cut off in this copy of the file (both branches and the
// end of the builder chain are missing) - restore it before changing this method.
.serverState(subscriberLaunched() ?
* When this is a slave node
* it returns the connection url for this node
* and the associated master connection urls in the form of:
* host:port:streamId
* @return the slave connection info
public SlaveConnectionInfo slaveConnectionInfo() {
if (isMaster())
throw new IllegalStateException("Unable to determine slave connection info. This is a master node");
return SlaveConnectionInfo.builder().connectionUrl(subscriber.connectionUrl()).masterUrl(publishMasterUrl)
* When this is a master node,
* it returns the connection url for this node,
* it's slaves (if any exist) and the responder
* connection url in the form of:
* host:port:streamId
* @return the master connection info
public MasterConnectionInfo masterConnectionInfo() {
if (!isMaster())
throw new IllegalStateException("Unable to determine master connection info. This is a slave node");
return MasterConnectionInfo.builder().connectionUrl(subscriber.connectionUrl())
.responderUrl(responder.connectionUrl()).slaveUrls(new ArrayList<>()).build();
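/**
 * Configures this subscriber from the given arguments and starts it.
 *
 * @param args the command-line arguments
 */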
public void run(String[] args) {
// NOTE(review): several statements of this method are truncated in this copy of the file (empty try
// bodies, dangling string literals, a missing operand, missing closing braces). Restore them from
// version control before changing any behavior here; the comments below describe only what is visible.

// JCommander is created around this instance, whose fields carry the @Parameter annotations.
JCommander jcmdr = new JCommander(this);
try {
// NOTE(review): the try body is empty in this view; presumably the argument parsing call belongs here - confirm.
} catch (ParameterException e) {
//User provides invalid input -> print the usage info
try {
// NOTE(review): empty in this view as well.
} catch (Exception e2) {
//ensure that the update opType is configured from the command line args
// Enum.valueOf throws IllegalArgumentException when the upper-cased string matches no UpdateType constant.
updateType = UpdateType.valueOf(updateTypeString.toUpperCase());
// publishMasterUrl is initialized to a non-null default on its field, so this guard only fires
// when the url has been explicitly set to null on a non-master node.
if (publishMasterUrl == null && !master)
throw new IllegalStateException("Please specify a master url or set master to true");
//allows passing in a media driver for things like unit tests
//also ensure we don't use a media driver when a directory is specified
//for a remote one
if (mediaDriver == null && mediaDriverDirectoryName == null) {
//length of array * sizeof(float)
// NOTE(review): the left operand of the multiplication is missing in this view.
int ipcLength = * 4;
//must be a power of 2
// NOTE(review): doubling alone does not make a value a power of two, and 64 is added
// afterwards - confirm the intended buffer-size rule.
ipcLength *= 2;
//padding for NDArrayMessage
ipcLength += 64;
//Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
// The embedded driver uses dedicated threading and busy-spin idle strategies for conductor, receiver and sender.
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED)
.maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy())
.receiverIdleStrategy(new BusySpinIdleStrategy())
.senderIdleStrategy(new BusySpinIdleStrategy());
mediaDriver = MediaDriver.launchEmbedded(mediaDriverCtx);
//set the variable since we are using a media driver directly
// NOTE(review): the string expression after the assignment looks like the remainder of a logging call whose start was lost.
mediaDriverDirectoryName = mediaDriver.aeronDirectoryName();"Using media driver directory " + mediaDriver.aeronDirectoryName());
// Connect the Aeron client only when none has been set yet.
if (aeron == null)
this.aeron = Aeron.connect(getContext());
if (master) {
// A callback that was set beforehand is kept; otherwise a listener is built for the configured update type.
if (this.callback == null) {
ParameterServerUpdater updater = null;
//instantiate with shape instead of just length
// NOTE(review): case labels, break statements and the default label appear to be missing from this
// switch in this view; as written the assignments and throws below are not separated - confirm.
switch (updateType) {
case SYNC:
updater = new SynchronousParameterUpdater(new InMemoryUpdateStorage(),
new InMemoryNDArrayHolder(Ints.toArray(shape)), updatesPerEpoch);
updater = new SoftSyncParameterUpdater();
case CUSTOM:
try {
// The custom updater class name is read from the CUSTOM_UPDATE_TYPE system property.
updater = (ParameterServerUpdater) Class.forName(System.getProperty(CUSTOM_UPDATE_TYPE))
} catch (Exception e) {
throw new RuntimeException(e);
throw new IllegalStateException("Illegal opType of updater");
callback = new ParameterServerListener(Ints.toArray(shape), updater);
parameterServerListener = (ParameterServerListener) callback;
//start an extra daemon for responding to get queries
// The cast fails with ClassCastException if a pre-set callback is not a ParameterServerListener.
ParameterServerListener cast = (ParameterServerListener) callback;
// The responder is offset by one from the subscriber for both port and stream id.
responder = AeronNDArrayResponder.startSubscriber(aeron, host, port + 1, cast.getUpdater().ndArrayHolder(),
streamId + 1);"Started responder on master node " + responder.connectionUrl());
} else {
// Non-master node: incoming arrays are handed to a PublishingListener created for the master url.
String[] publishMasterUrlArr = publishMasterUrl.split(":");
// NOTE(review): this check accepts two parts ("host:port"), but index 2 is read below as the stream
// id, so a two-part url - including the field's default value - would throw
// ArrayIndexOutOfBoundsException. The check and the message likely need to require three parts - confirm.
if (publishMasterUrlArr == null || publishMasterUrlArr.length < 2)
throw new IllegalStateException("Please specify publish master url as host:port");
callback = new PublishingListener(
String.format("aeron:udp?endpoint=%s:%s", publishMasterUrlArr[0], publishMasterUrlArr[1]),
Integer.parseInt(publishMasterUrlArr[2]), getContext());
}"Starting subscriber on " + host + ":" + port + " and stream " + streamId);
// Flag handed to the subscriber at startup.
AtomicBoolean running = new AtomicBoolean(true);
//start a node
subscriber = AeronNDArraySubscriber.startSubscriber(aeron, host, port, callback, streamId, running);
// Loop until the subscriber reports that it has launched.
// NOTE(review): the loop body is missing in this view.
while (!subscriber.launched()) {
//send heartbeat to a status server. There will usually be 1 status server per master.
//Only schedule this if a remote server is available.
if (CheckSocket.remotePortTaken(statusServerHost, statusServerPort, 10000)) {
// Single-threaded scheduler that runs the heartbeat every heartbeatMs milliseconds, starting immediately.
scheduledExecutorService = Executors.newScheduledThreadPool(1);
final AtomicInteger failCount = new AtomicInteger(0);
scheduledExecutorService.scheduleAtFixedRate(() -> {
try {
// NOTE(review): the statement guarded by this check is missing in this view.
if (failCount.get() >= 3)
// The current state is serialized with the object mapper and wrapped in a JSONObject.
SubscriberState subscriberState = asState();
JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
// NOTE(review): the format call and the request below are truncated in this view.
String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort,
val entity ="Content-Type", "application/json")
} catch (Exception e) {
// A warning is logged once the failure counter has reached 3.
if (failCount.get() >= 3) {
log.warn("Failed to send update, shutting down likely?", e);
}, 0, heartbeatMs, TimeUnit.MILLISECONDS);
} else {"No status server found. Will not send heartbeats. Specified host was " + statusServerHost
+ " and port was " + statusServerPort);
// A JVM shutdown hook is registered here.
// NOTE(review): the hook body is missing in this view.
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
//set the server for the status of the master and slave nodes
public void close() {
// Releases the resources created by run(String[]); each one is only touched when it was actually created.
// NOTE(review): the statement guarded by each null check below is missing in this copy of the file
// (presumably the close/shutdown call for that resource) - restore them before relying on this method.
if (subscriber != null)
if (responder != null)
if (scheduledExecutorService != null)
//get a context
public Aeron.Context getContext() {
Aeron.Context ctx = new Aeron.Context().publicationConnectionTimeout(-1)
.errorHandler(e -> log.error(e.toString(), e));
return ctx;
* Get the master ndarray from the
* internal {@link NDArrayHolder}
* @return the master ndarray
public INDArray getMasterArray() {
return parameterServerListener.getUpdater().ndArrayHolder().get();
* Returns true if the subscriber is launched
* @return
public boolean subscriberLaunched() {
return subscriber.launched();
public static void main(String[] args) {
new ParameterServerSubscriber().run(args);