// weka.classifiers.meta.RandomCommittee Maven / Gradle / Ivy
// Show all versions of weka-stable Show documentation
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* RandomCommittee.java
* Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.meta;
import java.util.Random;
import java.util.ArrayList;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.PartitionGenerator;
/**
 * Class for building an ensemble of randomizable base classifiers. Each base
 * classifier is built using a different random number seed (but based on the
 * same data). The final prediction is a straight average of the predictions
 * generated by the individual base classifiers.
 * <p>
 * Valid options are:
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -I &lt;num&gt;
 *  Number of iterations.
 *  (default 10)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.trees.RandomTree)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.trees.RandomTree:
 * </pre>
 *
 * <pre> -K &lt;number of attributes&gt;
 *  Number of attributes to randomly investigate
 *  (&lt;1 = int(log(#attributes)+1)).</pre>
 *
 * <pre> -M &lt;minimum number of instances&gt;
 *  Set minimum number of instances per leaf.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Seed for random number generator.
 *  (default 1)</pre>
 *
 * <pre> -depth &lt;num&gt;
 *  The maximum depth of the tree, 0 for unlimited.
 *  (default 0)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * Options after -- are passed to the designated classifier.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 11461 $
 */
public class RandomCommittee
  extends RandomizableParallelIteratedSingleClassifierEnhancer
  implements WeightedInstancesHandler, PartitionGenerator {

  /** for serialization */
  static final long serialVersionUID = -9204394360557300093L;

  /**
   * Training data; only set while {@link #buildClassifier(Instances)} is
   * running and reset to {@code null} afterwards.
   */
  protected Instances m_data;

  /**
   * Constructor. Installs a RandomTree as the base classifier.
   */
  public RandomCommittee() {
    m_Classifier = new weka.classifiers.trees.RandomTree();
  }

  /**
   * String describing default classifier.
   *
   * @return the default classifier classname
   */
  protected String defaultClassifierString() {
    return "weka.classifiers.trees.RandomTree";
  }

  /**
   * Returns a string describing classifier
   *
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for building an ensemble of randomizable base classifiers. Each "
      + "base classifiers is built using a different random number seed (but based "
      + "one the same data). The final prediction is a straight average of the "
      + "predictions generated by the individual base classifiers.";
  }

  /**
   * Builds the committee of randomizable classifiers.
   *
   * @param data the training data to be used for generating the
   * committee.
   * @exception Exception if the classifier could not be built successfully
   * @exception IllegalArgumentException if the base learner does not
   * implement {@link Randomizable}
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // get fresh instances
    m_data = new Instances(data);
    try {
      super.buildClassifier(m_data);

      if (!(m_Classifier instanceof Randomizable)) {
        throw new IllegalArgumentException("Base learner must implement Randomizable!");
      }

      m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, m_NumIterations);

      Random random = m_data.getRandomNumberGenerator(m_Seed);

      // Resample data based on weights if base learner can't handle weights
      if (!(m_Classifier instanceof WeightedInstancesHandler)) {
        m_data = m_data.resampleWithWeights(random);
      }

      // Give every committee member its own seed; the members themselves
      // are trained by buildClassifiers() below.
      for (int j = 0; j < m_Classifiers.length; j++) {
        ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());
      }

      buildClassifiers();
    } finally {
      // save memory, also when building fails part-way through
      m_data = null;
    }
  }

  /**
   * Returns a training set for a particular iteration.
   *
   * @param iteration the number of the iteration for the requested training set.
   * @return the training set for the supplied iteration number
   * @throws Exception if something goes wrong when generating a training set.
   */
  protected synchronized Instances getTrainingSet(int iteration) throws Exception {

    // we don't manipulate the training data in any way.
    return m_data;
  }

  /**
   * Calculates the class membership probabilities for the given test
   * instance. For a numeric class the single returned element is the mean of
   * the non-missing member predictions (or a missing value if there are
   * none); otherwise the member distributions are summed and normalized,
   * unless the sum is zero, in which case the all-zero array is returned.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    double[] sums = new double[instance.numClasses()];
    final boolean numericClass = instance.classAttribute().isNumeric();
    int numPreds = 0;

    // Iterate over the classifiers that were actually built rather than the
    // current value of the iteration option, which may have been changed
    // since the model was built.
    for (int i = 0; i < m_Classifiers.length; i++) {
      if (numericClass) {
        double pred = m_Classifiers[i].classifyInstance(instance);
        if (!Utils.isMissingValue(pred)) {
          sums[0] += pred;
          numPreds++;
        }
      } else {
        double[] newProbs = m_Classifiers[i].distributionForInstance(instance);
        for (int j = 0; j < newProbs.length; j++) {
          sums[j] += newProbs[j];
        }
      }
    }

    if (numericClass) {
      if (numPreds == 0) {
        sums[0] = Utils.missingValue();
      } else {
        sums[0] /= numPreds;
      }
      return sums;
    }

    // Nothing to normalize if no member assigned any probability mass.
    if (!Utils.eq(Utils.sum(sums), 0)) {
      Utils.normalize(sums);
    }
    return sums;
  }

  /**
   * Returns description of the committee.
   *
   * @return description of the committee as a string
   */
  public String toString() {

    if (m_Classifiers == null) {
      return "RandomCommittee: No model built yet.";
    }
    StringBuilder text = new StringBuilder();
    text.append("All the base classifiers: \n\n");
    for (int i = 0; i < m_Classifiers.length; i++) {
      text.append(m_Classifiers[i].toString()).append("\n\n");
    }
    return text.toString();
  }

  /**
   * Builds the classifier to generate a partition.
   *
   * @param data the training data
   * @throws Exception if the base classifier is not a
   * {@link PartitionGenerator} or the committee cannot be built
   */
  public void generatePartition(Instances data) throws Exception {

    requirePartitionGenerator();
    buildClassifier(data);
  }

  /**
   * Computes an array that indicates leaf membership. The result is the
   * concatenation, in committee order, of the membership arrays returned by
   * the individual base classifiers.
   *
   * @param inst the instance to compute membership values for
   * @return the concatenated membership values
   * @throws Exception if the base classifier is not a
   * {@link PartitionGenerator} or a member fails
   */
  public double[] getMembershipValues(Instance inst) throws Exception {

    requirePartitionGenerator();

    // Collect the per-member arrays first so the total length is known
    // before the result array is allocated.
    ArrayList<double[]> al = new ArrayList<double[]>(m_Classifiers.length);
    int size = 0;
    for (int i = 0; i < m_Classifiers.length; i++) {
      double[] r = ((PartitionGenerator) m_Classifiers[i]).getMembershipValues(inst);
      size += r.length;
      al.add(r);
    }

    double[] values = new double[size];
    int pos = 0;
    for (double[] v : al) {
      System.arraycopy(v, 0, values, pos, v.length);
      pos += v.length;
    }
    return values;
  }

  /**
   * Returns the number of elements in the partition, i.e. the sum of the
   * element counts reported by the individual base classifiers.
   *
   * @return the total number of partition elements
   * @throws Exception if the base classifier is not a
   * {@link PartitionGenerator} or a member fails
   */
  public int numElements() throws Exception {

    requirePartitionGenerator();

    int size = 0;
    for (int i = 0; i < m_Classifiers.length; i++) {
      size += ((PartitionGenerator) m_Classifiers[i]).numElements();
    }
    return size;
  }

  /**
   * Ensures the configured base classifier can generate partitions.
   *
   * @throws Exception if the base classifier is not a {@link PartitionGenerator}
   */
  private void requirePartitionGenerator() throws Exception {

    if (!(m_Classifier instanceof PartitionGenerator)) {
      throw new Exception("Classifier: " + getClassifierSpec()
        + " cannot generate a partition");
    }
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 11461 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    runClassifier(new RandomCommittee(), argv);
  }
}