All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.openimaj.ml.classification.StumpClassifier Maven / Gradle / Ivy

/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * 	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.classification;

import org.openimaj.util.function.Operation;
import org.openimaj.util.pair.ObjectFloatPair;
import org.openimaj.util.parallel.Parallel;
import org.openimaj.util.parallel.Parallel.IntRange;

/**
 * A decision stump: a classifier that compares a single feature dimension
 * against a threshold. The stump's output is {@link #sign} when the feature is
 * strictly greater than {@link #threshold} and {@code -sign} otherwise; an
 * output of {@code +1} is reported as the positive class.
 */
public class StumpClassifier {
	/**
	 * Learner that fits a {@link StumpClassifier} to weighted, labelled data by
	 * exhaustively searching every dimension and every candidate threshold.
	 */
	public static class WeightedLearner {
		/**
		 * Find the stump with the lowest weighted training error.
		 * <p>
		 * The error being minimised is
		 * {@code Error = sum_{i=1..N} D_i * [y_i != h(x_i)]} with
		 * {@code h(x) = sign * (2 * [x[dimension] > threshold] - 1)}, where
		 * {@code D_i} is {@code _weights[i]}.
		 * <p>
		 * If no candidate threshold is found (for example when there are fewer
		 * than two instances, or every dimension holds a single repeated
		 * value), the returned stump keeps its default field values
		 * ({@code dimension == 0}, {@code threshold == 0}, {@code sign == 0})
		 * and the returned error is {@link Float#POSITIVE_INFINITY}.
		 *
		 * @param trainingSet
		 *            provider of the class labels, the per-dimension feature
		 *            responses and the per-dimension sorted instance indices
		 * @param _weights
		 *            weight of each training instance; must hold at least
		 *            {@code trainingSet.numInstances()} values
		 * @return the best stump found, paired with its weighted error
		 */
		public ObjectFloatPair<StumpClassifier> learn(final LabelledDataProvider trainingSet, final float[] _weights) {
			final StumpClassifier classifier = new StumpClassifier();

			// Best error found so far across all ranges; held in an array so
			// that the anonymous class below can update it
			final float[] minimumError = { Float.POSITIVE_INFINITY };

			final boolean[] classes = trainingSet.getClasses();
			final int nInstances = trainingSet.numInstances();

			// totalError: sum of all weights (the largest possible error).
			// initialError: weight of the negative instances, i.e. the error
			// of a sign=+1 stump whose threshold lies below every value and
			// therefore labels every instance as positive.
			float totalErrorC = 0.0f;
			float initialErrorC = 0.0f;
			for (int i = 0; i < nInstances; i++) {
				totalErrorC += _weights[i];
				if (!classes[i]) {
					initialErrorC += _weights[i];
				}
			}
			final float totalError = totalErrorC;
			final float initialError = initialErrorC;

			// Loop over possible dimensions, one sub-range per invocation.
			// NOTE(review): invocations presumably run concurrently (hence the
			// lock when merging results) - confirm against Parallel.forRange
			Parallel.forRange(0, trainingSet.numDimensions(), 1, new Operation<IntRange>() {
				@Override
				public void perform(IntRange rng) {
					// Best stump within this range; sentinel values mean
					// "nothing found yet"
					final StumpClassifier currClassifier = new StumpClassifier();
					currClassifier.dimension = -1;
					currClassifier.threshold = Float.NaN;
					currClassifier.sign = 0;

					float currMinimumError = Float.POSITIVE_INFINITY;

					for (int d = rng.start; d < rng.stop; d += rng.incr) {
						// Feature values of dimension d and the pre-sorted
						// instance order used for the threshold sweep
						// (assumes ascending order - TODO confirm against
						// LabelledDataProvider)
						final float[] data = trainingSet.getFeatureResponse(d);
						final int[] indices = trainingSet.getSortedResponseIndices(d);

						// Error of the sign=+1 stump at the current threshold
						float currentError = initialError;

						// Sweep the threshold past one instance at a time
						for (int i = 0; i < nInstances - 1; i++) {
							// The instance that has just fallen on the
							// "not greater than threshold" side: a positive
							// becomes an error, a negative stops being one
							final int index = indices[i];
							if (classes[index]) {
								currentError += _weights[index];
							} else {
								currentError -= _weights[index];
							}

							// No threshold can separate two equal values
							final float lower = data[index];
							final float upper = data[indices[i + 1]];
							if (lower == upper) {
								continue;
							}

							// Midpoint maximises the margin between the two
							// neighbouring values
							final float testThreshold = (lower + upper) / 2.0f;

							if (currentError < currMinimumError) {
								// Good classifier with sign = +1
								currMinimumError = currentError;
								currClassifier.dimension = d;
								currClassifier.threshold = testThreshold;
								currClassifier.sign = +1;
							}

							// Flipping the sign turns every correct decision
							// into an error and vice versa
							final float flippedError = totalError - currentError;
							if (flippedError < currMinimumError) {
								// Good classifier with sign = -1
								currMinimumError = flippedError;
								currClassifier.dimension = d;
								currClassifier.threshold = testThreshold;
								currClassifier.sign = -1;
							}
						}
					}

					// Merge this range's best into the shared result; the
					// strict comparison keeps the earlier-merged stump on ties
					synchronized (classifier) {
						if (currMinimumError < minimumError[0]) {
							minimumError[0] = currMinimumError;
							classifier.dimension = currClassifier.dimension;
							classifier.sign = currClassifier.sign;
							classifier.threshold = currClassifier.threshold;
						}
					}
				}
			});

			return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]);
		}
	}

	/**
	 * The dimension of the feature on which this stump operates
	 */
	public int dimension;

	/**
	 * The threshold which the feature is tested against
	 */
	public float threshold;

	/**
	 * The sign of the stump (determines which side of the threshold corresponds
	 * to positive)
	 */
	public int sign;

	/**
	 * Classify an instance from its full feature vector.
	 *
	 * @param instanceFeature
	 *            the feature vector; the value at index {@link #dimension} is
	 *            the one tested
	 * @return {@code true} if the stump's output for the tested value is
	 *         {@code +1}, {@code false} otherwise
	 */
	public boolean classify(float[] instanceFeature) {
		return (instanceFeature[dimension] > threshold ? sign : -sign) == 1;
	}

	/**
	 * Classify a single feature value (the value of the instance in this
	 * stump's dimension).
	 *
	 * @param f
	 *            the feature value to test against {@link #threshold}
	 * @return {@code true} if the stump's output for {@code f} is {@code +1},
	 *         {@code false} otherwise
	 */
	public boolean classify(float f) {
		return (f > threshold ? sign : -sign) == 1;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy