/*
 * Copyright 2015 Skymind, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.deeplearning4j.scaleout.actor.runner;
import akka.actor.*;
import akka.cluster.Cluster;
import akka.contrib.pattern.ClusterClient;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.routing.RoundRobinPool;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.scaleout.actor.core.ClusterListener;
import org.deeplearning4j.scaleout.actor.core.ModelSaver;
import org.deeplearning4j.scaleout.actor.core.actor.BatchActor;
import org.deeplearning4j.scaleout.actor.core.actor.MasterActor;
import org.deeplearning4j.scaleout.actor.core.actor.ModelSavingActor;
import org.deeplearning4j.scaleout.actor.core.actor.WorkerActor;
import org.deeplearning4j.scaleout.actor.util.ActorRefUtils;
import org.deeplearning4j.scaleout.aggregator.INDArrayAggregator;
import org.deeplearning4j.scaleout.aggregator.JobAggregator;
import org.deeplearning4j.scaleout.api.workrouter.WorkRouter;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.job.JobIterator;
import org.deeplearning4j.scaleout.messages.MoreWorkMessage;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.perform.WorkerPerformerFactory;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.scaleout.workrouter.IterativeReduceWorkRouter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.Duration;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
/**
* Controller for coordinating distributed training of a neural network
* across an Akka cluster.
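*
* A minimal usage sketch (the {@code conf} and {@code jobIterator}
* values here are illustrative, not prescribed by this class):
* <pre>{@code
* DeepLearning4jDistributed master = new DeepLearning4jDistributed(jobIterator);
* master.setup(conf);  // joins/creates the cluster and starts the master actor
* master.train();      // blocks until the state tracker reports completion
* }</pre>
*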
* @author Adam Gibson
*
*/
public class DeepLearning4jDistributed implements DeepLearningConfigurable,Serializable {
private static final long serialVersionUID = -4385335922485305364L;
private transient ActorSystem system;
private ActorRef mediator;
private static final Logger log = LoggerFactory.getLogger(DeepLearning4jDistributed.class);
private static String systemName = "ClusterSystem";
private String type = "master";
private Address masterAddress;
private JobIterator iter;
protected ActorRef masterActor;
protected ModelSaver modelSaver;
private transient ScheduledExecutorService exec;
private transient StateTracker stateTracker;
private int stateTrackerPort = -1;
private String masterHost;
private transient WorkRouter workRouter;
/**
* Master constructor
* @param type the type of this node ("master" or "worker")
* @param iter the dataset iterator to use
*/
public DeepLearning4jDistributed(String type, JobIterator iter) {
this.type = type;
this.iter = iter;
}
/**
* Master constructor
* @param iter the dataset iterator to use
* @param stateTracker the state tracker to coordinate cluster state with
*/
public DeepLearning4jDistributed(JobIterator iter,StateTracker stateTracker) {
this("master",iter);
this.stateTracker = stateTracker;
}
/**
* Master constructor
* @param iter the dataset iterator to use
*/
public DeepLearning4jDistributed(JobIterator iter) {
this("master",iter);
}
/**
* The worker constructor
* @param type the type to use
* @param address the address of the master
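*                e.g. {@code akka.tcp://ClusterSystem@master-host:2552} (host and port here are illustrative)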
*/
public DeepLearning4jDistributed(String type, String address) {
this.type = type;
URI u = URI.create(address);
masterAddress = Address.apply(u.getScheme(), u.getUserInfo(), u.getHost(), u.getPort());
}
public DeepLearning4jDistributed() {
super();
}
/**
* Start a backend with the given role
* @param joinAddress the join address
* @param c the neural network configuration
* @return the actor for this backend
*/
public Address startBackend(Address joinAddress,Configuration c,JobIterator iter,StateTracker stateTracker) {
ActorRefUtils.addShutDownForSystem(system);
system.actorOf(Props.create(ClusterListener.class));
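//reflectively instantiate the configured WorkRouter (defaulting to IterativeReduceWorkRouter);
//the chosen implementation must expose a public single-argument (StateTracker) constructor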
try {
Class<? extends WorkRouter> routerClazz = (Class<? extends WorkRouter>)
Class.forName(c.get(WorkRouter.WORK_ROUTER, IterativeReduceWorkRouter.class.getName()));
Constructor<? extends WorkRouter> constructor = routerClazz.getConstructor(StateTracker.class);
workRouter = constructor.newInstance(stateTracker);
} catch (Exception e) {
throw new RuntimeException(e);
}
workRouter.setup(c);
ActorRef batchActor = system.actorOf(Props.create(BatchActor.class,iter,stateTracker,c,workRouter),"batch");
log.info("Started batch actor");
Props masterProps = Props.create(MasterActor.class,c,batchActor,stateTracker,workRouter);
/*
* Start the master actor as a cluster singleton named "master" (role "master"),
* with a PoisonPill as its termination message on failure or hand-over.
*/
final Address realJoinAddress = (joinAddress == null) ? Cluster.get(system).selfAddress() : joinAddress;
c.set(MASTER_URL,realJoinAddress.toString());
if(exec == null)
exec = Executors.newScheduledThreadPool(2);
Cluster cluster = Cluster.get(system);
cluster.join(realJoinAddress);
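//publish the current cluster state shortly after joining so subscribers see the membership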
exec.schedule(new Runnable() {
@Override
public void run() {
Cluster cluster = Cluster.get(system);
cluster.publishCurrentClusterState();
}
}, 10, TimeUnit.SECONDS);
masterActor = system.actorOf(
ClusterSingletonManager.defaultProps(masterProps, "master", PoisonPill.getInstance(), "master"));
log.info("Started master with address " + realJoinAddress.toString());
c.set(MASTER_PATH,ActorRefUtils.absPath(masterActor, system));
log.info("Set master abs path " + c.get(MASTER_PATH));
return realJoinAddress;
}
@Override
public void setup(final Configuration conf) {
system = ActorSystem.create(systemName);
ActorRefUtils.addShutDownForSystem(system);
mediator = DistributedPubSubExtension.get(system).mediator();
if(type.equals("master")) {
if(iter == null)
throw new IllegalStateException("Unable to initialize: no dataset to iterate over");
log.info("Starting master");
try {
if(stateTracker == null) {
if(stateTrackerPort > 0)
stateTracker = new HazelCastStateTracker(stateTrackerPort);
else
stateTracker = new HazelCastStateTracker();
}
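//install the configured job aggregator (defaulting to INDArrayAggregator) if none is set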
if(stateTracker.jobAggregator() == null) {
Class<? extends JobAggregator> clazz = (Class<? extends JobAggregator>)
Class.forName(conf.get(JobAggregator.AGGREGATOR, INDArrayAggregator.class.getName()));
JobAggregator agg = clazz.newInstance();
stateTracker.setJobAggregator(agg);
}
log.info("Started state tracker with connection string " + stateTracker.connectionString());
masterAddress = startBackend(null,conf,iter,stateTracker);
} catch (Exception e1) {
Thread.currentThread().interrupt();
throw new RuntimeException(e1);
}
log.info("Starting Save saver");
if(modelSaver == null)
system.actorOf(Props.create(ModelSavingActor.class,"model-saver",stateTracker));
else
system.actorOf(Props.create(ModelSavingActor.class,modelSaver,stateTracker));
//store it in zookeeper for service discovery
conf.set(MASTER_URL,getMasterAddress().toString());
conf.set(MASTER_PATH,ActorRefUtils.absPath(masterActor, system));
//sets up the connection string for reference on the external worker
conf.set(STATE_TRACKER_CONNECTION_STRING,stateTracker.connectionString());
ActorRefUtils.registerConfWithZooKeeper(conf, system);
system.scheduler().schedule(Duration.create(1, TimeUnit.MINUTES),
Duration.create(1, TimeUnit.MINUTES),
new Runnable() {
@Override
public void run() {
if (!system.isTerminated()) {
try {
log.info("Current cluster members " +
Cluster.get(system).readView().members());
} catch (Exception e) {
log.warn("Tried reading cluster members during shutdown");
}
}
}
}, system.dispatcher());
}
else {
log.info("Starting worker node");
Address a = AddressFromURIString.parse(conf.get(MASTER_URL));
Configuration c = new Configuration(conf);
Cluster cluster = Cluster.get(system);
cluster.join(a);
try {
String host = a.host().get();
if(host == null)
throw new IllegalArgumentException("No host set for worker");
String connectionString = conf.get(STATE_TRACKER_CONNECTION_STRING);
//the master URL was bound to 0.0.0.0 (hostname resolution issue); fall back to the configured master host
if(connectionString.contains("0.0.0.0")) {
if(masterHost == null)
throw new IllegalStateException("No master host specified and host discovery was lost due to" +
" improper setup on the master (related to hostname resolution) Please run the following" +
" command on your host: sudo hostname YOUR_HOST_NAME." +
" This will make your hostname resolution work correctly on master.");
connectionString = connectionString.replace("0.0.0.0",masterHost);
}
log.info("Creating state tracker with connection string "+ connectionString);
if(stateTracker == null)
stateTracker = new HazelCastStateTracker(connectionString);
} catch (Exception e1) {
Thread.currentThread().interrupt();
throw new RuntimeException(e1);
}
startWorker(c);
system.scheduler().schedule(Duration.create(1, TimeUnit.MINUTES), Duration.create(1, TimeUnit.MINUTES), new Runnable() {
@Override
public void run() {
log.info("Current cluster members " + Cluster.get(system).readView().members());
}
},system.dispatcher());
log.info("Setup worker nodes");
}
//only start dropwizard on the master
if(type.equals("master")) {
stateTracker.startRestApi();
}
else if(stateTracker instanceof HazelCastStateTracker)
log.info("Not starting drop wizard; worker state detected");
}
public void startWorker(Configuration conf) {
Address contactAddress = AddressFromURIString.parse(conf.get(MASTER_URL));
system.actorOf(Props.create(ClusterListener.class));
log.info("Attempting to join node " + contactAddress);
log.info("Starting workers");
Set<ActorSelection> initialContacts = new HashSet<>();
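//the master system's /user guardian serves as the initial contact for the ClusterClient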
initialContacts.add(system.actorSelection(contactAddress + "/user/"));
RoundRobinPool pool = new RoundRobinPool(Runtime.getRuntime().availableProcessors());
ActorRef clusterClient = system.actorOf(ClusterClient.defaultProps(initialContacts),
"clusterClient");
try {
String host = contactAddress.host().get();
log.info("Connecting to host " + host);
int workers = stateTracker.numWorkers();
if(workers <= 1)
throw new IllegalStateException("Did not properly connect to cluster");
log.info("Joining cluster of size " + workers);
Class<? extends WorkerPerformerFactory> factoryClazz = (Class<? extends WorkerPerformerFactory>)
Class.forName(conf.get(WorkerPerformerFactory.WORKER_PERFORMER));
WorkerPerformerFactory factory = factoryClazz.newInstance();
WorkerPerformer performer = factory.create(conf);
Props p = pool.props(WorkerActor.propsFor(conf, stateTracker,performer));
system.actorOf(p, "worker");
Cluster cluster = Cluster.get(system);
cluster.join(contactAddress);
log.info("Worker joining cluster of " + stateTracker.workers().size());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* Kicks off distributed training.
* It grabs the optimal batch size from the beginning of the
* dataset iterator, based on the desired mini-batch size
* ({@code conf.getSplit()}) and the number of workers in the
* state tracker after setup.
*
* For example, with a mini-batch size of 10 and 8 workers, the
* initial {@link JobIterator#next(int)} call would request 80
* records, i.e. 10 per worker.
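*
* In code terms (an illustrative sketch, not the exact internals):
* <pre>{@code
* int batch = conf.getSplit() * stateTracker.numWorkers(); // 10 * 8 = 80
* iter.next(batch);
* }</pre>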
*/
public void train() {
log.info("Publishing to results for training");
log.info("Started pipeline");
//start the pipeline
mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER,
MoreWorkMessage.getInstance()), mediator);
log.info("Published results");
while(!stateTracker.isDone()) {
log.info("State tracker not done...blocking");
try {
Thread.sleep(15000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
shutdown();
}
public Address getMasterAddress() {
return masterAddress;
}
public StateTracker getStateTracker() {
return stateTracker;
}
public void setStateTracker(
StateTracker stateTracker) {
this.stateTracker = stateTracker;
}
/**
*
* Shut down this network actor
*/
public void shutdown() {
//order matters here: shut down the actor system first, then the state tracker
try {
system.shutdown();
} catch(Exception e) {
// do nothing
}
try {
if(stateTracker != null)
stateTracker.shutdown();
} catch(Exception e) {
// do nothing
}
}
public ModelSaver getModelSaver() {
return modelSaver;
}
/**
* Sets a custom model saver. This will allow custom directories
* among other things when saving snapshots.
* @param modelSaver the model saver to use
*/
public void setModelSaver(ModelSaver modelSaver) {
this.modelSaver = modelSaver;
}
/**
* Gets the state tracker port.
* Many state trackers are servers that need to bind to a port;
* this allows the port to be overridden per state tracker implementation.
* @return the port that the state tracker server will bind to
*/
public int getStateTrackerPort() {
return stateTrackerPort;
}
public void setStateTrackerPort(int stateTrackerPort) {
this.stateTrackerPort = stateTrackerPort;
}
public String getMasterHost() {
return masterHost;
}
public void setMasterHost(String masterHost) {
this.masterHost = masterHost;
}
}