All downloads are free. The search and download functionality uses the official Maven repository.

cc.mallet.cluster.evaluate.PairF1Evaluator Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.cluster.evaluate;

import cc.mallet.cluster.Clustering;

/**
 * Evaluates two clusterings using pairwise comparisons. For each pair
 * of Instances, compute true positives, false positives and false
 * negatives as in classification performance, determined by whether
 * the pair should be in the same cluster or not.
 *
 * <p>Every call to {@link #getEvaluationScores(Clustering, Clustering)}
 * (and therefore to {@link #evaluate(Clustering, Clustering)}) also adds
 * its pair counts to running totals, which {@link #evaluateTotals()}
 * reports.
 *
 * <p>A score whose denominator is zero (for example precision when the
 * predicted clustering contains no same-cluster pairs) is reported as
 * {@code 0.0} rather than {@code NaN}.
 *
 * @author "Aron Culotta" 
 * @version 1.0
 * @since 1.0
 * @see ClusteringEvaluator
 */
public class PairF1Evaluator extends ClusteringEvaluator {
	
	/** Pair counts accumulated over all calls to getEvaluationScores. */
	int tpTotal, fnTotal, fpTotal;
	
	public PairF1Evaluator () {
		tpTotal = fnTotal = fpTotal = 0;
	}

	/**
	 * Scores one predicted clustering against the truth and adds the
	 * pair counts to the running totals.
	 *
	 * @param truth the reference clustering
	 * @param predicted the clustering being evaluated
	 * @return a string of the form {@code "pr=<p> re=<r> f1=<f>"}
	 */
	public String evaluate (Clustering truth, Clustering predicted) {
		return format(getEvaluationScores(truth, predicted));
	}

	/**
	 * Reports precision, recall and F1 computed from the pair counts
	 * accumulated so far.
	 *
	 * @return a string of the form {@code "pr=<p> re=<r> f1=<f>"}
	 */
	public String evaluateTotals () {
		return format(scores(tpTotal, fpTotal, fnTotal));
	}

	/**
	 * Counts pairwise true positives, false positives and false negatives
	 * for one predicted clustering, adds them to the running totals, and
	 * returns the resulting scores.
	 *
	 * @param truth the reference clustering
	 * @param predicted the clustering being evaluated
	 * @return a three-element array {@code {precision, recall, f1}}
	 */
	@Override
	public double[] getEvaluationScores(Clustering truth, Clustering predicted) {
		int tp = 0;
		int fn = 0;
		int fp = 0;
		
		// Pairs placed together by the prediction: correct (tp) if the truth
		// also places them together, otherwise a false positive.
		for (int i = 0; i < predicted.getNumClusters(); i++) {
			int[] predIndices = predicted.getIndicesWithLabel(i);
			
			for (int j = 0; j < predIndices.length; j++) {
				for (int k = j + 1; k < predIndices.length; k++) {
					if (truth.getLabel(predIndices[j]) == truth.getLabel(predIndices[k])) {
						tp++;
					} else {
						fp++;
					}
				}
			}
		}

		// Pairs placed together by the truth but split by the prediction
		// are false negatives.
		for (int i = 0; i < truth.getNumClusters(); i++) {
			int[] trueIndices = truth.getIndicesWithLabel(i);
			
			for (int j = 0; j < trueIndices.length; j++) {
				for (int k = j + 1; k < trueIndices.length; k++) {
					if (predicted.getLabel(trueIndices[j]) != predicted.getLabel(trueIndices[k])) {
						fn++;
					}
				}
			}
		}

		this.tpTotal += tp;
		this.fpTotal += fp;
		this.fnTotal += fn;

		return scores(tp, fp, fn);
	}

	/**
	 * Converts pair counts into {@code {precision, recall, f1}}.
	 * Denominators are summed as {@code long} so that adding two large
	 * {@code int} counts cannot overflow.
	 */
	private static double[] scores (int tp, int fp, int fn) {
		double pr = ratio(tp, (long) tp + fp);
		double rec = ratio(tp, (long) tp + fn);
		double sum = pr + rec;
		// With pr == rec == 0 the harmonic mean would be 0/0; report 0 instead.
		double f1 = (sum == 0.0) ? 0.0 : 2*pr*rec/sum;
		return new double[]{pr, rec, f1};
	}

	/** Returns numerator/denominator, or 0.0 when the denominator is zero. */
	private static double ratio (int numerator, long denominator) {
		return (denominator == 0) ? 0.0 : (double)numerator / denominator;
	}

	/** Renders a {@code {precision, recall, f1}} array as a report string. */
	private static String format (double[] vals) {
		return "pr=" + vals[0] + " re=" + vals[1] + " f1=" + vals[2];
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy