All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.clustering.impl.NearestCentroidAlgorithm Maven / Gradle / Ivy

package com.expleague.ml.clustering.impl;

import com.expleague.commons.math.metrics.Metric;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecIterator;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.clustering.ClusterizationAlgorithm;
import com.expleague.commons.util.CollectionTools;
import com.expleague.commons.util.Pair;
import org.jetbrains.annotations.NotNull;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.function.Function;

/**
 * User: terry
 * Date: 16.01.2010
 */
public class NearestCentroidAlgorithm implements ClusterizationAlgorithm {
  private final Metric metric;
  private final double acceptanceDistance;

  public NearestCentroidAlgorithm(final Metric metric, final double acceptanceDistance) {
    this.metric = metric;
    this.acceptanceDistance = acceptanceDistance;
  }

  @NotNull
  @Override
  public Collection> cluster(final Collection dataSet, final Function data2DVector) {
    final Collection,Vec>> clusters = new HashSet<>();
    for (final X data : dataSet) {
      final Vec dataVector = data2DVector.apply(data);
      Pair, Vec> nearestCluster = null;
      double minDistance = Double.MAX_VALUE;
      for (final Pair, Vec> pair : clusters) {
        final double candidateDistance = metric.distance(pair.getSecond(), dataVector);
        if (candidateDistance < minDistance) {
          minDistance = candidateDistance ;
          nearestCluster = pair;
        }
      }
//      if (nearestCluster != null) {
//        System.out.println(dataVector.toString().substring(0, 100));
//        System.out.println("");
//        System.out.println(nearestCluster.getSecond().toString().substring(0, 100));
//        System.out.println(minDistance);
//        System.out.println("");
//      }
      if (minDistance > acceptanceDistance) {
        clusters.add(Pair.,Vec>create(new HashSet<>(Collections.singleton(data)), dataVector));
      } else {
        final Collection collection = nearestCluster.getFirst();
        final Vec centroid = nearestCluster.getSecond();
        VecTools.scale(centroid, collection.size());
        VecTools.append(centroid, dataVector);
        collection.add(data);
        VecTools.scale(centroid, 1./collection.size());
        final VecIterator it = centroid.nonZeroes();
        while (it.advance()) {
          if (it.value() < 0.01) {
            it.setValue(0);
          }
        }
      }
    }
    return CollectionTools.mapFirst(clusters);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy