All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.shifu.guagua.example.kmeans.KMeansWorker Maven / Gradle / Ivy

/*
 * Copyright [2013-2014] PayPal Software Foundation
 *  
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *  
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.example.kmeans;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.util.MemoryDiskList;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;

import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Splitter;

/**
 * {@link KMeansWorker} re-computes each record tagged with new category.
 * 
 * 

* To calculate new k centers in master, {@link KMeansWorker} also help to accumulate worker info for new k centers by * using sum list and count list. */ public class KMeansWorker extends AbstractWorkerComputable, GuaguaWritableAdapter> { private static final Logger LOG = LoggerFactory.getLogger(KMeansWorker.class); /** * Data list of current worker cached in memory. */ private MemoryDiskList dataList; /** * K categories pre-defined */ private int k; /** * Columns (dimensions) for each record */ private int c; /** * Separator to split data for each record */ private String separator; /** * Reading input line by line */ @Override public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException { this.setRecordReader(new GuaguaLineRecordReader()); this.getRecordReader().initialize(fileSplit); } @Override public void init(WorkerContext context) { this.k = Integer.parseInt(context.getProps().getProperty(KMeansContants.KMEANS_K_NUMBER)); this.c = Integer.parseInt(context.getProps().getProperty(KMeansContants.KMEANS_COLUMN_NUMBER)); this.separator = context.getProps().getProperty(KMeansContants.KMEANS_DATA_SEPERATOR); double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.5")); String tmpFolder = context.getProps().getProperty("guagua.data.tmpfolder", "tmp"); this.dataList = new MemoryDiskList((long) (Runtime.getRuntime().maxMemory() * memoryFraction), tmpFolder + File.separator + System.currentTimeMillis()); // cannot find a good place to close these two data set, using Shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() { @Override public void run() { KMeansWorker.this.dataList.close(); KMeansWorker.this.dataList.clear(); } })); // just set into worker context for data output interceptor usage. context.setAttachment(this.dataList); } /** * Using the new k centers to tag each record with index denoting the record belongs to which category. 
*/ @Override public KMeansWorkerParams doCompute(WorkerContext context) { if(context.getCurrentIteration() == 1) { return doFirstIteration(context); } else { this.dataList.reOpen(); return doOtherIterations(context); } } private KMeansWorkerParams doFirstIteration(WorkerContext workerContext) { KMeansWorkerParams workerResult = new KMeansWorkerParams(); workerResult.setK(this.k); workerResult.setC(this.c); workerResult.setFirstIteration(true); int dataSize = (int) this.dataList.size(); List pointList = new ArrayList(dataSize); if(this.k >= dataSize) { for(TaggedRecord record: this.dataList) { pointList.add(toDouble(record)); } } else { int m = dataSize / this.k; int i = 0; this.dataList.reOpen(); for(TaggedRecord record: this.dataList) { if(i++ % m == 0) { pointList.add(toDouble(record)); } } } workerResult.setPointList(pointList); return workerResult; } private double[] toDouble(TaggedRecord record) { Double[] data = record.getRecord(); double[] newData = new double[data.length]; int i = 0; for(Double d: data) { newData[i] = d == null ? 0d : d; } return newData; } private KMeansWorkerParams doOtherIterations(WorkerContext workerContext) { // new centers used in this iteration. List centers = workerContext.getLastMasterResult().getPointList(); LOG.debug("Initial centers:%s", (centers)); // sum list and count list as worker result sent to master for global accumulation. List sumList = new LinkedList(); List countList = new LinkedList(); // Initializing sum list and count list. for(int i = 0; i < this.k; i++) { sumList.add(new double[this.c]); countList.add(0); } for(TaggedRecord record: this.dataList) { int index = findClosedCenter(record.getRecord(), centers); record.setTag(index); countList.set(index, countList.get(index) + 1); double[] sum = sumList.get(index); for(int i = 0; i < this.c; i++) { sum[i] += record.getRecord()[i] == null ? 
0d : record.getRecord()[i].doubleValue(); } } LOG.debug("sumList:%s", (sumList)); LOG.debug("countList:%s", countList); KMeansWorkerParams workerResult = new KMeansWorkerParams(); workerResult.setK(this.k); workerResult.setC(this.c); workerResult.setFirstIteration(false); workerResult.setPointList(sumList); workerResult.setCountList(countList); return workerResult; } @Override protected void postLoad(WorkerContext context) { this.dataList.switchState(); } /** * Finding closed center from all the k centers. Return the index of finding center. */ private int findClosedCenter(Double[] record, List centers) { int index = 0; double minDist = distance(record, centers.get(0)); for(int i = 1; i < centers.size(); i++) { double distance = distance(record, centers.get(i)); if(distance < minDist) { index = i; } } return index; } /** * Calculate cosine distance for two points. */ // TODO cache sqW2, no need re-computing private double distance(Double[] record, double[] center) { double denominator = 0; for(int i = 0; i < center.length; i++) { denominator += record[i] == null ? 0d : (record[i] * center[i]); } double sqW1 = 0, sqW2 = 0; for(int i = 0; i < center.length; i++) { sqW1 += record[i] == null ? 0d : (record[i] * record[i]); sqW2 += (center[i] * center[i]); } return denominator / (Math.sqrt(sqW1) * Math.sqrt(sqW2)); } /** * Loading data into memory. any invalid data will be set to null. */ @Override public void load(GuaguaWritableAdapter currentKey, GuaguaWritableAdapter currentValue, WorkerContext workerContext) { String line = currentValue.getWritable().toString(); Double[] record = new Double[this.c]; int i = 0; for(String input: Splitter.on(this.separator).split(line)) { try { record[i++] = Double.parseDouble(input); } catch (NumberFormatException e) { // use null to replace in-valid number record[i++] = null; } } this.dataList.append(new TaggedRecord(record)); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy