All downloads are free. Search and download functionality uses the official Maven repository.

ml.shifu.guagua.example.lnr.LinearRegressionWorker Maven / Gradle / Ivy

There is a newer version: 0.7.10
Show newest version
/*
 * Copyright [2013-2014] PayPal Software Foundation
 *  
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *  
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.example.lnr;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;

import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.util.MemoryDiskList;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;

import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Splitter;

/**
 * {@link LinearRegressionWorker} defines logic to accumulate local linear regression gradients.
 * 
 * 

* At first iteration, wait for master to use the consistent initiating model. * *

* At other iterations, workers include: *

    *
  • 1. Update local model by using global model from last step..
  • *
  • 2. Accumulate gradients by using local worker input data.
  • *
  • 3. Send new local gradients to master by returning parameters.
  • *
* *

* WARNING: Input data should be normalized before, or you will get a very bad model. */ public class LinearRegressionWorker extends AbstractWorkerComputable, GuaguaWritableAdapter> { private static final Logger LOG = LoggerFactory.getLogger(LinearRegressionWorker.class); /** * Input column number */ private int inputNum; /** * Output column number */ private int outputNum; /** * In-memory data which located in memory at the first iteration. */ private MemoryDiskList dataList; /** * Local linear regression model. */ private double[] weights; /** * A splitter to split data with specified delimiter. */ private Splitter splitter = Splitter.on(","); @Override public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException { this.setRecordReader(new GuaguaLineRecordReader(fileSplit)); } @Override public void init(WorkerContext context) { this.inputNum = NumberFormatUtils.getInt(LinearRegressionContants.LR_INPUT_NUM, LinearRegressionContants.LR_INPUT_DEFAULT_NUM); this.outputNum = 1; double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.5")); String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", System.getProperty("user.dir")); this.dataList = new MemoryDiskList((long) (Runtime.getRuntime().maxMemory() * memoryFraction), tmpFolder + File.separator + System.currentTimeMillis()); // cannot find a good place to close these two data set, using Shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() { @Override public void run() { LinearRegressionWorker.this.dataList.close(); LinearRegressionWorker.this.dataList.clear(); } })); } @Override public LinearRegressionParams doCompute(WorkerContext context) { if(context.isFirstIteration()) { return new LinearRegressionParams(); } else { this.weights = context.getLastMasterResult().getParameters(); double[] gradients = new double[this.inputNum + 1]; double finalError = 0.0d; int size = 0; this.dataList.reOpen(); for(Data data: 
dataList) { double error = dot(data.inputs, this.weights) - data.outputs[0]; finalError += error * error / 2; for(int i = 0; i < gradients.length; i++) { gradients[i] += error * data.inputs[i]; } size++; } LOG.info("Iteration {} with error {}", context.getCurrentIteration(), finalError / size); return new LinearRegressionParams(gradients, finalError / size); } } @Override protected void postLoad(WorkerContext context) { this.dataList.switchState(); } /** * Compute dot value of two vectors. */ private double dot(double[] inputs, double[] weights) { double value = 0.0d; for(int i = 0; i < weights.length; i++) { value += weights[i] * inputs[i]; } return value; } @Override public void load(GuaguaWritableAdapter currentKey, GuaguaWritableAdapter currentValue, WorkerContext context) { String line = currentValue.getWritable().toString(); double[] inputData = new double[inputNum + 1]; double[] outputData = new double[outputNum]; int count = 0, inputIndex = 0, outputIndex = 0; inputData[inputIndex++] = 1.0d; for(String unit: splitter.split(line)) { if(count < inputNum) { inputData[inputIndex++] = Double.valueOf(unit); } else if(count >= inputNum && count < (inputNum + outputNum)) { outputData[outputIndex++] = Double.valueOf(unit); } else { break; } count++; } this.dataList.append(new Data(inputData, outputData)); } private static class Data implements Serializable { private static final long serialVersionUID = 3739632336801994754L; public Data(double[] inputs, double[] outputs) { this.inputs = inputs; this.outputs = outputs; } private final double[] inputs; private final double[] outputs; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy