All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.classify.C45 Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
 This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
 http://www.cs.umass.edu/~mccallum/mallet
 This software is provided under the terms of the Common Public License,
 version 1.0, as published by http://www.opensource.org.  For further
 information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Logger;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.GainRatio;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;


/**
 * A C4.5 Decision Tree classifier.
 * <p>
 * Each internal {@link Node} tests a single feature against a threshold taken
 * from its {@link GainRatio}: instances whose value for that feature is
 * {@code <=} the threshold are routed to the left child, all others to the
 * right child. {@link #classify} returns the base label distribution of the
 * leaf an instance reaches.
 *
 * @see C45Trainer
 * @author Gary Huang
 */
public class C45 extends Classifier implements Boostable, Serializable
{
	private static final Logger logger = MalletLogger.getLogger(C45.class.getName());
	Node m_root;

	/**
	 * Wraps an already-grown tree as a classifier.
	 *
	 * @param instancePipe the pipe used to produce the data this tree classifies
	 * @param root the root node of the tree
	 */
	public C45 (Pipe instancePipe, C45.Node root)
	{
		super (instancePipe);
		m_root = root;
	}

	/** @return the root node of the tree (may be null for an empty tree) */
	public Node getRoot ()
	{
		return m_root;
	}

	/**
	 * Walks from {@code node} down to the leaf that {@code fv} falls into.
	 * Iterative rather than recursive so very deep trees cannot overflow the stack.
	 */
	private Node getLeaf (Node node, FeatureVector fv)
	{
		Node current = node;
		while (!current.isLeaf())
			current = current.goesLeft(fv) ? current.getLeftChild() : current.getRightChild();
		return current;
	}

	/**
	 * Classifies an instance by routing its feature vector down to a leaf.
	 *
	 * @param instance an instance whose data is a {@link FeatureVector}
	 * @return a classification carrying the reached leaf's base label distribution
	 */
	public Classification classify (Instance instance)
	{
		FeatureVector fv = (FeatureVector) instance.getData ();
		// feature indices stored in the tree are only meaningful in the pipe's data alphabet
		assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());

		Node leaf = getLeaf(m_root, fv);
		return new Classification (instance, this, leaf.getGainRatio().getBaseLabelDistribution());
	}

	/**
	 * Prunes the tree bottom-up using minimum description length.
	 * Does nothing if the tree has no root.
	 *
	 * @throws IllegalStateException if {@link Node#stopGrowth()} has already
	 *         released the training instances needed to compute costs
	 */
	public void prune()
	{
		Node root = getRoot();
		if (root != null)
			root.computeCostAndPrune();
	}

	/**
	 * Counts the decision nodes of this tree: the root plus every non-leaf
	 * descendant. Leaves below the root are not counted, so both a single-leaf
	 * tree and a root with two leaf children have size 1.
	 * NOTE(review): earlier documentation described this as the total node
	 * count; the leaf-excluding count is kept to preserve existing behavior.
	 *
	 * @return 0 for an empty tree, otherwise {@code 1 + root.getNumDescendants()}
	 */
	public int getSize()
	{
		Node root = getRoot();
		return (root == null) ? 0 : 1 + root.getNumDescendants();
	}

	/**
	 * Prints the tree to standard output; prints nothing for an empty tree.
	 */
	public void print()
	{
		Node root = getRoot();
		if (root != null)
			root.print();
	}

	/**
	 * A node of the decision tree. Every node keeps the {@link GainRatio}
	 * computed over the instances that reach it; a node is a leaf until
	 * {@link #split()} creates both of its children.
	 */
	public static class Node implements Serializable
	{
		private static final long serialVersionUID = 1L;

		GainRatio m_gainRatio;
		// the entire set of instances given to the root node; null after stopGrowth()
		InstanceList m_ilist;
		// indices (into m_ilist) of the instances that reach this node
		int[] m_instIndices;
		// data vocabulary, used to map feature indices back to feature objects
		Alphabet m_dataDict;
		// minimum number of instances allowed in this node
		int m_minNumInsts;
		Node m_parent, m_leftChild, m_rightChild;

		/**
		 * Creates an unsplit node over every instance of {@code ilist}.
		 */
		public Node(InstanceList ilist, Node parent, int minNumInsts)
		{
			this(ilist, parent, minNumInsts, null);
		}

		/**
		 * Creates an unsplit node and immediately computes its gain ratio.
		 *
		 * @param ilist the full training list shared by the whole tree
		 * @param parent the parent node, or null for the root
		 * @param minNumInsts minimum number of instances allowed in a node
		 * @param instIndices indices into {@code ilist} of the instances reaching
		 *        this node, or null to use all of them
		 */
		public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices)
		{
			if (instIndices == null) {
				instIndices = new int[ilist.size()];
				for (int ii = 0; ii < instIndices.length; ii++)
					instIndices[ii] = ii;
			}
			m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);
			m_ilist = ilist;
			m_instIndices = instIndices;
			m_dataDict = m_ilist.getDataAlphabet();
			m_minNumInsts = minNumInsts;
			m_parent = parent;
			m_leftChild = m_rightChild = null;
		}

		/** The root has depth zero. */
		public int depth ()
		{
			int depth = 0;
			Node p = m_parent;
			while (p != null) {
				p = p.m_parent;
				depth++;
			}
			return depth;
		}

		/** @return the number of instances that reach this node */
		public int getSize() { return m_instIndices.length; }
		public boolean isLeaf() { return (m_leftChild == null && m_rightChild == null); }
		public boolean isRoot() { return m_parent == null; }
		public Node getParent() { return m_parent; }
		public Node getLeftChild() { return m_leftChild; }
		public Node getRightChild() { return m_rightChild; }
		public GainRatio getGainRatio() { return m_gainRatio; }
		/** @return the feature object this node splits on (looked up in the data alphabet) */
		public Object getSplitFeature() { return m_dataDict.lookupObject(m_gainRatio.getMaxValuedIndex()); }

		/**
		 * Throws if {@link #stopGrowth()} has released the instance list.
		 * The message for {@code "split"} is {@code "Frozen.  Cannot split."}.
		 */
		private void checkNotFrozen(String action)
		{
			if (m_ilist == null)
				throw new IllegalStateException ("Frozen.  Cannot " + action + ".");
		}

		/** @return true if this node's split test routes {@code fv} to the left child */
		private boolean goesLeft(FeatureVector fv)
		{
			return fv.value(m_gainRatio.getMaxValuedIndex()) <= m_gainRatio.getMaxValuedThreshold();
		}

		/**
		 * Copies the instances that reach this node into a new list that uses
		 * the original list's pipe.
		 *
		 * @throws IllegalStateException if {@link #stopGrowth()} has been called
		 */
		public InstanceList getInstances()
		{
			checkNotFrozen("get instances");
			InstanceList ret = new InstanceList(m_ilist.getPipe());
			for (int ii = 0; ii < m_instIndices.length; ii++)
				ret.add(m_ilist.get(m_instIndices[ii]));
			return ret;
		}

		/**
		 * Counts the non-leaf descendant nodes (this node itself is excluded).
		 * Relies on internal nodes always having both children, which
		 * {@link #split()} guarantees.
		 */
		public int getNumDescendants()
		{
			if (isLeaf())
				return 0;
			int count = 0;
			if (! getLeftChild().isLeaf())
				count += 1 + getLeftChild().getNumDescendants();
			if (! getRightChild().isLeaf())
				count += 1 + getRightChild().getNumDescendants();
			return count;
		}

		/**
		 * Splits this node on its best feature and threshold. Instances whose
		 * feature value is {@code <=} the threshold go to the new left child,
		 * the rest to the new right child. Any existing children are replaced.
		 *
		 * @throws IllegalStateException if {@link #stopGrowth()} has been called
		 */
		public void split()
		{
			checkNotFrozen("split");
			int total = m_instIndices.length;
			boolean[] toLeftChild = new boolean[total];
			int numLeftChildren = 0;
			for (int i = 0; i < total; i++) {
				FeatureVector fv = (FeatureVector) m_ilist.get(m_instIndices[i]).getData();
				toLeftChild[i] = goesLeft(fv);
				if (toLeftChild[i])
					numLeftChildren++;
			}
			logger.info("leftChild.size=" + numLeftChildren
					+ " rightChild.size=" + (total - numLeftChildren));
			int[] leftIndices = new int[numLeftChildren];
			int[] rightIndices = new int[total - numLeftChildren];
			int li = 0, ri = 0;
			for (int i = 0; i < total; i++) {
				if (toLeftChild[i])
					leftIndices[li++] = m_instIndices[i];
				else
					rightIndices[ri++] = m_instIndices[i];
			}
			m_leftChild = new Node(m_ilist, this, m_minNumInsts, leftIndices);
			m_rightChild = new Node(m_ilist, this, m_minNumInsts, rightIndices);
		}

		/**
		 * Recursively computes the minimum description length of the subtree
		 * rooted here. Working bottom-up, it turns this node into a leaf (drops
		 * both children) whenever keeping it as a leaf is, within
		 * {@code Maths.almostEquals}, no more expensive than keeping its split.
		 *
		 * @return the minimum cost of the subtree rooted at this node
		 * @throws IllegalStateException if {@link #stopGrowth()} has been called
		 */
		public double computeCostAndPrune()
		{
			checkNotFrozen("prune");
			// NOTE(review): the "+ 1" terms presumably charge one bit for the leaf/internal flag — confirm
			double leafCost = getMDL() + 1;
			if (isLeaf())
				return leafCost;

			double leftCost = getLeftChild().computeCostAndPrune();
			double rightCost = getRightChild().computeCostAndPrune();
			// log2 of the number of candidate split points for the chosen feature
			double costSplit = Math.log(m_gainRatio.getNumSplitPointsForBestFeature()) / GainRatio.log2;
			double minCost = Math.min(leafCost, costSplit + 1 + leftCost + rightCost);

			if (Maths.almostEquals(minCost, leafCost))
				m_leftChild = m_rightChild = null;

			return minCost;
		}

		/**
		 * Computes the description length of this node treated as a leaf:
		 * <pre>
		 *   n * H + ((k - 1) / 2) * log2(n / 2) + log2(pi^(k/2) / Gamma(k/2))
		 * </pre>
		 * where n is {@link #getSize()}, k is the size of the target alphabet and
		 * H is {@code getGainRatio().getBaseEntropy()}. The first term charges
		 * for the instances' labels; the rest depends only on n and k.
		 * NOTE(review): the result is in bits assuming {@code GainRatio.log2} is
		 * ln(2) and the base entropy is base 2 — confirm in GainRatio.
		 *
		 * @throws IllegalStateException if {@link #stopGrowth()} has been called
		 */
		public double getMDL()
		{
			checkNotFrozen("compute MDL");
			int numClasses = m_ilist.getTargetAlphabet().size();
			double mdl = getSize() * getGainRatio().getBaseEntropy();
			mdl += ((numClasses-1) * Math.log(getSize() / 2.0)) / (2 * GainRatio.log2);
			double piPow = Math.pow(Math.PI, numClasses/2.0);
			double gammaVal = Maths.gamma(numClasses/2.0);
			mdl += Math.log(piPow/gammaVal) / GainRatio.log2;
			return mdl;
		}

		/**
		 * Saves memory by allowing the instance list to be garbage collected:
		 * clears the reference to it in this node and, recursively, in all
		 * descendants. Afterwards {@link #split()}, {@link #computeCostAndPrune()},
		 * {@link #getMDL()} and {@link #getInstances()} throw
		 * {@link IllegalStateException}. The instance indices are kept, so
		 * {@link #getSize()} still works.
		 */
		public void stopGrowth ()
		{
			if (m_leftChild != null)
				m_leftChild.stopGrowth();
			if (m_rightChild != null)
				m_rightChild.stopGrowth();
			m_ilist = null;
		}

		/** @return {@link #getStringBufferName()} as a String */
		public String getName()
		{
			return getStringBufferName().toString();
		}

		/**
		 * Describes the path from the root to this node as a conjunction of the
		 * split tests taken, e.g. {@code ("f1" <= 0.5) && ("f2" > 1.0)}.
		 * The root itself is named {@code root}.
		 */
		public StringBuffer getStringBufferName()
		{
			StringBuffer sb = new StringBuffer();
			if (m_parent == null)
				return sb.append("root");
			// only nodes deeper than the root's children carry an ancestor prefix
			if (m_parent.getParent() != null) {
				sb.append(m_parent.getStringBufferName());
				sb.append(" && ");
			}
			GainRatio parentRatio = m_parent.getGainRatio();
			sb.append("(\"");
			sb.append(m_dataDict.lookupObject(parentRatio.getMaxValuedIndex()).toString());
			sb.append("\"");
			sb.append(m_parent.getLeftChild() == this ? " <= " : " > ");
			sb.append(parentRatio.getMaxValuedThreshold());
			return sb.append(")");
		}

		/**
		 * Prints the tree rooted at this node to standard output.
		 */
		public void print()
		{
			print("");
		}

		/**
		 * Prints the subtree rooted at this node, indenting each line with
		 * {@code prefix} and one {@code "|    "} per level below this node.
		 */
		public void print(String prefix)
		{
			if (isLeaf()) {
				System.out.println(prefix + "root:" + leafSummary());
				return;
			}
			String featName = m_dataDict.lookupObject(getGainRatio().getMaxValuedIndex()).toString();
			double threshold = getGainRatio().getMaxValuedThreshold();
			printBranch(prefix, "\"" + featName + "\" <= " + threshold + ":", m_leftChild);
			printBranch(prefix, "\"" + featName + "\" > " + threshold + ":", m_rightChild);
		}

		/**
		 * Prints one branch test. A leaf child is summarized on the same line;
		 * a non-leaf child is printed recursively on the following lines.
		 */
		private static void printBranch(String prefix, String condition, Node child)
		{
			System.out.print(prefix + condition);
			if (child.isLeaf()) {
				System.out.println(child.leafSummary());
			}
			else {
				System.out.println();
				child.print(prefix + "|    ");
			}
		}

		/** Formats this node as {@code "<bestLabel> <majorityCount>/<size>"}. */
		private String leafSummary()
		{
			GainRatio ratio = getGainRatio();
			int bestLabelIndex = ratio.getBaseLabelDistribution().getBestIndex();
			// Round rather than truncate: proportion * size can land just below the
			// integer count (e.g. 28.999999999999996). Assumes the distribution holds proportions.
			int numMajorityLabel = (int) Math.round(ratio.getBaseLabelDistribution().value(bestLabelIndex) * getSize());
			return ratio.getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + getSize();
		}

	}

	// Serialization
	// serialVersionUID is overridden to prevent innocuous changes in this
	// class from making the serialization mechanism think the external
	// format has changed.

	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 1;

	/** Writes the serial version, then the instance pipe, then the root node. */
	private void writeObject(ObjectOutputStream out) throws IOException
	{
		out.writeInt(CURRENT_SERIAL_VERSION);
		out.writeObject(getInstancePipe());
		out.writeObject(m_root);
	}

	/**
	 * Reads the fields in the order {@link #writeObject} wrote them.
	 *
	 * @throws ClassNotFoundException if the stored serial version differs from
	 *         {@code CURRENT_SERIAL_VERSION}
	 */
	private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
		int version = in.readInt();
		if (version != CURRENT_SERIAL_VERSION)
			throw new ClassNotFoundException("Mismatched C45 versions: wanted " +
					CURRENT_SERIAL_VERSION + ", got " +
					version);
		instancePipe = (Pipe) in.readObject();
		m_root = (Node) in.readObject();
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy