All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.clustering.ClusteringUtils Maven / Gradle / Ivy

There is a newer version: 0.13.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering;

import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.Searcher;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;

public final class ClusteringUtils {
  private ClusteringUtils() {
  }

  /**
   * Builds the searcher used by {@link #summarizeClusterDistances} and
   * {@link #totalClusterCost(Iterable, Iterable)} to find the centroid a point is assigned to.
   * NOTE(review): {@link ProjectionSearch} is projection based and presumably approximate, so the centroid it
   * returns may not always be the exact nearest one — confirm this speed/accuracy trade-off is intended.
   *
   * @param distanceMeasure distance measure used by the searcher.
   * @param centroids the centroids to index.
   * @return a searcher containing all of {@code centroids}.
   */
  private static UpdatableSearcher newCentroidSearcher(DistanceMeasure distanceMeasure,
                                                       Iterable<? extends Vector> centroids) {
    // numProjections = 3, searchSize = 1.
    UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
    searcher.addAll(centroids);
    return searcher;
  }

  /**
   * Computes the summaries for the distances in each cluster.
   * Each datapoint is attributed to the centroid returned by a {@link ProjectionSearch} over {@code centroids},
   * and its distance to that centroid is added to that centroid's summarizer.
   *
   * @param datapoints iterable of datapoints.
   * @param centroids iterable of Centroids. Every element must be a {@link Centroid} (the search result is cast to
   *     one) whose index lies in {@code [0, number of centroids)}, since that index selects the summarizer.
   * @param distanceMeasure distance measure used both for the search and for the recorded distances.
   * @return a list of OnlineSummarizers where the i-th element is the summarizer corresponding to the cluster whose
   * index is i; an empty list when there are no centroids.
   */
  public static List<OnlineSummarizer> summarizeClusterDistances(Iterable<? extends Vector> datapoints,
                                                                 Iterable<? extends Vector> centroids,
                                                                 DistanceMeasure distanceMeasure) {
    UpdatableSearcher searcher = newCentroidSearcher(distanceMeasure, centroids);
    int numClusters = searcher.size();
    List<OnlineSummarizer> summarizers = new ArrayList<>(numClusters);
    if (numClusters == 0) {
      // Nothing to assign points to; searching would yield no neighbour to take get(0) from.
      return summarizers;
    }
    for (int i = 0; i < numClusters; ++i) {
      summarizers.add(new OnlineSummarizer());
    }
    for (Vector v : datapoints) {
      Centroid closest = (Centroid) searcher.search(v, 1).get(0).getValue();
      summarizers.get(closest.getIndex()).add(distanceMeasure.distance(v, closest));
    }
    return summarizers;
  }

  /**
   * Adds up the distances from each point to its closest cluster and returns the sum.
   * Uses a {@link EuclideanDistanceMeasure} and a {@link ProjectionSearch} over the centroids.
   *
   * @param datapoints iterable of datapoints.
   * @param centroids iterable of Centroids.
   * @return the total cost described above.
   */
  public static double totalClusterCost(Iterable<? extends Vector> datapoints,
                                        Iterable<? extends Vector> centroids) {
    DistanceMeasure distanceMeasure = new EuclideanDistanceMeasure();
    return totalClusterCost(datapoints, newCentroidSearcher(distanceMeasure, centroids));
  }

  /**
   * Adds up the distances from each point to its closest cluster and returns the sum.
   *
   * @param datapoints iterable of datapoints.
   * @param centroids searcher of Centroids; the weight of {@code searchFirst(point, false)} is summed per point.
   * @return the total cost described above.
   */
  public static double totalClusterCost(Iterable<? extends Vector> datapoints, Searcher centroids) {
    double totalCost = 0;
    for (Vector vector : datapoints) {
      totalCost += centroids.searchFirst(vector, false).getWeight();
    }
    return totalCost;
  }

  /**
   * Estimates the distance cutoff. In StreamingKMeans, the distance between two vectors divided
   * by this value is used as a probability threshold when deciding whether to form a new cluster
   * or not.
   * Small values (comparable to the minimum distance between two points) are preferred as they
   * guarantee with high likelihood that all but very close points are put in separate clusters
   * initially. The clusters themselves are actually collapsed periodically when their number goes
   * over the maximum number of clusters and the distanceCutoff is increased.
   * So, the returned value is only an initial estimate.
   *
   * @param data the datapoints whose distance is to be estimated.
   * @param distanceMeasure the distance measure used to compute the distance between two points.
   * @return the minimum distance between any point of {@code data} and its closest other point, or
   * {@link Double#POSITIVE_INFINITY} if no such distance was found (e.g. fewer than two points).
   * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean)
   */
  public static double estimateDistanceCutoff(List<? extends Vector> data, DistanceMeasure distanceMeasure) {
    BruteSearch searcher = new BruteSearch(distanceMeasure);
    // Every vector is indexed exactly once; re-adding vectors inside the loop only duplicated entries
    // (doubling the brute-force work) without changing any nearest distance.
    searcher.addAll(data);
    double minDistance = Double.POSITIVE_INFINITY;
    for (Vector vector : data) {
      // The query is itself in the searcher, so ask for the closest vector different from the query.
      double closest = searcher.searchFirst(vector, true).getWeight();
      if (minDistance > 0 && closest < minDistance) {
        minDistance = closest;
      }
    }
    return minDistance;
  }

  /**
   * Estimates the distance cutoff from at most the first {@code sampleLimit} points of {@code data}.
   * See {@link #estimateDistanceCutoff(List, DistanceMeasure)} for how the estimate is computed.
   *
   * @param data the datapoints whose distance is to be estimated; only a prefix is consumed.
   * @param distanceMeasure the distance measure used to compute the distance between two points.
   * @param sampleLimit maximum number of leading points of {@code data} to consider.
   * @param <T> element type of {@code data}.
   * @return the minimum distance between the first {@code sampleLimit} points.
   * @throws IllegalArgumentException if {@code sampleLimit} is negative (rejected by {@code Iterables.limit}).
   */
  public static <T extends Vector> double estimateDistanceCutoff(
      Iterable<T> data, DistanceMeasure distanceMeasure, int sampleLimit) {
    return estimateDistanceCutoff(Lists.newArrayList(Iterables.limit(data, sampleLimit)), distanceMeasure);
  }

  /**
   * Computes the Davies-Bouldin Index for a given clustering.
   * See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation
   * The result is NaN when {@code centroids} is empty (the total is divided by the number of clusters), and a
   * ratio becomes infinite when two distinct centroids are at distance 0 from each other.
   *
   * @param centroids list of centroids
   * @param distanceMeasure distance measure for inter-cluster distances
   * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances
   * @return the Davies-Bouldin Index
   * @throws IllegalArgumentException if the two lists have different sizes
   */
  public static double daviesBouldinIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
                                          List<OnlineSummarizer> clusterDistanceSummaries) {
    Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
        "Number of centroids and cluster summaries differ.");
    int n = centroids.size();
    double totalDBIndex = 0;
    // The inner loop shouldn't be reduced for j = i + 1 to n because the computation of the Davies-Bouldin
    // index is not really symmetric.
    // For a given cluster i, we look for a cluster j that maximizes the ratio of the sum of average distances
    // from points in cluster i to its center and points in cluster j to its center to the distance between
    // cluster i and cluster j.
    // The maximization is the key issue, as the cluster that maximizes this ratio might be j for i but is NOT
    // NECESSARILY i for j.
    for (int i = 0; i < n; ++i) {
      Vector centroidI = centroids.get(i);
      double averageDistanceI = clusterDistanceSummaries.get(i).getMean();
      double maxDBIndex = 0;
      for (int j = 0; j < n; ++j) {
        if (i != j) {
          double dbIndex = (averageDistanceI + clusterDistanceSummaries.get(j).getMean())
              / distanceMeasure.distance(centroidI, centroids.get(j));
          if (dbIndex > maxDBIndex) {
            maxDBIndex = dbIndex;
          }
        }
      }
      totalDBIndex += maxDBIndex;
    }
    return totalDBIndex / n;
  }

  /**
   * Computes the Dunn Index of a given clustering. See http://en.wikipedia.org/wiki/Dunn_index
   * The numerator is the minimum distance between two centroids, the denominator the largest per-cluster
   * median (or single) point-to-center distance. With fewer than two centroids the numerator stays
   * {@link Double#POSITIVE_INFINITY}, so the result is infinite; a zero denominator also yields an infinite
   * (or NaN) result.
   *
   * @param centroids list of centroids
   * @param distanceMeasure distance measure to compute inter-centroid distance with
   * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances
   * @return the Dunn Index
   * @throws IllegalArgumentException if the two lists have different sizes
   */
  public static double dunnIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
                                 List<OnlineSummarizer> clusterDistanceSummaries) {
    Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
        "Number of centroids and cluster summaries differ.");
    int n = centroids.size();
    // Intra-cluster distances will come from the OnlineSummarizer, and will be the median distance (noting that
    // the median for just one value is that value).
    // A variety of metrics can be used for the intra-cluster distance including max distance between two points,
    // mean distance, etc. Median distance was chosen as this is more robust to outliers and characterizes the
    // distribution of distances (from a point to the center) better.
    double maxIntraClusterDistance = 0;
    for (OnlineSummarizer summarizer : clusterDistanceSummaries) {
      if (summarizer.getCount() > 0) {
        double intraClusterDistance;
        if (summarizer.getCount() == 1) {
          intraClusterDistance = summarizer.getMean();
        } else {
          intraClusterDistance = summarizer.getMedian();
        }
        if (maxIntraClusterDistance < intraClusterDistance) {
          maxIntraClusterDistance = intraClusterDistance;
        }
      }
    }
    double minDunnIndex = Double.POSITIVE_INFINITY;
    for (int i = 0; i < n; ++i) {
      // Distances are symmetric, so d(i, j) = d(j, i).
      for (int j = i + 1; j < n; ++j) {
        double dunnIndex = distanceMeasure.distance(centroids.get(i), centroids.get(j));
        if (minDunnIndex > dunnIndex) {
          minDunnIndex = dunnIndex;
        }
      }
    }
    return minDunnIndex / maxIntraClusterDistance;
  }

  /**
   * Returns n choose 2, i.e. {@code n * (n - 1) / 2}, evaluated in floating point so it also accepts the
   * non-integral (weighted) counts found in a confusion matrix.
   *
   * @param n the count
   * @return {@code n * (n - 1) / 2}
   */
  public static double choose2(double n) {
    return n * (n - 1) / 2;
  }

  /**
   * Creates a confusion matrix by searching for the closest cluster of both the row clustering and column clustering
   * of a point and adding its weight to that cell of the matrix.
   * It doesn't matter which clustering is the row clustering and which is the column clustering. If they're
   * interchanged, the resulting matrix is the transpose of the original one.
   * Every centroid must be a {@link Centroid} whose index is a valid row (resp. column) of the result, and both
   * centroid lists must be non-empty if {@code datapoints} is non-empty.
   * Points that are {@link WeightedVector}s contribute their weight; other points contribute 1.
   *
   * @param rowCentroids clustering one
   * @param columnCentroids clustering two
   * @param datapoints datapoints whose closest cluster we need to find
   * @param distanceMeasure distance measure to use
   * @return the confusion matrix
   */
  public static Matrix getConfusionMatrix(List<? extends Vector> rowCentroids, List<? extends Vector> columnCentroids,
                                          Iterable<? extends Vector> datapoints, DistanceMeasure distanceMeasure) {
    Searcher rowSearcher = new BruteSearch(distanceMeasure);
    rowSearcher.addAll(rowCentroids);
    Searcher columnSearcher = new BruteSearch(distanceMeasure);
    columnSearcher.addAll(columnCentroids);

    int numRows = rowCentroids.size();
    int numCols = columnCentroids.size();
    Matrix confusionMatrix = new DenseMatrix(numRows, numCols);

    for (Vector vector : datapoints) {
      WeightedThing<Vector> closestRowCentroid = rowSearcher.search(vector, 1).get(0);
      WeightedThing<Vector> closestColumnCentroid = columnSearcher.search(vector, 1).get(0);
      int row = ((Centroid) closestRowCentroid.getValue()).getIndex();
      int column = ((Centroid) closestColumnCentroid.getValue()).getIndex();
      double vectorWeight;
      if (vector instanceof WeightedVector) {
        vectorWeight = ((WeightedVector) vector).getWeight();
      } else {
        vectorWeight = 1;
      }
      confusionMatrix.set(row, column, confusionMatrix.get(row, column) + vectorWeight);
    }

    return confusionMatrix;
  }

  /**
   * Computes the Adjusted Rand Index for a given confusion matrix.
   * The result is not finite when the matrix total is below 2 (n choose 2 is then 0) or when both clusterings
   * are trivial (the denominator is then 0).
   *
   * @param confusionMatrix confusion matrix; not to be confused with the more restrictive ConfusionMatrix class
   * @return the Adjusted Rand Index
   */
  public static double getAdjustedRandIndex(Matrix confusionMatrix) {
    int numRows = confusionMatrix.numRows();
    int numCols = confusionMatrix.numCols();
    double rowChoiceSum = 0;
    double columnChoiceSum = 0;
    double totalChoiceSum = 0;
    double total = 0;
    for (int i = 0; i < numRows; ++i) {
      double rowSum = 0;
      for (int j = 0; j < numCols; ++j) {
        rowSum += confusionMatrix.get(i, j);
        totalChoiceSum += choose2(confusionMatrix.get(i, j));
      }
      total += rowSum;
      rowChoiceSum += choose2(rowSum);
    }
    for (int j = 0; j < numCols; ++j) {
      double columnSum = 0;
      for (int i = 0; i < numRows; ++i) {
        columnSum += confusionMatrix.get(i, j);
      }
      columnChoiceSum += choose2(columnSum);
    }
    // Expected index under the permutation model: (sum_i C(a_i, 2) * sum_j C(b_j, 2)) / C(n, 2).
    double rowColumnChoiceSumDivTotal = rowChoiceSum * columnChoiceSum / choose2(total);
    return (totalChoiceSum - rowColumnChoiceSumDivTotal)
        / ((rowChoiceSum + columnChoiceSum) / 2 - rowColumnChoiceSumDivTotal);
  }

  /**
   * Computes the total weight of the points in the given Vector iterable.
   * {@link WeightedVector}s contribute their weight; every other vector contributes 1.
   *
   * @param data iterable of points
   * @return total weight
   * @throws NullPointerException if {@code data} contains a null element
   */
  public static double totalWeight(Iterable<? extends Vector> data) {
    double sum = 0;
    for (Vector row : data) {
      Preconditions.checkNotNull(row);
      if (row instanceof WeightedVector) {
        sum += ((WeightedVector) row).getWeight();
      } else {
        sum++;
      }
    }
    return sum;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy