All downloads are free. The search and download functionality uses the official Maven repository.

cc.mallet.classify.AdaBoostTrainer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.util.Random;
import java.util.logging.*;

import cc.mallet.types.*;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;

/**
 * This version of AdaBoost should be used only for binary classification.
 * Use AdaBoost.M2 for multi-class problems.
 *
 * 

Robert E. Schapire. * "A decision-theoretic generalization of on-line learning and * an application to boosting" * In Journal of Computer and System Sciences * http://www.cs.princeton.edu/~schapire/uncompress-papers.cgi/FreundSc95.ps * * @author Andrew McCallum [email protected] */ public class AdaBoostTrainer extends ClassifierTrainer { private static Logger logger = MalletLogger.getLogger(AdaBoostTrainer.class.getName()); private static int MAX_NUM_RESAMPLING_ITERATIONS = 10; ClassifierTrainer weakLearner; int numRounds; AdaBoost classifier; public AdaBoost getClassifier () { return classifier; } public AdaBoostTrainer (ClassifierTrainer weakLearner, int numRounds) { if (! (weakLearner instanceof Boostable)) throw new IllegalArgumentException ("weak learner not boostable"); if (numRounds <= 0) throw new IllegalArgumentException ("number of rounds must be positive"); this.weakLearner = weakLearner; this.numRounds = numRounds; } public AdaBoostTrainer (ClassifierTrainer weakLearner) { this (weakLearner, 100); } /** * Boosting method that resamples instances using their weights */ public AdaBoost train (InstanceList trainingList) { FeatureSelection selectedFeatures = trainingList.getFeatureSelection(); if (selectedFeatures != null) throw new UnsupportedOperationException("FeatureSelection not yet implemented."); java.util.Random random = new java.util.Random(); // Set the initial weights to be uniform double w = 1.0 / trainingList.size(); InstanceList trainingInsts = new InstanceList(trainingList.getPipe(), trainingList.size()); for (int i = 0; i < trainingList.size(); i++) trainingInsts.add(trainingList.get(i), w); boolean[] correct = new boolean[trainingInsts.size()]; int numClasses = trainingInsts.getTargetAlphabet().size(); if (numClasses != 2) logger.info("AdaBoostTrainer.train: WARNING: more than two classes"); Classifier[] weakLearners = new Classifier[numRounds]; double[] alphas = new double[numRounds]; InstanceList roundTrainingInsts = new 
InstanceList(trainingInsts.getPipe()); // Boosting iterations for (int round = 0; round < numRounds; round++) { logger.info("=========== AdaBoostTrainer round " + (round+1) + " begin"); // Keep resampling the training instances (using the distribution // of instance weights) on which to train the weak learner until // either we exceed the preset number of maximum iterations, or // the weak learner makes a non-zero error on trainingInsts // (this makes sure we sample at least some 'hard' instances). int resamplingIterations = 0; double err; do { err = 0; roundTrainingInsts = trainingInsts.sampleWithInstanceWeights(random); weakLearners[round] = weakLearner.train (roundTrainingInsts); // Calculate error for (int i = 0; i < trainingInsts.size(); i++) { Instance inst = trainingInsts.get(i); if (weakLearners[round].classify(inst).bestLabelIsCorrect()) correct[i] = true; else { correct[i] = false; err += trainingInsts.getInstanceWeight(i); } } resamplingIterations++; } while (Maths.almostEquals(err, 0) && resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS); // Stop boosting when error is too big or 0, // ignoring weak classifier trained this round if (Maths.almostEquals(err, 0) || err > 0.5) { logger.info("AdaBoostTrainer stopped at " + (round+1) + " / " + numRounds + " rounds: numClasses=" + numClasses + " error=" + err); // If we are in the first round, have to use the weak classifier in any case int numClassifiersToUse = (round == 0) ? 
1 : round; if (round == 0) alphas[0] = 1; double[] betas = new double[numClassifiersToUse]; Classifier[] weakClassifiers = new Classifier[numClassifiersToUse]; System.arraycopy(alphas, 0, betas, 0, numClassifiersToUse); System.arraycopy(weakLearners, 0, weakClassifiers, 0, numClassifiersToUse); for (int i = 0; i < betas.length; i++) logger.info("AdaBoostTrainer weight[weakLearner[" + i + "]]=" + betas[i]); return new AdaBoost (roundTrainingInsts.getPipe(), weakClassifiers, betas); } // Calculate the weight to assign to this weak classifier // This formula is really designed for binary classifiers that don't // give a confidence score. Use AdaBoostMH for multi-class or // multi-labeled data. alphas[round] = Math.log((1 - err) / err); double reweightFactor = err / (1 - err); double sum = 0; // Decrease weights of correctly classified instances for (int i = 0; i < trainingInsts.size(); i++) { w = trainingInsts.getInstanceWeight(i); if (correct[i]) w *= reweightFactor; trainingInsts.setInstanceWeight (i, w); sum += w; } // Normalize the instance weights for (int i = 0; i < trainingInsts.size(); i++) { trainingInsts.setInstanceWeight (i, trainingInsts.getInstanceWeight(i) / sum); } logger.info("=========== AdaBoostTrainer round " + (round+1) + " finished, weak classifier training error = " + err); } for (int i = 0; i < alphas.length; i++) logger.info("AdaBoostTrainer weight[weakLearner[" + i + "]]=" + alphas[i]); this.classifier = new AdaBoost (roundTrainingInsts.getPipe(), weakLearners, alphas); return classifier; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy