All downloads are free. Search and download functionality uses the official Maven repository.

ml.shifu.guagua.example.lr.LogisticRegressionMaster Maven / Gradle / Ivy

/*
 * Copyright [2013-2014] PayPal Software Foundation
 *  
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *  
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.example.lr;

import java.util.Arrays;
import java.util.Random;

import ml.shifu.guagua.master.AbstractMasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link LogisticRegressionMaster} defines logic to update global logistic regression model.
 * 
 * 

* At first iteration, master builds a random model then send to all workers to start computing. This is to make all * workers use the same model at the starting time. * *

* At other iterations, master works: *

    *
  • 1. Accumulate all gradients from workers.
  • *
  • 2. Update global models by using accumulated gradients.
  • *
  • 3. Send new global model to workers by returning model parameters.
  • *
*/ // FIXME miss one parameter: size, the formula should be weights[i] -= learnRate * (1/size) * gradients[i]; pass from // workers public class LogisticRegressionMaster extends AbstractMasterComputable { private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionMaster.class); private static final Random RANDOM = new Random(); private int inputNum; private double[] weights; private double learnRate; @Override public void init(MasterContext context) { this.inputNum = NumberFormatUtils.getInt(LogisticRegressionContants.LR_INPUT_NUM, LogisticRegressionContants.LR_INPUT_DEFAULT_NUM); this.learnRate = NumberFormatUtils.getDouble(LogisticRegressionContants.LR_LEARNING_RATE, LogisticRegressionContants.LR_LEARNING_DEFAULT_RATE); // if not first iteration, means this is fail-over and should be recovered for state in master. if(!context.isFirstIteration()) { LogisticRegressionParams masterResult = context.getMasterResult(); if(masterResult != null && masterResult.getParameters() != null) { this.weights = masterResult.getParameters(); } else { initWeights(); } } } @Override public LogisticRegressionParams doCompute(MasterContext context) { if(context.isFirstIteration()) { initWeights(); } else { double[] gradients = new double[this.inputNum + 1]; double sumError = 0.0d; int size = 0; for(LogisticRegressionParams param: context.getWorkerResults()) { if(param != null) { for(int i = 0; i < gradients.length; i++) { gradients[i] += param.getParameters()[i]; } sumError += param.getError(); } size++; } for(int i = 0; i < weights.length; i++) { weights[i] -= learnRate * gradients[i]; } LOG.debug("DEBUG: Weights: {}", Arrays.toString(this.weights)); LOG.info("Iteration {} with error {}", context.getCurrentIteration(), sumError / size); } return new LogisticRegressionParams(weights); } private void initWeights() { weights = new double[this.inputNum + 1]; for(int i = 0; i < weights.length; i++) { weights[i] = RANDOM.nextDouble(); } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy