
package hex.kmeans;
import hex.ClusteringModelBuilder;
import hex.ModelCategory;
import hex.ModelMetricsClustering;
import hex.schemas.KMeansV3;
import hex.schemas.ModelBuilderSchema;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.*;
import water.H2O.H2OCountedCompleter;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
* Scalable K-Means++ (KMeans||)
* http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf
* http://www.youtube.com/watch?v=cigXAxV3XcY
*/
public class KMeans extends ClusteringModelBuilder {
@Override public ModelCategory[] can_build() {
return new ModelCategory[]{ ModelCategory.Clustering };
}
@Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }
public enum Initialization {
Random, PlusPlus, Furthest, User
}
// Convergence tolerance
final private double TOLERANCE = 1e-6;
// Called from an http request
public KMeans(Key dest, String desc, KMeansModel.KMeansParameters parms) { super(dest, desc, parms); init(false); }
public KMeans( KMeansModel.KMeansParameters parms ) { super("K-means",parms); init(false); }
public ModelBuilderSchema schema() { return new KMeansV3(); }
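// A minimal usage sketch (not part of the original source): assuming a
// training Frame "fr" is already installed in the DKV, this builder is
// typically driven programmatically like so:
//
//   KMeansModel.KMeansParameters parms = new KMeansModel.KMeansParameters();
//   parms._train = fr._key;       // key of the training frame
//   parms._k = 3;                 // number of clusters
//   parms._max_iterations = 10;
//   parms._standardize = true;
//   KMeansModel model = (KMeansModel) new KMeans(parms).trainModel().get();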
protected void checkMemoryFootPrint() {
long mem_usage = 8 /*doubles*/ * _parms._k * _train.numCols() * (_parms._standardize ? 2 : 1);
long max_mem = H2O.SELF.get_max_mem();
if (mem_usage > max_mem) {
String msg = "Centroids won't fit in the driver node's memory ("
+ PrettyPrint.bytes(mem_usage) + " > " + PrettyPrint.bytes(max_mem)
+ ") - try reducing the number of columns and/or the number of categorical factors.";
error("_train", msg);
cancel(msg);
}
}
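// Worked example of the estimate above (hypothetical numbers): with k=100
// clusters, 1,000 training columns and standardization on, the centroids need
// roughly 8 B * 100 * 1000 * 2 = 1.6 MB on the driver node - tiny, so this
// check only bites for very wide (e.g. heavily one-hot-expanded) frames.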
/** Start the KMeans training Job on an F/J thread. */
@Override public Job trainModel() {
return start(new KMeansDriver(), _parms._max_iterations);
}
/** Initialize the ModelBuilder, validating all arguments and preparing the
* training frame. This call is expected to be overridden in the subclasses
* and each subclass will start with "super.init();".
*
* Validate K, max_iterations and the number of rows. */
@Override public void init(boolean expensive) {
super.init(expensive);
if( _parms._max_iterations < 0 || _parms._max_iterations > 1e6) error("_max_iterations", "max_iterations must be between 0 and 1e6");
if( _train == null ) return;
if( _parms._init == Initialization.User && _parms._user_points == null )
error("_user_points","Must specify initial cluster centers");
if( null != _parms._user_points ){ // Check dimensions of user-specified centers
if( _parms._user_points.get().numCols() != _train.numCols() ) {
error("_user_points","The user-specified points must have the same number of columns (" + _train.numCols() + ") as the training observations");
}
}
if (expensive && error_count() == 0) checkMemoryFootPrint();
}
// ----------------------
private class KMeansDriver extends H2OCountedCompleter {
private String[][] _isCats; // Categorical columns
// Initialize cluster centers
double[][] initial_centers( KMeansModel model, final Vec[] vecs, final double[] means, final double[] mults ) {
// Categoricals use a different distance metric than numeric columns.
model._output._categorical_column_count=0;
_isCats = new String[vecs.length][];
for( int v=0; v<vecs.length; v++ ) {
  _isCats[v] = vecs[v].isEnum() ? new String[0] : null;
  if( _isCats[v] != null ) model._output._categorical_column_count++;
}
Random rand = RandomUtils.getRNG(_parms._seed-1);
double centers[][];             // Cluster centers
if( null != _parms._user_points ) { // User-specified starting points
  int numCenters = (int)_parms._user_points.get().numRows();
  int numCols = _parms._user_points.get().numCols();
  centers = new double[numCenters][numCols];
  Vec[] centersVecs = _parms._user_points.get().vecs();
  // Get the centers and standardize them if requested
  for( int r=0; r<numCenters; r++ )
    for( int c=0; c<numCols; c++ )
      centers[r][c] = data(centersVecs[c].at(r), c, means, mults, vecs[c].cardinality());
}
else if( _parms._init == Initialization.Random ) {
  // Initialize all cluster centers to random rows
  centers = new double[_parms._k][vecs.length];
  for( double[] center : centers )
    randomRow(vecs, rand, center, means, mults);
}
else { // PlusPlus or Furthest: scalable K-Means++ (KMeans||) sampling rounds
  centers = new double[1][vecs.length];
  // Initialize the first cluster center to a random row
  randomRow(vecs, rand, centers[0], means, mults);
  model._output._iterations = 0;
  while( model._output._iterations < 5 ) {
    // Sum of squared distances of all rows to their nearest current center
    SumSqr sqr = new SumSqr(centers, means, mults, _isCats).doAll(vecs);
    // Sample rows with probability proportional to their squared distance
    Sampler sampler = new Sampler(centers, means, mults, _isCats, sqr._sqr, _parms._k * 3, _parms._seed).doAll(vecs);
    centers = ArrayUtils.append(centers, sampler._sampled);
    // Fill the sampled centers into the model
    if( !isRunning() ) return null; // Stopped/cancelled
    model._output._centers_raw = destandardize(centers, _isCats, means, mults);
    model._output._tot_withinss = sqr._sqr / _train.numRows();
    model._output._iterations++;  // One sampling round done
    model.update(_key);           // Make an early version of the model visible
  }
  // Recluster the oversampled centers down to k cluster centers
  centers = recluster(centers, rand, _parms._k, _parms._init, _isCats);
  model._output._iterations = 0;  // Reset iteration count for Lloyds
}
return centers;
}
// Number of reinitialization attempts, for preventing empty clusters
transient private int _reinit_attempts;
// Handle the case where some centers go dry.  Rescue only 1 cluster
// per iteration ('cause we only tracked the 1 worst row)
boolean cleanupBadClusters( Lloyds task, final Vec[] vecs, final double[][] centers, final double[] means, final double[] mults ) {
  // Find any bad (empty) clusters
  int clu;
  for( clu=0; clu<_parms._k; clu++ )
    if( task._size[clu] == 0 ) break;
  if( clu == _parms._k ) return false; // No bad clusters
  // Re-initialize the dry cluster to the worst (farthest) row seen
  long row = task._worst_row;
  Log.warn("KMeans: Re-initializing cluster " + clu + " to row " + row);
  data(centers[clu] = task._cMeans[clu], vecs, row, means, mults);
  task._size[clu] = 1; //FIXME: PUBDEV-871 Some other cluster had its membership count reduced by one (which one?)
  // Find any MORE bad clusters; we only fixed the first one
  for( clu=0; clu<_parms._k; clu++ )
    if( task._size[clu] == 0 ) break;
  if( clu == _parms._k ) return false; // No MORE bad clusters
  // With 2 or more dry clusters, re-run Lloyds to find the next-worst row.
  // This is not counted as an iteration, since the centers are only being
  // rescued, not adjusted.
  Log.warn("KMeans: Re-running Lloyds to re-init another cluster");
  if( _reinit_attempts++ < _parms._k ) {
    return true;  // Rerun Lloyds, and assign points to centroids
  } else {
    _reinit_attempts = 0;
    return false;
  }
}
// Compute all interesting KMeans stats (errors & variances of clusters,
// etc) and fill them into the model; return the new standardized centers.
double[][] computeStatsFillModel( Lloyds task, KMeansModel model, final Vec[] vecs, final double[][] centers, final double[] means, final double[] mults ) {
  // Fill in the model, with destandardized centers for output
  model._output._centers_std_raw = task._cMeans;
  model._output._centers_raw = destandardize(task._cMeans, _isCats, means, mults);
  model._output._size = task._size;
  model._output._withinss = task._cSqr;
  double ssq = 0;               // Sum squared error over all clusters
  for( int i=0; i<_parms._k; i++ )
    ssq += model._output._withinss[i];
  model._output._tot_withinss = ssq;
  // Sum-of-squared distance from the grand mean
  if( _parms._k == 1 )
    model._output._totss = model._output._tot_withinss;
  else
    model._output._totss = new TotSS(means, mults, _isCats).doAll(vecs)._tss;
  model._output._betweenss = model._output._totss - model._output._tot_withinss;
  model._output._iterations++;
  // Add to the scoring history
  model._output._history_withinss = ArrayUtils.copyAndFillOf(
    model._output._history_withinss,
    model._output._history_withinss.length+1, model._output._tot_withinss);
  // Two small TwoDimTables - cheap
  model._output._model_summary = createModelSummaryTable(model._output);
  model._output._scoring_history = createScoringHistoryTable(model._output);
  // Assemble the cluster stats into a training-metrics object
  model._output._training_metrics = makeTrainingMetrics(model);
  return task._cMeans;          // New standardized centers
}
// Stopping criteria
boolean isDone( KMeansModel model, double[][] newCenters, double[][] oldCenters ) {
  if( !isRunning() ) return true; // Stopped/cancelled
  // Stopped for running out of iterations
  if( model._output._iterations > _parms._max_iterations) return true;
// Compute average change in standardized cluster centers
if( oldCenters==null ) return false; // No prior iteration, not stopping
double average_change = 0;
for( int clu=0; clu<_parms._k; clu++ )
average_change += hex.genmodel.GenModel.KMeans_distance(oldCenters[clu],newCenters[clu],_isCats,null,null);
average_change /= _parms._k; // Average change per cluster
model._output._avg_centroids_chg = ArrayUtils.copyAndFillOf(
model._output._avg_centroids_chg,
model._output._avg_centroids_chg.length+1, average_change);
model._output._training_time_ms = ArrayUtils.copyAndFillOf(
model._output._training_time_ms,
model._output._training_time_ms.length+1, System.currentTimeMillis());
return average_change < TOLERANCE;
}
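// Worked example of the criterion above (hypothetical numbers): with k=2 and
// squared center movements of 0.8e-6 and 1.0e-6 in one iteration, the average
// change is 0.9e-6 < TOLERANCE (1e-6), so training stops even though
// max_iterations may not have been reached.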
// Main worker thread
@Override protected void compute2() {
KMeansModel model = null;
try {
init(true);
// Do lock even before checking the errors, since this block is finalized by unlock
// (not the best solution, but the code is more readable)
_parms.read_lock_frames(KMeans.this); // Fetch & read-lock input frames
// Bail out early if anything went wrong during init
if( error_count() > 0 ) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(KMeans.this);
// The model to be built
model = new KMeansModel(dest(), _parms, new KMeansModel.KMeansOutput(KMeans.this));
model.delete_and_lock(_key);
//
final Vec vecs[] = _train.vecs();
// mults & means for standardization
final double[] means = _train.means(); // means are used to impute NAs
final double[] mults = _parms._standardize ? _train.mults() : null;
model._output._normSub = means;
model._output._normMul = mults;
// Initialize cluster centers and standardize if requested
double[][] centers = initial_centers(model,vecs,means,mults);
if( centers==null ) return; // Stopped/cancelled during center-finding
double[][] oldCenters = null;
// ---
// Run the main KMeans Clustering loop
// Stop after enough iterations or average_change < TOLERANCE
while( !isDone(model,centers,oldCenters) ) {
Lloyds task = new Lloyds(centers,means,mults,_isCats, _parms._k).doAll(vecs);
// Pick the max categorical level for cluster center
max_cats(task._cMeans,task._cats,_isCats);
// Handle the case where some centers go dry. Rescue only 1 cluster
// per iteration ('cause we only tracked the 1 worst row)
if( cleanupBadClusters(task,vecs,centers,means,mults) ) continue;
// Compute model stats; update standardized cluster centers
oldCenters = centers;
centers = computeStatsFillModel(task, model, vecs, centers, means, mults);
model.update(_key); // Update model in K/V store
update(1); // One unit of work
if (model._parms._score_each_iteration)
Log.info(model._output._model_summary);
}
Log.info(model._output._model_summary);
// Log.info(model._output._scoring_history);
// Log.info(((ModelMetricsClustering)model._output._training_metrics).createCentroidStatsTable().toString());
// FIXME: Remove (most of) this code - once it passes...
// PUBDEV-871: Double-check the training metrics (gathered by computeStatsFillModel) and the scoring logic by scoring on the training set
if (false) {
assert((ArrayUtils.sum(model._output._size) - _parms.train().numRows()) <= 1);
// Log.info(model._output._model_summary);
// Log.info(model._output._scoring_history);
// Log.info(((ModelMetricsClustering)model._output._training_metrics).createCentroidStatsTable().toString());
model.score(_parms.train()).delete(); //this scores on the training data and appends a ModelMetrics
ModelMetricsClustering mm = DKV.getGet(model._output._model_metrics[model._output._model_metrics.length - 1]);
assert(Arrays.equals(mm._size, ((ModelMetricsClustering) model._output._training_metrics)._size));
for (int i=0; i<_parms._k; ++i) {
assert(MathUtils.compare(mm._withinss[i], ((ModelMetricsClustering) model._output._training_metrics)._withinss[i], 1e-6, 1e-6));
}
assert(MathUtils.compare(mm._totss, ((ModelMetricsClustering) model._output._training_metrics)._totss, 1e-6, 1e-6));
assert(MathUtils.compare(mm._betweenss, ((ModelMetricsClustering) model._output._training_metrics)._betweenss, 1e-6, 1e-6));
assert(MathUtils.compare(mm._tot_withinss, ((ModelMetricsClustering) model._output._training_metrics)._tot_withinss, 1e-6, 1e-6));
}
// At the end: validation scoring (no need to gather scoring history)
if (_valid != null) {
Frame pred = model.score(_parms.valid()); //this appends a ModelMetrics on the validation set
model._output._validation_metrics = DKV.getGet(model._output._model_metrics[model._output._model_metrics.length-1]);
pred.delete();
model.update(_key); // Update model in K/V store
}
done(); // Job done!
} catch( Throwable t ) {
Job thisJob = DKV.getGet(_key);
if (thisJob._state == JobState.CANCELLED) {
Log.info("Job cancelled by user.");
} else {
t.printStackTrace();
failed(t);
throw t;
}
} finally {
if( model != null ) model.unlock(_key);
_parms.read_unlock_frames(KMeans.this);
}
tryComplete();
}
private TwoDimTable createModelSummaryTable(KMeansModel.KMeansOutput output) {
List<String> colHeaders = new ArrayList<>();
List<String> colTypes = new ArrayList<>();
List<String> colFormat = new ArrayList<>();
colHeaders.add("Number of Clusters"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Number of Categorical Columns"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Number of Iterations"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Within Cluster Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Total Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Between Cluster Sum of Squares"); colTypes.add("double"); colFormat.add("%.5f");
final int rows = 1;
TwoDimTable table = new TwoDimTable(
"Model Summary", null,
new String[rows],
colHeaders.toArray(new String[0]),
colTypes.toArray(new String[0]),
colFormat.toArray(new String[0]),
"");
int row = 0;
int col = 0;
table.set(row, col++, output._centers_raw.length);
table.set(row, col++, output._categorical_column_count);
table.set(row, col++, output._iterations);
table.set(row, col++, output._tot_withinss);
table.set(row, col++, output._totss);
table.set(row, col++, output._betweenss);
return table;
}
private TwoDimTable createScoringHistoryTable(KMeansModel.KMeansOutput output) {
List<String> colHeaders = new ArrayList<>();
List<String> colTypes = new ArrayList<>();
List<String> colFormat = new ArrayList<>();
colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Iteration"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Avg. Change of Std. Centroids"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Within Cluster Sum Of Squares"); colTypes.add("double"); colFormat.add("%.5f");
final int rows = output._avg_centroids_chg.length;
TwoDimTable table = new TwoDimTable(
"Scoring History", null,
new String[rows],
colHeaders.toArray(new String[0]),
colTypes.toArray(new String[0]),
colFormat.toArray(new String[0]),
"");
int row = 0;
for( int i = 0; i < output._avg_centroids_chg.length; i++ ) {
  int col = 0;
  DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
  table.set(row, col++, fmt.print(output._training_time_ms[i]));
  table.set(row, col++, PrettyPrint.msecs(output._training_time_ms[i] - _start_time, true));
  table.set(row, col++, i);
  table.set(row, col++, output._avg_centroids_chg[i]);
  table.set(row, col++, output._history_withinss[i]);
  row++;
}
return table;
}
}
// -------------------------------------------------------------------------
// Total sum-of-squared distance of all rows to the grand mean
private static class TotSS extends MRTask<TotSS> {
// IN
final double[] _means, _mults;
final String[][] _isCats;
// OUT
double _tss;
TotSS(double[] means, double[] mults, String[][] isCats) {
_means = means;
_mults = mults;
_tss = 0;
_isCats = isCats;
}
@Override public void map(Chunk[] cs) {
// Grand mean in the standardized space: the zero vector when standardizing,
// the raw column means otherwise; categoricals use their rounded mean level
double[] gc = _mults != null ? new double[_means.length] : Arrays.copyOf(_means, _means.length);
for( int i = 0; i < gc.length; i++ )
  if( _isCats[i] != null ) gc[i] = Math.round(_means[i]);
double[] values = new double[cs.length];
for( int row = 0; row < cs[0]._len; row++ ) {
  data(values, cs, row, _means, _mults);
  _tss += hex.genmodel.GenModel.KMeans_distance(gc, values, _isCats, null, null);
}
}
@Override public void reduce(TotSS other) { _tss += other._tss; }
}
// -------------------------------------------------------------------------
// Sum-of-squared distance of every row to its nearest cluster center
private static class SumSqr extends MRTask<SumSqr> {
// IN
double[][] _centers;
double[] _means, _mults; // Standardization
final String[][] _isCats;
// OUT
double _sqr;
SumSqr( double[][] centers, double[] means, double[] mults, String[][] isCats ) {
_centers = centers;
_means = means;
_mults = mults;
_isCats = isCats;
}
@Override public void map(Chunk[] cs) {
double[] values = new double[cs.length];
ClusterDist cd = new ClusterDist();
for( int row = 0; row < cs[0]._len; row++ ) {
data(values, cs, row, _means, _mults);
_sqr += minSqr(_centers, values, _isCats, cd);
}
_means = _mults = null;
_centers = null;
}
@Override public void reduce(SumSqr other) { _sqr += other._sqr; }
}
// -------------------------------------------------------------------------
// Sample rows with increasing probability the farther they are from any
// cluster center.
private static class Sampler extends MRTask<Sampler> {
// IN
double[][] _centers;
double[] _means, _mults; // Standardization
final String[][] _isCats;
final double _sqr; // Min-square-error
final double _probability; // Odds to select this point
final long _seed;
// OUT
double[][] _sampled; // New cluster centers
Sampler( double[][] centers, double[] means, double[] mults, String[][] isCats, double sqr, double prob, long seed ) {
_centers = centers;
_means = means;
_mults = mults;
_isCats = isCats;
_sqr = sqr;
_probability = prob;
_seed = seed;
}
@Override public void map(Chunk[] cs) {
double[] values = new double[cs.length];
ArrayList<double[]> list = new ArrayList<>();
Random rand = RandomUtils.getRNG(_seed + cs[0].start());
ClusterDist cd = new ClusterDist();
for( int row = 0; row < cs[0]._len; row++ ) {
data(values, cs, row, _means, _mults);
double sqr = minSqr(_centers, values, _isCats, cd);
if( _probability * sqr > rand.nextDouble() * _sqr )
list.add(values.clone());
}
_sampled = new double[list.size()][];
list.toArray(_sampled);
_centers = null;
_means = _mults = null;
}
@Override public void reduce(Sampler other) {
_sampled = ArrayUtils.append(_sampled, other._sampled);
}
}
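// Note on the acceptance test in map() above: a row is kept with probability
// _probability * sqr / _sqr, so summing over all rows (whose sqr values add
// up to _sqr) gives an expected sample size of roughly _probability points
// per pass, independent of the data size - the KMeans|| oversampling factor
// (modulo rows whose per-row probability saturates at 1).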
// ---------------------------------------
// A Lloyd's pass:
// Find nearest cluster center for every point
// Compute new mean/center & variance & rows for each cluster
// Compute distance between clusters
// Compute total sqr distance
private static class Lloyds extends MRTask<Lloyds> {
// IN
double[][] _centers;
double[] _means, _mults; // Standardization
final int _k;
final String[][] _isCats;
// OUT
double[][] _cMeans; // Means for each cluster
long[/*k*/][/*features*/][/*nfactors*/] _cats; // Histogram of cat levels
double[] _cSqr; // Sum of squares for each cluster
long[] _size; // Number of rows in each cluster
long _worst_row; // Row with max err
double _worst_err; // Max-err-row's max-err
Lloyds( double[][] centers, double[] means, double[] mults, String[][] isCats, int k ) {
_centers = centers;
_means = means;
_mults = mults;
_isCats = isCats;
_k = k;
}
@Override public void map(Chunk[] cs) {
int N = cs.length;
assert _centers[0].length==N;
_cMeans = new double[_k][N];
_cSqr = new double[_k];
_size = new long[_k];
// Space for cat histograms
_cats = new long[_k][N][];
for( int clu=0; clu< _k; clu++ )
  for( int col=0; col<N; col++ )
    _cats[clu][col] = _isCats[col]==null ? null : new long[cs[col].vec().cardinality()];
// Find the nearest cluster center for each row
double[] values = new double[N];  // Temp buffer to hold one row as doubles
ClusterDist cd = new ClusterDist();
for( int row = 0; row < cs[0]._len; row++ ) {
  data(values, cs, row, _means, _mults);  // Load the row, NA-imputed & standardized
  closest(_centers, values, _isCats, cd); // Find the nearest cluster center
  int clu = cd._cluster;
  assert clu != -1;                       // No broken rows
  _cSqr[clu] += cd._dist;
  // Add values and increment the counter for the chosen cluster
  for( int col = 0; col < N; col++ )
    if( _isCats[col] != null )
      _cats[clu][col][(int)values[col]]++; // Histogram the categorical levels
    else
      _cMeans[clu][col] += values[col];    // Sum the numeric columns
  _size[clu]++;
  // Track the worst (farthest) row
  if( cd._dist > _worst_err) { _worst_err = cd._dist; _worst_row = cs[0].start()+row; }
}
// Scale back down to local mean
for( int clu = 0; clu < _k; clu++ )
if( _size[clu] != 0 ) ArrayUtils.div(_cMeans[clu], _size[clu]);
_centers = null;
_means = _mults = null;
}
@Override public void reduce(Lloyds mr) {
for( int clu = 0; clu < _k; clu++ ) {
long ra = _size[clu];
long rb = mr._size[clu];
double[] ma = _cMeans[clu];
double[] mb = mr._cMeans[clu];
for( int c = 0; c < ma.length; c++ ) // Recursive mean
if( ra+rb > 0 ) ma[c] = (ma[c] * ra + mb[c] * rb) / (ra + rb);
}
ArrayUtils.add(_cats, mr._cats);
ArrayUtils.add(_cSqr, mr._cSqr);
ArrayUtils.add(_size, mr._size);
// track global worst-row
if( _worst_err < mr._worst_err) { _worst_err = mr._worst_err; _worst_row = mr._worst_row; }
}
}
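// Worked example of the recursive mean in reduce() above (hypothetical
// numbers): merging a cluster slice with ra=2 rows and mean 4.0 into one with
// rb=6 rows and mean 8.0 yields (4.0*2 + 8.0*6) / (2+6) = 7.0, exactly the
// mean of all 8 rows - so the merge order across chunks does not matter.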
// A pair result: nearest cluster center and the square distance
private static final class ClusterDist { int _cluster; double _dist; }
private static double minSqr(double[][] centers, double[] point, String[][] isCats, ClusterDist cd) {
return closest(centers, point, isCats, cd, centers.length)._dist;
}
private static double minSqr(double[][] centers, double[] point, String[][] isCats, ClusterDist cd, int count) {
return closest(centers,point,isCats,cd,count)._dist;
}
private static ClusterDist closest(double[][] centers, double[] point, String[][] isCats, ClusterDist cd) {
return closest(centers, point, isCats, cd, centers.length);
}
/** Return both the nearest of N cluster centers/centroids and the squared distance to it. */
private static ClusterDist closest(double[][] centers, double[] point, String[][] isCats, ClusterDist cd, int count) {
int min = -1;
double minSqr = Double.MAX_VALUE;
for( int cluster = 0; cluster < count; cluster++ ) {
double sqr = hex.genmodel.GenModel.KMeans_distance(centers[cluster],point,isCats,null,null);
if( sqr < minSqr ) { // Record nearest cluster
min = cluster;
minSqr = sqr;
}
}
cd._cluster = min; // Record nearest cluster
cd._dist = minSqr; // Record square-distance
return cd; // Return for flow-coding
}
// KMeans++ re-clustering
private static double[][] recluster(double[][] points, Random rand, int N, Initialization init, String[][] isCats) {
double[][] res = new double[N][];
res[0] = points[0];
int count = 1;
ClusterDist cd = new ClusterDist();
switch( init ) {
case Random:
break;
case PlusPlus: { // k-means++
while( count < res.length ) {
double sum = 0;
for (double[] point1 : points) sum += minSqr(res, point1, isCats, cd, count);
for (double[] point : points) {
if (minSqr(res, point, isCats, cd, count) >= rand.nextDouble() * sum) {
res[count++] = point;
break;
}
}
}
break;
}
case Furthest: { // Take the cluster center farthest from any already chosen one
while( count < res.length ) {
double max = 0;
int index = 0;
for( int i = 0; i < points.length; i++ ) {
double sqr = minSqr(res, points[i], isCats, cd, count);
if( sqr > max ) {
max = sqr;
index = i;
}
}
res[count++] = points[index];
}
break;
}
default: throw H2O.fail();
}
return res;
}
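// Note on the two strategies above: PlusPlus accepts each candidate with
// chance approximately proportional to its squared distance from the centers
// chosen so far (the k-means++ rule), while Furthest is deterministic given
// the first (random) center and the candidate set - it greedily takes the
// point with the largest such distance.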
private void randomRow(Vec[] vecs, Random rand, double[] center, double[] means, double[] mults) {
long row = Math.max(0, (long) (rand.nextDouble() * vecs[0].length()) - 1);
data(center, vecs, row, means, mults);
}
// Pick most common cat level for each cluster_centers' cat columns
private static double[][] max_cats(double[][] centers, long[][][] cats, String[][] isCats) {
for( int clu = 0; clu < centers.length; clu++ )
for( int col = 0; col < centers[0].length; col++ )
if( isCats[col] != null )
centers[clu][col] = ArrayUtils.maxIndex(cats[clu][col]);
return centers;
}
private static double[][] destandardize(double[][] centers, String[][] isCats, double[] means, double[] mults) {
int K = centers.length;
int N = centers[0].length;
double[][] value = new double[K][N];
for( int clu = 0; clu < K; clu++ ) {
System.arraycopy(centers[clu],0,value[clu],0,N);
if( mults!=null ) { // Reverse standardization
for( int col = 0; col < N; col++)
if( isCats[col] == null )
value[clu][col] = value[clu][col] / mults[col] + means[col];
}
}
return value;
}
private static void data(double[] values, Vec[] vecs, long row, double[] means, double[] mults) {
for( int i = 0; i < values.length; i++ ) {
double d = vecs[i].at(row);
values[i] = data(d, i, means, mults, vecs[i].cardinality());
}
}
private static void data(double[] values, Chunk[] chks, int row, double[] means, double[] mults) {
for( int i = 0; i < values.length; i++ ) {
double d = chks[i].atd(row);
values[i] = data(d, i, means, mults, chks[i].vec().cardinality());
}
}
/**
* Takes mean if NaN, standardize if requested.
*/
private static double data(double d, int i, double[] means, double[] mults, int cardinality) {
if(cardinality == -1) {
if( Double.isNaN(d) )
d = means[i];
if( mults != null ) {
d -= means[i];
d *= mults[i];
}
} else {
// TODO: If NaN, then replace with majority class?
if(Double.isNaN(d)) {
d = Math.min(Math.round(means[i]), cardinality-1);
if( mults != null ) {
d = 0;
}
}
}
return d;
}
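// Worked example (hypothetical numbers): for a numeric column with mean=5.0
// and mult=0.5, a value d=7.0 standardizes to (7.0-5.0)*0.5 = 1.0, and a
// missing value is first imputed to the mean and therefore standardizes to
// exactly 0. For a categorical column with mean level 2.4 and cardinality 3,
// an NA becomes min(round(2.4), 3-1) = 2 - unless standardization is on, in
// which case the code above maps a categorical NA to level 0 instead.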
/**
* This helper creates a ModelMetricsClustering from a trained model
* @param model must contain valid statistics from training, such as _betweenss etc.
*/
private ModelMetricsClustering makeTrainingMetrics(KMeansModel model) {
ModelMetricsClustering mm = new ModelMetricsClustering(model, model._parms.train());
mm._size = model._output._size;
mm._withinss = model._output._withinss;
mm._betweenss = model._output._betweenss;
mm._totss = model._output._totss;
mm._tot_withinss = model._output._tot_withinss;
model.addMetrics(mm);
return mm;
}
}