All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.trees.j48.C45Split Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    C45Split.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.j48;

import java.util.Enumeration;

import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * Class implementing a C4.5-type split on an attribute.
 * 
 * <p>
 * A nominal attribute is split into one branch per attribute value. A numeric
 * attribute is split into two branches ({@code <= splitPoint} and
 * {@code > splitPoint}); the threshold is chosen to maximize information gain,
 * optionally with an MDL correction for the number of candidate thresholds.
 * 
 * @author Eibe Frank ([email protected])
 * @version $Revision: 14911 $
 */
public class C45Split extends ClassifierSplitModel {

  /** for serialization */
  private static final long serialVersionUID = 3064079330067903161L;

  /**
   * Desired number of branches: the number of values of a nominal attribute,
   * or 2 for a numeric attribute.
   */
  protected int m_complexityIndex;

  /** Index of the attribute to split on. */
  protected final int m_attIndex;

  /** Minimum number of objects in a split. */
  protected final int m_minNoObj;

  /** Whether to apply the MDL correction to the gain of numeric splits. */
  protected final boolean m_useMDLcorrection;

  /** Value of split point (numeric attributes only). */
  protected double m_splitPoint;

  /** Information gain of the split. */
  protected double m_infoGain;

  /** Gain ratio of the split. */
  protected double m_gainRatio;

  /** The sum of the weights of the instances, passed to the split criteria. */
  protected final double m_sumOfWeights;

  /**
   * Number of split points: the number of values for a nominal attribute, or
   * the number of admissible thresholds evaluated for a numeric attribute.
   * Used by the MDL correction and by {@link #codingCost()}.
   */
  protected int m_index;

  /** Shared instance of the information-gain splitting criterion. */
  protected static InfoGainSplitCrit infoGainCrit = new InfoGainSplitCrit();

  /** Shared instance of the gain-ratio splitting criterion. */
  protected static GainRatioSplitCrit gainRatioCrit = new GainRatioSplitCrit();

  /**
   * Initializes the split model.
   * 
   * @param attIndex index of the attribute to split on
   * @param minNoObj minimum number of objects in a split
   * @param sumOfWeights sum of the weights of the instances
   * @param useMDLcorrection whether to apply the MDL correction for numeric
   *          attributes
   */
  public C45Split(int attIndex, int minNoObj, double sumOfWeights,
    boolean useMDLcorrection) {

    m_attIndex = attIndex;
    m_minNoObj = minNoObj;
    m_sumOfWeights = sumOfWeights;
    m_useMDLcorrection = useMDLcorrection;
  }

  /**
   * Creates a C4.5-type split on the given data. Assumes that none of the class
   * values is missing. If no admissible split is found, the number of subsets
   * stays 0.
   * 
   * @param trainInstances the training data; for a numeric attribute it is
   *          sorted in place on that attribute
   * @exception Exception if something goes wrong
   */
  @Override
  public void buildClassifier(Instances trainInstances) throws Exception {

    // Reset the results of any previous build. The handlers below only set
    // m_numSubsets when they find an admissible split.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_infoGain = 0;
    m_gainRatio = 0;

    // Different treatment for enumerated and numeric attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      m_index = m_complexityIndex;
      handleEnumeratedAttribute(trainInstances);
    } else {
      m_complexityIndex = 2;
      m_index = 0;
      // handleNumericAttribute scans the sorted data and stops at the first
      // missing value, so it relies on missing values being sorted last.
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }

  /**
   * Returns index of attribute for which split was generated.
   * 
   * @return the attribute index
   */
  public final int attIndex() {

    return m_attIndex;
  }

  /**
   * Returns the split point (numeric attribute only).
   * 
   * @return the split point used for a test on a numeric attribute
   */
  public double splitPoint() {
    return m_splitPoint;
  }

  /**
   * Gets class probability for instance.
   * 
   * <p>
   * For a negative subset index, the probability is the combination of the
   * per-subset probabilities weighted by {@link #weights(Instance)}, or the
   * overall class probability if the instance belongs to a single subset. For
   * a concrete subset, the per-subset probability is used if that subset has
   * positive weight, otherwise the overall class probability.
   * 
   * @param classIndex the class value
   * @param instance the instance
   * @param theSubset the subset index, or a negative value if unknown
   * @return the class probability
   * @exception Exception if something goes wrong
   */
  @Override
  public final double classProb(int classIndex, Instance instance, int theSubset)
    throws Exception {

    if (theSubset <= -1) {
      double[] weights = weights(instance);
      if (weights == null) {
        return m_distribution.prob(classIndex);
      } else {
        double prob = 0;
        for (int i = 0; i < weights.length; i++) {
          prob += weights[i] * m_distribution.prob(classIndex, i);
        }
        return prob;
      }
    } else {
      if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
        return m_distribution.prob(classIndex, theSubset);
      } else {
        return m_distribution.prob(classIndex);
      }
    }
  }

  /**
   * Returns coding cost for split (used in rule learner).
   * 
   * @return log2 of the number of split points
   */
  @Override
  public final double codingCost() {

    return Utils.log2(m_index);
  }

  /**
   * Returns (C4.5-type) gain ratio for the generated split.
   * 
   * @return the gain ratio
   */
  public final double gainRatio() {
    return m_gainRatio;
  }

  /**
   * Creates split on enumerated attribute.
   * 
   * @param trainInstances the training data
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances)
    throws Exception {

    m_distribution = new Distribution(m_complexityIndex,
      trainInstances.numClasses());

    // Only instances with known values are relevant; each goes into the bag
    // of its attribute value.
    Enumeration<Instance> enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      Instance instance = enu.nextElement();
      if (!instance.isMissing(m_attIndex)) {
        m_distribution.add((int) instance.value(m_attIndex), instance);
      }
    }

    // Check if minimum number of instances in at least two subsets.
    if (m_distribution.check(m_minNoObj)) {
      m_numSubsets = m_complexityIndex;
      m_infoGain = infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights);
      m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,
        m_sumOfWeights, m_infoGain);
    }
  }

  /**
   * Creates split on numeric attribute. Expects the data to be sorted on the
   * split attribute with missing values last.
   * 
   * @param trainInstances the sorted training data
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances)
    throws Exception {

    // Bag 0 holds values <= the candidate threshold, bag 1 the rest. Initially
    // every instance with a known value is in bag 1.
    m_distribution = new Distribution(2, trainInstances.numClasses());

    // Only instances with known values are relevant. They form a prefix of
    // the sorted data, ending at index firstMiss (exclusive).
    int firstMiss = 0;
    Enumeration<Instance> enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      Instance instance = enu.nextElement();
      if (instance.isMissing(m_attIndex)) {
        break;
      }
      m_distribution.add(1, instance);
      firstMiss++;
    }

    // Compute minimum number of instances required in each subset: 10% of the
    // average known weight per class, raised to m_minNoObj if it does not
    // exceed it, otherwise capped at 25.
    double minSplit = 0.1 * (m_distribution.total())
      / (trainInstances.numClasses());
    if (Utils.smOrEq(minSplit, m_minNoObj)) {
      minSplit = m_minNoObj;
    } else if (Utils.gr(minSplit, 25)) {
      minSplit = 25;
    }

    // Enough instances with known values?
    if (Utils.sm(firstMiss, 2 * minSplit)) {
      return;
    }

    // Compute values of criteria for all possible split indices.
    double defaultEnt = infoGainCrit.oldEnt(m_distribution);
    int last = 0;
    int splitIndex = -1;
    for (int next = 1; next < firstMiss; next++) {

      // Only a gap of more than 1e-5 between consecutive sorted values is a
      // candidate split point.
      if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5 < trainInstances
        .instance(next).value(m_attIndex)) {

        // Move class values for all instances up to next possible split point
        // from bag 1 to bag 0.
        m_distribution.shiftRange(1, 0, trainInstances, last, next);

        // Check if enough instances in each subset and compute values for
        // criteria.
        if (Utils.grOrEq(m_distribution.perBag(0), minSplit)
          && Utils.grOrEq(m_distribution.perBag(1), minSplit)) {
          double currentInfoGain = infoGainCrit.splitCritValue(m_distribution,
            m_sumOfWeights, defaultEnt);
          if (Utils.gr(currentInfoGain, m_infoGain)) {
            m_infoGain = currentInfoGain;
            splitIndex = next - 1;
          }
          m_index++;
        }
        last = next;
      }
    }

    // Was there any useful split?
    if (m_index == 0) {
      return;
    }

    // Compute modified information gain for best split.
    if (m_useMDLcorrection) {
      m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights);
    }
    if (Utils.smOrEq(m_infoGain, 0)) {
      return;
    }

    // Set instance variables' values to values for best split. splitIndex is
    // valid here: a positive gain means some candidate improved on 0.
    m_numSubsets = 2;
    double below = trainInstances.instance(splitIndex).value(m_attIndex);
    double above = trainInstances.instance(splitIndex + 1).value(m_attIndex);
    m_splitPoint = (above + below) / 2;

    // In case we have a numerical precision problem (the midpoint rounds to
    // the upper value), choose the smaller value. The upper instance then
    // still satisfies "> splitPoint" in whichSubset().
    if (m_splitPoint == above) {
      m_splitPoint = below;
    }

    // Restore distribution for best split.
    m_distribution = new Distribution(2, trainInstances.numClasses());
    m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);
    m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss);

    // Compute modified gain ratio for best split.
    m_gainRatio = gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights,
      m_infoGain);
  }

  /**
   * Returns (C4.5-type) information gain for the generated split.
   * 
   * @return the information gain
   */
  public final double infoGain() {

    return m_infoGain;
  }

  /**
   * Prints left side of condition.
   * 
   * @param data training set.
   * @return the name of the split attribute
   */
  @Override
  public final String leftSide(Instances data) {

    return data.attribute(m_attIndex).name();
  }

  /**
   * Prints the condition satisfied by instances in a subset.
   * 
   * @param index of subset
   * @param data training set.
   * @return the right-hand side of the condition, e.g. " <= 2.5"
   */
  @Override
  public final String rightSide(int index, Instances data) {

    if (data.attribute(m_attIndex).isNominal()) {
      return " = " + data.attribute(m_attIndex).value(index);
    }
    if (index == 0) {
      return " <= " + Utils.doubleToString(m_splitPoint, 6);
    }
    return " > " + Utils.doubleToString(m_splitPoint, 6);
  }

  /**
   * Returns a string containing java source code equivalent to the test made at
   * this node. The instance being tested is called "i".
   * 
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  @Override
  public final String sourceExpression(int index, Instances data) {

    if (index < 0) {
      return "i[" + m_attIndex + "] == null";
    }
    StringBuilder expr;
    if (data.attribute(m_attIndex).isNominal()) {
      expr = new StringBuilder("i[");
      expr.append(m_attIndex).append("]");
      // The value is placed inside a Java string literal, so it must be
      // escaped for the generated code to compile.
      expr.append(".equals(\"")
        .append(escapeJavaString(data.attribute(m_attIndex).value(index)))
        .append("\")");
    } else {
      expr = new StringBuilder("((Double) i[");
      expr.append(m_attIndex).append("])");
      if (index == 0) {
        expr.append(".doubleValue() <= ").append(m_splitPoint);
      } else {
        expr.append(".doubleValue() > ").append(m_splitPoint);
      }
    }
    return expr.toString();
  }

  /**
   * Escapes text for embedding in a Java string literal. Backslashes, double
   * quotes, newlines, carriage returns and tabs are backslash-escaped. Text
   * without these characters is returned unchanged.
   * 
   * @param value the raw text
   * @return the escaped text
   */
  private static String escapeJavaString(String value) {
    StringBuilder escaped = new StringBuilder(value.length());
    for (int i = 0; i < value.length(); i++) {
      char c = value.charAt(i);
      switch (c) {
      case '\\':
        escaped.append("\\\\");
        break;
      case '"':
        escaped.append("\\\"");
        break;
      case '\n':
        escaped.append("\\n");
        break;
      case '\r':
        escaped.append("\\r");
        break;
      case '\t':
        escaped.append("\\t");
        break;
      default:
        escaped.append(c);
      }
    }
    return escaped.toString();
  }

  /**
   * Sets split point to greatest value in given data smaller or equal to old
   * split point, as the original C4.5 implementation does.
   * 
   * <p>
   * Only applies to a numeric attribute with an actual split (more than one
   * subset). If no known value is smaller or equal to the old split point, the
   * split point becomes {@code -Double.MAX_VALUE}.
   * 
   * @param allInstances the data whose values are candidates
   */
  public final void setSplitPoint(Instances allInstances) {

    if (!allInstances.attribute(m_attIndex).isNumeric() || m_numSubsets <= 1) {
      return;
    }
    double newSplitPoint = -Double.MAX_VALUE;
    for (int i = 0; i < allInstances.numInstances(); i++) {
      double tempValue = allInstances.instance(i).value(m_attIndex);
      if (!Utils.isMissingValue(tempValue) && tempValue > newSplitPoint
        && tempValue <= m_splitPoint) {
        newSplitPoint = tempValue;
      }
    }
    m_splitPoint = newSplitPoint;
  }

  /**
   * Returns the minsAndMaxs of the index-th subset. Entries [i][0] and [i][1]
   * are copied for every attribute. The entry for the split attribute is then
   * narrowed:
   * <ul>
   * <li>numeric, subset 0: [1] is set to the split point;</li>
   * <li>numeric, subset 1: [0] is set to the split point;</li>
   * <li>nominal: [1] is set to 1.</li>
   * </ul>
   * NOTE(review): presumably [i][0] is the minimum and [i][1] the maximum of
   * attribute i; confirm against callers.
   * 
   * @param data the data (used for the number of attributes and types)
   * @param minsAndMaxs the bounds of the parent node
   * @param index the subset index
   * @return a new array with the subset's bounds
   */
  public final double[][] minsAndMaxs(Instances data, double[][] minsAndMaxs,
    int index) {

    double[][] newMinsAndMaxs = new double[data.numAttributes()][2];

    for (int i = 0; i < data.numAttributes(); i++) {
      newMinsAndMaxs[i][0] = minsAndMaxs[i][0];
      newMinsAndMaxs[i][1] = minsAndMaxs[i][1];
      if (i == m_attIndex) {
        if (data.attribute(m_attIndex).isNominal()) {
          newMinsAndMaxs[m_attIndex][1] = 1;
        } else {
          newMinsAndMaxs[m_attIndex][1 - index] = m_splitPoint;
        }
      }
    }

    return newMinsAndMaxs;
  }

  /**
   * Sets distribution associated with model. Instances assigned to a single
   * subset form the base distribution; the distribution's
   * {@code addInstWithUnknown} is then called on the full data for the split
   * attribute.
   * 
   * @param data the data to compute the distribution from
   * @exception Exception if something goes wrong
   */
  @Override
  public void resetDistribution(Instances data) throws Exception {

    Instances insts = new Instances(data, data.numInstances());
    for (int i = 0; i < data.numInstances(); i++) {
      if (whichSubset(data.instance(i)) > -1) {
        insts.add(data.instance(i));
      }
    }
    Distribution newD = new Distribution(insts, this);
    newD.addInstWithUnknown(data, m_attIndex);
    m_distribution = newD;
  }

  /**
   * Returns weights if instance is assigned to more than one subset. Returns
   * null if instance is only assigned to one subset.
   * 
   * <p>
   * An instance with a missing value for the split attribute gets, for each
   * subset, that subset's share of the total weight. The division is not
   * guarded against a zero total.
   * 
   * @param instance the instance
   * @return the per-subset weights, or null
   */
  @Override
  public final double[] weights(Instance instance) {

    if (!instance.isMissing(m_attIndex)) {
      return null;
    }
    double[] weights = new double[m_numSubsets];
    for (int i = 0; i < m_numSubsets; i++) {
      weights[i] = m_distribution.perBag(i) / m_distribution.total();
    }
    return weights;
  }

  /**
   * Returns index of subset instance is assigned to. Returns -1 if instance is
   * assigned to more than one subset.
   * 
   * @param instance the instance
   * @return the subset index: the value index for a nominal attribute, 0 or 1
   *         for a numeric attribute, or -1 if the value is missing
   * @exception Exception if something goes wrong
   */
  @Override
  public final int whichSubset(Instance instance) throws Exception {

    if (instance.isMissing(m_attIndex)) {
      return -1;
    }
    if (instance.attribute(m_attIndex).isNominal()) {
      return (int) instance.value(m_attIndex);
    }
    return instance.value(m_attIndex) <= m_splitPoint ? 0 : 1;
  }

  /**
   * Returns the revision string.
   * 
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 14911 $");
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy