All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.trees.ht.InfoGainSplitMetric Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    InfoGainSplitMetric.java
 *    Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.ht;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

import weka.core.ContingencyTables;
import weka.core.Utils;

/**
 * Implements the info gain splitting criterion
 * 
 * @author Richard Kirkby ([email protected])
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 9720 $
 */
public class InfoGainSplitMetric extends SplitMetric implements Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = 2173840581308675428L;

  protected double m_minFracWeightForTwoBranches;

  public InfoGainSplitMetric(double minFracWeightForTwoBranches) {
    m_minFracWeightForTwoBranches = minFracWeightForTwoBranches;
  }

  @Override
  public double evaluateSplit(Map preDist,
      List> postDist) {

    double[] pre = new double[preDist.size()];
    int count = 0;
    for (Map.Entry e : preDist.entrySet()) {
      pre[count++] = e.getValue().m_weight;
    }

    double preEntropy = ContingencyTables.entropy(pre);

    double[] distWeights = new double[postDist.size()];
    double totalWeight = 0.0;
    for (int i = 0; i < postDist.size(); i++) {
      distWeights[i] = SplitMetric.sum(postDist.get(i));
      totalWeight += distWeights[i];
    }

    int fracCount = 0;
    for (double d : distWeights) {
      if (d / totalWeight > m_minFracWeightForTwoBranches) {
        fracCount++;
      }
    }

    if (fracCount < 2) {
      return Double.NEGATIVE_INFINITY;
    }

    double postEntropy = 0;
    for (int i = 0; i < postDist.size(); i++) {
      Map d = postDist.get(i);
      double[] post = new double[d.size()];
      count = 0;
      for (Map.Entry e : d.entrySet()) {
        post[count++] = e.getValue().m_weight;
      }
      postEntropy += distWeights[i] * ContingencyTables.entropy(post);
    }

    if (totalWeight > 0) {
      postEntropy /= totalWeight;
    }

    return preEntropy - postEntropy;
  }

  @Override
  public double getMetricRange(Map preDist) {

    int numClasses = preDist.size();
    if (numClasses < 2) {
      numClasses = 2;
    }

    return Utils.log2(numClasses);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy