
weka.distributed.KMeansReduceTask Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of distributedWekaBase Show documentation
Show all versions of distributedWekaBase Show documentation
This package provides generic configuration class and distributed map/reduce style tasks for Weka
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* KMeansReduceTask
* Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
*
*/
package weka.distributed;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.NormalizableDistance;
import weka.core.Utils;
import weka.core.stats.ArffSummaryNumericMetric;
import weka.core.stats.NominalStats;
/**
* Reduce task for k-means clustering. Processes partial cluster summary
* metadata for a particular run in order to produce a set of Instances that
* contains new cluster centroids.
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: $
*/
public class KMeansReduceTask implements Serializable {
/**
* For serialization
*/
private static final long serialVersionUID = 6222983145960081251L;
protected double m_totalWithinClustersError;
/**
* Will hold the updated centroids for this run after aggregating the partial
* clusterings
*/
protected Instances m_newCentroidsForRun;
/**
* This will hold the aggregated summary instances (with summary stats
* attributes), one for each cluster
*/
protected List m_aggregatedCentroidSummaries;
/**
* If run number = 0, then this will hold the some priming data for
* initializing the ranges of numeric attributes in the case where filters may
* have transformed/altered the original space.
*/
protected Instances m_globalDistanceFunctionPrimingData;
/** The run number */
protected int m_runNumber;
/** The iteration number */
protected int m_iterationNumber;
/**
* Reduce the cluster centroid summary metadata instances for a particular run
* in order to produce a new set of Instances that contains the new cluster
* centroids for the run. Adds the total within cluster error to the relation
* name of the instances. If the iteration number is 0 then also generates a
* two instance data set that can be used for initializing a distance
* function. The two instances contain global minimum and maximum values for
* numeric attributes respectively, which is used in the distance function for
* normalization. This particular dataset is useful when filters (beyond
* missing values replacement) have been specified for k-means and it is not
* possible to use the summary stats in the global ARFF header file for
* initializing the distance function
*
* @param runNumber the current run number
* @param iterationNumber the current iteration number of k-means
* @param headerNoSummary the global ARFF header (as computed by the
* ArffHeader job on the entire dataset, and having passed through
* any preprocessing filters). We need this so that the correct index
* for nominal attribute values can be set in the new centroids (map
* tasks accumulating summary stats when clustering partitions of the
* data may see nominal values in different orders, or not see some
* values at all, compared to the global header)
* @param clusterSummaries a list of cluster summary information. Each inner
* list of Instances will have been generated by a map task on a
* subset of the data. Each instances object in the list contains the
* summary stats for one cluster centroid. Inner lists are in order
* of centroid number. A particular Instances entry in a list may be
* null - this indicates that the cluster was empty within that
* particular map task (i.e. no training instances were assigned to
* it)
* @return an instance of KMeansReduceTask with new centroids and supporting
* data computed.
* @throws DistributedWekaException if a problem occurs
*/
public KMeansReduceTask reduceClusters(int runNumber, int iterationNumber,
Instances headerNoSummary, List> clusterSummaries)
throws DistributedWekaException {
m_runNumber = runNumber;
m_iterationNumber = iterationNumber;
int numClusters = clusterSummaries.get(0).size();
// headerNoSummary =
// CSVToARFFHeaderReduceTask.stripSummaryAtts(headerNoSummary);
List> partialsPerCentroid =
new ArrayList>();
for (int i = 0; i < numClusters; i++) {
partialsPerCentroid.add(new ArrayList());
}
for (int i = 0; i < clusterSummaries.size(); i++) {
List currentPartial = clusterSummaries.get(i);
if (currentPartial.size() != numClusters) {
throw new DistributedWekaException(
"Each list of centroid summary stats should be "
+ "equal to the number of clusters. Expected " + numClusters
+ " but this list" + " contains " + currentPartial.size());
}
for (int j = 0; j < currentPartial.size(); j++) {
Instances centroidPartial = currentPartial.get(j);
if (centroidPartial != null) {
partialsPerCentroid.get(j).add(centroidPartial);
}
}
}
CSVToARFFHeaderReduceTask reduceTask = new CSVToARFFHeaderReduceTask();
List aggregatedCentroidSummaries = new ArrayList();
m_totalWithinClustersError = 0;
for (int i = 0; i < partialsPerCentroid.size(); i++) {
if (partialsPerCentroid.get(i).size() > 0) {
double clusterError = getErrorsForCluster(partialsPerCentroid.get(i));
m_totalWithinClustersError += clusterError;
Instances aggregated = reduceTask.aggregate(partialsPerCentroid.get(i));
// update the relation name
aggregated.setRelationName("Stats for centroid " + i + " : "
+ clusterError);
aggregatedCentroidSummaries.add(aggregated);
} else {
// this means that this is now a global empty cluster (i.e.
// no mappers assigned any training instances to this centroid).
// So we just drop it.
}
}
m_aggregatedCentroidSummaries = aggregatedCentroidSummaries;
double[] globalMins = null;
double[] globalMaxes = null;
if (iterationNumber == 0) {
globalMins = new double[headerNoSummary.numAttributes()];
globalMaxes = new double[headerNoSummary.numAttributes()];
for (int i = 0; i < headerNoSummary.numAttributes(); i++) {
if (headerNoSummary.attribute(i).isNumeric()) {
globalMins[i] = Double.MAX_VALUE;
globalMaxes[i] = Double.MIN_VALUE;
} else {
globalMins[i] = Utils.missingValue();
globalMaxes[i] = Utils.missingValue();
}
}
}
// now construct the new centroids
for (int i = 0; i < aggregatedCentroidSummaries.size(); i++) {
Instances centroidSummary = aggregatedCentroidSummaries.get(i);
double[] centerVals = new double[headerNoSummary.numAttributes()];
for (int j = 0; j < headerNoSummary.numAttributes(); j++) {
Attribute origAtt = headerNoSummary.attribute(j);
String name = origAtt.name();
Attribute summaryAtt =
centroidSummary
.attribute(CSVToARFFHeaderMapTask.ARFF_SUMMARY_ATTRIBUTE_PREFIX
+ name);
if (origAtt.isNumeric()) {
double nonMissingCountForAtt =
ArffSummaryNumericMetric.COUNT.valueFromAttribute(summaryAtt);
double missingCountForAtt =
ArffSummaryNumericMetric.MISSING.valueFromAttribute(summaryAtt);
double clusterMeanForAtt =
ArffSummaryNumericMetric.MEAN.valueFromAttribute(summaryAtt);
if (missingCountForAtt > nonMissingCountForAtt
|| Utils.isMissingValue(clusterMeanForAtt)) {
System.err
.println("********************************* att: "
+ origAtt.name() + " mean: " + clusterMeanForAtt
+ "non-missing: " + nonMissingCountForAtt + " missing: "
+ missingCountForAtt);
centerVals[j] = Utils.missingValue();
} else {
centerVals[j] = clusterMeanForAtt;
}
if (iterationNumber == 0) {
double min =
ArffSummaryNumericMetric.MIN.valueFromAttribute(summaryAtt);
double max =
ArffSummaryNumericMetric.MAX.valueFromAttribute(summaryAtt);
if (!Utils.isMissingValue(min) && !Double.isInfinite(min)) {
if (min < globalMins[j]) {
globalMins[j] = min;
}
}
if (!Utils.isMissingValue(max) && !Double.isInfinite(max)) {
if (max > globalMaxes[j]) {
globalMaxes[j] = max;
}
}
}
} else if (origAtt.isNominal()) {
NominalStats stats = NominalStats.attributeToStats(summaryAtt);
// int clusterModeForAttIndex = stats.getMode();
String clusterModeLabelForAtt = stats.getModeLabel();
double modeCountForAtt = stats.getCount(clusterModeLabelForAtt);
double missingCountForAtt = stats.getNumMissing();
if (missingCountForAtt > modeCountForAtt) {
centerVals[j] = Utils.missingValue();
} else {
// centerVals[j] = clusterModeForAttIndex;
int mappedIndex =
headerNoSummary.attribute(j).indexOfValue(clusterModeLabelForAtt);
if (mappedIndex < 0) {
throw new DistributedWekaException(
"Unable to find nominal value '" +
clusterModeLabelForAtt + "' in global header attribute '"
+ headerNoSummary.attribute(j));
}
centerVals[j] = mappedIndex;
}
} else {
// this could happen if the user has applied a streamable filter that
// creates string attributes or something
throw new DistributedWekaException(
"k-means can only handle numeric and nominal attributes!");
}
}
// add the new centroid
headerNoSummary.add(new DenseInstance(1.0, centerVals));
}
m_newCentroidsForRun = headerNoSummary;
// If iteration 0 then compute global priming data for distance functions
// in the (potentially) filtered space
if (iterationNumber == 0) {
m_globalDistanceFunctionPrimingData = new Instances(headerNoSummary, 0);
m_globalDistanceFunctionPrimingData
.add(new DenseInstance(1.0, globalMins));
m_globalDistanceFunctionPrimingData.add(new DenseInstance(1.0,
globalMaxes));
}
return this;
}
/**
* Return the centroids for the run
*
* @return the centroids as a set of instances
*/
public Instances getCentroidsForRun() {
return m_newCentroidsForRun;
}
/**
* Get the aggregated summary data for each individual centroid. This is
* represented as a Instances header with summary meta attributes
*
* @return a list of summary meta data
*/
public List getAggregatedCentroidSummaries() {
return m_aggregatedCentroidSummaries;
}
/**
* Get the global distance function priming data. This contains global mins
* and maxes for attributes in the transformed (by any filters) space
*
* @return the global distance function priming data
*/
public Instances getGlobalDistanceFunctionPrimingData() {
return m_globalDistanceFunctionPrimingData;
}
/**
* Get the run number
*
* @return the run number
*/
public int getRunNumber() {
return m_runNumber;
}
/**
* Get the current iteration number
*
* @return the current iteration number
*/
public int getIterationNumber() {
return m_iterationNumber;
}
/**
* Get the total within cluster error for this run
*
* @return the total within cluster error for this run
*/
public double getTotalWithinClustersError() {
return m_totalWithinClustersError;
}
/**
* Computes the errors for a particular cluster from a list of partial cluster
* summary data
*
* @param clusterPartials a list of Instances containing summary meta
* attributes
* @return the total error for the cluster
* @throws DistributedWekaException if a problem occurs
*/
protected static double getErrorsForCluster(List clusterPartials)
throws DistributedWekaException {
double error = 0;
for (Instances i : clusterPartials) {
String relationName = i.relationName();
String[] parts = relationName.split(":");
if (parts.length != 2) {
throw new DistributedWekaException(
"Can't find within cluster error in the "
+ "relation name of a cluster centroid partial stats instances:\n "
+ i.toString());
}
try {
error += Double.parseDouble(parts[1].trim());
} catch (NumberFormatException e) {
throw new DistributedWekaException(
"Unable to parse within cluster error"
+ " from a cluster centroid partial stats instances: \n"
+ i.toString());
}
}
return error;
}
/**
* Utility function to examine the attribute ranges in a bunch of distance
* functions and return a two instance dataset with the global mins/maxes of
* numeric attributes set. This can be used to "prime" a distance function.
*
* @param distanceFuncs a list of distance functions (where each potentially
* has only seen part of the overall dataset
* @param headerNoSummary the header of the data that the distance functions
* have been seeing
* @return a priming data set with global min and max values for numeric
* attributes
* @throws DistributedWekaException if a problem occurs
*/
public static Instances computeDistancePrimingDataFromDistanceFunctions(
List distanceFuncs, Instances headerNoSummary)
throws DistributedWekaException {
Instances prime = null;
double[] mins = new double[headerNoSummary.numAttributes()];
double[] maxes = new double[headerNoSummary.numAttributes()];
try {
for (int i = 0; i < headerNoSummary.numAttributes(); i++) {
if (headerNoSummary.attribute(i).isNumeric()) {
mins[i] = Double.MAX_VALUE;
maxes[i] = Double.MIN_VALUE;
} else {
mins[i] = Utils.missingValue();
maxes[i] = Utils.missingValue();
}
}
for (NormalizableDistance d : distanceFuncs) {
double[][] ranges = d.getRanges();
for (int i = 0; i < headerNoSummary.numAttributes(); i++) {
if (ranges[i][NormalizableDistance.R_MIN] < mins[i]) {
mins[i] = ranges[i][NormalizableDistance.R_MIN];
}
if (ranges[i][NormalizableDistance.R_MAX] > maxes[i]) {
maxes[i] = ranges[i][NormalizableDistance.R_MAX];
}
}
}
} catch (Exception ex) {
throw new DistributedWekaException(ex);
}
prime = new Instances(headerNoSummary, 2);
prime.add(new DenseInstance(1.0, mins));
prime.add(new DenseInstance(1.0, maxes));
return prime;
}
}
// © 2015 - 2025 Weber Informatics LLC