weka.classifiers.trees.lmt.ResidualSplit Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* ResidualSplit.java
* Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees.lmt;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.Distribution;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
/**
 * Helper class for logistic model trees (weka.classifiers.trees.lmt.LMT) to implement the
 * splitting criterion based on residuals of the LogitBoost algorithm.
 *
 * <p>A candidate split is scored by how much it reduces the weighted squared deviation of the
 * LogitBoost Z-values from their per-class weighted means (see {@link #entropyGain()}).
 *
 * @author Niels Landwehr
 * @version $Revision: 8034 $
 */
public class ResidualSplit
  extends ClassifierSplitModel {

  /** for serialization */
  private static final long serialVersionUID = -5055883734183713525L;

  /** The attribute selected for the split. */
  protected Attribute m_attribute;

  /** The index of the attribute selected for the split. */
  protected int m_attIndex;

  /** Number of instances in the set. */
  protected int m_numInstances;

  /** Number of classes. */
  protected int m_numClasses;

  /** The set of instances. */
  protected Instances m_data;

  /** The Z-values (LogitBoost response) for the set of instances, indexed [instance][class]. */
  protected double[][] m_dataZs;

  /** The LogitBoost-weights for the set of instances, indexed [instance][class]. */
  protected double[][] m_dataWs;

  /** The split point (for numeric attributes). */
  protected double m_splitPoint;

  /**
   * Creates a split object.
   *
   * @param attIndex the index of the attribute to split on
   */
  public ResidualSplit(int attIndex) {
    m_attIndex = attIndex;
  }

  /**
   * Builds the split. Needs the Z/W values of LogitBoost for the set of instances.
   *
   * <p>A nominal attribute yields one subset per attribute value. A numeric attribute yields two
   * subsets, separated at the split point chosen by {@link #getSplitPoint()}.
   *
   * @param data the instances to split
   * @param dataZs the LogitBoost Z-values, one row per instance of {@code data}
   * @param dataWs the LogitBoost weights, one row per instance of {@code data}
   * @throws Exception if {@code data} is empty, or if evaluating a numeric split point meets a
   *     missing attribute value
   */
  public void buildClassifier(Instances data, double[][] dataZs, double[][] dataWs)
    throws Exception {

    m_numClasses = data.numClasses();
    m_numInstances = data.numInstances();
    if (m_numInstances == 0) {
      throw new Exception("Can't build split on 0 instances");
    }

    // save data/Zs/Ws
    m_data = data;
    m_dataZs = dataZs;
    m_dataWs = dataWs;
    m_attribute = data.attribute(m_attIndex);

    // determine number of subsets and split point for numeric attributes
    if (m_attribute.isNominal()) {
      m_splitPoint = 0.0;
      m_numSubsets = m_attribute.numValues();
    } else {
      // The result is intentionally ignored. When no candidate exists (e.g. all values are
      // equal), m_splitPoint keeps its previous value. The resulting distribution is then
      // judged by checkModel().
      getSplitPoint();
      m_numSubsets = 2;
    }

    // create distribution for data
    m_distribution = new Distribution(data, this);
  }

  /**
   * Selects the split point for a numeric attribute.
   *
   * <p>Candidates are the midpoints between consecutive values of the attribute, taken in sorted
   * order and skipping pairs that {@link Utils#eq} considers equal. The candidate with the
   * largest {@link #entropyGain()} is stored in {@code m_splitPoint}. On ties, the earliest
   * (smallest) candidate wins.
   *
   * @return true if a split point was selected; false if there was no candidate or no candidate
   *     had a gain comparable with {@code >} (e.g. NaN). In the false case m_splitPoint should
   *     not be relied upon.
   * @throws Exception if evaluating a candidate meets a missing attribute value
   */
  protected boolean getSplitPoint() throws Exception {
    Instances sortedData = new Instances(m_data);
    sortedData.sort(sortedData.attribute(m_attIndex));

    boolean found = false;
    double bestGain = -Double.MAX_VALUE;
    double bestSplitPoint = 0.0;

    double previous = sortedData.instance(0).value(m_attIndex);
    for (int i = 1; i < m_numInstances; i++) {
      double current = sortedData.instance(i).value(m_attIndex);
      if (!Utils.eq(current, previous)) {
        // entropyGain() evaluates the split located at m_splitPoint
        m_splitPoint = (previous + current) / 2.0;
        double gain = entropyGain();
        if (gain > bestGain) {
          bestGain = gain;
          bestSplitPoint = m_splitPoint;
          found = true;
        }
      }
      previous = current;
    }

    if (found) {
      m_splitPoint = bestSplitPoint;
    }
    return found;
  }

  /**
   * Computes the gain of the current split.
   *
   * <p>The gain is {@link #entropy(double[][], double[][])} of the whole set minus the sum of
   * that measure over the subsets given by {@link #whichSubset(Instance)}. For numeric
   * attributes, the split is taken at the current {@code m_splitPoint}.
   *
   * @return the gain of the current split
   * @throws Exception if an instance has a missing value for the split attribute
   */
  public double entropyGain() throws Exception {
    int numSubsets = m_attribute.isNominal() ? m_attribute.numValues() : 2;

    // Assign every instance to its subset once and count subset sizes.
    int[] subsetOf = new int[m_numInstances];
    int[] subsetSize = new int[numSubsets];
    for (int i = 0; i < m_numInstances; i++) {
      int subset = whichSubset(m_data.instance(i));
      if (subset < 0) {
        throw new Exception("ResidualSplit: no support for splits on missing values");
      }
      subsetOf[i] = subset;
      subsetSize[subset]++;
    }

    double[][][] splitDataZs = new double[numSubsets][][];
    double[][][] splitDataWs = new double[numSubsets][][];
    for (int s = 0; s < numSubsets; s++) {
      splitDataZs[s] = new double[subsetSize[s]][];
      splitDataWs[s] = new double[subsetSize[s]][];
    }

    // sort Zs/Ws rows into subsets (rows are shared, not copied)
    int[] filled = new int[numSubsets];
    for (int i = 0; i < m_numInstances; i++) {
      int s = subsetOf[i];
      splitDataZs[s][filled[s]] = m_dataZs[i];
      splitDataWs[s][filled[s]] = m_dataWs[i];
      filled[s]++;
    }

    // calculate entropy gain
    double entropySplit = 0.0;
    for (int s = 0; s < numSubsets; s++) {
      entropySplit += entropy(splitDataZs[s], splitDataWs[s]);
    }
    return entropy(m_dataZs, m_dataWs) - entropySplit;
  }

  /**
   * Helper function to compute entropy from Z/W values.
   *
   * <p>Returns, summed over classes, the weighted squared deviation of the Z-values from their
   * weighted mean, i.e. entropy * sumOfWeights. A class whose weights sum to zero contributes
   * nothing. This also covers an empty set of instances, and it avoids a 0/0 mean that would
   * make the result NaN.
   *
   * @param dataZs Z-values, indexed [instance][class]
   * @param dataWs weights, indexed [instance][class]
   * @return the weighted sum of squared deviations over all classes
   */
  protected double entropy(double[][] dataZs, double[][] dataWs) {
    double entropy = 0.0;
    int numInstances = dataZs.length;
    for (int j = 0; j < m_numClasses; j++) {
      // compute weighted mean for class j
      double weightedSum = 0.0;
      double sumOfWeights = 0.0;
      for (int i = 0; i < numInstances; i++) {
        weightedSum += dataZs[i][j] * dataWs[i][j];
        sumOfWeights += dataWs[i][j];
      }
      if (sumOfWeights == 0.0) {
        continue;
      }
      double mean = weightedSum / sumOfWeights;

      // sum up weighted squared deviations for class j
      for (int i = 0; i < numInstances; i++) {
        double deviation = dataZs[i][j] - mean;
        entropy += dataWs[i][j] * deviation * deviation;
      }
    }
    return entropy;
  }

  /**
   * Checks if there are at least 2 subsets that contain >= minNumInstances.
   *
   * @param minNumInstances the minimum value of {@code perBag} a subset must reach to count
   * @return true if at least two subsets reach the threshold
   */
  public boolean checkModel(int minNumInstances) {
    int count = 0;
    for (int i = 0; i < m_distribution.numBags(); i++) {
      if (m_distribution.perBag(i) >= minNumInstances) {
        count++;
      }
    }
    return count >= 2;
  }

  /**
   * Returns name of splitting attribute (left side of condition).
   *
   * @param data the dataset whose attribute name is used
   * @return the name of the split attribute
   */
  public final String leftSide(Instances data) {
    return data.attribute(m_attIndex).name();
  }

  /**
   * Prints the condition satisfied by instances in a subset.
   *
   * @param index the subset index
   * @param data the dataset describing the split attribute
   * @return {@code " = value"} for nominal attributes. For numeric attributes,
   *     {@code " <= splitPoint"} for subset 0 and {@code " > splitPoint"} otherwise.
   */
  public final String rightSide(int index, Instances data) {
    if (data.attribute(m_attIndex).isNominal()) {
      return " = " + data.attribute(m_attIndex).value(index);
    }
    String threshold = Utils.doubleToString(m_splitPoint, 6);
    return (index == 0) ? " <= " + threshold : " > " + threshold;
  }

  /**
   * Returns the index of the subset an instance belongs to.
   *
   * @param instance the instance to route
   * @return -1 if the split attribute is missing. For nominal attributes, the value index. For
   *     numeric attributes, 0 if the value is at most the split point (per
   *     {@link Utils#smOrEq}) and 1 otherwise.
   * @throws Exception never thrown by this implementation
   */
  public final int whichSubset(Instance instance)
    throws Exception {

    if (instance.isMissing(m_attIndex)) {
      return -1;
    }
    if (instance.attribute(m_attIndex).isNominal()) {
      return (int) instance.value(m_attIndex);
    }
    return Utils.smOrEq(instance.value(m_attIndex), m_splitPoint) ? 0 : 1;
  }

  /** Method not in use; the split is built via the Z/W overload instead. */
  public void buildClassifier(Instances data) {
    // method not in use
  }

  /**
   * Method not in use.
   *
   * @param instance ignored
   * @return always null
   */
  public final double [] weights(Instance instance) {
    // method not in use
    return null;
  }

  /**
   * Method not in use.
   *
   * @param index ignored
   * @param data ignored
   * @return always the empty string
   */
  public final String sourceExpression(int index, Instances data) {
    // method not in use
    return "";
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 8034 $");
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy