All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.clustering.impl.KMeansAlgorithm Maven / Gradle / Ivy

package com.expleague.ml.clustering.impl;

import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.clustering.ClusterizationAlgorithm;
import org.jetbrains.annotations.NotNull;

import java.util.*;

import static com.expleague.commons.math.vectors.VecTools.scale;

/**
 * User: solar
 * Date: 13.02.2010
 * Time: 22:49:47
 */
public class KMeansAlgorithm implements ClusterizationAlgorithm {
  int clustCount;
  double maxDist;

  public KMeansAlgorithm(final int clustCount, final double maxDist) {
    this.clustCount = clustCount;
    this.maxDist = maxDist;
  }

  @NotNull
  @Override
  public Collection> cluster(final Collection dataSet, final Computable data2DVector) {
    Vec[] centroids = new Vec[clustCount];
    final List> clusters = new ArrayList>();
    while (clusters.size() < centroids.length) {
      clusters.add(new HashSet());
    }
    int fullIndex = 0;
    for (final T point : dataSet) {
      final Vec vec = data2DVector.compute(point);
      final int index = fullIndex++ % centroids.length;
      if (centroids[index] == null)
        //noinspection unchecked
        centroids[index] = VecTools.copy(vec);
      else
        VecTools.append(centroids[index], vec);
      clusters.get(index).add(point);
    }
    for (int i = 0; i < centroids.length; i++) {
      VecTools.scale(centroids[i], 1./clusters.size());
    }

    int iteration = 0;
    do {
      final Vec[] nextCentroids = new Vec[clustCount];
      for (int i = 0; i < centroids.length; i++) {
        clusters.get(i).clear();
      }

      for (final T point : dataSet) {
        final Vec vec = data2DVector.compute(point);
        double minResemblance = Double.MAX_VALUE;
        int minIndex = -1;
        for (int i = 0; i < centroids.length; i++) {
          final Vec centroid = centroids[i];
          final double resemblance = VecTools.distanceAV(centroid, vec);
          if (resemblance < minResemblance) {
            minResemblance = resemblance;
            minIndex = i;
          }
        }
        clusters.get(minIndex).add(point);
        VecTools.append(nextCentroids[minIndex], vec);
      }

      for (int i = 0; i < centroids.length; i++) {
        VecTools.scale(centroids[i], 1./clusters.size());
      }
      centroids = nextCentroids;
    }
    while (++iteration < 10);

    final Iterator> iter = clusters.iterator();
    int index = 0;
    while (iter.hasNext()) {
      final Set cluster = iter.next();
      double meanDist = 0;
      final Vec centroid = centroids[index++];
      for (final T term : cluster) {
        meanDist += VecTools.distanceAV(data2DVector.compute(term), centroid);
      }
      meanDist /= cluster.size();
      if (meanDist > maxDist) {
//        iter.remove();
      }
    }

    return clusters;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy