All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.cluster.tui.Clusterings2Clusterings Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
package cc.mallet.cluster.tui;

import gnu.trove.TIntHashSet;

import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.Clusterings;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;

/**
 * Command-line tool that manipulates serialized {@link Clusterings}.
 * <p>
 * Reads a {@code Clusterings} object from {@code --input}, then optionally:
 * <ul>
 * <li>keeps only the instances whose cluster has at least {@code --min-cluster-size}
 *     members, writing the pruned result to {@code --output-prefix};</li>
 * <li>splits a single clustering into training and testing clusterings by sampling
 *     whole clusters (fixed seed) until {@code --training-proportion} of the instances
 *     are in the training set, writing {@code <prefix>.train} and {@code <prefix>.test}.</li>
 * </ul>
 * Work in progress.
 */
public class Clusterings2Clusterings {

	private static final Logger logger =
			MalletLogger.getLogger(Clusterings2Clusterings.class.getName());

	/** Seed for the train/test split; fixed so repeated runs produce the same split. */
	private static final int SPLIT_SEED = 123;

	public static void main (String[] args) {
		CommandOption.setSummary(Clusterings2Clusterings.class,
				"A tool to manipulate Clusterings.");
		CommandOption.process(Clusterings2Clusterings.class, args);

		Clusterings clusterings = readClusterings(inputFile.value);
		logger.info("number clusterings=" + clusterings.size());

		// Prune clusters based on size.
		if (minClusterSize.value > 1) {
			pruneSmallClusters(clusterings, minClusterSize.value);
			if (outputPrefixFile.value != null)
				writeObject(clusterings, outputPrefixFile.value);
		}

		// Split into training/testing.
		if (trainingProportion.value > 0)
			splitTrainTest(clusterings, trainingProportion.value, outputPrefixFile.value);
	}

	/**
	 * Deserializes a {@link Clusterings} object from {@code filename}.
	 *
	 * @throws IllegalArgumentException if the file cannot be read or deserialized; failing
	 *         here avoids continuing with a null Clusterings and crashing later with an NPE
	 */
	private static Clusterings readClusterings (String filename) {
		FileInputStream fis = null;
		try {
			fis = new FileInputStream(filename);
			return (Clusterings) new ObjectInputStream(fis).readObject();
		} catch (IOException e) {
			throw new IllegalArgumentException("Exception reading clusterings from " + filename, e);
		} catch (ClassNotFoundException e) {
			throw new IllegalArgumentException("Exception reading clusterings from " + filename, e);
		} finally {
			closeQuietly(fis);
		}
	}

	/**
	 * Replaces every clustering in {@code clusterings}, in place, with one built only from
	 * the instances whose cluster has at least {@code minSize} members.
	 */
	private static void pruneSmallClusters (Clusterings clusterings, int minSize) {
		for (int i = 0; i < clusterings.size(); i++) {
			Clustering clustering = clusterings.get(i);
			InstanceList oldInstances = clustering.getInstances();
			Alphabet alph = oldInstances.getDataAlphabet();
			LabelAlphabet lalph = (LabelAlphabet) oldInstances.getTargetAlphabet();
			if (alph == null) alph = new Alphabet();
			if (lalph == null) lalph = new LabelAlphabet();
			Pipe noop = new Noop(alph, lalph);
			InstanceList newInstances = new InstanceList(noop);
			for (int j = 0; j < oldInstances.size(); j++) {
				int label = clustering.getLabel(j);
				if (clustering.size(label) >= minSize)
					newInstances.add(noop.pipe(relabel(oldInstances.get(j), lalph, label)));
			}
			clusterings.set(i, createSmallerClustering(newInstances));
		}
	}

	/**
	 * Splits the single clustering in {@code clusterings} into a training clustering,
	 * made of whole clusters sampled at random until at least {@code proportion} of the
	 * instances are included, and a testing clustering made of all remaining clusters.
	 * Writes {@code prefix + ".train"} and {@code prefix + ".test"} when {@code prefix}
	 * is non-null.
	 *
	 * @throws IllegalArgumentException if there is not exactly one clustering, or if
	 *         {@code proportion} exceeds 1.0 (the sampling loop could never terminate)
	 */
	private static void splitTrainTest (Clusterings clusterings, double proportion, String prefix) {
		if (clusterings.size() != 1)
			throw new IllegalArgumentException("Expect one clustering to do train/test split, not " + clusterings.size());
		if (proportion > 1.0)
			throw new IllegalArgumentException("Training proportion must be at most 1.0, not " + proportion);

		Clustering clustering = clusterings.get(0);
		int targetTrainSize = (int) (proportion * clustering.getNumInstances());
		TIntHashSet clustersSampled = new TIntHashSet();
		Randoms random = new Randoms(SPLIT_SEED);
		LabelAlphabet lalph = new LabelAlphabet();

		// Rejection-sample unseen clusters; the draw order fixes the split for a given seed.
		InstanceList trainingInstances = new InstanceList(new Noop(null, lalph));
		while (trainingInstances.size() < targetTrainSize) {
			int cluster = random.nextInt(clustering.getNumClusters());
			if (!clustersSampled.contains(cluster)) {
				clustersSampled.add(cluster);
				InstanceList members = clustering.getCluster(cluster);
				for (int i = 0; i < members.size(); i++)
					trainingInstances.add(relabel(members.get(i), lalph, cluster));
			}
		}
		trainingInstances.shuffle(random);
		Clustering trainingClustering = createSmallerClustering(trainingInstances);

		InstanceList testingInstances = new InstanceList(null, lalph);
		for (int cluster = 0; cluster < clustering.getNumClusters(); cluster++) {
			if (clustersSampled.contains(cluster))
				continue;
			InstanceList members = clustering.getCluster(cluster);
			for (int j = 0; j < members.size(); j++)
				testingInstances.add(relabel(members.get(j), lalph, cluster));
		}
		testingInstances.shuffle(random);
		Clustering testingClustering = createSmallerClustering(testingInstances);

		logger.info(prefix + ".train : " + trainingClustering.getNumInstances() + " objects, "
				+ trainingClustering.getNumClusters() + " clusters");
		logger.info(prefix + ".test : " + testingClustering.getNumInstances() + " objects, "
				+ testingClustering.getNumClusters() + " clusters");

		if (prefix != null) {
			writeObject(new Clusterings(new Clustering[]{trainingClustering}), prefix + ".train");
			writeObject(new Clusterings(new Clustering[]{testingClustering}), prefix + ".test");
		}
	}

	/**
	 * Returns a new Instance with the same data, name and source as {@code inst}, whose
	 * target is the label for {@code cluster} in {@code lalph}.
	 */
	private static Instance relabel (Instance inst, LabelAlphabet lalph, int cluster) {
		// Boxed deliberately so the Object-keyed lookupLabel overload is chosen, i.e. the
		// cluster id is looked up as an alphabet entry rather than used as an index.
		return new Instance(inst.getData(), lalph.lookupLabel(Integer.valueOf(cluster)),
				inst.getName(), inst.getSource());
	}

	/**
	 * Serializes {@code object} to {@code filename}. Failures are logged, not thrown,
	 * matching this tool's best-effort handling of output files.
	 */
	private static void writeObject (Object object, String filename) {
		FileOutputStream fos = null;
		try {
			fos = new FileOutputStream(new File(filename));
			ObjectOutputStream oos = new ObjectOutputStream(fos);
			oos.writeObject(object);
			// close() flushes buffered data, so an error here means the file is incomplete
			// and must be reported rather than swallowed.
			oos.close();
			fos = null;
		} catch (IOException e) {
			logger.warning("Exception writing clustering to file " + filename + " " + e);
			e.printStackTrace();
		} finally {
			closeQuietly(fos);
		}
	}

	/** Closes {@code c} if it is non-null, ignoring any error raised by close(). */
	private static void closeQuietly (Closeable c) {
		if (c == null)
			return;
		try {
			c.close();
		} catch (IOException ignored) {
			// Either the primary failure is already being reported, or the input was fully read.
		}
	}

	/**
	 * Builds a Clustering of {@code instances}. It first creates a singleton clustering,
	 * then merges instances that share a label, using ClusterUtils for both steps.
	 */
	private static Clustering createSmallerClustering (InstanceList instances) {
		Clustering c = ClusterUtils.createSingletonClustering(instances);
		return ClusterUtils.mergeInstancesWithSameLabel(c);
	}

	static CommandOption.String inputFile =
			new CommandOption.String(Clusterings2Clusterings.class,
			                         "input",
			                         "FILENAME",
			                         true,
			                         "text.clusterings",
			                         "The filename from which to read the serialized Clusterings.",
			                         null);

	static CommandOption.String outputPrefixFile =
			new CommandOption.String(Clusterings2Clusterings.class,
			                         "output-prefix",
			                         "FILENAME",
			                         false,
			                         "text.clusterings",
			                         "The filename prefix to write output. When splitting, the suffixes '.train' and '.test' are appended.",
			                         null);

	static CommandOption.Integer minClusterSize =
			new CommandOption.Integer(Clusterings2Clusterings.class,
			                          "min-cluster-size",
			                          "INTEGER",
			                          false,
			                          1,
			                          "Remove clusters with fewer than this many Instances.",
			                          null);

	static CommandOption.Double trainingProportion =
			new CommandOption.Double(Clusterings2Clusterings.class,
			                         "training-proportion",
			                         "DOUBLE",
			                         false,
			                         0.0,
			                         "Split into training and testing, with this proportion (between 0.0 and 1.0) of instances reserved for training.",
			                         null);
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy