/*
* Copyright [2013-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.guagua.example.nn;
import java.io.IOException;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.example.nn.meta.NNParams;
import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Splitter;
/**
 * {@link NNWorker} computes the NN model on the data splits assigned to it. The result is sent to the master for
 * accumulation.
 *
 * <p>
 * Gradients computed in each worker are sent to the master, which updates the model weights and distributes them back
 * to the workers; this follows Encog's multi-core implementation.
 */
public class NNWorker extends
        AbstractWorkerComputable<NNParams, NNParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
private static final Logger LOG = LoggerFactory.getLogger(NNWorker.class);
/**
* Training data set
*/
private MLDataSet trainingData = null;
/**
* Testing data set
*/
private MLDataSet testingData = null;
/**
* NN algorithm runner instance.
*/
private Gradient gradient;
/**
     * Count of input records, incremented one by one as records are loaded.
*/
private long count;
private int inputs;
private int hiddens;
private int outputs;
@Override
    public void init(WorkerContext<NNParams, NNParams> context) {
inputs = NumberFormatUtils.getInt(context.getProps().getProperty(NNConstants.GUAGUA_NN_INPUT_NODES),
NNConstants.GUAGUA_NN_DEFAULT_INPUT_NODES);
hiddens = NumberFormatUtils.getInt(context.getProps().getProperty(NNConstants.GUAGUA_NN_HIDDEN_NODES),
NNConstants.GUAGUA_NN_DEFAULT_HIDDEN_NODES);
outputs = NumberFormatUtils.getInt(context.getProps().getProperty(NNConstants.GUAGUA_NN_OUTPUT_NODES),
NNConstants.GUAGUA_NN_DEFAULT_OUTPUT_NODES);
LOG.info("NNWorker is loading data into memory and disk.");
double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.5"));
long memoryStoreSize = (long) (Runtime.getRuntime().maxMemory() * memoryFraction);
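        // keep up to memoryFraction of the max heap in memory for loaded records, split evenly between the training
        // and testing stores below; records beyond that capacity spill to the on-disk files train.egb and test.egb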
this.trainingData = new MemoryDiskMLDataSet((long) (memoryStoreSize * 0.5), "train.egb", this.inputs,
this.outputs);
this.testingData = new MemoryDiskMLDataSet((long) (memoryStoreSize * 0.5), "test.egb", this.inputs,
this.outputs);
        // there is no good place to close these two data sets explicitly, so register a JVM shutdown hook
Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
@Override
public void run() {
((MemoryDiskMLDataSet) (NNWorker.this.trainingData)).close();
((MemoryDiskMLDataSet) (NNWorker.this.testingData)).close();
}
}));
}
@Override
    public NNParams doCompute(WorkerContext<NNParams, NNParams> context) {
        // For the first iteration we do nothing and just wait for the master to send weights in the next iteration.
        // This makes sure all workers start from the same initial weights.
if(context.getCurrentIteration() == 1) {
return buildEmptyNNParams(context);
}
if(context.getLastMasterResult() == null) {
            // This should not happen, since the master sets the initial weights first.
LOG.warn("Master result of last iteration is null.");
return null;
}
LOG.debug("Set current model with params {}", context.getLastMasterResult());
// initialize gradients if null
if(gradient == null) {
initGradient(this.trainingData, context.getLastMasterResult().getWeights());
}
        // use the weights from the master to train the model in the current iteration
this.gradient.setWeights(context.getLastMasterResult().getWeights());
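        // a single run() pass computes the gradients and the training error over this worker's local training data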
this.gradient.run();
// get train errors and test errors
double trainError = this.gradient.getError();
double testError = this.testingData.getRecordCount() > 0 ? (this.gradient.getNetwork()
.calculateError(this.testingData)) : 0;
LOG.info("NNWorker compute iteration {} (train error {} validation error {})",
new Object[] { context.getCurrentIteration(), trainError, testError });
NNParams params = new NNParams();
params.setTestError(testError);
params.setTrainError(trainError);
params.setGradients(this.gradient.getGradients());
        // prevent null pointer on serialization: workers send gradients back to the master, not weights
params.setWeights(new double[0]);
params.setTrainSize(this.trainingData.getRecordCount());
return params;
}
private void initGradient(MLDataSet training, double[] weights) {
BasicNetwork network = NNUtils.generateNetwork(this.inputs, this.hiddens, this.outputs);
// use the weights from master
network.getFlat().setWeights(weights);
FlatNetwork flat = network.getFlat();
        // copied from Encog's Propagation: offset sigmoid derivatives by a 0.1 "flat spot" constant to avoid stalling
double[] flatSpot = new double[flat.getActivationFunctions().length];
for(int i = 0; i < flat.getActivationFunctions().length; i++) {
flatSpot[i] = flat.getActivationFunctions()[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
}
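        // LinearErrorFunction computes the raw (ideal - actual) error, matching Encog's default propagation setup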
this.gradient = new Gradient(flat, training, flatSpot, new LinearErrorFunction());
}
    private NNParams buildEmptyNNParams(WorkerContext<NNParams, NNParams> workerContext) {
NNParams params = new NNParams();
params.setWeights(new double[0]);
params.setGradients(new double[0]);
params.setTestError(0.0d);
params.setTrainError(0.0d);
return params;
}
@Override
    protected void postLoad(WorkerContext<NNParams, NNParams> workerContext) {
((MemoryDiskMLDataSet) this.trainingData).endLoad();
((MemoryDiskMLDataSet) this.testingData).endLoad();
LOG.info("- # Records of the whole data set: {}.", this.count);
LOG.info("- # Records of the training data set: {}.", this.trainingData.getRecordCount());
LOG.info("- # Records of the testing data set: {}.", this.testingData.getRecordCount());
}
@Override
    public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue,
            WorkerContext<NNParams, NNParams> workerContext) {
++this.count;
if((this.count) % 100000 == 0) {
LOG.info("Read {} records.", this.count);
}
        // use Guava Splitter to split the record and iterate over its columns only once
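        // record layout: the first column is the target (ideal) value, the remaining columns are the input features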
double[] ideal = new double[1];
int inputNodes = NumberFormatUtils.getInt(
workerContext.getProps().getProperty(NNConstants.GUAGUA_NN_INPUT_NODES),
NNConstants.GUAGUA_NN_DEFAULT_INPUT_NODES);
double[] inputs = new double[inputNodes];
int i = 0;
for(String input: Splitter.on(NNConstants.NN_DEFAULT_COLUMN_SEPARATOR).split(
currentValue.getWritable().toString())) {
if(i == 0) {
ideal[i++] = NumberFormatUtils.getDouble(input, 0.0d);
} else {
int inputsIndex = (i++) - 1;
if(inputsIndex >= inputNodes) {
break;
}
inputs[inputsIndex] = NumberFormatUtils.getDouble(input, 0.0d);
}
}
if(i < (inputNodes + 1)) {
            throw new GuaguaRuntimeException(String.format(
                    "Not enough data columns; input nodes setting: %s, data columns: %s", inputNodes, i));
}
int scale = NumberFormatUtils.getInt(workerContext.getProps().getProperty(NNConstants.NN_RECORD_SCALE), 1);
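        // NN_RECORD_SCALE replicates each record 'scale' times to enlarge the data set; the default of 1 means no
        // replication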
        for(int j = 0; j < scale; j++) {
            double[] tmpInputs = j == 0 ? inputs : new double[inputs.length];
            double[] tmpIdeal = j == 0 ? ideal : new double[ideal.length];
            System.arraycopy(inputs, 0, tmpInputs, 0, inputs.length);
            // also copy the ideal value so replicated records keep their target instead of a zeroed array
            System.arraycopy(ideal, 0, tmpIdeal, 0, ideal.length);
MLDataPair pair = new BasicMLDataPair(new BasicMLData(tmpInputs), new BasicMLData(tmpIdeal));
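            // assign each pair to the training or testing set by a coin flip, giving a roughly 50/50 split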
double r = Math.random();
if(r >= 0.5d) {
this.trainingData.add(pair);
} else {
this.testingData.add(pair);
}
}
}
/*
* (non-Javadoc)
*
* @see ml.shifu.guagua.worker.AbstractWorkerComputable#initRecordReader(ml.shifu.guagua.io.GuaguaFileSplit)
*/
@Override
public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
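        // read the split as plain text lines; as with Hadoop's LineRecordReader, the key is the byte offset of the
        // line and the value is the line content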
this.setRecordReader(new GuaguaLineRecordReader());
this.getRecordReader().initialize(fileSplit);
}
public MLDataSet getTrainingData() {
return trainingData;
}
public void setTrainingData(MLDataSet trainingData) {
this.trainingData = trainingData;
}
public MLDataSet getTestingData() {
return testingData;
}
public void setTestingData(MLDataSet testingData) {
this.testingData = testingData;
}
}