All downloads are FREE. The search and download functionality uses the official Maven repository.

cc.mallet.cluster.neighbor_evaluator.PairwiseEvaluator Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.cluster.neighbor_evaluator;


import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;

import cc.mallet.classify.Classifier;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.types.MatrixOps;

/**
 * Uses a {@link Classifier} over pairs of {@link Instances} to score
 * {@link Neighbor}. Currently only supports {@link
 * AgglomerativeNeighbor}s.
 *
 * @author "Aron Culotta" 
 * @version 1.0
 * @since 1.0
 * @see ClassifyingNeighborEvaluator
 */
public class PairwiseEvaluator extends ClassifyingNeighborEvaluator {

	private static final long serialVersionUID = 1L;

	/**
	 * How to combine a set of pairwise scores (e.g. mean, max, ...).
	 */
	CombiningStrategy combiningStrategy;

	/**
	 * If true, score all edges involved in a merge. If false, only
	 * score the edges that croess the boundaries of the clusters being
	 * merged.
	 */
	boolean mergeFirst;

	/**
	 * Cache for calls to getScore. In some experiments, reduced running
	 * time by nearly half.
	 */
	PairwiseMatrix scoreCache;
	
	/**
	 *
	 * @param classifier Classifier to assign scores to {@link
	 * Neighbor}s for which a pair of Instances has been merged.
	 * @param scoringLabel The predicted label that corresponds to a
	 * positive example (e.g. "YES").
	 * @param combiningStrategy How to combine the pairwise scores
	 * (e.g. max, mean, ...).
	 * @param mergeFirst If true, score all edges involved in a
	 * merge. If false, only score the edges that cross the boundaries
	 * of the clusters being merged.
	 * @return
	 */
	public PairwiseEvaluator (Classifier classifier,
														String scoringLabel,
														CombiningStrategy combiningStrategy,
														boolean mergeFirst) {
		super(classifier, scoringLabel);
		this.combiningStrategy = combiningStrategy;
		this.mergeFirst = mergeFirst;
	}

	public double[] evaluate (Neighbor[] neighbors) {
		double[] scores = new double[neighbors.length];
		for (int i = 0; i < neighbors.length; i++)
			scores[i] = evaluate(neighbors[i]);
		return scores;
	}
	
	public double evaluate (Neighbor neighbor) {
 		if (!(neighbor instanceof AgglomerativeNeighbor))
 			throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());
 		AgglomerativeNeighbor aneighbor = (AgglomerativeNeighbor) neighbor;

		Clustering original = neighbor.getOriginal();
//		int[] mergedIndices = ((AgglomerativeNeighbor)neighbor).getNewCluster();
		int[] cluster1 = aneighbor.getOldClusters()[0];
		int[] cluster2 = aneighbor.getOldClusters()[1];
		ArrayList scores = new ArrayList();

		for (int i = 0; i < cluster1.length; i++) // Between cluster scores.
			for (int j = 0; j < cluster2.length; j++) {
				AgglomerativeNeighbor pwneighbor =
					new AgglomerativeNeighbor(original,	original, cluster1[i], cluster2[j]);
				scores.add(new Double(getScore(pwneighbor)));
			}
		if (mergeFirst) { // Also add w/in cluster scores.
			for (int i = 0; i < cluster1.length; i++)
				for (int j = i + 1; j < cluster1.length; j++) {
					AgglomerativeNeighbor pwneighbor =
						new AgglomerativeNeighbor(original,	original, cluster1[i], cluster1[j]);
				scores.add(new Double(getScore(pwneighbor)));				
			}
			for (int i = 0; i < cluster2.length; i++)
				for (int j = i + 1; j < cluster2.length; j++) {
					AgglomerativeNeighbor pwneighbor =
						new AgglomerativeNeighbor(original,	original, cluster2[i], cluster2[j]);
				scores.add(new Double(getScore(pwneighbor)));				
			}				
		}
				
// XXX This breaks during training if original cluster does not agree with mergedIndices.		
// 		for (int i = 0; i < mergedIndices.length; i++) {
//			for (int j = i + 1; j < mergedIndices.length; j++) {
//				if ((original.getLabel(mergedIndices[i]) != original.getLabel(mergedIndices[j])) || mergeFirst) {
//					AgglomerativeNeighbor pwneighbor =
//						new AgglomerativeNeighbor(original,	original,
//																			mergedIndices[i], mergedIndices[j]);
//					scores.add(new Double(getScore(pwneighbor)));
//				}
//			}
//		}

		if (scores.size() < 1)
			throw new IllegalStateException("No pairs of Instances were scored.");
		
 		double[] vals = new double[scores.size()];
		for (int i = 0; i < vals.length; i++)
			vals[i] = ((Double)scores.get(i)).doubleValue();
 		return combiningStrategy.combine(vals);
	}

	public void reset () {
		scoreCache = null;
	}
	
	public String toString () {
		return "class=" + this.getClass().getName() +
			" classifier=" + classifier.getClass().getName();
	}

	private double getScore (AgglomerativeNeighbor pwneighbor) {
		if (scoreCache == null)
			scoreCache = new PairwiseMatrix(pwneighbor.getOriginal().getNumInstances());
		int[] indices = pwneighbor.getNewCluster();
		if (scoreCache.get(indices[0], indices[1]) == 0.0) {
			scoreCache.set(indices[0], indices[1],
								 classifier.classify(pwneighbor).getLabelVector().value(scoringLabel));
		}
		return scoreCache.get(indices[0], indices[1]);
	}

	/**
	 * Specifies how to combine a set of pairwise scores into a
	 * cluster-wise score.
	 *
	 * @author "Aron Culotta" 
	 * @version 1.0
	 * @since 1.0
	 */
	public static interface CombiningStrategy {
		public double combine (double[] scores);
	}

	public static class Average implements CombiningStrategy, Serializable {
		public double combine (double[] scores) {
			return MatrixOps.mean(scores);
		}		
		// SERIALIZATION

		private static final long serialVersionUID = 1;

		private static final int CURRENT_SERIAL_VERSION = 1;

		private void writeObject(ObjectOutputStream out) throws IOException {
			out.defaultWriteObject();
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		private void readObject(ObjectInputStream in) throws IOException,
				ClassNotFoundException {
			in.defaultReadObject();
			int version = in.readInt();
		}	
	}

	public static class Minimum implements CombiningStrategy, Serializable {
		public double combine (double[] scores) {
			return MatrixOps.min(scores);
		}		
		// SERIALIZATION

		private static final long serialVersionUID = 1;

		private static final int CURRENT_SERIAL_VERSION = 1;

		private void writeObject(ObjectOutputStream out) throws IOException {
			out.defaultWriteObject();
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		private void readObject(ObjectInputStream in) throws IOException,
				ClassNotFoundException {
			in.defaultReadObject();
			int version = in.readInt();
		}	
	}

	public static class Maximum implements CombiningStrategy, Serializable {
		public double combine (double[] scores) {
			return MatrixOps.max(scores);
		}		
		// SERIALIZATION

		private static final long serialVersionUID = 1;

		private static final int CURRENT_SERIAL_VERSION = 1;

		private void writeObject(ObjectOutputStream out) throws IOException {
			out.defaultWriteObject();
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		private void readObject(ObjectInputStream in) throws IOException,
				ClassNotFoundException {
			in.defaultReadObject();
			int version = in.readInt();
		}			
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy