boofcv.alg.bow.LearnSceneFromFiles

BoofCV is an open source Java library for real-time computer vision and robotics applications.

/*
 * Copyright (c) 2021, Peter Abeles. All Rights Reserved.
 *
 * This file is part of BoofCV (http://boofcv.org).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package boofcv.alg.bow;

import boofcv.struct.learning.ClassificationHistogram;
import boofcv.struct.learning.Confusion;
import org.jetbrains.annotations.Nullable;

import java.io.File;
import java.util.*;

/**
 * Abstract class which provides a framework for learning a scene classifier from a set of images.
 *
 * Images are organized into lists by scene name, either loaded from separate train,
 * cross-validation, and test directories or randomly split from a single directory. A concrete
 * subclass implements {@link #classify} to predict a scene, and {@link #evaluate} summarizes
 * the predictions as a confusion matrix.
 *
 * @author Peter Abeles
 */
@SuppressWarnings("NullAway.Init")
public abstract class LearnSceneFromFiles {

	protected Random rand;

	protected List<String> scenes = new ArrayList<>();

	// The minimum number of images in each type of set
	int minimumTrain;
	int minimumCross;
	int minimumTest;

	// fraction of the input set assigned to train and cross-validation; the remainder is the test set
	double fractionTrain;
	double fractionCross;

	// maps for each set of images
	protected Map<String, List<String>> train;
	protected Map<String, List<String>> cross;
	protected Map<String, List<String>> test;

	public Confusion evaluateTest() {
		return evaluate(test);
	}

	/**
	 * Given a set of images with known classification, predict which scene each one belongs in and compute
	 * a confusion matrix for the results.
	 *
	 * @param set Set of classified images
	 * @return Confusion matrix
	 */
	protected Confusion evaluate( Map<String, List<String>> set ) {
		ClassificationHistogram histogram = new ClassificationHistogram(scenes.size());

		int total = 0;
		for (int i = 0; i < scenes.size(); i++) {
			total += Objects.requireNonNull(set.get(scenes.get(i))).size();
		}
		System.out.println("total images " + total);

		for (int sceneIdx = 0; sceneIdx < scenes.size(); sceneIdx++) {
			String scene = scenes.get(sceneIdx);

			List<String> images = Objects.requireNonNull(set.get(scene));
			System.out.println("  " + scene + " " + images.size());
			for (int imageIdx = 0; imageIdx < images.size(); imageIdx++) {
				String image = images.get(imageIdx);
				int predicted = classify(image);
				histogram.increment(sceneIdx, predicted);
			}
		}

		return histogram.createConfusion();
	}

	/**
	 * Given an image, computes which scene it belongs to.
	 *
	 * @param path Path to input image
	 * @return Index of the predicted scene in the list returned by {@link #getScenes()}
	 */
	protected abstract int classify( String path );

	/**
	 * Loads the train, cross-validation, and test sets from three separate directories. Each
	 * directory must contain one subdirectory per scene. The cross-validation directory may be
	 * null, in which case no cross-validation set is created.
	 */
	public void loadSets( File dirTraining, @Nullable File dirCross, File dirTest ) {
		train = Objects.requireNonNull(findImages(dirTraining));
		if (dirCross != null)
			cross = Objects.requireNonNull(findImages(dirCross));
		test = Objects.requireNonNull(findImages(dirTest));

		extractKeys(train);
		extractKeys(test);
	}

	/** Adds any scene names in the given map that are not already in {@link #scenes}. */
	private void extractKeys( Map<String, List<String>> images ) {
		Set<String> keys = images.keySet();

		for (String key : keys) { // lint:forbidden ignore_line
			if (!scenes.contains(key)) {
				scenes.add(key);
			}
		}
	}

	/**
	 * Loads all images found under 'directory' (one subdirectory per scene), then randomly splits
	 * each scene's images into train, cross-validation (only if fractionCross is non-zero), and
	 * test sets according to the configured fractions and minimum set sizes. For example, with
	 * 100 images in a scene, fractionTrain = 0.6, and fractionCross = 0.2, the split is 60/20/20
	 * (illustrative numbers).
	 */
	public void loadThenSplit( File directory ) {
		Map<String, List<String>> all = Objects.requireNonNull(findImages(directory));
		train = new HashMap<>();
		if (fractionCross != 0)
			cross = new HashMap<>();
		test = new HashMap<>();

		Set<String> keys = all.keySet();

		for (String key : keys) { // lint:forbidden ignore_line
			List<String> allImages = Objects.requireNonNull(all.get(key));

			// randomize the ordering to remove bias
			Collections.shuffle(allImages, rand);

			int numTrain = (int)(allImages.size()*fractionTrain);
			numTrain = Math.max(minimumTrain, numTrain);
			int numCross = (int)(allImages.size()*fractionCross);
			numCross = Math.max(minimumCross, numCross);
			int numTest = allImages.size() - numTrain - numCross;

			if (numTest < minimumTest)
				throw new RuntimeException("Not enough images to create test set. " + key + " total = " + allImages.size());

			createSubSet(key, allImages, train, 0, numTrain);
			if (cross != null) {
				createSubSet(key, allImages, cross, numTrain, numCross + numTrain);
			}
			createSubSet(key, allImages, test, numCross + numTrain, allImages.size());
		}

		scenes.addAll(keys);
	}

	/** Copies allImages[start, end) into the subset map under the given scene key. */
	private void createSubSet( String key, List<String> allImages, Map<String, List<String>> subset,
							   int start, int end ) {
		List<String> images = new ArrayList<>();
		for (int i = start; i < end; i++) {
			images.add(allImages.get(i));
		}
		subset.put(key, images);
	}

	/**
	 * Loads the paths to image files contained in subdirectories of the root directory. Each
	 * subdirectory is assumed to be a different category of images.
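	 * Hidden files, nested directories, and files ending in ".txt" are skipped. For example, a
	 * root directory containing "kitchen/" and "beach/" (names illustrative) yields a map with
	 * keys "kitchen" and "beach", each mapped to the image paths found inside.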
	 */
	public static @Nullable Map<String, List<String>> findImages( File rootDir ) {
		File[] files = rootDir.listFiles();
		if (files == null)
			return null;

		List imageDirectories = new ArrayList<>();
		for (File f : files) { // lint:forbidden ignore_line
			if (f.isDirectory()) {
				imageDirectories.add(f);
			}
		}
		Map<String, List<String>> out = new HashMap<>();
		for (File d : imageDirectories) { // lint:forbidden ignore_line
			List<String> images = new ArrayList<>();

			files = d.listFiles();
			if (files == null)
				throw new RuntimeException("Should be a directory!");

			for (File f : files) { // lint:forbidden ignore_line
				if (f.isHidden() || f.isDirectory() || f.getName().endsWith(".txt")) {
					continue;
				}

				images.add(f.getPath());
			}

			String key = d.getName().toLowerCase();

			out.put(key, images);
		}

		return out;
	}

	public List<String> getScenes() {
		return scenes;
	}
}
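
The snippet below is a minimal usage sketch, not part of the library: DummySceneLearner, the
random seed, the split fractions, and the "images" directory are all illustrative. The subclass
lives in the same package because the split fractions are package-private, and its classify()
is a stub; a real implementation would first train a model on the images in the training set.

package boofcv.alg.bow;

import boofcv.struct.learning.Confusion;

import java.io.File;
import java.util.Random;

public class DummySceneLearner extends LearnSceneFromFiles {

	public DummySceneLearner() {
		rand = new Random(234);  // seeds the shuffle used when splitting each scene's images
		fractionTrain = 0.6;     // 60% of each scene's images go to training
		fractionCross = 0.2;     // 20% to cross-validation; the remaining 20% becomes the test set
	}

	// Placeholder classifier that always predicts the first scene. A real implementation
	// would evaluate a trained model here and return an index into getScenes().
	@Override protected int classify( String path ) {
		return 0;
	}

	public static void main( String[] args ) {
		DummySceneLearner learner = new DummySceneLearner();
		learner.loadThenSplit(new File("images")); // one subdirectory per scene category
		Confusion confusion = learner.evaluateTest();
		System.out.println("scenes = " + learner.getScenes());
		// inspect 'confusion' using whatever accessors the Confusion class provides
	}
}

Since evaluate() is protected, scoring the cross-validation set would likewise be done from
inside the subclass, e.g. by calling evaluate(cross).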



