All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.cluster.GreedyAgglomerative Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.cluster;

import java.util.logging.Logger;

import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.neighbor_evaluator.Neighbor;
import cc.mallet.cluster.neighbor_evaluator.NeighborEvaluator;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletProgressMessageLogger;


/**
 * Greedily merges Instances until convergence. New merges are scored
 * using {@link NeighborEvaluator}.
 *
 * @author "Aron Culotta" 
 * @version 1.0
 * @since 1.0
 * @see HillClimbingClusterer
 */
public class GreedyAgglomerative extends HillClimbingClusterer {

	
	private static final long serialVersionUID = 1L;

	private static Logger progressLogger =
		MalletProgressMessageLogger.getLogger(GreedyAgglomerative.class.getName()+"-pl");

	/**
	 * Converged when merge score is below this value.
	 */
	protected double stoppingThreshold;

	/**
	 * True if should stop clustering.
	 */
   protected boolean converged;

	/**
	 * Cache for calls to {@link NeighborhoodEvaluator}. In some
	 * experiments, reduced running time by nearly half.
	 */
   protected PairwiseMatrix scoreCache;
	
	/**
	 *
	 * @param instancePipe Pipe for each underying {@link Instance}.
	 * @param evaluator To score potential merges.
	 * @param stoppingThreshold Clustering converges when the evaluator score is below this value.
	 * @return
	 */
	public GreedyAgglomerative (Pipe instancePipe,
															NeighborEvaluator evaluator,
															double stoppingThreshold) {
		super(instancePipe, evaluator);		
		this.stoppingThreshold = stoppingThreshold;
		this.converged = false;
	}

	/**
	 *
	 * @param instances
	 * @return A singleton clustering (each Instance in its own cluster).
	 */
	public Clustering initializeClustering (InstanceList instances) {
		reset();
		return ClusterUtils.createSingletonClustering(instances);
	}

	public boolean converged (Clustering clustering) {
		return converged;
	}

	/**
	 * Reset convergence to false so a new round of clustering can begin.
	 */
	public void reset () {
		converged = false;
		scoreCache = null;
		evaluator.reset();
	}
	
	/**
	 * For each pair of clusters, calculate the score of the {@link Neighbor}
	 * that would result from merging the two clusters. Choose the merge that
	 * obtains the highest score. If no merge improves score, return original
	 * Clustering
	 * 
	 * @param clustering
	 * @return
	 */
	public Clustering improveClustering (Clustering clustering) {
		double bestScore = Double.NEGATIVE_INFINITY;
		int[] toMerge = new int[]{-1,-1};
		for (int i = 0; i < clustering.getNumClusters(); i++) {
			for (int j = i + 1; j < clustering.getNumClusters(); j++) {
				double score = getScore(clustering, i, j);
				if (score > bestScore) {
					bestScore = score;
					toMerge[0] = i;
					toMerge[1] = j;
				}				
			}
		}
		
		converged = (bestScore < stoppingThreshold);

		if (!(converged)) {
			progressLogger.info("Merging " + toMerge[0] + "(" + clustering.size(toMerge[0]) +
													" nodes) and " + toMerge[1] + "(" + clustering.size(toMerge[1]) +
													" nodes) [" + bestScore + "] numClusters=" +
													clustering.getNumClusters());
			updateScoreMatrix(clustering, toMerge[0], toMerge[1]);
			clustering = ClusterUtils.mergeClusters(clustering, toMerge[0], toMerge[1]);
		} else {
			progressLogger.info("Converged with score " + bestScore);
		}
		return clustering;
	}
	
	/**
	 *
	 * @param clustering
	 * @param i
	 * @param j
	 * @return The score for merging these two clusters.
	 */
	protected double getScore (Clustering clustering, int i, int j) {
		if (scoreCache == null)
			scoreCache = new PairwiseMatrix(clustering.getNumInstances());

		int[] ci = clustering.getIndicesWithLabel(i);
		int[] cj = clustering.getIndicesWithLabel(j);
		if (scoreCache.get(ci[0], cj[0]) == 0.0) {
			double val = evaluator.evaluate(
				new AgglomerativeNeighbor(clustering,
																	ClusterUtils.copyAndMergeClusters(clustering, i, j),
																	ci, cj));
			for (int ni = 0; ni < ci.length; ni++) 
				for (int nj = 0; nj < cj.length; nj++)
					scoreCache.set(ci[ni], cj[nj], val);
		}

		return scoreCache.get(ci[0], cj[0]);														
	}

	/**
	 * Resets the values of clusters that have been merged.
	 * @param clustering
	 * @param i
	 * @param j
	 */
	protected void updateScoreMatrix (Clustering clustering, int i, int j) {
		int size = clustering.getNumInstances();
		int[] ci = clustering.getIndicesWithLabel(i);
		for (int ni = 0; ni < ci.length; ni++) {
			for (int nj = 0; nj < size; nj++)
				if (ci[ni] != nj)
					scoreCache.set(ci[ni], nj, 0.0);
		}
		int[] cj = clustering.getIndicesWithLabel(j);
		for (int ni = 0; ni < cj.length; ni++) {
			for (int nj = 0; nj < size; nj++)
				if (cj[ni] != nj)
					scoreCache.set(cj[ni], nj, 0.0);
		}
	}
		
	public String toString () {
		return "class=" + this.getClass().getName() +
			"\nstoppingThreshold=" + stoppingThreshold + 
			"\nneighborhoodEvaluator=[" + evaluator + "]";		
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy