All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.cluster.KMeans Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.cluster;

import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.Vector;

import java.text.NumberFormat;
import java.util.*;

/**
 * Implements the "batch" version of the K Means clustering algorithm-- that is, in one sweep, assign each point to its
 * nearest center, and then in a second sweep, reset each center to the mean of the cluster for that center, repeating
 * until convergence.
 * 

* Note that this algorithm is guaranteed to converge, since the total squared error is guaranteed to be reduced at each * step. * * @author josephramsey * @version $Id: $Id */ public class KMeans implements ClusteringAlgorithm { /** * The type of initialization in which random points are selected from the data to serve as initial centers. */ private static final int RANDOM_POINTS = 0; /** * The type of initialization in which points are assigned randomly to clusters. */ private static final int RANDOM_CLUSTERS = 1; /** * The type of initialiation in which explicit points are provided to serve as clusters. */ private static final int EXPLICIT_POINTS = 2; /** * The dissimilarity metric being used. For K means, the metric must be squared Euclidean. It's an assumption of the * algorithm. */ private final Dissimilarity metric = new SquaredErrorLoss(); /** * The data, columns as features, rows as cases. */ private Matrix data; /** * The centers. */ private Matrix centers; /** * The maximum number of interations. */ private int maxIterations = 50; /** * Current clusters. */ private List clusters; /** * Number of iterations of algorithm. */ private int iterations; /** * The number of centers (i.e. the number clusters) that the algorithm will find. */ private int numCenters; /** * The type of initialization, one of RANDOM_POINTS, */ private int initializationType = KMeans.RANDOM_POINTS; /** * True if verbose output should be printed. */ private boolean verbose; //============================CONSTRUCTOR==========================// /** * Private constructor. (Please keep it that way.) */ private KMeans() { } /** * Constructs a new KMeansBatch, initializing the algorithm by picking * numCeneters centers randomly from the data itself. * * @param numCenters The number of centers (clusters). * @return The parametrized algorithm. */ public static KMeans randomPoints(int numCenters) { KMeans algorithm = new KMeans(); algorithm.numCenters = numCenters; algorithm.initializationType = KMeans.RANDOM_POINTS; return algorithm; } /** * Constructs a new KMeansBatch, initializing the algorithm by randomly assigning each point in the data to one of * the numCenters clusters, then calculating the centroid of each cluster. * * @param numCenters The number of centers (clusters). * @return The constructed algorithm. */ public static KMeans randomClusters(int numCenters) { KMeans algorithm = new KMeans(); algorithm.numCenters = numCenters; algorithm.initializationType = KMeans.RANDOM_CLUSTERS; return algorithm; } //===========================PUBLIC METHODS=======================// private static List> convertClusterIndicesToLists(List clusterIndices) { int max = 0; for (Integer clusterIndice : clusterIndices) { if (clusterIndice > max) max = clusterIndice; } List> clusters = new ArrayList<>(); for (int i = 0; i <= max; i++) { clusters.add(new LinkedList<>()); } for (int i = 0; i < clusterIndices.size(); i++) { Integer index = clusterIndices.get(i); if (index == -1) continue; clusters.get(index).add(i); } return clusters; } /** * {@inheritDoc} *

* Runs the batch K-means clustering algorithm on the data, returning a result. */ public void cluster(Matrix data) { this.data = data; if (this.initializationType == KMeans.RANDOM_POINTS) { this.centers = pickCenters(this.numCenters, data); this.clusters = new ArrayList<>(); for (int i = 0; i < data.getNumRows(); i++) { this.clusters.add(-1); } } else if (this.initializationType == KMeans.RANDOM_CLUSTERS) { this.centers = new Matrix(this.numCenters, data.getNumColumns()); // Randomly assign points to clusters and get the initial centers of // mass from that assignment. this.clusters = new ArrayList<>(); for (int i = 0; i < data.getNumRows(); i++) { this.clusters.add(RandomUtil.getInstance() .nextInt(this.centers.getNumRows())); } moveCentersToMeans(); } else if (this.initializationType == KMeans.EXPLICIT_POINTS) { this.clusters = new ArrayList<>(); for (int i = 0; i < data.getNumRows(); i++) { this.clusters.add(-1); } } boolean changed = true; this.iterations = 0; // System.out.println("Original centers: " + centers); while (changed && (this.maxIterations == -1 || this.iterations < this.maxIterations)) { this.iterations++; // System.out.println("Iteration = " + iterations); // Step #1: Assign each point to its closest center, forming a cluster for // each center. int numChanged = reassignPoints(); changed = numChanged > 0; // Step #2: Replace each center by the center of mass of its cluster. moveCentersToMeans(); } } /** *

Getter for the field clusters.

* * @return a {@link java.util.List} object */ public List> getClusters() { return KMeans.convertClusterIndicesToLists(this.clusters); } /** * Return the maximum number of iterations, or -1 if the algorithm is allowed to run unconstrainted. * * @return This value. */ public int getMaxIterations() { return this.maxIterations; } /** * Sets the maximum number of iterations, or -1 if the algorithm is allowed to run unconstrainted. * * @param maxIterations This value. */ public void setMaxIterations(int maxIterations) { this.maxIterations = maxIterations; } /** *

getNumClusters.

* * @return a int */ public int getNumClusters() { return this.centers.getNumRows(); } /** *

getCluster.

* * @param k a int * @return a {@link java.util.List} object */ public List getCluster(int k) { List cluster = new ArrayList<>(); for (int i = 0; i < this.clusters.size(); i++) { if (this.clusters.get(i) == k) { cluster.add(i); } } return cluster; } private Dissimilarity getMetric() { return this.metric; } /** *

iterations.

* * @return the number of iterations. */ public int iterations() { return this.iterations; } /** * The squared error of the kth cluster. * * @param k The index of the cluster in question. * @return this squared error. */ private double squaredError(int k) { double squaredError = 0.0; for (int i = 0; i < this.data.getNumRows(); i++) { if (this.clusters.get(i) == k) { Vector datum = this.data.getRow(i); Vector center = this.centers.getRow(k); squaredError += this.metric.dissimilarity(datum, center); } } return squaredError; } /** * Total squared error for most recent run. * * @return the total squared error. */ private double totalSquaredError() { double totalSquaredError = 0.0; for (int k = 0; k < this.centers.getNumRows(); k++) { totalSquaredError += squaredError(k); } return totalSquaredError; } /** *

toString.

* * @return a string representation of the cluster result. */ public String toString() { NumberFormat n1 = NumberFormatUtil.getInstance().getNumberFormat(); Vector counts = countClusterSizes(); double totalSquaredError = totalSquaredError(); StringBuilder buf = new StringBuilder(); buf.append("Cluster Result (").append(this.clusters.size()) .append(" cases, ").append(this.centers.getNumColumns()) .append(" feature(s), ").append(this.centers.getNumRows()) .append(" clusters)"); for (int k = 0; k < this.centers.getNumRows(); k++) { buf.append("\n\tCluster #").append(k + 1).append(": n = ").append(counts.get(k)); buf.append(" Squared Error = ").append(n1.format(squaredError(k))); } buf.append("\n\tTotal Squared Error = ").append(n1.format(totalSquaredError)); return buf.toString(); } //==========================PRIVATE METHODS=========================// private int reassignPoints() { int numChanged = 0; for (int i = 0; i < this.data.getNumRows(); i++) { Vector datum = this.data.getRow(i); double minDissimilarity = Double.POSITIVE_INFINITY; int cluster = -1; for (int k = 0; k < this.centers.getNumRows(); k++) { Vector center = this.centers.getRow(k); double dissimilarity = getMetric().dissimilarity(datum, center); if (dissimilarity < minDissimilarity) { minDissimilarity = dissimilarity; cluster = k; } } if (cluster != this.clusters.get(i)) { this.clusters.set(i, cluster); numChanged++; } } //System.out.println("Moved " + numChanged + " points."); return numChanged; } private void moveCentersToMeans() { for (int k = 0; k < this.centers.getNumRows(); k++) { double[] sums = new double[this.centers.getNumColumns()]; int count = 0; for (int i = 0; i < this.data.getNumRows(); i++) { if (this.clusters.get(i) == k) { for (int j = 0; j < this.data.getNumColumns(); j++) { sums[j] += this.data.get(i, j); } count++; } } if (count != 0) { for (int j = 0; j < this.centers.getNumColumns(); j++) { this.centers.set(k, j, sums[j] / count); } } } } private Matrix pickCenters(int numCenters, Matrix data) { SortedSet indexSet = new TreeSet<>(); while (indexSet.size() < numCenters) { int candidate = RandomUtil.getInstance().nextInt(data.getNumRows()); indexSet.add(candidate); } int[] rows = new int[numCenters]; int i = -1; for (int row : indexSet) { rows[++i] = row; } int[] cols = new int[data.getNumColumns()]; for (int j = 0; j < data.getNumColumns(); j++) { cols[j] = j; } return data.getSelection(rows, cols).copy(); } private Vector countClusterSizes() { Vector counts = new Vector(this.centers.getNumRows()); for (int cluster : this.clusters) { if (cluster == -1) { continue; } counts.set(cluster, counts.get(cluster) + 1); } return counts; } /** *

isVerbose.

* * @return a boolean */ public boolean isVerbose() { return this.verbose; } /** * {@inheritDoc} */ public void setVerbose(boolean verbose) { this.verbose = verbose; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy