All downloads are FREE. The search and download functionality uses the official Maven repository.

cc.mallet.classify.DecisionTree Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify;

import java.io.Serializable;
import java.util.logging.Logger;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.util.MalletLogger;

/**
   Decision Tree classifier.
   @author Andrew McCallum [email protected]
 */

public class DecisionTree extends Classifier implements Serializable //implements InduceFeatures
{
	private static final long serialVersionUID = 1L;

	// final: loggers are per-class constants and should never be reassigned.
	private static final Logger logger = MalletLogger.getLogger(DecisionTree.class.getName());

	/** Root of the binary tree; each interior node splits on the presence of one feature. */
	Node root;

	/**
	 * Constructs a classifier over an already-grown tree.
	 *
	 * @param instancePipe pipe whose data alphabet the tree's feature indices refer to
	 * @param root root node of the (possibly single-node) decision tree
	 */
	public DecisionTree (Pipe instancePipe, DecisionTree.Node root)
	{
		super (instancePipe);
		this.root = root;
	}

	/** @return the root node of the tree. */
	public Node getRoot ()
	{
		return root;
	}

	/**
	 * Descends from {@code node} to the leaf selected by {@code fv}:
	 * at each interior node, follow the feature-present child ({@code child1})
	 * when the split feature has a non-zero value, else the feature-absent
	 * child ({@code child0}).  Iterative rather than recursive so that very
	 * deep trees cannot overflow the call stack; the traversal order and
	 * result are identical to the former recursive version.
	 */
	private Node getLeaf (Node node, FeatureVector fv)
	{
		// split() always assigns child0 and child1 together, so testing
		// child0 alone is a sufficient leaf check.
		while (node.child0 != null)
			node = (fv.value (node.featureIndex) != 0) ? node.child1 : node.child0;
		return node;
	}

	/**
	 * Classifies an instance by routing its feature vector to a leaf and
	 * returning that leaf's stored label distribution.
	 *
	 * @param instance instance whose data is a {@link FeatureVector} over this
	 *                 classifier's data alphabet
	 * @return a {@link Classification} holding the leaf's labeling
	 */
	public Classification classify (Instance instance)
	{
		FeatureVector fv = (FeatureVector) instance.getData ();
		// Sanity check: the vector must be indexed by the same alphabet the tree was grown with.
		assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());

		Node leaf = getLeaf (root, fv);
		return new Classification (instance, this, leaf.labeling);
	}

	// Entropy of 1.0 would say that it take "one bit" to indicate the correct class,
	// e.g. that there is a 50/50 split between two classes given a particular feature
	public double addFeaturesClassEntropyThreshold = 0.7;

	/**
	 * Adds tree-path conjunction features to every instance in {@code ilist}.
	 * Only the per-class variant is implemented; the flag-off path throws.
	 *
	 * @param ilist instances whose data are {@link AugmentableFeatureVector}s
	 * @param withFeatureShrinkage forwarded to {@link Node#induceFeatures} as
	 *        the "include interior nodes" flag — NOTE(review): the parameter
	 *        name suggests shrinkage but it is passed where the node method
	 *        expects {@code withInteriorNodes}; confirm intent upstream
	 * @param inducePerClassFeatures must be true; per-class feature selections
	 *        are cloned so newly added features do not affect membership tests
	 * @throws UnsupportedOperationException when {@code inducePerClassFeatures} is false
	 */
	public void induceFeatures (InstanceList ilist, boolean withFeatureShrinkage, boolean inducePerClassFeatures)
	{
		if (inducePerClassFeatures) {
			int numClasses = ilist.getTargetAlphabet().size();
//			int numFeatures = ilist.getDataAlphabet().size();
			FeatureSelection[] pcfs = new FeatureSelection[numClasses];
			// Clone so that features added during this pass don't change the
			// "already there" test for later instances.
			for (int j = 0; j < numClasses; j++)
				pcfs[j] = (FeatureSelection) ilist.getPerLabelFeatureSelection()[j].clone();
			for (int i = 0; i < ilist.size(); i++) {
				Object data = ilist.get(i).getData();
				AugmentableFeatureVector afv = (AugmentableFeatureVector) data;
				root.induceFeatures (afv, null, pcfs, ilist.getFeatureSelection(), ilist.getPerLabelFeatureSelection(),
														 withFeatureShrinkage, inducePerClassFeatures, addFeaturesClassEntropyThreshold);
			}
		} else {
			throw new UnsupportedOperationException ("Not yet implemented");
		}
	}


	/**
	 * One node of the binary decision tree.  Each node caches the feature with
	 * maximum information gain over its instance subset, the label
	 * distribution of that subset, and (until {@link #stopGrowth}) the subset
	 * itself so the node can later be {@link #split}.
	 */
	public static class Node implements Serializable
	{
		private static final long serialVersionUID = 1L;

		int featureIndex;										// the feature on which the children (would) distinguish
		double infoGain;										// the information gain of splitting on this feature
		InstanceList ilist;									// instances routed to this node; null after stopGrowth()
		Alphabet dictionary;								// data alphabet, used to render feature names
		double labelEntropy;								// the class label entropy of data in this (unsplit) node
		Labeling labeling;									// the class label distribution in the node (unsplit)
		Node parent, child0, child1;				// child0 = feature absent, child1 = feature present
		// NOTE(review): never assigned anywhere in this class and shadowed by a
		// local in induceFeatures(); retained only to keep the serialized form
		// stable under the fixed serialVersionUID.
		String name;

		// xxx Also calculate some sort of inverted entropy for feature induction,
		// in order to find the one class needs a new feature with a negative weight.

		/**
		 * Builds a node over {@code ilist}, precomputing the best split feature
		 * (restricted to {@code fs}), its information gain, and the node's base
		 * label distribution and entropy.  Does not split; call {@link #split}.
		 */
		public Node (InstanceList ilist, Node parent, FeatureSelection fs)
		{
			InfoGain ig = new InfoGain (ilist);
			this.featureIndex = ig.getMaxValuedIndexIn (fs);
			this.infoGain = ig.value(featureIndex);
			this.ilist = ilist;
			this.dictionary = ilist.getDataAlphabet();
			this.parent = parent;
			this.labeling = ig.getBaseLabelDistribution();
			this.labelEntropy = ig.getBaseEntropy();
			this.child0 = this.child1 = null;
		}

		/** The root has depth zero. */
		public int depth ()
		{
			int depth = 0;
			Node p = parent;
			while (p != null) {
				p = p.parent;
				depth++;
			}
			return depth;
		}

		/** @return true when this node has no children (split() never called on it). */
		public boolean isLeaf ()
		{
			return (child0 == null && child1 == null);
		}

		/** @return true when this node has no parent. */
		public boolean isRoot ()
		{
			return parent == null;
		}

		/** @return the child holding instances where the split feature is zero/absent. */
		public Node getFeatureAbsentChild () { return child0; }
		/** @return the child holding instances where the split feature is non-zero/present. */
		public Node getFeaturePresentChild () { return child1; }
		/** @return the information gain of splitting this node on its chosen feature. */
		public double getSplitInfoGain () { return infoGain; }
		/** @return the feature object (e.g. a word) this node splits on. */
		public Object getSplitFeature () { return ilist.getDataAlphabet().lookupObject(featureIndex); }

		/**
		 * Partitions this node's instances on the chosen feature (non-zero value
		 * goes to child1, zero to child0) and builds both children, preserving
		 * instance weights.
		 *
		 * @param fs feature selection restricting the children's split choices
		 * @throws IllegalStateException if stopGrowth() already discarded the instances
		 */
		public void split (FeatureSelection fs)
		{
			if (ilist == null)
				throw new IllegalStateException ("Frozen.  Cannot split.");
			InstanceList ilist0 = new InstanceList (ilist.getPipe());
			InstanceList ilist1 = new InstanceList (ilist.getPipe());
			for (int i = 0; i < ilist.size(); i++) {
				Instance instance = ilist.get(i);
				FeatureVector fv = (FeatureVector) instance.getData ();
				// xxx What test should this be?  What to do with negative values?
					// Whatever is decided here should also go in InfoGain.calcInfoGains()
				if (fv.value (featureIndex) != 0) {
					//System.out.println ("list1 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i));
					ilist1.add (instance, ilist.getInstanceWeight(i));
				} else {
					//System.out.println ("list0 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i));
					ilist0.add (instance, ilist.getInstanceWeight(i));
				}
			}
			logger.info("child0="+ilist0.size()+" child1="+ilist1.size());
			child0 = new Node (ilist0, this, fs);
			child1 = new Node (ilist1, this, fs);
		}

		// Saves memory by allowing ilist to be garbage collected
		public void stopGrowth ()
		{
			// split() assigns both children together, so child0 != null
			// implies child1 != null.
			if (child0 != null) {
				child0.stopGrowth();
				child1.stopGrowth();
			}
			ilist = null;
		}

		/**
		 * Recursively adds this node's path-conjunction feature to {@code afv}
		 * when the node is pure enough (label entropy below the threshold),
		 * then descends into whichever child matches the vector's value for
		 * this node's split feature.
		 *
		 * @param afv vector to augment in place
		 * @param featuresAlreadyThere unused in the per-class path (passed through)
		 * @param perClassFeaturesAlreadyThere membership test per class, frozen by the caller
		 * @param newFeatureSelection unused in the per-class path (passed through)
		 * @param perClassNewFeatureSelection selections that newly added features are recorded in
		 * @param withInteriorNodes when true, interior (non-leaf) nodes may also contribute features
		 * @param addPerClassFeatures must be true; the global variant is unimplemented
		 * @param classEntropyThreshold maximum label entropy for a node to contribute a feature
		 * @throws UnsupportedOperationException when {@code addPerClassFeatures} is false
		 *         and a qualifying node is reached
		 */
		public void induceFeatures (AugmentableFeatureVector afv,
																FeatureSelection featuresAlreadyThere,
																FeatureSelection[] perClassFeaturesAlreadyThere,
																FeatureSelection newFeatureSelection,
																FeatureSelection[] perClassNewFeatureSelection,
																boolean withInteriorNodes,
																boolean addPerClassFeatures,
																double classEntropyThreshold)
		{
			if (!isRoot() && (isLeaf() || withInteriorNodes) && labelEntropy < classEntropyThreshold) {
				// Local deliberately computed fresh; it shadows the (unused) field.
				String name = getName();
				logger.info("Trying to add feature "+name);
				//int conjunctionIndex = afv.getAlphabet().lookupIndex (name, false);
				if (addPerClassFeatures) {
					int classIndex = labeling.getBestIndex();
					if (!perClassFeaturesAlreadyThere[classIndex].contains (name)) {
						afv.add (name, 1.0);
						perClassNewFeatureSelection[classIndex].add (name);
					}
				} else {
					throw new UnsupportedOperationException ("Not yet implemented.");
				}
			}
			// Follow only the branch this vector would actually take.
			boolean featurePresent = afv.value (featureIndex) != 0;
			if (child0 != null && !featurePresent)
				child0.induceFeatures (afv, featuresAlreadyThere, perClassFeaturesAlreadyThere,
															 newFeatureSelection, perClassNewFeatureSelection,
															 withInteriorNodes, addPerClassFeatures, classEntropyThreshold);
			if (child1 != null && featurePresent)
				child1.induceFeatures (afv, featuresAlreadyThere, perClassFeaturesAlreadyThere,
															 newFeatureSelection, perClassNewFeatureSelection,
															 withInteriorNodes, addPerClassFeatures, classEntropyThreshold);
		}

		/**
		 * Renders the conjunction of feature tests on the path from the root to
		 * this node, e.g. {@code "foo&!bar"}; the root itself is {@code "root"}
		 * and is omitted from descendants' names.  A leading {@code !} marks the
		 * feature-absent branch.
		 */
		public String getName ()
		{
//			String prefix;
			if (parent == null)
				return "root";
			else if (parent.parent == null) {
				// Depth-1 node: name is just the (possibly negated) split feature.
				if (parent.getFeaturePresentChild() == this)
					return dictionary.lookupObject(parent.featureIndex).toString();
				else {
					assert (dictionary != null);
					assert (dictionary.lookupObject(parent.featureIndex) != null);
					return "!" + dictionary.lookupObject(parent.featureIndex).toString();
				}
			} else {
				if (parent.getFeaturePresentChild() == this)
					return parent.getName() + "&" + dictionary.lookupObject(parent.featureIndex).toString();
				else
					return parent.getName() + "&!" + dictionary.lookupObject(parent.featureIndex).toString();
			}
		}

		/** Prints one "path: best-label" line per leaf of the subtree rooted here. */
		public void print ()
		{
			if (child0 == null)
				System.out.println (getName() + ": " + labeling.getBestLabel());
			else {
				child0.print();
				child1.print();
			}
		}

	}



}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy