All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.cluster.tui.Clusterings2Clusterer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.cluster.tui;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.logging.Logger;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.cluster.Clusterer;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.Clusterings;
import cc.mallet.cluster.GreedyAgglomerativeByDensity;
import cc.mallet.cluster.Record;
import cc.mallet.cluster.evaluate.AccuracyEvaluator;
import cc.mallet.cluster.evaluate.BCubedEvaluator;
import cc.mallet.cluster.evaluate.ClusteringEvaluator;
import cc.mallet.cluster.evaluate.ClusteringEvaluators;
import cc.mallet.cluster.evaluate.MUCEvaluator;
import cc.mallet.cluster.evaluate.PairF1Evaluator;
import cc.mallet.cluster.iterator.PairSampleIterator;
import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.neighbor_evaluator.NeighborEvaluator;
import cc.mallet.cluster.neighbor_evaluator.PairwiseEvaluator;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.PropertyList;
import cc.mallet.util.Randoms;
import cc.mallet.util.Strings;

// NOTE: This tool is a work in progress.
public class Clusterings2Clusterer {

	/**
	 * Logger used by {@link #main(String[])} to report the number of generated
	 * training instances, the pairwise training accuracy and the evaluation
	 * results.
	 */
	private static Logger logger =
		MalletLogger.getLogger(Clusterings2Clusterer.class.getName());

	/**
	 * Trains (or loads) a {@link Clusterer}, evaluates it on the test
	 * clusterings, serializes the clusterer and writes the predictions.
	 *
	 * <p>Steps, in order:
	 * <ol>
	 * <li>Process the command-line options declared as static fields of this
	 * class.</li>
	 * <li>If no readable <code>load-clusterer</code> file is available, read the
	 * training clusterings, sample instance pairs from each of them, train a
	 * MaxEnt classifier on the piped pairs and wrap it in a
	 * {@link GreedyAgglomerativeByDensity} clusterer. Otherwise deserialize the
	 * clusterer from that file.</li>
	 * <li>Cluster every test clustering's instances and log the evaluator's
	 * output for each one, followed by the totals.</li>
	 * <li>Serialize the clusterer to the <code>save-clusterer</code> file and,
	 * when <code>output-clusterings</code> is non-null, write the predicted
	 * clusterings' <code>toString()</code> to that file.</li>
	 * </ol>
	 *
	 * @param args command-line arguments, handed to {@link CommandOption}
	 * @throws IllegalArgumentException if the training file contains no
	 *         clusterings
	 * @throws Exception on any I/O, deserialization or training failure
	 */
	public static void main(String[] args) throws Exception {

		CommandOption.setSummary(Clusterings2Clusterer.class,
				"A tool to train and test a Clusterer.");
		CommandOption.process(Clusterings2Clusterer.class, args);

		// TRAIN

		Randoms random = new Randoms(123);
		Clusterer clusterer = null;
		// The load-clusterer option defaults to null, so it must be checked
		// before exists() is called on it.
		File clustererFile = loadClusterer.value;
		if (clustererFile == null || !clustererFile.exists()) {
			Clusterings training = readClusterings(trainingFile.value);
			if (training.size() == 0) {
				throw new IllegalArgumentException(
						"No clusterings found in training file " + trainingFile.value);
			}

			// The field alphabet is taken from the first record of the first
			// training clustering.
			Alphabet fieldAlphabet = ((Record) training.get(0).getInstances()
					.get(0).getData()).fieldAlphabet();

			Pipe pipe = new ClusteringPipe(string2ints(exactMatchFields.value, fieldAlphabet),
					                       string2ints(approxMatchFields.value, fieldAlphabet),
					                       string2ints(substringMatchFields.value, fieldAlphabet));

			InstanceList trainingInstances = new InstanceList(pipe);
			for (int i = 0; i < training.size(); i++) {
				PairSampleIterator iterator = new PairSampleIterator(training
						.get(i), random, 0.5, training.get(i).getNumInstances());
				while (iterator.hasNext()) {
					Instance inst = iterator.next();
					trainingInstances.add(pipe.pipe(inst));
				}
			}
			logger.info("generated " + trainingInstances.size()
					+ " training instances");
			Classifier classifier = new MaxEntTrainer().train(trainingInstances);
			logger.info("InfoGain:\n");
			new InfoGain(trainingInstances).printByRank(System.out);
			logger.info("pairwise training accuracy="
					+ new Trial(classifier, trainingInstances).getAccuracy());
			NeighborEvaluator neval = new PairwiseEvaluator(classifier, "YES",
					new PairwiseEvaluator.Average(), true);
			clusterer = new GreedyAgglomerativeByDensity(
					training.get(0).getInstances().getPipe(), neval, 0.5, false,
					random);
			// Drop the references so the training data can be garbage
			// collected before testing starts.
			training = null;
			trainingInstances = null;
		} else {
			// Closing the underlying file stream in finally releases the
			// handle even when the object-stream header or readObject fails.
			FileInputStream fis = new FileInputStream(clustererFile);
			try {
				ObjectInputStream ois = new ObjectInputStream(fis);
				clusterer = (Clusterer) ois.readObject();
			} finally {
				fis.close();
			}
		}

		// TEST

		Clusterings testing = readClusterings(testingFile.value);
		ClusteringEvaluator evaluator = (ClusteringEvaluator) clusteringEvaluatorOption.value;
		if (evaluator == null) {
			// No evaluator configured: report all four built-in metrics.
			evaluator = new ClusteringEvaluators(
					new ClusteringEvaluator[] { new BCubedEvaluator(),
							new PairF1Evaluator(), new MUCEvaluator(), new AccuracyEvaluator() });
		}
		ArrayList<Clustering> predictions = new ArrayList<Clustering>();
		for (int i = 0; i < testing.size(); i++) {
			Clustering clustering = testing.get(i);
			Clustering predicted = clusterer.cluster(clustering.getInstances());
			predictions.add(predicted);
			logger.info(evaluator.evaluate(clustering, predicted));
		}
		logger.info(evaluator.evaluateTotals());

		// WRITE OUTPUT

		FileOutputStream fos = new FileOutputStream(saveClusterer.value);
		try {
			ObjectOutputStream oos = new ObjectOutputStream(fos);
			oos.writeObject(clusterer);
			oos.close();
		} finally {
			// Harmless if oos.close() already closed it; releases the file
			// handle when serialization fails part-way.
			fos.close();
		}

		if (outputClusterings.value != null) {
			BufferedWriter writer = new BufferedWriter(new FileWriter(new File(outputClusterings.value)));
			try {
				writer.write(predictions.toString());
				writer.flush();
			} finally {
				writer.close();
			}
		}
	}

	/**
	 * Maps each string to the index returned by
	 * {@code alph.lookupIndex(String)}, preserving order.
	 *
	 * <p>A {@code null} array is treated as "no entries" and yields an empty
	 * result, so callers may pass the value of an option that was not given.
	 *
	 * @param ss   the entries to look up; may be {@code null}
	 * @param alph the alphabet to look the entries up in; not consulted when
	 *             {@code ss} is {@code null} or empty
	 * @return a new array with one index per element of {@code ss}
	 */
	public static int[] string2ints(String[] ss, Alphabet alph) {
		if (ss == null) {
			return new int[0];
		}
		int[] ret = new int[ss.length];
		for (int i = 0; i < ss.length; i++) {
			// NOTE(review): whether lookupIndex adds unknown entries or
			// returns a sentinel depends on Alphabet - confirm.
			ret[i] = alph.lookupIndex(ss[i]);
		}
		return ret;
	}

	/**
	 * Deserializes a {@link Clusterings} object from the named file.
	 *
	 * <p>The file is closed before this method returns, whether or not
	 * deserialization succeeds.
	 *
	 * @param f path of a file containing a Java-serialized {@link Clusterings}
	 * @return the deserialized clusterings
	 * @throws Exception if the file cannot be opened or read, or its content
	 *         cannot be deserialized and cast to {@link Clusterings}
	 */
	public static Clusterings readClusterings(String f) throws Exception {
		// Open the file stream first so that it is released even if the
		// ObjectInputStream constructor fails while reading the header.
		FileInputStream fis = new FileInputStream(new File(f));
		try {
			ObjectInputStream ois = new ObjectInputStream(fis);
			return (Clusterings) ois.readObject();
		} finally {
			fis.close();
		}
	}

	// Command-line options. Each one is constructed with this class as its
	// first argument and is read through its `value` field in main().
	// NOTE(review): the boolean argument and the trailing null presumably mean
	// "argument required" and "long help text" - confirm against CommandOption.

	/**
	 * Optional file holding a serialized clusterer (default: null). main()
	 * deserializes the clusterer from it when the file exists and trains a new
	 * clusterer otherwise.
	 */
	static CommandOption.File loadClusterer = new CommandOption.File(
			Clusterings2Clusterer.class,
			"load-clusterer",
			"FILE",
			false,
			null,
			"The file from which to read the clusterer.",
			null);

	/**
	 * File the clusterer is serialized to after evaluation (default:
	 * "clusterer.mallet"). main() writes it whether the clusterer was trained
	 * or loaded.
	 */
	static CommandOption.File saveClusterer = new CommandOption.File(
			Clusterings2Clusterer.class,
			"save-clusterer",
			"FILE",
			false,
			new File("clusterer.mallet"),			
			"The filename in which to write the clusterer after it has been trained.",
			null);

	/**
	 * Name of the file that receives the string form of the predicted
	 * clusterings (default: "predictions"). main() skips this output when the
	 * value is null.
	 */
	static CommandOption.String outputClusterings = new CommandOption.String(
			Clusterings2Clusterer.class,
			"output-clusterings",
			"FILENAME",
			false,
			"predictions",
			"The filename in which to write the predicted clusterings.",
			null);

	/**
	 * Name of the file holding the serialized training Clusterings (default:
	 * "text.clusterings.train"). It is read only when a clusterer is trained.
	 * The help text mentions an "input file parameter"; no such option is
	 * declared in this class.
	 */
	static CommandOption.String trainingFile = new CommandOption.String(
			Clusterings2Clusterer.class,
			"train",
			"FILENAME",
			false,
			"text.clusterings.train",
			"Read the training set Clusterings from this file. "
					+ "If this is specified, the input file parameter is ignored",
			null);

	/**
	 * Name of the file holding the serialized test Clusterings (default:
	 * "text.clusterings.test"). main() always reads it. The help text mentions
	 * "training-file" and "input-file" parameters; the training option declared
	 * here is named "train" and no input-file option is declared in this class.
	 */
	static CommandOption.String testingFile = new CommandOption.String(
			Clusterings2Clusterer.class,
			"test",
			"FILENAME",
			false,
			"text.clusterings.test",
			"Read the test set Clusterings from this file. "
					+ "If this option is specified, the training-file parameter must be specified and "
					+ " the input-file parameter is ignored", null);

	/**
	 * Optional evaluator object (default: null). main() casts the value to
	 * ClusteringEvaluator; when it is null, main() evaluates with BCubed,
	 * PairF1, MUC and Accuracy evaluators combined.
	 */
	 static CommandOption.Object clusteringEvaluatorOption = new CommandOption.Object(
			Clusterings2Clusterer.class, "clustering-evaluator", "CONSTRUCTOR",
			true, null,
			"Java code for constructing a ClusteringEvaluator object", null);

	/**
	 * Names of the record fields compared for exact value matches (default:
	 * null). main() converts them to field indices with string2ints() and
	 * passes those to ClusteringPipe.
	 */
	static CommandOption.SpacedStrings exactMatchFields = new CommandOption.SpacedStrings(
			Clusterings2Clusterer.class, "exact-match-fields", "STRING...",
			false, null,
			"The field names to be checked for exactly matching values", null);

	/**
	 * Names of the record fields compared for approximate value matches
	 * (default: null). Converted and passed on like exactMatchFields.
	 */
	static CommandOption.SpacedStrings approxMatchFields = new CommandOption.SpacedStrings(
			Clusterings2Clusterer.class, "approx-match-fields", "STRING...",
			false, null,
			"The field names to be checked for approx matching values", null);

	/**
	 * Names of the record fields compared for substring value matches
	 * (default: null). Converted and passed on like exactMatchFields.
	 */
	static CommandOption.SpacedStrings substringMatchFields = new CommandOption.SpacedStrings(
			Clusterings2Clusterer.class, "substring-match-fields", "STRING...",
			false, null,
			"The field names to be checked for substring matching values. Note that values fewer than 3 characters are ignored.", null);

	
	
	/**
	 * Pipe that turns a proposed cluster merge into a labelled training
	 * instance for a pairwise classifier.
	 *
	 * <p>The incoming instance's data must be an {@link AgglomerativeNeighbor}.
	 * The records of its merged cluster are compared pairwise on the
	 * configured fields. For each field, up to two features are produced:
	 * <ul>
	 * <li><code>&lt;field&gt;_exists_&lt;kind&gt;</code> when at least one
	 * record pair matches;</li>
	 * <li><code>&lt;field&gt;_all_&lt;kind&gt;</code> when every compared pair
	 * matches and more than one pair was compared.</li>
	 * </ul>
	 * where <code>&lt;kind&gt;</code> is <code>match</code>,
	 * <code>approx_match</code> or <code>substring_match</code>. A pair is only
	 * compared when both records have values for the field, and each pair
	 * contributes at most one match. Features are decided once per field,
	 * after all pairs have been compared.
	 *
	 * <p>The target is the label "YES" when the first members of the two old
	 * clusters carry the same label in the original clustering, else "NO".
	 */
	public static class ClusteringPipe extends Pipe {
		private static final long serialVersionUID = 1L;

		/** Values shorter than this many characters are ignored by substring matching. */
		private static final int MIN_SUBSTRING_LENGTH = 3;

		/** Comparison kinds understood by {@link #pairMatches}. */
		private static final int EXACT = 0;
		private static final int APPROX = 1;
		private static final int SUBSTRING = 2;

		/** Indices (in the records' field alphabet) of fields compared exactly. */
		int[] exactMatchFields;

		/** Indices of fields compared by edit distance. */
		int[] approxMatchFields;

		/** Indices of fields compared by substring containment. */
		int[] substringMatchFields;

		/**
		 * Two values match approximately when their Levenshtein distance is
		 * strictly below this threshold.
		 */
		double approxMatchThreshold;

		/**
		 * Creates a pipe with an approximate-match threshold of 0.0.
		 *
		 * <p>Because the distance test is a strict "less than", a threshold of
		 * 0.0 means no approximate-match feature is ever produced (assuming
		 * edit distances are non-negative). Use the four-argument constructor
		 * to enable approximate matching.
		 *
		 * @param exactMatchFields     field indices compared exactly; null means none
		 * @param approxMatchFields    field indices compared approximately; null means none
		 * @param substringMatchFields field indices compared by substring; null means none
		 */
		public ClusteringPipe(int[] exactMatchFields, int[] approxMatchFields,
				int[] substringMatchFields) {
			this(exactMatchFields, approxMatchFields, substringMatchFields, 0.0);
		}

		/**
		 * Creates a pipe with an explicit approximate-match threshold.
		 *
		 * @param exactMatchFields     field indices compared exactly; null means none
		 * @param approxMatchFields    field indices compared approximately; null means none
		 * @param substringMatchFields field indices compared by substring; null means none
		 * @param approxMatchThreshold two values match approximately when their
		 *                             Levenshtein distance is strictly below this value
		 */
		public ClusteringPipe(int[] exactMatchFields, int[] approxMatchFields,
				int[] substringMatchFields, double approxMatchThreshold) {
			super(new Alphabet(), new LabelAlphabet());
			this.exactMatchFields = orEmpty(exactMatchFields);
			this.approxMatchFields = orEmpty(approxMatchFields);
			this.substringMatchFields = orEmpty(substringMatchFields);
			this.approxMatchThreshold = approxMatchThreshold;
		}

		/** Returns the given array, or an empty array when it is null. */
		private static int[] orEmpty(int[] a) {
			return a == null ? new int[0] : a;
		}

		/** Collects the Record stored as data of each listed instance. */
		private Record[] array2Records(int[] a, InstanceList list) {
			Record[] records = new Record[a.length];
			for (int i = 0; i < a.length; i++) {
				records[i] = (Record) list.get(a[i]).getData();
			}
			return records;
		}

		/**
		 * Replaces the carrier's {@link AgglomerativeNeighbor} data with a
		 * feature vector describing the merged cluster and sets the "YES"/"NO"
		 * target.
		 *
		 * @param carrier instance whose data is an AgglomerativeNeighbor with a
		 *                non-empty new cluster and two non-empty old clusters
		 * @return the same instance, with data and target replaced
		 */
		public Instance pipe(Instance carrier) {
			AgglomerativeNeighbor neighbor = (AgglomerativeNeighbor) carrier
					.getData();
			Clustering original = neighbor.getOriginal();
			int[] cluster1 = neighbor.getOldClusters()[0];
			int[] cluster2 = neighbor.getOldClusters()[1];
			InstanceList list = original.getInstances();
			int[] mergedIndices = neighbor.getNewCluster();
			Record[] records = array2Records(mergedIndices, list);
			// All records are assumed to share the first record's alphabets.
			Alphabet fieldAlph = records[0].fieldAlphabet();
			Alphabet valueAlph = records[0].valueAlphabet();

			PropertyList features = null;
			features = addExactMatch(records, fieldAlph, valueAlph, features);
			features = addApproxMatch(records, fieldAlph, valueAlph, features);
			features = addSubstringMatch(records, fieldAlph, valueAlph, features);
			carrier
					.setData(new FeatureVector(getDataAlphabet(), features,
							true));

			LabelAlphabet ldict = (LabelAlphabet) getTargetAlphabet();
			String label = (original.getLabel(cluster1[0]) == original
					.getLabel(cluster2[0])) ? "YES" : "NO";
			carrier.setTarget(ldict.lookupLabel(label));
			return carrier;
		}

		/** Adds the <code>_all_match</code> / <code>_exists_match</code> features. */
		private PropertyList addExactMatch(Record[] records,
				Alphabet fieldAlph, Alphabet valueAlph, PropertyList features) {
			return addFieldFeatures(records, exactMatchFields, EXACT, "match",
					fieldAlph, valueAlph, features);
		}

		/** Adds the <code>_all_approx_match</code> / <code>_exists_approx_match</code> features. */
		private PropertyList addApproxMatch(Record[] records,
				Alphabet fieldAlph, Alphabet valueAlph, PropertyList features) {
			return addFieldFeatures(records, approxMatchFields, APPROX,
					"approx_match", fieldAlph, valueAlph, features);
		}

		/** Adds the <code>_all_substring_match</code> / <code>_exists_substring_match</code> features. */
		private PropertyList addSubstringMatch(Record[] records,
				Alphabet fieldAlph, Alphabet valueAlph, PropertyList features) {
			return addFieldFeatures(records, substringMatchFields, SUBSTRING,
					"substring_match", fieldAlph, valueAlph, features);
		}

		/**
		 * For every field in {@code fields}, compares all record pairs that
		 * both have values for it and adds the "all"/"exists" features.
		 *
		 * @param kind   one of EXACT, APPROX, SUBSTRING
		 * @param suffix feature-name suffix placed after "_all_" / "_exists_"
		 */
		private PropertyList addFieldFeatures(Record[] records, int[] fields,
				int kind, String suffix, Alphabet fieldAlph,
				Alphabet valueAlph, PropertyList features) {
			for (int fi = 0; fi < fields.length; fi++) {
				int field = fields[fi];
				int matches = 0;
				int comparisons = 0;
				for (int i = 0; i < records.length; i++) {
					FeatureVector valsi = records[i].values(field);
					if (valsi == null) {
						continue;
					}
					for (int j = i + 1; j < records.length; j++) {
						FeatureVector valsj = records[j].values(field);
						if (valsj == null) {
							continue;
						}
						comparisons++;
						if (pairMatches(kind, valsi, valsj, valueAlph)) {
							matches++;
						}
					}
				}
				// Decide the features only after every pair has been counted,
				// so "all" really means all compared pairs.
				Object fieldName = fieldAlph.lookupObject(field);
				if (matches == comparisons && comparisons > 1) {
					features = PropertyList.add(fieldName + "_all_" + suffix,
							1.0, features);
				}
				if (matches > 0) {
					features = PropertyList.add(fieldName + "_exists_" + suffix,
							1.0, features);
				}
			}
			return features;
		}

		/**
		 * Tells whether any value of {@code valsi} matches any value of
		 * {@code valsj} under the given comparison kind.
		 */
		private boolean pairMatches(int kind, FeatureVector valsi,
				FeatureVector valsj, Alphabet valueAlph) {
			for (int ii = 0; ii < valsi.numLocations(); ii++) {
				Object vi = valueAlph.lookupObject(valsi.indexAtLocation(ii));
				if (kind == EXACT) {
					if (valsj.contains(vi)) {
						return true;
					}
					continue;
				}
				String si = (String) vi;
				if (kind == SUBSTRING && si.length() < MIN_SUBSTRING_LENGTH) {
					// Too short to be meaningful; skip this value only.
					continue;
				}
				for (int jj = 0; jj < valsj.numLocations(); jj++) {
					String sj = (String) valueAlph.lookupObject(valsj
							.indexAtLocation(jj));
					if (kind == APPROX) {
						if (Strings.levenshteinDistance(si, sj) < approxMatchThreshold) {
							return true;
						}
					} else if (sj.length() >= MIN_SUBSTRING_LENGTH
							&& (si.contains(sj) || sj.contains(si))) {
						return true;
					}
				}
			}
			return false;
		}

	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy