// com.expleague.ml.clustering.impl.KMeansAlgorithm Maven / Gradle / Ivy
package com.expleague.ml.clustering.impl;
import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.clustering.ClusterizationAlgorithm;
import org.jetbrains.annotations.NotNull;
import java.util.*;
import static com.expleague.commons.math.vectors.VecTools.scale;
/**
* User: solar
* Date: 13.02.2010
* Time: 22:49:47
*/
public class KMeansAlgorithm implements ClusterizationAlgorithm {
int clustCount;
double maxDist;
public KMeansAlgorithm(final int clustCount, final double maxDist) {
this.clustCount = clustCount;
this.maxDist = maxDist;
}
@NotNull
@Override
public Collection extends Collection> cluster(final Collection dataSet, final Computable data2DVector) {
Vec[] centroids = new Vec[clustCount];
final List> clusters = new ArrayList>();
while (clusters.size() < centroids.length) {
clusters.add(new HashSet());
}
int fullIndex = 0;
for (final T point : dataSet) {
final Vec vec = data2DVector.compute(point);
final int index = fullIndex++ % centroids.length;
if (centroids[index] == null)
//noinspection unchecked
centroids[index] = VecTools.copy(vec);
else
VecTools.append(centroids[index], vec);
clusters.get(index).add(point);
}
for (int i = 0; i < centroids.length; i++) {
VecTools.scale(centroids[i], 1./clusters.size());
}
int iteration = 0;
do {
final Vec[] nextCentroids = new Vec[clustCount];
for (int i = 0; i < centroids.length; i++) {
clusters.get(i).clear();
}
for (final T point : dataSet) {
final Vec vec = data2DVector.compute(point);
double minResemblance = Double.MAX_VALUE;
int minIndex = -1;
for (int i = 0; i < centroids.length; i++) {
final Vec centroid = centroids[i];
final double resemblance = VecTools.distanceAV(centroid, vec);
if (resemblance < minResemblance) {
minResemblance = resemblance;
minIndex = i;
}
}
clusters.get(minIndex).add(point);
VecTools.append(nextCentroids[minIndex], vec);
}
for (int i = 0; i < centroids.length; i++) {
VecTools.scale(centroids[i], 1./clusters.size());
}
centroids = nextCentroids;
}
while (++iteration < 10);
final Iterator> iter = clusters.iterator();
int index = 0;
while (iter.hasNext()) {
final Set cluster = iter.next();
double meanDist = 0;
final Vec centroid = centroids[index++];
for (final T term : cluster) {
meanDist += VecTools.distanceAV(data2DVector.compute(term), centroid);
}
meanDist /= cluster.size();
if (meanDist > maxDist) {
// iter.remove();
}
}
return clusters;
}
}