package hex.kmeans;
import hex.ModelBuilder;
import hex.schemas.KMeansV2;
import hex.schemas.ModelBuilderSchema;
import water.*;
import water.H2O.H2OCountedCompleter;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;
import java.util.ArrayList;
import java.util.Random;
/**
* Scalable K-Means++ (KMeans||)
* http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf
* http://www.youtube.com/watch?v=cigXAxV3XcY
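*
* Initialization runs a few KMeans|| sampling rounds: each round keeps rows
* with probability proportional to their squared distance from the current
* centers (about 3*K candidates per round), then the oversampled pool is
* re-clustered down to K centers. Lloyd's iterations refine those centers
* until they stabilize or max_iters is reached.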
*/
public class KMeans extends ModelBuilder<KMeansModel,KMeansModel.KMeansParameters,KMeansModel.KMeansOutput> {
public enum Initialization {
None, PlusPlus, Furthest
}
// Number of categorical columns
private int _ncats;
// Number of reinitialization attempts for preventing empty clusters
transient private int _reinit_attempts;
// Called from an http request
public KMeans( KMeansModel.KMeansParameters parms ) { super("K-means",parms); init(false); }
public ModelBuilderSchema schema() { return new KMeansV2(); }
/** Start the KMeans training Job on an F/J thread. */
@Override public Job<KMeansModel> trainModel() {
return start(new KMeansDriver(), _parms._max_iters);
}
/** Initialize the ModelBuilder, validating all arguments and preparing the
* training frame. This call is expected to be overridden in the subclasses
* and each subclass will start with "super.init();".
*
* Validate K, max_iters and the number of rows. Precompute the number of
* categorical columns. */
@Override public void init(boolean expensive) {
super.init(expensive);
if( _parms._K < 2 || _parms._K > 9999999 ) error("_K", "K must be between 2 and 9999999");
if( _parms._max_iters < 1 || _parms._max_iters > 999999 ) error("_max_iters", "must be between 1 and 999999");
if( _train == null ) return; // Nothing more to check
if( _train.numRows() < _parms._K ) error("_K","Cannot make " + _parms._K + " clusters out of " + _train.numRows() + " rows.");
// Sort columns, so the categoricals are all up front. They use a
// different distance metric than numeric columns.
Vec vecs[] = _train.vecs();
int ncats=0, len=vecs.length; // Categorical and total feature counts
while( ncats != len ) {
while( ncats < len && vecs[ncats].isEnum() ) ncats++;
while( len > 0 && !vecs[len-1].isEnum() ) len--;
if( ncats < len-1 ) { _train.swap(ncats,len-1); if( _valid != null ) _valid.swap(ncats,len-1); }
}
_ncats = ncats;
}
// ----------------------
private class KMeansDriver extends H2OCountedCompleter {
@Override protected void compute2() {
KMeansModel model = null;
try {
init(true);
_parms.lock_frames(KMeans.this); // Fetch & read-lock input frames
// The model to be built
model = new KMeansModel(dest(), _parms, new KMeansModel.KMeansOutput(KMeans.this));
model.delete_and_lock(_key);
// means are used to impute NAs
model._output._ncats = _ncats;
Vec vecs[] = _train.vecs();
final int N = vecs.length; // Feature count
double[] means = new double[N];
for( int i = 0; i < N; i++ )
means[i] = vecs[i].mean();
// mults & means for normalization
double[] mults = null;
if( _parms._normalize ) {
mults = new double[N];
for( int i = 0; i < N; i++ ) {
double sigma = vecs[i].sigma();
mults[i] = normalize(sigma) ? 1.0 / sigma : 1.0;
}
}
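// With normalization on, data() maps numeric values to (x - mean) * mult,
// i.e. (x - mean) / sigma, so numeric features contribute to the distance
// on a comparable scale; near-constant columns (sigma <= 1e-6) keep a
// multiplier of 1 to avoid dividing by ~0.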
// Initialize clusters
Random rand = RandomUtils.getRNG(_parms._seed - 1);
double clusters[][]; // Normalized cluster centers
if( _parms._init == Initialization.None ) {
// Initialize all clusters to random rows
clusters = model._output._clusters = new double[_parms._K][_train.numCols()];
for( double[] cluster : clusters )
randomRow(vecs, rand, cluster, means, mults);
} else {
clusters = new double[1][vecs.length];
// Initialize first cluster to random row
randomRow(vecs, rand, clusters[0], means, mults);
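// A fixed number of KMeans|| sampling rounds; the paper iterates O(log psi)
// times but observes that a small constant number of rounds (about 5, as
// used here) is enough in practice.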
while( model._output._iters < 5 ) {
// Sum squares distances to clusters
SumSqr sqr = new SumSqr(clusters,means,mults,_ncats).doAll(vecs);
// Sample with probability proportional to the squared distance
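// The oversampling factor is 3*K, so each round adds about 3*K new
// candidate centers in expectation.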
Sampler sampler = new Sampler(clusters, means, mults, _ncats, sqr._sqr, _parms._K * 3, _parms._seed).doAll(vecs);
clusters = ArrayUtils.append(clusters,sampler._sampled);
// Fill in sample clusters into the model
if( !isRunning() ) return; // Stopped/cancelled
model._output._clusters = denormalize(clusters, _ncats, means, mults);
model._output._mse = sqr._sqr/_train.numRows();
model._output._iters++; // One iteration done
// This doesn't count towards model building (these init iterations were not budgeted as work when the job was constructed)
// update(1); // One unit of work
model.update(_key); // Early version of model is visible
}
// Recluster down to K normalized clusters
clusters = recluster(clusters, rand);
}
model._output._iters = 0; // Reset iteration count
// ---
// Run the main KMeans Clustering loop
// Stop after enough iterations
LOOP:
for( ; model._output._iters < _parms._max_iters; model._output._iters++ ) {
if( !isRunning() ) return; // Stopped/cancelled
Lloyds task = new Lloyds(clusters,means,mults,_ncats, _parms._K).doAll(vecs);
// Pick the max categorical level for clusters' center
max_cats(task._cMeans,task._cats);
// Handle the case where some clusters go dry. Rescue only 1 cluster
// per iteration ('cause we only tracked the 1 worst row)
boolean badrow=false;
for( int clu=0; clu<_parms._K; clu++ ) {
if (task._rows[clu] == 0) {
// If we see 2 or more bad rows, just re-run Lloyds to get the
// next-worst row. We don't count this as an iteration, because
// we're not really adjusting the centers, we're trying to get
// some centers *at-all*.
if (badrow) {
Log.warn("KMeans: Re-running Lloyds to re-init another cluster");
model._output._iters--; // Do not count against iterations
if (_reinit_attempts++ < _parms._K) {
continue LOOP; // Rerun Lloyds, and assign points to centroids
} else {
_reinit_attempts = 0;
break; // Give up and accept the empty cluster
}
}
long row = task._worst_row;
Log.warn("KMeans: Re-initializing cluster " + clu + " to row " + row);
data(clusters[clu] = task._cMeans[clu], vecs, row, means, mults);
task._rows[clu] = 1;
badrow = true;
}
}
// Fill in the model; denormalized centers
model._output._clusters = denormalize(task._cMeans, _ncats, means, mults);
model._output._rows = task._rows;
model._output._mses = task._cSqr;
double ssq = 0; // sum squared error
for( int i=0; i<_parms._K; i++ ) {
ssq += model._output._mses[i]; // sum squared error all clusters
model._output._mses[i] /= task._rows[i]; // mse per-cluster
}
model._output._mse = ssq/_train.numRows(); // mse total
model.update(_key); // Update model in K/V store
update(1); // One unit of work
// Compute the change in cluster centers
double sum=0;
for( int clu=0; clu<_parms._K; clu++ )
sum += distance(clusters[clu],task._cMeans[clu],_ncats);
sum /= N; // Average change per feature
Log.info("KMeans: Change in cluster centers="+sum);
if( sum < 1e-6 ) break; // Model appears to be stable
clusters = task._cMeans; // Update cluster centers
StringBuilder sb = new StringBuilder();
sb.append("KMeans: iter: ").append(model._output._iters).append(", MSE=").append(model._output._mse);
for( int i=0; i<_parms._K; i++ )
sb.append(", ").append(task._cSqr[i]).append("/").append(task._rows[i]);
Log.info(sb);
}
} catch( Throwable t ) {
t.printStackTrace();
cancel2(t);
throw t;
} finally {
if( model != null ) model.unlock(_key);
_parms.unlock_frames(KMeans.this);
done(); // Job done!
}
tryComplete();
}
}
// -------------------------------------------------------------------------
// Initial sum-of-square-distance to nearest cluster
private static class SumSqr extends MRTask<SumSqr> {
// IN
double[][] _clusters;
double[] _means, _mults; // Normalization
final int _ncats;
// OUT
double _sqr;
SumSqr( double[][] clusters, double[] means, double[] mults, int ncats ) {
_clusters = clusters;
_means = means;
_mults = mults;
_ncats = ncats;
}
@Override public void map(Chunk[] cs) {
double[] values = new double[cs.length];
ClusterDist cd = new ClusterDist();
for( int row = 0; row < cs[0]._len; row++ ) {
data(values, cs, row, _means, _mults);
_sqr += minSqr(_clusters, values, _ncats, cd);
}
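// Null the IN fields so this per-chunk clone does not ship the cluster
// array back over the wire when results are reduced.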
_means = _mults = null;
_clusters = null;
}
@Override public void reduce(SumSqr other) { _sqr += other._sqr; }
}
// -------------------------------------------------------------------------
// Sample rows with increasing probability the farther they are from any
// cluster.
private static class Sampler extends MRTask<Sampler> {
// IN
double[][] _clusters;
double[] _means, _mults; // Normalization
final int _ncats;
final double _sqr; // Min-square-error
final double _probability; // Oversampling factor: scales the odds of selecting a point
final long _seed;
// OUT
double[][] _sampled; // New clusters
Sampler( double[][] clusters, double[] means, double[] mults, int ncats, double sqr, double prob, long seed ) {
_clusters = clusters;
_means = means;
_mults = mults;
_ncats = ncats;
_sqr = sqr;
_probability = prob;
_seed = seed;
}
@Override public void map(Chunk[] cs) {
double[] values = new double[cs.length];
ArrayList<double[]> list = new ArrayList<>();
Random rand = RandomUtils.getRNG(_seed + cs[0].start());
ClusterDist cd = new ClusterDist();
for( int row = 0; row < cs[0]._len; row++ ) {
data(values, cs, row, _means, _mults);
double sqr = minSqr(_clusters, values, _ncats, cd);
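// The test below is "_probability * sqr > U * _sqr" with U uniform in [0,1),
// i.e. accept with probability min(1, _probability * sqr / _sqr).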
if( _probability * sqr > rand.nextDouble() * _sqr )
list.add(values.clone());
}
_sampled = new double[list.size()][];
list.toArray(_sampled);
_clusters = null;
_means = _mults = null;
}
@Override public void reduce(Sampler other) {
_sampled = ArrayUtils.append(_sampled, other._sampled);
}
}
// ---------------------------------------
// A Lloyd's pass:
// Find the nearest cluster for every point;
// Compute the new mean/center and row count for each cluster;
// Compute the total squared distance per cluster;
// Track the single worst (farthest) row for empty-cluster rescue.
private static class Lloyds extends MRTask<Lloyds> {
// IN
double[][] _clusters;
double[] _means, _mults; // Normalization
final int _ncats, _K;
// OUT
double[][] _cMeans; // Means for each cluster
long[/*K*/][/*ncats*/][] _cats; // Histogram of cat levels
double[] _cSqr; // Sum of squares for each cluster
long[] _rows; // Rows per cluster
long _worst_row; // Row with max err
double _worst_err; // Max-err-row's max-err
Lloyds( double[][] clusters, double[] means, double[] mults, int ncats, int K ) {
_clusters = clusters;
_means = means;
_mults = mults;
_ncats = ncats;
_K = K;
}
@Override public void map(Chunk[] cs) {
int N = cs.length;
assert _clusters[0].length==N;
_cMeans = new double[_K][N];
_cSqr = new double[_K];
_rows = new long[_K];
// Space for cat histograms
_cats = new long[_K][_ncats][];
for( int clu=0; clu<_K; clu++ )
for( int col=0; col<_ncats; col++ )
_cats[clu][col] = new long[cs[col].vec().cardinality()];
_worst_err = 0;
// Find closest cluster for each row
double[] values = new double[N];
ClusterDist cd = new ClusterDist();
for( int row = 0; row < cs[0]._len; row++ ) {
data(values, cs, row, _means, _mults);
closest(_clusters, values, _ncats, cd);
int clu = cd._cluster;
assert clu != -1; // No broken rows
_cSqr[clu] += cd._dist;
// Add values and increment counter for chosen cluster
for( int col = 0; col < _ncats; col++ )
_cats[clu][col][(int)values[col]]++; // Histogram the cats
for( int col = _ncats; col < N; col++ )
_cMeans[clu][col] += values[col];
_rows[clu]++;
// Track worst row
if( cd._dist > _worst_err) { _worst_err = cd._dist; _worst_row = cs[0].start()+row; }
}
// Scale back down to local mean
for( int clu = 0; clu < _K; clu++ )
if( _rows[clu] != 0 ) ArrayUtils.div(_cMeans[clu],_rows[clu]);
_clusters = null;
_means = _mults = null;
}
@Override public void reduce(Lloyds mr) {
for( int clu = 0; clu < _K; clu++ ) {
long ra = _rows[clu];
long rb = mr._rows[clu];
double[] ma = _cMeans[clu];
double[] mb = mr._cMeans[clu];
for( int c = 0; c < ma.length; c++ ) // Recursive mean
if( ra+rb > 0 ) ma[c] = (ma[c] * ra + mb[c] * rb) / (ra + rb);
}
ArrayUtils.add(_cats, mr._cats);
ArrayUtils.add(_cSqr, mr._cSqr);
ArrayUtils.add(_rows, mr._rows);
// track global worst-row
if( _worst_err < mr._worst_err) { _worst_err = mr._worst_err; _worst_row = mr._worst_row; }
}
}
// A pair result: nearest cluster, and the square distance
private static final class ClusterDist { int _cluster; double _dist; }
private static double minSqr(double[][] clusters, double[] point, int ncats, ClusterDist cd) {
return closest(clusters, point, ncats, cd, clusters.length)._dist;
}
private static double minSqr(double[][] clusters, double[] point, int ncats, ClusterDist cd, int count) {
return closest(clusters,point,ncats,cd,count)._dist;
}
private static ClusterDist closest(double[][] clusters, double[] point, int ncats, ClusterDist cd) {
return closest(clusters, point, ncats, cd, clusters.length);
}
private static double distance(double[] cluster, double[] point, int ncats) {
double sqr = 0; // Sum of dimensional distances
int pts = point.length; // Count of valid points
// Categorical columns first. Only equality matters (the distance is either 0 or 1).
for(int column = 0; column < ncats; column++) {
double d = point[column];
if( Double.isNaN(d) ) pts--;
else if( d != cluster[column] )
sqr += 1.0; // Manhattan distance
}
// Numeric column distance
for( int column = ncats; column < cluster.length; column++ ) {
double d = point[column];
if( Double.isNaN(d) ) pts--; // Do not count
else {
double delta = d - cluster[column];
sqr += delta * delta;
}
}
// Scale distance by ratio of valid dimensions to all dimensions - since
// we did not add any error term for the missing point, the sum of errors
// is small - ratio up "as if" the missing error term is equal to the
// average of other error terms. Same math another way:
// double avg_dist = sqr / pts; // average distance per feature/column/dimension
// sqr = avg_dist * point.length; // Total dist is average * #dimensions
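// Example: 4 features with 1 NaN leaves pts = 3; a raw sqr of 0.75 scales
// by 4/3 to 1.0, as if the missing feature contributed the average error.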
if( 0 < pts && pts < point.length )
sqr *= (double) point.length / pts; // Cast avoids integer division
return sqr;
}
/** Return both the nearest of the clusters/centroids and the squared distance. */
private static ClusterDist closest(double[][] clusters, double[] point, int ncats, ClusterDist cd, int count) {
int min = -1;
double minSqr = Double.MAX_VALUE;
for( int cluster = 0; cluster < count; cluster++ ) {
double sqr = distance(clusters[cluster],point,ncats);
if( sqr < minSqr ) { // Record nearest cluster
min = cluster;
minSqr = sqr;
}
}
cd._cluster = min; // Record nearest cluster
cd._dist = minSqr; // Record square-distance
return cd; // Return for flow-coding
}
// For KMeansModel scoring; just the closest cluster
static int closest(double[][] clusters, double[] point, int ncats) {
int min = -1;
double minSqr = Double.MAX_VALUE;
for( int cluster = 0; cluster < clusters.length; cluster++ ) {
double sqr = distance(clusters[cluster],point,ncats);
if( sqr < minSqr ) { // Record nearest cluster
min = cluster;
minSqr = sqr;
}
}
return min;
}
// KMeans++ re-clustering
private double[][] recluster(double[][] points, Random rand) {
double[][] res = new double[_parms._K][];
res[0] = points[0];
int count = 1;
ClusterDist cd = new ClusterDist();
switch( _parms._init ) {
case None:
break;
case PlusPlus: { // k-means++
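// Approximate D^2 sampling: each point is tested against a fresh uniform
// threshold in [0, sum), so points far from the already-chosen centers are
// far more likely to be accepted (and the outer loop retries if none is).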
while( count < res.length ) {
double sum = 0;
for (double[] point1 : points) sum += minSqr(res, point1, _ncats, cd, count);
for (double[] point : points) {
if (minSqr(res, point, _ncats, cd, count) >= rand.nextDouble() * sum) {
res[count++] = point;
break;
}
}
}
break;
}
case Furthest: { // Take the candidate point farthest from any already-chosen center
while( count < res.length ) {
double max = 0;
int index = 0;
for( int i = 0; i < points.length; i++ ) {
double sqr = minSqr(res, points[i], _ncats, cd, count);
if( sqr > max ) {
max = sqr;
index = i;
}
}
res[count++] = points[index];
}
break;
}
default: throw H2O.fail();
}
return res;
}
private void randomRow(Vec[] vecs, Random rand, double[] cluster, double[] means, double[] mults) {
long row = Math.max(0, (long) (rand.nextDouble() * vecs[0].length()) - 1);
data(cluster, vecs, row, means, mults);
}
private static boolean normalize(double sigma) {
// TODO unify handling of constant columns
return sigma > 1e-6;
}
// Pick the most common cat level for each cluster center's cat columns
private static double[][] max_cats(double[][] clusters, long[][][] cats) {
int K = cats.length;
int ncats = cats[0].length;
for( int clu = 0; clu < K; clu++ )
for( int col = 0; col < ncats; col++ ) // Cats use max level for cluster center
clusters[clu][col] = ArrayUtils.maxIndex(cats[clu][col]);
return clusters;
}
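// Inverse of the data() normalization for numeric columns: x = v / mult + mean.
// Categorical columns (the first ncats) are copied through unchanged.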
private static double[][] denormalize(double[][] clusters, int ncats, double[] means, double[] mults) {
int K = clusters.length;
int N = clusters[0].length;
double[][] value = new double[K][N];
for( int clu = 0; clu < K; clu++ ) {
System.arraycopy(clusters[clu],0,value[clu],0,N);
if( mults!=null ) // Reverse normalization
for( int col = ncats; col < N; col++ )
value[clu][col] = value[clu][col] / mults[col] + means[col];
}
return value;
}
private static void data(double[] values, Vec[] vecs, long row, double[] means, double[] mults) {
for( int i = 0; i < values.length; i++ ) {
double d = vecs[i].at(row);
values[i] = data(d, i, means, mults, vecs[i].cardinality());
}
}
private static void data(double[] values, Chunk[] chks, int row, double[] means, double[] mults) {
for( int i = 0; i < values.length; i++ ) {
double d = chks[i].at0(row);
values[i] = data(d, i, means, mults, chks[i].vec().cardinality());
}
}
/**
* Takes the column mean if the value is NaN, and normalizes if requested.
*/
private static double data(double d, int i, double[] means, double[] mults, int cardinality) {
if(cardinality == -1) {
if( Double.isNaN(d) )
d = means[i];
if( mults != null ) {
d -= means[i];
d *= mults[i];
}
} else {
// TODO: If NaN, then replace with majority class?
if(Double.isNaN(d))
d = Math.min(Math.round(means[i]), cardinality-1);
}
return d;
}
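// A minimal usage sketch (hypothetical wiring; how the training frame is
// attached to KMeansParameters differs across h2o-dev revisions):
//   KMeansModel.KMeansParameters parms = new KMeansModel.KMeansParameters();
//   parms._K = 3;             // number of clusters
//   parms._max_iters = 10;    // cap on Lloyd's iterations
//   parms._normalize = true;  // z-scale numeric columns before clustering
//   Job<KMeansModel> job = new KMeans(parms).trainModel();
//   KMeansModel model = job.get(); // block until the driver finishes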
}