All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.openimaj.image.objectdetection.haar.training.Testing Maven / Gradle / Ivy

Go to download

A project for various tests that don't quite constitute demos but might be useful to look at.

There is a newer version: 1.3.10
Show newest version
/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * 	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.image.objectdetection.haar.training;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;

import org.openimaj.image.FImage;
import org.openimaj.image.ImageUtilities;
import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
import org.openimaj.image.objectdetection.haar.HaarFeature;
import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
import org.openimaj.image.objectdetection.haar.Stage;
import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
import org.openimaj.image.objectdetection.haar.ValueClassifier;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.classification.StumpClassifier;
import org.openimaj.ml.classification.boosting.AdaBoost;
import org.openimaj.util.pair.ObjectFloatPair;

public class Testing {
	List features;
	List positive = new ArrayList();
	List negative = new ArrayList();

	void createFeatures(int width, int height) {
		features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE);

		final float invArea = 1f / ((width - 2) * (height - 2));
		for (final HaarFeature f : features) {
			f.setScale(1, invArea);
		}
	}

	// void loadPositive(boolean tilted) throws IOException {
	// final String base = "/Users/jsh2/Data/att_faces/s%d/%d.pgm";
	//
	// for (int j = 1; j <= 40; j++) {
	// for (int i = 1; i <= 10; i++) {
	// final File file = new File(String.format(base, j, i));
	//
	// FImage img = ImageUtilities.readF(file);
	// img = img.extractCenter(50, 50);
	// img = ResizeProcessor.resample(img, 19, 19);
	// positive.add(new SummedSqTiltAreaTable(img, tilted));
	// }
	// }
	// }
	//
	// void loadNegative(boolean tilted) throws IOException {
	// final File dir = new File(
	// "/Volumes/Raid/face_databases/haartraining/tutorial-haartraining.googlecode.com/svn/trunk/data/negatives/");
	//
	// for (final File f : dir.listFiles()) {
	// if (f.getName().endsWith(".jpg")) {
	// FImage img = ImageUtilities.readF(f);
	//
	// final int minwh = Math.min(img.width, img.height);
	//
	// img = img.extractCenter(minwh, minwh);
	// img = ResizeProcessor.resample(img, 19, 19);
	//
	// negative.add(new SummedSqTiltAreaTable(img, tilted));
	// }
	// }
	// }

	void loadImage(File image, List sats, boolean
			tilted) throws IOException
	{
		final FImage img = ImageUtilities.readF(image);

		sats.add(new SummedSqTiltAreaTable(img, false));
	}

	void loadPositive(boolean tilted) throws IOException {
		for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/face").listFiles()) {
			if (file.getName().endsWith(".pgm")) {
				loadImage(file, positive, tilted);
			}
		}
	}

	void loadNegative(boolean tilted) throws IOException {
		for (final File file : new File("/Users/jsh2/Data/cbcl-faces/train/non-face").listFiles()) {
			if (file.getName().endsWith(".pgm")) {
				loadImage(file, negative, tilted);
			}
		}
	}

	void perform() throws IOException {
		System.out.println("Creating feature set");
		createFeatures(19, 19);

		System.out.println("Loading positive images and computing SATs");
		loadPositive(false);

		System.out.println("Loading negative images and computing SATs");
		loadNegative(false);

		System.out.println("+ve: " + positive.size());
		System.out.println("-ve: " + negative.size());
		System.out.println("features: " + features.size());

		System.out.println("Computing cached feature sets");
		final CachedTrainingData data = new CachedTrainingData(positive, negative, features);

		System.out.println("Starting Training");
		final AdaBoost boost = new AdaBoost();
		final List> ensemble = boost.learn(data, 500);

		System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers.");

		for (float threshold = 3; threshold >= -3; threshold -= 0.25f) {
			System.out.println("Threshold = " + threshold);
			boost.printClassificationQuality(data, ensemble, threshold);
		}

		final Stage root = createStage(ensemble);
		final StageTreeClassifier classifier = new StageTreeClassifier(19, 19, "test cascade", false, root);
		classifier.setScale(1);

		for (int i = 0; i < positive.size(); i++) {
			if ((classifier.classify(positive.get(i), 0, 0) == 1) != AdaBoost.classify(data.getInstanceFeature(i),
					ensemble))
				System.out.println("ERROR");
		}
		for (int i = 0; i < negative.size(); i++) {
			if ((classifier.classify(negative.get(i), 0, 0) == 1) != AdaBoost.classify(
					data.getInstanceFeature(i + positive.size()), ensemble))
			{
				System.out.println(classifier.classify(negative.get(i), 0, 0) + " " + AdaBoost.classify(
						data.getInstanceFeature(i + positive.size()), ensemble));
				System.out.println("ERROR2");
			}
		}

		final ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("test-classifier.bin")));
		IOUtils.write(classifier, oos);
		oos.close();
	}

	/**
	 * Create a {@link Stage} from a trained ensemble.
	 *
	 * @param ensemble
	 *            the ensemble
	 * @return the stage
	 */
	private Stage createStage(final List> ensemble) {
		final HaarFeatureClassifier[] trees = new HaarFeatureClassifier[ensemble.size()];

		for (int i = 0; i < trees.length; i++) {
			final ObjectFloatPair wc = ensemble.get(i);
			final StumpClassifier c = wc.first;
			final float alpha = wc.second;
			final float threshold = c.threshold;
			final float leftValue = c.sign > 0 ? -alpha : alpha;
			final HaarFeature feature = features.get(c.dimension);

			final ValueClassifier left = new ValueClassifier(leftValue);
			final ValueClassifier right = new ValueClassifier(-leftValue);

			trees[i] = new HaarFeatureClassifier(feature, threshold, left, right);
		}

		final Stage root = new Stage(0, trees, null, null);
		return root;
	}

	public static void main(String[] args) throws IOException {
		new Testing().perform();
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy