All downloads are free. The search and download features use the official Maven repository.

com.expleague.ml.clustering.impl.ForelAlgorithm Maven / Gradle / Ivy

package com.expleague.ml.clustering.impl;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecIterator;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.util.logging.Logger;
import com.expleague.ml.clustering.ClusterizationAlgorithm;
import org.jetbrains.annotations.NotNull;

import java.util.*;
import java.util.function.Function;

import static com.expleague.commons.math.vectors.VecTools.scale;

/**
 * User: solar
 * Date: 13.02.2010
 * Time: 20:32:44
 */
public class ForelAlgorithm implements ClusterizationAlgorithm {
  Logger LOG = Logger.create(ForelAlgorithm.class);
  private final double maxDist0;

  public ForelAlgorithm(final double maxDist0) {
    this.maxDist0 = maxDist0;
  }//  @Required

  @NotNull
  @Override
  public Collection> cluster(final Collection dataSet, final Function data2DVector) {
    int count = 0;
    final List> clusters = new ArrayList>();
    final Set unclassified = new HashSet(dataSet);
    while (!unclassified.isEmpty()) {
      final T first = unclassified.iterator().next();
      final Vec vec = data2DVector.apply(first);
      @SuppressWarnings({"unchecked"})
      final Set cluster = new HashSet();
      Vec centroid = VecTools.copy(vec);
      int changesCount;
      double maxDist = maxDist0 + (1 - maxDist0) * Math.max(0, (1 - Math.log(1000) / Math.log(dataSet.size())));
      do {
        changesCount = 0;
        final Vec nextCentroid = VecTools.copy(centroid);
        cluster.add(first);
        unclassified.remove(first);
        for (final T currentTerm : unclassified) {
          final Vec currentVec = data2DVector.apply(currentTerm);
          final double distance = 1 - VecTools.cosine(centroid, currentVec);
          count ++;
          if (distance < maxDist && !cluster.contains(currentTerm)) {
            changesCount++;
            VecTools.scale(nextCentroid, cluster.size());
            VecTools.append(nextCentroid, currentVec);
            VecTools.scale(nextCentroid, 1. / (cluster.size() + 1));
            cluster.add(currentTerm);
          }
          else if (distance >= maxDist && cluster.contains(currentTerm)) {
            changesCount++;
            VecTools.scale(nextCentroid, -cluster.size());
            VecTools.append(nextCentroid, currentVec);
            VecTools.scale(nextCentroid, -1. / (cluster.size() - 1));
            cluster.remove(currentTerm);
          }
        }
        centroid = nextCentroid;
        final VecIterator iter = centroid.nonZeroes();
        while (iter.advance()) {
          if (iter.value() < 0.01)
            iter.setValue(0);
        }
        maxDist *= 0.99;
      }
      while (changesCount > 0);
      clusters.add(cluster);
      unclassified.removeAll(cluster);
    }
//    LOG.debug("Multiplications " + count + " for " + dataSet.size() + " objects");
    return clusters;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy