All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.shifu.guagua.example.kmeans.KMeansMaster Maven / Gradle / Ivy

/*
 * Copyright [2013-2014] PayPal Software Foundation
 *  
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *  
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.example.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;

import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link KMeansMaster} computes new k center points for next iteration.
 * 
 * 

* By accumulating all the k center points sum values from all workers, then average to get new k center points. */ public class KMeansMaster implements MasterComputable { private static final Logger LOG = LoggerFactory.getLogger(KMeansMaster.class); /** * Master computation by accumulating all the k center points sum values from all workers, then average to get new k * center points. * * @throws NullPointerException * if worker result or worker results is null. */ @Override public KMeansMasterParams compute(MasterContext context) { if(context.getWorkerResults() == null) { throw new NullPointerException("No worker results received in Master."); } if(context.getCurrentIteration() == 1) { return doFirstIteration(context); } else { return doOtherIterations(context); } } private KMeansMasterParams doFirstIteration(MasterContext context) { List allInitialCentriods = new ArrayList(); boolean initilized = false; int k = 0, c = 0; for(KMeansWorkerParams workerResult: context.getWorkerResults()) { allInitialCentriods.addAll(workerResult.getPointList()); if(!initilized) { k = workerResult.getK(); c = workerResult.getC(); } } if(allInitialCentriods.size() < k) { throw new GuaguaRuntimeException( "Error: data size is smaller than k, please check your input and k settings."); } Collections.sort(allInitialCentriods, new Comparator() { @Override public int compare(double[] o1, double[] o2) { double dist = distance(o1) - distance(o2); return Double.valueOf(dist).compareTo(Double.valueOf(0d)); } }); List initialCentriods = new ArrayList(k); int step = allInitialCentriods.size() / k; for(int i = 0; i < k; i++) { initialCentriods.add(allInitialCentriods.get(i * step)); } KMeansMasterParams masterResult = new KMeansMasterParams(); masterResult.setK(k); masterResult.setC(c); masterResult.setPointList(initialCentriods); return masterResult; } private double distance(double[] record) { double sumSquare = 0d; for(int i = 0; i < record.length; i++) { sumSquare += (record[i] * record[i]); } 
return Math.sqrt(sumSquare); } private KMeansMasterParams doOtherIterations(MasterContext context) { // Accumulate all values for all categories List sumAllList = new LinkedList(); // here use long to avoid over flow List countAllList = new LinkedList(); boolean initilized = false; int k = 0, c = 0; for(KMeansWorkerParams workerResult: context.getWorkerResults()) { LOG.debug("Worker result: %s", workerResult); if(!initilized) { k = workerResult.getK(); c = workerResult.getC(); } for(int i = 0; i < k; i++) { if(!initilized) { sumAllList.add(new double[c]); countAllList.add(0L); } long currCount = countAllList.get(i); countAllList.set(i, currCount + workerResult.getCountList().get(i)); double[] sumAll = sumAllList.get(i); for(int j = 0; j < c; j++) { sumAll[j] += workerResult.getPointList().get(i)[j]; } } initilized = true; } LOG.debug("sumList: %s", (sumAllList)); LOG.debug("countList: %s", countAllList); // Get new center points List meanList = new LinkedList(); for(int i = 0; i < k; i++) { double[] means = new double[c]; for(int j = 0; j < c; j++) { means[j] = sumAllList.get(i)[j] / countAllList.get(i); } meanList.add(means); } LOG.debug("meanList: %s", (meanList)); // Construct new master result with new center points KMeansMasterParams masterResult = new KMeansMasterParams(); masterResult.setK(k); masterResult.setC(c); masterResult.setPointList(meanList); return masterResult; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy