
cc.mallet.classify.AdaBoostM2Trainer


MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.util.Arrays;
import java.util.logging.*;

import cc.mallet.types.*;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;

/**
 * This version of AdaBoost can handle multi-class problems.  For
 * binary classification, one can also use AdaBoostTrainer.
 *
 * Yoav Freund and Robert E. Schapire,
 * "Experiments with a New Boosting Algorithm",
 * In Journal of Machine Learning: Proceedings of the 13th International Conference, 1996.
 * http://www.cs.princeton.edu/~schapire/papers/FreundSc96b.ps.Z
 *
 * @author Gary Huang
 */
public class AdaBoostM2Trainer extends ClassifierTrainer<AdaBoostM2>
{
    private static Logger logger = MalletLogger.getLogger(AdaBoostM2Trainer.class.getName());
    private static int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    ClassifierTrainer weakLearner;
    int numRounds;
    AdaBoostM2 classifier;

    public AdaBoostM2 getClassifier () { return classifier; }

    public AdaBoostM2Trainer (ClassifierTrainer weakLearner, int numRounds)
    {
        if (! (weakLearner instanceof Boostable))
            throw new IllegalArgumentException ("weak learner not boostable");
        if (numRounds <= 0)
            throw new IllegalArgumentException ("number of rounds must be positive");
        this.weakLearner = weakLearner;
        this.numRounds = numRounds;
    }

    public AdaBoostM2Trainer (ClassifierTrainer weakLearner)
    {
        this (weakLearner, 100);
    }

    /**
     * Boosting method that resamples instances using their weights
     */
    public AdaBoostM2 train (InstanceList trainingList)
    {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null)
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");

        int numClasses = trainingList.getTargetAlphabet().size();
        int numInstances = trainingList.size();

        // Construct the set "B", a list of instances of size
        // (numInstances * (numClasses - 1)).
        // Each instance in this list will have weights
        // (mislabel distribution) associated with classes
        // the instance doesn't belong to.
        InstanceList trainingInsts = new InstanceList(trainingList.getPipe());
        // Set the initial weights to be uniform
        double[] weights = new double[numInstances * (numClasses - 1)];
        double w = 1.0 / weights.length;
        Arrays.fill(weights, w);
        int[] classIndices = new int[weights.length];
        int numAdded = 0;
        for (int i = 0; i < numInstances; i++) {
            Instance inst = trainingList.get(i);
            int trueClassIndex = inst.getLabeling().getBestIndex();
            for (int j = 0; j < numClasses; j++) {
                if (j != trueClassIndex) {
                    trainingInsts.add(inst, 1);
                    classIndices[numAdded] = j;
                    numAdded++;
                }
            }
        }

        java.util.Random random = new java.util.Random();
        Classifier[] weakLearners = new Classifier[numRounds];
        double[] classifierWeights = new double[numRounds];
        double[] exponents = new double[weights.length];
        int[] instIndices = new int[weights.length];
        for (int i = 0; i < instIndices.length; i++)
            instIndices[i] = i;

        // Boosting iterations
        for (int round = 0; round < numRounds; round++) {
            logger.info("=========== AdaBoostM2Trainer round " + (round+1) + " begin");
            // Sample instances from set B using the
            // weight vector to train the weak learner
            double epsilon;
            InstanceList roundTrainingInsts = new InstanceList(trainingInsts.getPipe());
            int resamplingIterations = 0;
            do {
                epsilon = 0;
                int[] sampleIndices = sampleWithWeights(instIndices, weights, random);
                roundTrainingInsts = new InstanceList(trainingInsts.getPipe(), sampleIndices.length);
                for (int i = 0; i < sampleIndices.length; i++) {
                    Instance inst = trainingInsts.get(sampleIndices[i]);
                    roundTrainingInsts.add(inst, 1);
                }
                weakLearners[round] = weakLearner.train(roundTrainingInsts);
                // Calculate the pseudo-loss of the weak learner:
                // each term compares the score given to the true label
                // with the score given to the "wrong" label paired with
                // this entry of set B
                for (int i = 0; i < trainingInsts.size(); i++) {
                    Instance inst = trainingInsts.get(i);
                    Classification c = weakLearners[round].classify(inst);
                    double htCorrect = c.valueOfCorrectLabel();
                    double htWrong = c.getLabeling().value(classIndices[i]);
                    epsilon += weights[i] * (1 - htCorrect + htWrong);
                    exponents[i] = 1 + htCorrect - htWrong;
                }
                epsilon *= 0.5;
                resamplingIterations++;
            } while (Maths.almostEquals(epsilon, 0) && resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);

            // Stop boosting when the pseudo-loss is 0, ignoring
            // the weak classifier trained this round
            if (Maths.almostEquals(epsilon, 0)) {
                logger.info("AdaBoostM2Trainer stopped at " + (round+1) + " / "
                            + numRounds + " pseudo-loss=" + epsilon);
                // If we are in the first round, we have to use
                // the weak classifier in any case
                int numClassifiersToUse = (round == 0) ? 1 : round;
                if (round == 0)
                    classifierWeights[0] = 1;
                double[] classifierWeights2 = new double[numClassifiersToUse];
                Classifier[] weakLearners2 = new Classifier[numClassifiersToUse];
                System.arraycopy(classifierWeights, 0, classifierWeights2, 0, numClassifiersToUse);
                System.arraycopy(weakLearners, 0, weakLearners2, 0, numClassifiersToUse);
                for (int i = 0; i < classifierWeights2.length; i++)
                    logger.info("AdaBoostM2Trainer weight[weakLearner[" + i + "]]=" + classifierWeights2[i]);
                return new AdaBoostM2 (trainingInsts.getPipe(), weakLearners2, classifierWeights2);
            }

            double beta = epsilon / (1 - epsilon);
            classifierWeights[round] = Math.log(1.0 / beta);
            // Update and normalize the weights: beta < 1 when the
            // pseudo-loss is below 1/2, so the entries the weak
            // learner got right are scaled down the most
            double sum = 0;
            for (int i = 0; i < weights.length; i++) {
                weights[i] *= Math.pow(beta, 0.5 * exponents[i]);
                sum += weights[i];
            }
            MatrixOps.timesEquals(weights, 1.0 / sum);
            logger.info("=========== AdaBoostM2Trainer round " + (round+1)
                        + " finished, pseudo-loss = " + epsilon);
        }

        for (int i = 0; i < classifierWeights.length; i++)
            logger.info("AdaBoostM2Trainer weight[weakLearner[" + i + "]]=" + classifierWeights[i]);
        this.classifier = new AdaBoostM2 (trainingInsts.getPipe(), weakLearners, classifierWeights);
        return classifier;
    }

    // Returns an array of ints of the same size as data,
    // where the samples are randomly chosen from data
    // using the distribution of the weights vector
    private int[] sampleWithWeights(int[] data, double[] weights, java.util.Random random)
    {
        if (weights.length != data.length)
            throw new IllegalArgumentException("length of weight vector must equal number of data points");
        double sumOfWeights = 0;
        for (int i = 0; i < data.length; i++) {
            if (weights[i] < 0)
                throw new IllegalArgumentException("weight vector must be non-negative");
            sumOfWeights += weights[i];
        }
        if (sumOfWeights <= 0)
            throw new IllegalArgumentException("weights must sum to positive value");

        int[] sample = new int[data.length];
        double[] probabilities = new double[data.length];
        double sumProbs = 0;
        // Generate data.length sorted random points in [0, sumOfWeights)
        for (int i = 0; i < data.length; i++) {
            sumProbs += random.nextDouble();
            probabilities[i] = sumProbs;
        }
        MatrixOps.timesEquals(probabilities, sumOfWeights / sumProbs);
        // Make sure rounding didn't mess things up
        probabilities[data.length - 1] = sumOfWeights;
        // Do the sampling: walk the cumulative weight distribution
        // and the sorted random points in parallel
        int a = 0;
        int b = 0;
        sumProbs = 0;
        while (a < data.length && b < data.length) {
            sumProbs += weights[b];
            while (a < data.length && probabilities[a] <= sumProbs) {
                sample[a] = data[b];
                a++;
            }
            b++;
        }
        return sample;
    }
}
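For reference, the pseudo-loss and weight update computed in train above correspond to the following quantities from Freund and Schapire's paper. The D_t / h_t notation below is the paper's, not the code's: weights plays the role of the mislabel distribution D_t, and the labeling scores htCorrect / htWrong play h_t(x_i, y_i) / h_t(x_i, y).

\[
\epsilon_t = \frac{1}{2} \sum_{(i,y):\, y \neq y_i} D_t(i,y)\,\bigl(1 - h_t(x_i, y_i) + h_t(x_i, y)\bigr),
\qquad
\beta_t = \frac{\epsilon_t}{1 - \epsilon_t},
\]
\[
D_{t+1}(i,y) = \frac{D_t(i,y)}{Z_t}\,\beta_t^{\frac{1}{2}\bigl(1 + h_t(x_i, y_i) - h_t(x_i, y)\bigr)},
\]

where Z_t is the normalizer (sum in the code), and the round-t weak learner's vote in the final ensemble is weighted by log(1/beta_t) (classifierWeights[round]).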
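A minimal usage sketch, not part of the original file: it assumes you already have a serialized MALLET InstanceList on disk (e.g., built with MALLET's import tools), and it uses DecisionTreeTrainer, which implements Boostable, as the weak learner. The class name AdaBoostM2Example is illustrative.

import java.io.File;

import cc.mallet.classify.AdaBoostM2;
import cc.mallet.classify.AdaBoostM2Trainer;
import cc.mallet.classify.DecisionTreeTrainer;
import cc.mallet.types.InstanceList;

public class AdaBoostM2Example {
    public static void main(String[] args) {
        // Assumes args[0] points to a serialized InstanceList.
        InstanceList training = InstanceList.load(new File(args[0]));

        // DecisionTreeTrainer implements Boostable, so it passes the
        // constructor's instanceof check; depth 2 keeps the trees "weak".
        AdaBoostM2Trainer trainer =
            new AdaBoostM2Trainer(new DecisionTreeTrainer(2), 50); // 50 boosting rounds

        AdaBoostM2 boosted = trainer.train(training);
        System.out.println("Training accuracy: " + boosted.getAccuracy(training));
    }
}

Note that train(...) builds the "mislabel" set B internally, so the InstanceList passed in is just the plain labeled training set.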




